Repository: mudler/LocalAI Branch: master Commit: 2b128753029a Files: 1287 Total size: 13.0 MB Directory structure: gitextract_p22ybhoq/ ├── .agents/ │ ├── adding-backends.md │ ├── api-endpoints-and-auth.md │ ├── building-and-testing.md │ ├── coding-style.md │ ├── llama-cpp-backend.md │ └── testing-mcp-apps.md ├── .air.toml ├── .devcontainer/ │ ├── devcontainer.json │ ├── docker-compose-devcontainer.yml │ ├── grafana/ │ │ └── datasource.yml │ └── prometheus/ │ └── prometheus.yml ├── .devcontainer-scripts/ │ ├── postcreate.sh │ ├── poststart.sh │ └── utils.sh ├── .dockerignore ├── .editorconfig ├── .gitattributes ├── .github/ │ ├── FUNDING.yml │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ ├── config.yml │ │ └── feature_request.md │ ├── PULL_REQUEST_TEMPLATE.md │ ├── bump_deps.sh │ ├── bump_docs.sh │ ├── check_and_update.py │ ├── checksum_checker.sh │ ├── ci/ │ │ └── modelslist.go │ ├── dependabot.yml │ ├── gallery-agent/ │ │ ├── agent.go │ │ ├── gallery.go │ │ ├── main.go │ │ ├── testing.go │ │ └── tools.go │ ├── labeler.yml │ ├── release.yml │ ├── stale.yml │ └── workflows/ │ ├── backend.yml │ ├── backend_build.yml │ ├── backend_build_darwin.yml │ ├── backend_pr.yml │ ├── build-test.yaml │ ├── bump_deps.yaml │ ├── bump_docs.yaml │ ├── checksum_checker.yaml │ ├── deploy-explorer.yaml │ ├── disabled/ │ │ ├── comment-pr.yaml │ │ ├── dependabot_auto.yml │ │ ├── labeler.yml │ │ ├── localaibot_automerge.yml │ │ ├── notify-models.yaml │ │ ├── prlint.yaml │ │ └── test-gpu.yml │ ├── gallery-agent.yaml │ ├── generate_grpc_cache.yaml │ ├── generate_intel_image.yaml │ ├── image-pr.yml │ ├── image.yml │ ├── image_build.yml │ ├── notify-releases.yaml │ ├── release.yaml │ ├── secscan.yaml │ ├── stalebot.yml │ ├── test-extra.yml │ ├── test.yml │ ├── tests-e2e.yml │ ├── tests-ui-e2e.yml │ ├── update_swagger.yaml │ └── yaml-check.yml ├── .gitignore ├── .gitmodules ├── .goreleaser.yaml ├── .vscode/ │ ├── extensions.json │ └── launch.json ├── .yamllint ├── AGENTS.md ├── 
CONTRIBUTING.md ├── Dockerfile ├── Entitlements.plist ├── LICENSE ├── Makefile ├── README.md ├── SECURITY.md ├── backend/ │ ├── Dockerfile.golang │ ├── Dockerfile.llama-cpp │ ├── Dockerfile.python │ ├── README.md │ ├── backend.proto │ ├── cpp/ │ │ ├── grpc/ │ │ │ ├── .gitignore │ │ │ └── Makefile │ │ └── llama-cpp/ │ │ ├── CMakeLists.txt │ │ ├── Makefile │ │ ├── grpc-server.cpp │ │ ├── package.sh │ │ ├── prepare.sh │ │ └── run.sh │ ├── go/ │ │ ├── acestep-cpp/ │ │ │ ├── CMakeLists.txt │ │ │ ├── Makefile │ │ │ ├── acestepcpp_test.go │ │ │ ├── cpp/ │ │ │ │ ├── goacestepcpp.cpp │ │ │ │ └── goacestepcpp.h │ │ │ ├── goacestepcpp.go │ │ │ ├── main.go │ │ │ ├── package.sh │ │ │ ├── run.sh │ │ │ └── test.sh │ │ ├── llm/ │ │ │ └── llama/ │ │ │ ├── llama.go │ │ │ └── main.go │ │ ├── local-store/ │ │ │ ├── Makefile │ │ │ ├── debug.go │ │ │ ├── main.go │ │ │ ├── package.sh │ │ │ ├── production.go │ │ │ ├── run.sh │ │ │ └── store.go │ │ ├── opus/ │ │ │ ├── Makefile │ │ │ ├── codec.go │ │ │ ├── csrc/ │ │ │ │ └── opus_shim.c │ │ │ ├── main.go │ │ │ ├── opus.go │ │ │ ├── opus_test.go │ │ │ ├── package.sh │ │ │ └── run.sh │ │ ├── piper/ │ │ │ ├── Makefile │ │ │ ├── main.go │ │ │ ├── package.sh │ │ │ ├── piper.go │ │ │ └── run.sh │ │ └── silero-vad/ │ │ ├── Makefile │ │ ├── main.go │ │ ├── package.sh │ │ ├── run.sh │ │ └── vad.go │ ├── index.yaml │ └── python/ │ ├── README.md │ ├── ace-step/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-intel.txt │ │ ├── requirements-l4t13.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── chatterbox/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-install.txt │ │ ├── 
requirements-intel.txt │ │ ├── requirements-l4t12.txt │ │ ├── requirements-l4t13.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── common/ │ │ ├── libbackend.sh │ │ └── template/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── protogen.sh │ │ ├── requirements-hipblas.txt │ │ ├── requirements-intel.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ └── test.sh │ ├── coqui/ │ │ ├── Makefile │ │ ├── README.md │ │ ├── backend.py │ │ ├── install.sh │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-intel.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── diffusers/ │ │ ├── Makefile │ │ ├── README.md │ │ ├── backend.py │ │ ├── diffusers_dynamic_loader.py │ │ ├── install.sh │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-intel.txt │ │ ├── requirements-l4t12.txt │ │ ├── requirements-l4t13.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── faster-qwen3-tts/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-install.txt │ │ ├── requirements-l4t12.txt │ │ ├── requirements-l4t13.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── faster-whisper/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── protogen.sh │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-intel.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ └── test.sh │ ├── fish-speech/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── package.sh │ │ ├── requirements-cpu.txt │ │ ├── 
requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-intel.txt │ │ ├── requirements-l4t12.txt │ │ ├── requirements-l4t13.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── kitten-tts/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── kokoro/ │ │ ├── Makefile │ │ ├── README.md │ │ ├── backend.py │ │ ├── install.sh │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-intel.txt │ │ ├── requirements-l4t12.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── mlx/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── mlx_cache.py │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-l4t12.txt │ │ ├── requirements-l4t13.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ ├── test.sh │ │ └── test_mlx_cache.py │ ├── mlx-audio/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-l4t12.txt │ │ ├── requirements-l4t13.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── mlx-distributed/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── coordinator.py │ │ ├── install.sh │ │ ├── mlx_cache.py │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-l4t12.txt │ │ ├── requirements-l4t13.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── sharding.py │ │ ├── test.py │ │ └── test.sh │ ├── mlx-vlm/ │ │ ├── Makefile │ │ ├── backend.py 
│ │ ├── install.sh │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-l4t12.txt │ │ ├── requirements-l4t13.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── moonshine/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── protogen.sh │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── nemo/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── protogen.sh │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-intel.txt │ │ ├── requirements-l4t12.txt │ │ ├── requirements-l4t13.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── neutts/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── requirements-after.txt │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-l4t12.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── outetts/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-intel.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── pocket-tts/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── protogen.sh │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-intel.txt │ │ ├── requirements-l4t12.txt │ │ ├── requirements-l4t13.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── qwen-asr/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── 
requirements-cpu.txt │ │ ├── requirements-cublas12-after.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-intel-after.txt │ │ ├── requirements-intel.txt │ │ ├── requirements-l4t12.txt │ │ ├── requirements-l4t13.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── qwen-tts/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12-after.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-intel-after.txt │ │ ├── requirements-intel.txt │ │ ├── requirements-l4t12.txt │ │ ├── requirements-l4t13.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── rerankers/ │ │ ├── Makefile │ │ ├── README.md │ │ ├── backend.py │ │ ├── install.sh │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-intel.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── rfdetr/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-intel.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ └── test.sh │ ├── transformers/ │ │ ├── Makefile │ │ ├── README.md │ │ ├── backend.py │ │ ├── install.sh │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-intel.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── vibevoice/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── requirements-cpu.txt │ │ ├── 
requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-intel.txt │ │ ├── requirements-l4t12.txt │ │ ├── requirements-l4t13.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── vllm/ │ │ ├── Makefile │ │ ├── README.md │ │ ├── backend.py │ │ ├── install.sh │ │ ├── requirements-after.txt │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12-after.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-install.txt │ │ ├── requirements-intel.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── vllm-omni/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── requirements-after.txt │ │ ├── requirements-cublas12-after.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ ├── voxcpm/ │ │ ├── Makefile │ │ ├── backend.py │ │ ├── install.sh │ │ ├── protogen.sh │ │ ├── requirements-cpu.txt │ │ ├── requirements-cublas12.txt │ │ ├── requirements-cublas13.txt │ │ ├── requirements-hipblas.txt │ │ ├── requirements-intel.txt │ │ ├── requirements-l4t12.txt │ │ ├── requirements-l4t13.txt │ │ ├── requirements-mps.txt │ │ ├── requirements.txt │ │ ├── run.sh │ │ ├── test.py │ │ └── test.sh │ └── whisperx/ │ ├── Makefile │ ├── backend.py │ ├── install.sh │ ├── protogen.sh │ ├── requirements-cpu.txt │ ├── requirements-cublas12.txt │ ├── requirements-cublas13.txt │ ├── requirements-hipblas.txt │ ├── requirements-mps.txt │ ├── requirements.txt │ ├── run.sh │ ├── test.py │ └── test.sh ├── cmd/ │ ├── launcher/ │ │ ├── icon.go │ │ ├── internal/ │ │ │ ├── launcher.go │ │ │ ├── launcher_suite_test.go │ │ │ ├── launcher_test.go │ │ │ ├── release_manager.go │ │ │ ├── release_manager_test.go │ │ │ ├── systray_manager.go │ │ │ └── ui.go │ │ └── main.go │ └── local-ai/ │ └── main.go ├── configuration/ │ └── .keep ├── 
core/ │ ├── application/ │ │ ├── agent_jobs.go │ │ ├── application.go │ │ ├── config_file_watcher.go │ │ ├── p2p.go │ │ ├── startup.go │ │ └── watchdog.go │ ├── backend/ │ │ ├── backend_suite_test.go │ │ ├── detection.go │ │ ├── embeddings.go │ │ ├── image.go │ │ ├── llm.go │ │ ├── llm_test.go │ │ ├── options.go │ │ ├── rerank.go │ │ ├── soundgeneration.go │ │ ├── stores.go │ │ ├── token_metrics.go │ │ ├── tokenize.go │ │ ├── transcript.go │ │ ├── tts.go │ │ ├── vad.go │ │ └── video.go │ ├── cli/ │ │ ├── agent.go │ │ ├── agent_test.go │ │ ├── backends.go │ │ ├── cli.go │ │ ├── completion.go │ │ ├── completion_test.go │ │ ├── context/ │ │ │ └── context.go │ │ ├── deprecations.go │ │ ├── explorer.go │ │ ├── federated.go │ │ ├── models.go │ │ ├── run.go │ │ ├── soundgeneration.go │ │ ├── transcript.go │ │ ├── tts.go │ │ ├── util.go │ │ └── worker/ │ │ ├── worker.go │ │ ├── worker_llamacpp.go │ │ ├── worker_mlx_common.go │ │ ├── worker_mlx_distributed.go │ │ ├── worker_p2p.go │ │ └── worker_p2p_mlx.go │ ├── clients/ │ │ └── store.go │ ├── config/ │ │ ├── application_config.go │ │ ├── application_config_test.go │ │ ├── config_suite_test.go │ │ ├── gallery.go │ │ ├── gguf.go │ │ ├── guesser.go │ │ ├── model_config.go │ │ ├── model_config_filter.go │ │ ├── model_config_loader.go │ │ ├── model_config_test.go │ │ ├── model_test.go │ │ └── runtime_settings.go │ ├── dependencies_manager/ │ │ └── manager.go │ ├── explorer/ │ │ ├── database.go │ │ ├── database_test.go │ │ ├── discovery.go │ │ └── explorer_suite_test.go │ ├── gallery/ │ │ ├── backend_resolve.go │ │ ├── backend_types.go │ │ ├── backends.go │ │ ├── backends_test.go │ │ ├── gallery.go │ │ ├── gallery_suite_test.go │ │ ├── gallery_test.go │ │ ├── importers/ │ │ │ ├── diffuser.go │ │ │ ├── diffuser_test.go │ │ │ ├── importers.go │ │ │ ├── importers_suite_test.go │ │ │ ├── importers_test.go │ │ │ ├── llama-cpp.go │ │ │ ├── llama-cpp_test.go │ │ │ ├── mlx.go │ │ │ ├── mlx_test.go │ │ │ ├── transformers.go │ │ │ ├── 
transformers_test.go │ │ │ ├── vllm.go │ │ │ └── vllm_test.go │ │ ├── metadata_type.go │ │ ├── models.go │ │ ├── models_test.go │ │ ├── models_types.go │ │ └── request_test.go │ ├── http/ │ │ ├── app.go │ │ ├── app_test.go │ │ ├── auth/ │ │ │ ├── apikeys.go │ │ │ ├── apikeys_test.go │ │ │ ├── auth_suite_test.go │ │ │ ├── db.go │ │ │ ├── db_nosqlite.go │ │ │ ├── db_sqlite.go │ │ │ ├── db_test.go │ │ │ ├── features.go │ │ │ ├── helpers_test.go │ │ │ ├── middleware.go │ │ │ ├── middleware_test.go │ │ │ ├── models.go │ │ │ ├── oauth.go │ │ │ ├── password.go │ │ │ ├── permissions.go │ │ │ ├── roles.go │ │ │ ├── roles_test.go │ │ │ ├── session.go │ │ │ ├── session_test.go │ │ │ ├── usage.go │ │ │ └── usage_test.go │ │ ├── endpoints/ │ │ │ ├── anthropic/ │ │ │ │ └── messages.go │ │ │ ├── elevenlabs/ │ │ │ │ ├── soundgeneration.go │ │ │ │ └── tts.go │ │ │ ├── explorer/ │ │ │ │ └── dashboard.go │ │ │ ├── jina/ │ │ │ │ └── rerank.go │ │ │ ├── localai/ │ │ │ │ ├── agent_collections.go │ │ │ │ ├── agent_jobs.go │ │ │ │ ├── agent_responses.go │ │ │ │ ├── agent_skills.go │ │ │ │ ├── agents.go │ │ │ │ ├── backend.go │ │ │ │ ├── backend_monitor.go │ │ │ │ ├── cors_proxy.go │ │ │ │ ├── detection.go │ │ │ │ ├── edit_model.go │ │ │ │ ├── edit_model_test.go │ │ │ │ ├── gallery.go │ │ │ │ ├── get_token_metrics.go │ │ │ │ ├── import_model.go │ │ │ │ ├── localai_suite_test.go │ │ │ │ ├── mcp.go │ │ │ │ ├── mcp_prompts.go │ │ │ │ ├── mcp_resources.go │ │ │ │ ├── mcp_tools.go │ │ │ │ ├── metrics.go │ │ │ │ ├── p2p.go │ │ │ │ ├── settings.go │ │ │ │ ├── stores.go │ │ │ │ ├── system.go │ │ │ │ ├── tokenize.go │ │ │ │ ├── tts.go │ │ │ │ ├── types.go │ │ │ │ ├── vad.go │ │ │ │ ├── video.go │ │ │ │ └── welcome.go │ │ │ ├── mcp/ │ │ │ │ └── tools.go │ │ │ ├── openai/ │ │ │ │ ├── chat.go │ │ │ │ ├── chat_test.go │ │ │ │ ├── completion.go │ │ │ │ ├── constants.go │ │ │ │ ├── edit.go │ │ │ │ ├── embeddings.go │ │ │ │ ├── image.go │ │ │ │ ├── image_test.go │ │ │ │ ├── inference.go │ │ │ │ ├── 
inference_test.go │ │ │ │ ├── inpainting.go │ │ │ │ ├── inpainting_test.go │ │ │ │ ├── list.go │ │ │ │ ├── openai_suite_test.go │ │ │ │ ├── realtime.go │ │ │ │ ├── realtime_model.go │ │ │ │ ├── realtime_transport.go │ │ │ │ ├── realtime_transport_webrtc.go │ │ │ │ ├── realtime_transport_ws.go │ │ │ │ ├── realtime_webrtc.go │ │ │ │ ├── transcription.go │ │ │ │ └── types/ │ │ │ │ ├── client_events.go │ │ │ │ ├── int_or_inf.go │ │ │ │ ├── message_item.go │ │ │ │ ├── server_events.go │ │ │ │ └── types.go │ │ │ └── openresponses/ │ │ │ ├── responses.go │ │ │ ├── store.go │ │ │ ├── store_suite_test.go │ │ │ ├── store_test.go │ │ │ └── websocket.go │ │ ├── explorer.go │ │ ├── http_suite_test.go │ │ ├── middleware/ │ │ │ ├── auth.go │ │ │ ├── auth_test.go │ │ │ ├── baseurl.go │ │ │ ├── baseurl_test.go │ │ │ ├── middleware_suite_test.go │ │ │ ├── request.go │ │ │ ├── strippathprefix.go │ │ │ ├── strippathprefix_test.go │ │ │ ├── trace.go │ │ │ └── usage.go │ │ ├── openresponses_test.go │ │ ├── react-ui/ │ │ │ ├── e2e/ │ │ │ │ ├── backend-logs.spec.js │ │ │ │ ├── manage-logs-link.spec.js │ │ │ │ ├── models-gallery.spec.js │ │ │ │ ├── navigation.spec.js │ │ │ │ ├── settings-backend-logging.spec.js │ │ │ │ ├── traces-errors.spec.js │ │ │ │ └── traces.spec.js │ │ │ ├── eslint.config.js │ │ │ ├── index.html │ │ │ ├── package.json │ │ │ ├── playwright.config.js │ │ │ ├── src/ │ │ │ │ ├── App.css │ │ │ │ ├── App.jsx │ │ │ │ ├── components/ │ │ │ │ │ ├── CanvasPanel.jsx │ │ │ │ │ ├── ClientMCPDropdown.jsx │ │ │ │ │ ├── CodeEditor.jsx │ │ │ │ │ ├── ConfirmDialog.jsx │ │ │ │ │ ├── LoadingSpinner.jsx │ │ │ │ │ ├── MCPAppFrame.jsx │ │ │ │ │ ├── Modal.jsx │ │ │ │ │ ├── ModelSelector.jsx │ │ │ │ │ ├── OperationsBar.jsx │ │ │ │ │ ├── RequireAdmin.jsx │ │ │ │ │ ├── RequireAuth.jsx │ │ │ │ │ ├── RequireFeature.jsx │ │ │ │ │ ├── ResourceCards.jsx │ │ │ │ │ ├── ResourceMonitor.jsx │ │ │ │ │ ├── SearchableModelSelect.jsx │ │ │ │ │ ├── SearchableSelect.jsx │ │ │ │ │ ├── SettingRow.jsx │ │ │ │ │ 
├── Sidebar.jsx │ │ │ │ │ ├── ThemeToggle.jsx │ │ │ │ │ ├── Toast.jsx │ │ │ │ │ ├── Toggle.jsx │ │ │ │ │ ├── UnifiedMCPDropdown.jsx │ │ │ │ │ └── UserGroupSection.jsx │ │ │ │ ├── context/ │ │ │ │ │ └── AuthContext.jsx │ │ │ │ ├── contexts/ │ │ │ │ │ └── ThemeContext.jsx │ │ │ │ ├── hooks/ │ │ │ │ │ ├── useAgentChat.js │ │ │ │ │ ├── useChat.js │ │ │ │ │ ├── useMCPClient.js │ │ │ │ │ ├── useModels.js │ │ │ │ │ ├── useOperations.js │ │ │ │ │ ├── useResources.js │ │ │ │ │ └── useUserMap.js │ │ │ │ ├── index.css │ │ │ │ ├── main.jsx │ │ │ │ ├── pages/ │ │ │ │ │ ├── Account.jsx │ │ │ │ │ ├── AgentChat.jsx │ │ │ │ │ ├── AgentCreate.jsx │ │ │ │ │ ├── AgentJobDetails.jsx │ │ │ │ │ ├── AgentJobs.jsx │ │ │ │ │ ├── AgentStatus.jsx │ │ │ │ │ ├── AgentTaskDetails.jsx │ │ │ │ │ ├── Agents.jsx │ │ │ │ │ ├── BackendLogs.jsx │ │ │ │ │ ├── Backends.jsx │ │ │ │ │ ├── Chat.jsx │ │ │ │ │ ├── CollectionDetails.jsx │ │ │ │ │ ├── Collections.jsx │ │ │ │ │ ├── Explorer.jsx │ │ │ │ │ ├── Home.jsx │ │ │ │ │ ├── ImageGen.jsx │ │ │ │ │ ├── ImportModel.jsx │ │ │ │ │ ├── Login.jsx │ │ │ │ │ ├── Manage.jsx │ │ │ │ │ ├── ModelEditor.jsx │ │ │ │ │ ├── Models.jsx │ │ │ │ │ ├── NotFound.jsx │ │ │ │ │ ├── P2P.jsx │ │ │ │ │ ├── Settings.jsx │ │ │ │ │ ├── SkillEdit.jsx │ │ │ │ │ ├── Skills.jsx │ │ │ │ │ ├── Sound.jsx │ │ │ │ │ ├── TTS.jsx │ │ │ │ │ ├── Talk.jsx │ │ │ │ │ ├── Traces.jsx │ │ │ │ │ ├── Usage.jsx │ │ │ │ │ ├── Users.jsx │ │ │ │ │ ├── VideoGen.jsx │ │ │ │ │ └── auth.css │ │ │ │ ├── router.jsx │ │ │ │ ├── theme.css │ │ │ │ └── utils/ │ │ │ │ ├── api.js │ │ │ │ ├── artifacts.js │ │ │ │ ├── basePath.js │ │ │ │ ├── config.js │ │ │ │ ├── format.js │ │ │ │ ├── markdown.js │ │ │ │ └── mcpClientStorage.js │ │ │ └── vite.config.js │ │ ├── render.go │ │ ├── routes/ │ │ │ ├── agents.go │ │ │ ├── anthropic.go │ │ │ ├── auth.go │ │ │ ├── auth_test.go │ │ │ ├── elevenlabs.go │ │ │ ├── explorer.go │ │ │ ├── health.go │ │ │ ├── jina.go │ │ │ ├── localai.go │ │ │ ├── openai.go │ │ │ ├── openresponses.go │ │ │ 
├── ui.go │ │ │ ├── ui_api.go │ │ │ ├── ui_api_backends_test.go │ │ │ ├── ui_backend_gallery.go │ │ │ └── ui_gallery.go │ │ ├── static/ │ │ │ ├── animations.css │ │ │ ├── assets/ │ │ │ │ ├── alpine.js │ │ │ │ ├── font1.css │ │ │ │ ├── font2.css │ │ │ │ ├── fontawesome/ │ │ │ │ │ └── css/ │ │ │ │ │ ├── all.css │ │ │ │ │ ├── brands.css │ │ │ │ │ ├── fontawesome.css │ │ │ │ │ ├── regular.css │ │ │ │ │ ├── solid.css │ │ │ │ │ ├── svg-with-js.css │ │ │ │ │ ├── v4-font-face.css │ │ │ │ │ ├── v4-shims.css │ │ │ │ │ └── v5-font-face.css │ │ │ │ ├── fontawesome.css │ │ │ │ ├── highlightjs.css │ │ │ │ ├── highlightjs.js │ │ │ │ ├── htmx.js │ │ │ │ ├── marked.js │ │ │ │ ├── purify.js │ │ │ │ ├── tailwindcss.js │ │ │ │ ├── tw-elements.css │ │ │ │ └── tw-elements.js │ │ │ ├── chat.js │ │ │ ├── components.css │ │ │ ├── general.css │ │ │ ├── image.js │ │ │ ├── p2panimation.js │ │ │ ├── sound.js │ │ │ ├── talk.js │ │ │ ├── theme.css │ │ │ ├── tts.js │ │ │ ├── typography.css │ │ │ └── video.js │ │ └── views/ │ │ ├── 404.html │ │ ├── agent-job-details.html │ │ ├── agent-jobs.html │ │ ├── agent-task-details.html │ │ ├── backends.html │ │ ├── chat.html │ │ ├── error.html │ │ ├── explorer.html │ │ ├── image.html │ │ ├── index.html │ │ ├── login.html │ │ ├── manage.html │ │ ├── model-editor.html │ │ ├── models.html │ │ ├── p2p.html │ │ ├── partials/ │ │ │ ├── footer.html │ │ │ ├── head.html │ │ │ ├── inprogress.html │ │ │ ├── navbar.html │ │ │ └── navbar_explorer.html │ │ ├── settings.html │ │ ├── sound.html │ │ ├── talk.html │ │ ├── traces.html │ │ ├── tts.html │ │ └── video.html │ ├── p2p/ │ │ ├── federated.go │ │ ├── federated_server.go │ │ ├── node.go │ │ ├── p2p.go │ │ └── p2p_common.go │ ├── schema/ │ │ ├── agent_jobs.go │ │ ├── anthropic.go │ │ ├── anthropic_test.go │ │ ├── backend.go │ │ ├── elevenlabs.go │ │ ├── gallery-model.schema.json │ │ ├── jina.go │ │ ├── localai.go │ │ ├── message.go │ │ ├── message_test.go │ │ ├── openai.go │ │ ├── openresponses.go │ │ ├── prediction.go 
│ │ ├── request.go │ │ ├── schema_suite_test.go │ │ ├── tokenize.go │ │ └── transcription.go │ ├── services/ │ │ ├── agent_jobs.go │ │ ├── agent_jobs_test.go │ │ ├── agent_pool.go │ │ ├── agent_pool_sse.go │ │ ├── backend_monitor.go │ │ ├── backends.go │ │ ├── backends_test.go │ │ ├── gallery.go │ │ ├── list_models.go │ │ ├── metrics.go │ │ ├── models.go │ │ ├── operation.go │ │ ├── services_suite_test.go │ │ ├── user_services.go │ │ └── user_storage.go │ ├── startup/ │ │ ├── model_preload.go │ │ ├── model_preload_test.go │ │ └── startup_suite_test.go │ ├── templates/ │ │ ├── cache.go │ │ ├── evaluator.go │ │ ├── evaluator_test.go │ │ ├── multimodal.go │ │ ├── multimodal_test.go │ │ └── templates_suite_test.go │ └── trace/ │ ├── audio_snippet.go │ └── backend_trace.go ├── custom-ca-certs/ │ └── .keep ├── docker-compose.yaml ├── docs/ │ ├── Dockerfile │ ├── README.md │ ├── assets/ │ │ └── jsconfig.json │ ├── content/ │ │ ├── _index.md │ │ ├── advanced/ │ │ │ ├── _index.en.md │ │ │ ├── _index.md │ │ │ ├── advanced-usage.md │ │ │ ├── fine-tuning.md │ │ │ ├── model-configuration.md │ │ │ ├── reverse-proxy-tls.md │ │ │ └── vram-management.md │ │ ├── faq.md │ │ ├── features/ │ │ │ ├── GPU-acceleration.md │ │ │ ├── _index.en.md │ │ │ ├── agents.md │ │ │ ├── audio-to-text.md │ │ │ ├── authentication.md │ │ │ ├── backend-monitor.md │ │ │ ├── backends.md │ │ │ ├── constrained_grammars.md │ │ │ ├── distributed_inferencing.md │ │ │ ├── embeddings.md │ │ │ ├── gpt-vision.md │ │ │ ├── image-generation.md │ │ │ ├── mcp.md │ │ │ ├── mlx-distributed.md │ │ │ ├── model-gallery.md │ │ │ ├── object-detection.md │ │ │ ├── openai-functions.md │ │ │ ├── openai-realtime.md │ │ │ ├── p2p.md │ │ │ ├── reranker.md │ │ │ ├── runtime-settings.md │ │ │ ├── sound-generation.md │ │ │ ├── stores.md │ │ │ ├── text-generation.md │ │ │ ├── text-to-audio.md │ │ │ ├── video-generation.md │ │ │ └── voice-activity-detection.md │ │ ├── getting-started/ │ │ │ ├── _index.en.md │ │ │ ├── build.md │ │ │ ├── 
container-images.md │ │ │ ├── customize-model.md │ │ │ ├── kubernetes.md │ │ │ ├── models.md │ │ │ ├── quickstart.md │ │ │ ├── troubleshooting.md │ │ │ └── try-it-out.md │ │ ├── installation/ │ │ │ ├── _index.en.md │ │ │ ├── build.md │ │ │ ├── containers.md │ │ │ ├── docker.md │ │ │ ├── kubernetes.md │ │ │ ├── linux.md │ │ │ └── macos.md │ │ ├── integrations.md │ │ ├── overview.md │ │ ├── reference/ │ │ │ ├── _index.en.md │ │ │ ├── _index.md │ │ │ ├── api-errors.md │ │ │ ├── architecture.md │ │ │ ├── binaries.md │ │ │ ├── cli-reference.md │ │ │ ├── compatibility-table.md │ │ │ ├── nvidia-l4t.md │ │ │ ├── shell-completion.md │ │ │ └── system-info.md │ │ └── whats-new.md │ ├── data/ │ │ ├── landing.yaml │ │ └── version.json │ ├── docker-compose.yaml │ ├── go.mod │ ├── go.sum │ ├── hugo.toml │ ├── layouts/ │ │ ├── 404.html │ │ ├── partials/ │ │ │ ├── docs/ │ │ │ │ ├── gitinfo.html │ │ │ │ ├── sidebar.html │ │ │ │ └── top-header.html │ │ │ ├── head.html │ │ │ ├── header.html │ │ │ ├── logo.html │ │ │ └── menu-footer.html │ │ └── shortcodes/ │ │ ├── github.html │ │ ├── pr.html │ │ └── version.html │ ├── netlify.toml │ ├── package.json │ └── static/ │ └── site.webmanifest ├── entrypoint.sh ├── examples/ │ └── README.md ├── gallery/ │ ├── alpaca.yaml │ ├── arch-function.yaml │ ├── cerbero.yaml │ ├── chatml-hercules.yaml │ ├── chatml.yaml │ ├── codellama.yaml │ ├── command-r.yaml │ ├── deephermes.yaml │ ├── deepseek-r1.yaml │ ├── deepseek.yaml │ ├── dreamshaper.yaml │ ├── falcon3.yaml │ ├── flux-ggml.yaml │ ├── flux.yaml │ ├── gemma.yaml │ ├── granite.yaml │ ├── granite3-2.yaml │ ├── granite4.yaml │ ├── harmony.yaml │ ├── hermes-2-pro-mistral.yaml │ ├── hermes-vllm.yaml │ ├── index.yaml │ ├── jamba.yaml │ ├── lfm.yaml │ ├── llama3-instruct.yaml │ ├── llama3.1-instruct-grammar.yaml │ ├── llama3.1-instruct.yaml │ ├── llama3.1-reflective.yaml │ ├── llama3.2-fcall.yaml │ ├── llama3.2-quantized.yaml │ ├── llava.yaml │ ├── mathstral.yaml │ ├── mistral-0.3.yaml │ ├── 
moondream.yaml │ ├── mudler.yaml │ ├── nanbeige4.1.yaml │ ├── noromaid.yaml │ ├── openvino.yaml │ ├── parler-tts.yaml │ ├── phi-2-chat.yaml │ ├── phi-2-orange.yaml │ ├── phi-3-chat.yaml │ ├── phi-3-vision.yaml │ ├── phi-4-chat-fcall.yaml │ ├── phi-4-chat.yaml │ ├── piper.yaml │ ├── pocket-tts.yaml │ ├── qwen-fcall.yaml │ ├── qwen-image.yaml │ ├── qwen3-deepresearch.yaml │ ├── qwen3-openbuddy.yaml │ ├── qwen3.yaml │ ├── rerankers.yaml │ ├── rwkv.yaml │ ├── sd-ggml.yaml │ ├── sentencetransformers.yaml │ ├── smolvlm.yaml │ ├── stablediffusion3.yaml │ ├── tuluv2.yaml │ ├── vibevoice.yaml │ ├── vicuna-chat.yaml │ ├── virtual.yaml │ ├── vllm.yaml │ ├── whisper-base.yaml │ ├── wizardlm2.yaml │ └── z-image-ggml.yaml ├── go.mod ├── go.sum ├── internal/ │ └── version.go ├── pkg/ │ ├── audio/ │ │ ├── audio.go │ │ ├── audio_suite_test.go │ │ ├── audio_test.go │ │ └── identify.go │ ├── concurrency/ │ │ ├── concurrency_suite_test.go │ │ ├── jobresult.go │ │ └── jobresult_test.go │ ├── downloader/ │ │ ├── downloader_suite_test.go │ │ ├── huggingface.go │ │ ├── progress.go │ │ ├── uri.go │ │ └── uri_test.go │ ├── format/ │ │ └── transcription.go │ ├── functions/ │ │ ├── chat_deltas.go │ │ ├── function_structure.go │ │ ├── functions.go │ │ ├── functions_suite_test.go │ │ ├── functions_test.go │ │ ├── grammars/ │ │ │ ├── bnf_rules.go │ │ │ ├── grammars_suite_test.go │ │ │ ├── json_schema.go │ │ │ ├── json_schema_test.go │ │ │ ├── llama31_schema.go │ │ │ ├── llama31_schema_test.go │ │ │ ├── options.go │ │ │ ├── rules.go │ │ │ └── types.go │ │ ├── iterative_parser.go │ │ ├── json_mode.go │ │ ├── json_stack_parser.go │ │ ├── parse.go │ │ ├── parse_test.go │ │ ├── peg/ │ │ │ ├── arena.go │ │ │ ├── builder.go │ │ │ ├── chat.go │ │ │ ├── chat_test.go │ │ │ ├── parser.go │ │ │ ├── parser_test.go │ │ │ ├── peg_suite_test.go │ │ │ ├── trie.go │ │ │ ├── types.go │ │ │ └── utils_test.go │ │ ├── peg_integration.go │ │ └── peg_integration_test.go │ ├── grpc/ │ │ ├── backend.go │ │ ├── base/ │ │ 
│ ├── base.go │ │ │ └── singlethread.go │ │ ├── client.go │ │ ├── embed.go │ │ ├── interface.go │ │ └── server.go │ ├── huggingface-api/ │ │ ├── client.go │ │ ├── client_test.go │ │ └── hfapi_suite_test.go │ ├── langchain/ │ │ └── langchain.go │ ├── model/ │ │ ├── backend_log_store.go │ │ ├── filters.go │ │ ├── initializers.go │ │ ├── loader.go │ │ ├── loader_options.go │ │ ├── loader_test.go │ │ ├── model.go │ │ ├── model_suite_test.go │ │ ├── process.go │ │ ├── watchdog.go │ │ ├── watchdog_options.go │ │ ├── watchdog_options_test.go │ │ └── watchdog_test.go │ ├── oci/ │ │ ├── blob.go │ │ ├── blob_test.go │ │ ├── image.go │ │ ├── image_test.go │ │ ├── oci_suite_test.go │ │ ├── ollama.go │ │ ├── ollama_test.go │ │ └── tarball.go │ ├── reasoning/ │ │ ├── config.go │ │ ├── extractor.go │ │ ├── extractor_test.go │ │ ├── reasoning.go │ │ ├── reasoning_suite_test.go │ │ └── reasoning_test.go │ ├── signals/ │ │ └── handler.go │ ├── sound/ │ │ ├── float32.go │ │ ├── int16.go │ │ ├── int16_test.go │ │ ├── sound_suite_test.go │ │ └── testutil_test.go │ ├── store/ │ │ └── client.go │ ├── system/ │ │ ├── capabilities.go │ │ ├── capabilities_test.go │ │ ├── state.go │ │ └── system_suite_test.go │ ├── utils/ │ │ ├── base64.go │ │ ├── base64_test.go │ │ ├── ffmpeg.go │ │ ├── hash.go │ │ ├── json.go │ │ ├── logging.go │ │ ├── path.go │ │ ├── strings.go │ │ ├── untar.go │ │ ├── urlfetch.go │ │ ├── urlfetch_test.go │ │ └── utils_suite_test.go │ ├── vram/ │ │ ├── cache.go │ │ ├── estimate.go │ │ ├── estimate_test.go │ │ ├── gguf_reader.go │ │ ├── hf_estimate.go │ │ ├── hf_estimate_test.go │ │ ├── types.go │ │ └── vram_suite_test.go │ ├── xio/ │ │ └── copy.go │ ├── xsync/ │ │ ├── map.go │ │ ├── map_test.go │ │ └── sync_suite_test.go │ └── xsysinfo/ │ ├── cpu.go │ ├── gpu.go │ └── memory.go ├── prompt-templates/ │ ├── alpaca.tmpl │ ├── getting_started.tmpl │ ├── ggml-gpt4all-j.tmpl │ ├── koala.tmpl │ ├── llama2-chat-message.tmpl │ ├── vicuna.tmpl │ └── wizardlm.tmpl ├── renovate.json 
├── scripts/ │ ├── build/ │ │ ├── golang-darwin.sh │ │ ├── llama-cpp-darwin.sh │ │ ├── package-gpu-libs.sh │ │ └── python-darwin.sh │ ├── changed-backends.js │ ├── latest_hf.py │ ├── model_gallery_info.py │ └── prepare-libs.sh ├── swagger/ │ ├── docs.go │ ├── swagger.json │ └── swagger.yaml ├── tests/ │ ├── e2e/ │ │ ├── e2e_anthropic_test.go │ │ ├── e2e_mcp_test.go │ │ ├── e2e_suite_test.go │ │ ├── e2e_websocket_responses_test.go │ │ ├── mock_backend_test.go │ │ ├── realtime_webrtc_test.go │ │ └── realtime_ws_test.go │ ├── e2e-aio/ │ │ ├── e2e_suite_test.go │ │ ├── e2e_test.go │ │ ├── models/ │ │ │ ├── embeddings.yaml │ │ │ ├── image-gen.yaml │ │ │ ├── rerank.yaml │ │ │ ├── speech-to-text.yaml │ │ │ ├── text-to-speech.yaml │ │ │ ├── text-to-text.yaml │ │ │ ├── vad.yaml │ │ │ └── vision.yaml │ │ └── sample_data_test.go │ ├── e2e-ui/ │ │ ├── .gitignore │ │ ├── Dockerfile │ │ └── main.go │ ├── fixtures/ │ │ ├── backend-image/ │ │ │ ├── Dockerfile │ │ │ ├── run.sh │ │ │ └── src/ │ │ │ └── .keep │ │ └── gallery_simple.yaml │ ├── integration/ │ │ ├── integration_suite_test.go │ │ └── stores_test.go │ └── models_fixtures/ │ ├── completion.tmpl │ ├── config.yaml │ ├── embeddings.yaml │ ├── ggml-gpt4all-j.tmpl │ ├── gpt4.yaml │ ├── gpt4_2.yaml │ ├── grpc.yaml │ ├── rwkv.yaml │ └── whisper.yaml └── webui_static.yaml ================================================ FILE CONTENTS ================================================ ================================================ FILE: .agents/adding-backends.md ================================================ # Adding a New Backend When adding a new backend to LocalAI, you need to update several files to ensure the backend is properly built, tested, and registered. Here's a step-by-step guide based on the pattern used for adding backends like `moonshine`: ## 1. 
Create Backend Directory Structure Create the backend directory under the appropriate location: - **Python backends**: `backend/python/<backend-name>/` - **Go backends**: `backend/go/<backend-name>/` - **C++ backends**: `backend/cpp/<backend-name>/` For Python backends, you'll typically need: - `backend.py` - Main gRPC server implementation - `Makefile` - Build configuration - `install.sh` - Installation script for dependencies - `protogen.sh` - Protocol buffer generation script - `requirements.txt` - Python dependencies - `run.sh` - Runtime script - `test.py` / `test.sh` - Test files ## 2. Add Build Configurations to `.github/workflows/backend.yml` Add build matrix entries for each platform/GPU type you want to support. Look at similar backends (e.g., `chatterbox`, `faster-whisper`) for reference. **Placement in file:** - CPU builds: Add after other CPU builds (e.g., after `cpu-chatterbox`) - CUDA 12 builds: Add after other CUDA 12 builds (e.g., after `gpu-nvidia-cuda-12-chatterbox`) - CUDA 13 builds: Add after other CUDA 13 builds (e.g., after `gpu-nvidia-cuda-13-chatterbox`) **Additional build types you may need:** - ROCm/HIP: Use `build-type: 'hipblas'` with `base-image: "rocm/dev-ubuntu-24.04:6.4.4"` - Intel/SYCL: Use `build-type: 'intel'` or `build-type: 'sycl_f16'`/`sycl_f32` with `base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04"` - L4T (ARM): Use `build-type: 'l4t'` with `platforms: 'linux/arm64'` and `runs-on: 'ubuntu-24.04-arm'` ## 3. Add Backend Metadata to `backend/index.yaml` **Step 3a: Add Meta Definition** Add a YAML anchor definition in the `## metas` section (around line 2-300). Look for similar backends to use as a template such as `diffusers` or `chatterbox`. **Step 3b: Add Image Entries** Add image entries at the end of the file, following the pattern of similar backends such as `diffusers` or `chatterbox`. Include both `latest` (production) and `master` (development) tags. ## 4.
Update the Makefile The Makefile needs to be updated in several places to support building and testing the new backend: **Step 4a: Add to `.NOTPARALLEL`** Add `backends/<backend-name>` to the `.NOTPARALLEL` line (around line 2) to prevent parallel execution conflicts: ```makefile .NOTPARALLEL: ... backends/<backend-name> ``` **Step 4b: Add to `prepare-test-extra`** Add the backend to the `prepare-test-extra` target (around line 312) to prepare it for testing: ```makefile prepare-test-extra: protogen-python ... $(MAKE) -C backend/python/<backend-name> ``` **Step 4c: Add to `test-extra`** Add the backend to the `test-extra` target (around line 319) to run its tests: ```makefile test-extra: prepare-test-extra ... $(MAKE) -C backend/python/<backend-name> test ``` **Step 4d: Add Backend Definition** Add a backend definition variable in the backend definitions section (around line 428-457). The format depends on the backend type: **For Python backends with root context** (like `faster-whisper`, `coqui`): ```makefile BACKEND_<BACKEND_NAME> = <backend-name>|python|.|false|true ``` **For Python backends with `./backend` context** (like `chatterbox`, `moonshine`): ```makefile BACKEND_<BACKEND_NAME> = <backend-name>|python|./backend|false|true ``` **For Go backends**: ```makefile BACKEND_<BACKEND_NAME> = <backend-name>|golang|.|false|true ``` **Step 4e: Generate Docker Build Target** Add an eval call to generate the docker-build target (around line 480-501): ```makefile $(eval $(call generate-docker-build-target,$(BACKEND_<BACKEND_NAME>))) ``` **Step 4f: Add to `docker-build-backends`** Add `docker-build-<backend-name>` to the `docker-build-backends` target (around line 507): ```makefile docker-build-backends: ... docker-build-<backend-name> ``` **Determining the Context:** - If the backend is in `backend/python/<backend-name>/` and uses `./backend` as context in the workflow file, use `./backend` context - If the backend is in `backend/python/<backend-name>/` but uses `.` as context in the workflow file, use `.` context - Check similar backends to determine the correct context ## 5.
Verification Checklist After adding a new backend, verify: - [ ] Backend directory structure is complete with all necessary files - [ ] Build configurations added to `.github/workflows/backend.yml` for all desired platforms - [ ] Meta definition added to `backend/index.yaml` in the `## metas` section - [ ] Image entries added to `backend/index.yaml` for all build variants (latest + development) - [ ] Tag suffixes match between workflow file and index.yaml - [ ] Makefile updated with all 6 required changes (`.NOTPARALLEL`, `prepare-test-extra`, `test-extra`, backend definition, docker-build target eval, `docker-build-backends`) - [ ] No YAML syntax errors (check with linter) - [ ] No Makefile syntax errors (check with linter) - [ ] Follows the same pattern as similar backends (e.g., if it's a transcription backend, follow `faster-whisper` pattern) ## 6. Example: Adding a Python Backend For reference, when `moonshine` was added: - **Files created**: `backend/python/moonshine/{backend.py, Makefile, install.sh, protogen.sh, requirements.txt, run.sh, test.py, test.sh}` - **Workflow entries**: 3 build configurations (CPU, CUDA 12, CUDA 13) - **Index entries**: 1 meta definition + 6 image entries (cpu, cuda12, cuda13 x latest/development) - **Makefile updates**: - Added to `.NOTPARALLEL` line - Added to `prepare-test-extra` and `test-extra` targets - Added `BACKEND_MOONSHINE = moonshine|python|./backend|false|true` - Added eval for docker-build target generation - Added `docker-build-moonshine` to `docker-build-backends` ================================================ FILE: .agents/api-endpoints-and-auth.md ================================================ # API Endpoints and Authentication This guide covers how to add new API endpoints and properly integrate them with the auth/permissions system. ## Architecture overview Authentication and authorization flow through three layers: 1. 
**Global auth middleware** (`core/http/auth/middleware.go` → `auth.Middleware`) — applied to every request in `core/http/app.go`. Handles session cookies, Bearer tokens, API keys, and legacy API keys. Populates `auth_user` and `auth_role` in the Echo context. 2. **Feature middleware** (`auth.RequireFeature`) — per-feature access control applied to route groups or individual routes. Checks if the authenticated user has the specific feature enabled. 3. **Admin middleware** (`auth.RequireAdmin`) — restricts endpoints to admin users only. When auth is disabled (no auth DB, no legacy API keys), all middleware becomes pass-through (`auth.NoopMiddleware`). ## Adding a new API endpoint ### Step 1: Create the handler Write the endpoint handler in the appropriate package under `core/http/endpoints/`. Follow existing patterns: ```go // core/http/endpoints/localai/my_feature.go func MyFeatureEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { // Use auth.GetUser(c) to get the authenticated user (may be nil if auth is disabled) user := auth.GetUser(c) // Your logic here return c.JSON(http.StatusOK, result) } } ``` ### Step 2: Register routes Add routes in the appropriate file under `core/http/routes/`. The file you use depends on the endpoint category: | File | Category | |------|----------| | `routes/openai.go` | OpenAI-compatible API endpoints (`/v1/...`) | | `routes/localai.go` | LocalAI-specific endpoints (`/api/...`, `/models/...`, `/backends/...`) | | `routes/agents.go` | Agent pool endpoints (`/api/agents/...`) | | `routes/auth.go` | Auth endpoints (`/api/auth/...`) | | `routes/ui_api.go` | UI backend API endpoints | ### Step 3: Apply the right middleware Choose the appropriate protection level: #### No auth required (public) Exempt paths bypass auth entirely. Add to `isExemptPath()` in `middleware.go` or use the `/api/auth/` prefix (always exempt). Use sparingly — most endpoints should require auth. 
#### Standard auth (any authenticated user) The global middleware already handles this. API paths (`/api/`, `/v1/`, etc.) automatically require authentication when auth is enabled. You don't need to add any extra middleware. ```go router.GET("/v1/my-endpoint", myHandler) // auth enforced by global middleware ``` #### Admin only Pass `adminMiddleware` to the route. This is set up in `app.go` and passed to `Register*Routes` functions: ```go // In the Register function signature, accept the middleware: func RegisterMyRoutes(router *echo.Echo, app *application.Application, adminMiddleware echo.MiddlewareFunc) { router.POST("/models/apply", myHandler, adminMiddleware) } ``` #### Feature-gated For endpoints that should be toggleable per-user, use feature middleware. There are two approaches: **Approach A: Route-level middleware** (preferred for groups of related endpoints) ```go // In app.go, create the feature middleware: myFeatureMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMyFeature) // Pass it to the route registration function: routes.RegisterMyRoutes(e, app, myFeatureMw) // In the routes file, apply to a group: g := e.Group("/api/my-feature", myFeatureMw) g.GET("", listHandler) g.POST("", createHandler) ``` **Approach B: RouteFeatureRegistry** (preferred for individual OpenAI-compatible endpoints) Add an entry to `RouteFeatureRegistry` in `core/http/auth/features.go`. The `RequireRouteFeature` global middleware will automatically enforce it: ```go var RouteFeatureRegistry = []RouteFeature{ // ... existing entries ... {"POST", "/v1/my-endpoint", FeatureMyFeature}, } ``` ## Adding a new feature When you need a new toggleable feature (not just a new endpoint under an existing feature): ### 1. 
Define the feature constant Add to `core/http/auth/permissions.go`: ```go const ( // Add to the appropriate group: // Agent features (default OFF for new users) FeatureMyFeature = "my_feature" // OR API features (default ON for new users) FeatureMyFeature = "my_feature" ) ``` Then add it to the appropriate slice: ```go // Default OFF — user must be explicitly granted access: var AgentFeatures = []string{..., FeatureMyFeature} // Default ON — user has access unless explicitly revoked: var APIFeatures = []string{..., FeatureMyFeature} ``` ### 2. Add feature metadata In `core/http/auth/features.go`, add to the appropriate `FeatureMetas` function so the admin UI can display it: ```go func AgentFeatureMetas() []FeatureMeta { return []FeatureMeta{ // ... existing ... {FeatureMyFeature, "My Feature", false}, // false = default OFF } } ``` ### 3. Wire up the middleware In `core/http/app.go`: ```go myFeatureMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMyFeature) ``` Then pass it to the route registration function. ### 4. Register route-feature mappings (if applicable) If your feature gates standard API endpoints (like `/v1/...`), add entries to `RouteFeatureRegistry` in `features.go` instead of using per-route middleware. ## Accessing the authenticated user in handlers ```go import "github.com/mudler/LocalAI/core/http/auth" func MyHandler(c echo.Context) error { // Get the user (nil when auth is disabled or unauthenticated) user := auth.GetUser(c) if user == nil { // Handle unauthenticated — or let middleware handle it } // Check role if user.Role == auth.RoleAdmin { // admin-specific logic } // Check feature access programmatically (when you need conditional behavior, not full blocking) if auth.HasFeatureAccess(db, user, auth.FeatureMyFeature) { // feature-specific logic } // Check model access if !auth.IsModelAllowed(db, user, modelName) { return c.JSON(http.StatusForbidden, ...) 
} } ``` ## Middleware composition patterns Middleware can be composed at different levels. Here are the patterns used in the codebase: ### Group-level middleware (agents pattern) ```go // All routes in the group share the middleware g := e.Group("/api/agents", poolReadyMw, agentsMw) g.GET("", listHandler) g.POST("", createHandler) ``` ### Per-route middleware (localai pattern) ```go // Individual routes get middleware as extra arguments router.POST("/models/apply", applyHandler, adminMiddleware) router.GET("/metrics", metricsHandler, adminMiddleware) ``` ### Middleware slice (openai pattern) ```go // Build a middleware chain for a handler chatMiddleware := []echo.MiddlewareFunc{ usageMiddleware, traceMiddleware, modelFilterMiddleware, } app.POST("/v1/chat/completions", chatHandler, chatMiddleware...) ``` ## Error response format Always use `schema.ErrorResponse` for auth/permission errors to stay consistent with the OpenAI-compatible API: ```go return c.JSON(http.StatusForbidden, schema.ErrorResponse{ Error: &schema.APIError{ Message: "feature not enabled for your account", Code: http.StatusForbidden, Type: "authorization_error", }, }) ``` Use these HTTP status codes: - `401 Unauthorized` — no valid credentials provided - `403 Forbidden` — authenticated but lacking permission - `429 Too Many Requests` — rate limited (auth endpoints) ## Usage tracking If your endpoint should be tracked for usage (token counts, request counts), add the `usageMiddleware` to its middleware chain. See `core/http/middleware/usage.go` and how it's applied in `routes/openai.go`. 
## Path protection rules The global auth middleware classifies paths as API paths or non-API paths: - **API paths** (always require auth when auth is enabled): `/api/`, `/v1/`, `/models/`, `/backends/`, `/backend/`, `/tts`, `/vad`, `/video`, `/stores/`, `/system`, `/ws/`, `/metrics` - **Exempt paths** (never require auth): `/api/auth/` prefix, anything in `appConfig.PathWithoutAuth` - **Non-API paths** (UI, static assets): pass through without auth — the React UI handles login redirects client-side If you add endpoints under a new top-level path prefix, add it to `isAPIPath()` in `middleware.go` to ensure it requires authentication. ## Checklist When adding a new endpoint: - [ ] Handler in `core/http/endpoints/` - [ ] Route registered in appropriate `core/http/routes/` file - [ ] Auth level chosen: public / standard / admin / feature-gated - [ ] If feature-gated: constant in `permissions.go`, metadata in `features.go`, middleware in `app.go` - [ ] If new path prefix: added to `isAPIPath()` in `middleware.go` - [ ] If OpenAI-compatible: entry in `RouteFeatureRegistry` - [ ] If token-counting: `usageMiddleware` added to middleware chain - [ ] Error responses use `schema.ErrorResponse` format - [ ] Tests cover both authenticated and unauthenticated access ================================================ FILE: .agents/building-and-testing.md ================================================ # Build and Testing Building and testing the project depends on the components involved and the platform where development is taking place. Due to the amount of context required, it's usually best not to try building or testing the project unless the user requests it. If you must build the project, then inspect the Makefile in the project root and the Makefiles of any backends that are affected by the changes you are making. In addition, the workflows in .github/workflows can be used as a reference when it is unclear how to build or test a component.
The primary Makefile contains targets for building inside or outside Docker; if the user has not previously specified a preference, then ask which they would like to use. ## Building a specified backend Let's say the user wants to build a particular backend for a given platform. For example let's say they want to build coqui for ROCM/hipblas - The Makefile has targets like `docker-build-coqui` created with `generate-docker-build-target` at the time of writing. Recently added backends may require a new target. - At a minimum we need to set the BUILD_TYPE, BASE_IMAGE build-args - Use .github/workflows/backend.yml as a reference; it lists the needed args in the `include` job strategy matrix - l4t and cublas also require the CUDA major and minor version - You can pretty print a command like `DOCKER_MAKEFLAGS=-j$(nproc --ignore=1) BUILD_TYPE=hipblas BASE_IMAGE=rocm/dev-ubuntu-24.04:6.4.4 make docker-build-coqui` - Unless the user specifies that they want you to run the command, then just print it because not all agent frontends handle long running jobs well and the output may overflow your context - The user may say they want to build AMD or ROCM instead of hipblas, or Intel instead of SYCL or NVIDIA instead of l4t or cublas. Ask for confirmation if there is ambiguity. - Sometimes the user may need extra parameters to be added to `docker build` (e.g. `--platform` for cross-platform builds or `--progress` to view the full logs), in which case you can generate the `docker build` command directly.
================================================ FILE: .agents/coding-style.md ================================================ # Coding Style The project has the following .editorconfig: ``` root = true [*] indent_style = space indent_size = 2 end_of_line = lf charset = utf-8 trim_trailing_whitespace = true insert_final_newline = true [*.go] indent_style = tab [Makefile] indent_style = tab [*.proto] indent_size = 2 [*.py] indent_size = 4 [*.js] indent_size = 2 [*.yaml] indent_size = 2 [*.md] trim_trailing_whitespace = false ``` - Use comments sparingly to explain why code does something, not what it does. Comments are there to add context that would be difficult to deduce from reading the code. - Prefer modern Go e.g. use `any` not `interface{}` ## Logging Use `github.com/mudler/xlog` for logging which has the same API as slog. ## Documentation The project documentation is located in `docs/content`. When adding new features or changing existing functionality, it is crucial to update the documentation to reflect these changes. This helps users understand how to use the new capabilities and ensures the documentation stays relevant. - **Feature Documentation**: If you add a new feature (like a new backend or API endpoint), create a new markdown file in `docs/content/features/` explaining what it is, how to configure it, and how to use it. - **Configuration**: If you modify configuration options, update the relevant sections in `docs/content/`. - **Examples**: providing concrete examples (like YAML configuration blocks) is highly encouraged to help users get started quickly. ================================================ FILE: .agents/llama-cpp-backend.md ================================================ # llama.cpp Backend The llama.cpp backend (`backend/cpp/llama-cpp/grpc-server.cpp`) is a gRPC adaptation of the upstream HTTP server (`llama.cpp/tools/server/server.cpp`). 
It uses the same underlying server infrastructure from `llama.cpp/tools/server/server-context.cpp`. ## Building and Testing - Test llama.cpp backend compilation: `make backends/llama-cpp` - The backend is built as part of the main build process - Check `backend/cpp/llama-cpp/Makefile` for build configuration ## Architecture - **grpc-server.cpp**: gRPC server implementation, adapts HTTP server patterns to gRPC - Uses shared server infrastructure: `server-context.cpp`, `server-task.cpp`, `server-queue.cpp`, `server-common.cpp` - The gRPC server mirrors the HTTP server's functionality but uses gRPC instead of HTTP ## Common Issues When Updating llama.cpp When fixing compilation errors after upstream changes: 1. Check how `server.cpp` (HTTP server) handles the same change 2. Look for new public APIs or getter methods 3. Store copies of needed data instead of accessing private members 4. Update function calls to match new signatures 5. Test with `make backends/llama-cpp` ## Key Differences from HTTP Server - gRPC uses `BackendServiceImpl` class with gRPC service methods - HTTP server uses `server_routes` with HTTP handlers - Both use the same `server_context` and task queue infrastructure - gRPC methods: `LoadModel`, `Predict`, `PredictStream`, `Embedding`, `Rerank`, `TokenizeString`, `GetMetrics`, `Health` ## Tool Call Parsing Maintenance When working on JSON/XML tool call parsing functionality, always check llama.cpp for reference implementation and updates: ### Checking for XML Parsing Changes 1. **Review XML Format Definitions**: Check `llama.cpp/common/chat-parser-xml-toolcall.h` for `xml_tool_call_format` struct changes 2. **Review Parsing Logic**: Check `llama.cpp/common/chat-parser-xml-toolcall.cpp` for parsing algorithm updates 3. **Review Format Presets**: Check `llama.cpp/common/chat-parser.cpp` for new XML format presets (search for `xml_tool_call_format form`) 4. 
**Review Model Lists**: Check `llama.cpp/common/chat.h` for `COMMON_CHAT_FORMAT_*` enum values that use XML parsing: - `COMMON_CHAT_FORMAT_GLM_4_5` - `COMMON_CHAT_FORMAT_MINIMAX_M2` - `COMMON_CHAT_FORMAT_KIMI_K2` - `COMMON_CHAT_FORMAT_QWEN3_CODER_XML` - `COMMON_CHAT_FORMAT_APRIEL_1_5` - `COMMON_CHAT_FORMAT_XIAOMI_MIMO` - Any new formats added ### Model Configuration Options Always check `llama.cpp` for new model configuration options that should be supported in LocalAI: 1. **Check Server Context**: Review `llama.cpp/tools/server/server-context.cpp` for new parameters 2. **Check Chat Params**: Review `llama.cpp/common/chat.h` for `common_chat_params` struct changes 3. **Check Server Options**: Review `llama.cpp/tools/server/server.cpp` for command-line argument changes 4. **Examples of options to check**: - `ctx_shift` - Context shifting support - `parallel_tool_calls` - Parallel tool calling - `reasoning_format` - Reasoning format options - Any new flags or parameters ### Implementation Guidelines 1. **Feature Parity**: Always aim for feature parity with llama.cpp's implementation 2. **Test Coverage**: Add tests for new features matching llama.cpp's behavior 3. **Documentation**: Update relevant documentation when adding new formats or options 4. 
**Backward Compatibility**: Ensure changes don't break existing functionality ### Files to Monitor - `llama.cpp/common/chat-parser-xml-toolcall.h` - Format definitions - `llama.cpp/common/chat-parser-xml-toolcall.cpp` - Parsing logic - `llama.cpp/common/chat-parser.cpp` - Format presets and model-specific handlers - `llama.cpp/common/chat.h` - Format enums and parameter structures - `llama.cpp/tools/server/server-context.cpp` - Server configuration options ================================================ FILE: .agents/testing-mcp-apps.md ================================================ # Testing MCP Apps (Interactive Tool UIs) MCP Apps is an extension to MCP where tools declare interactive HTML UIs via `_meta.ui.resourceUri`. When the LLM calls such a tool, the UI renders the app in a sandboxed iframe inline in the chat. The app communicates bidirectionally with the host via `postMessage` (JSON-RPC) and can call server tools, send messages, and update model context. Spec: https://modelcontextprotocol.io/extensions/apps/overview ## Quick Start: Run a Test MCP App Server The `@modelcontextprotocol/server-basic-react` npm package is a ready-to-use test server that exposes a `get-time` tool with an interactive React clock UI. 
It requires Node >= 20, so run it in Docker: ```bash docker run -d --name mcp-app-test -p 3001:3001 node:22-slim \ sh -c 'npx -y @modelcontextprotocol/server-basic-react' ``` Wait ~10 seconds for it to start, then verify: ```bash # Check it's running docker logs mcp-app-test # Expected: "MCP server listening on http://localhost:3001/mcp" # Verify MCP protocol works curl -s -X POST http://localhost:3001/mcp \ -H 'Content-Type: application/json' \ -H 'Accept: application/json, text/event-stream' \ -d '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-03-26","capabilities":{},"clientInfo":{"name":"test","version":"1.0.0"}}}' # List tools — should show get-time with _meta.ui.resourceUri curl -s -X POST http://localhost:3001/mcp \ -H 'Content-Type: application/json' \ -H 'Accept: application/json, text/event-stream' \ -d '{"jsonrpc":"2.0","id":2,"method":"tools/list","params":{}}' ``` The `tools/list` response should contain: ```json { "name": "get-time", "_meta": { "ui": { "resourceUri": "ui://get-time/mcp-app.html" } } } ``` ## Testing in LocalAI's UI 1. Make sure LocalAI is running (e.g. `http://localhost:8080`) 2. Build the React UI: `cd core/http/react-ui && npm install && npm run build` 3. Open the Chat page in your browser 4. Click **"Client MCP"** in the chat header 5. Add a new client MCP server: - **URL**: `http://localhost:3001/mcp` - **Use CORS proxy**: enabled (default) — required because the browser can't hit `localhost:3001` directly due to CORS; LocalAI's proxy at `/api/cors-proxy` handles it 6. The server should connect and discover the `get-time` tool 7. Select a model and send: **"What time is it?"** 8. The LLM should call the `get-time` tool 9. 
The tool result should render the interactive React clock app in an iframe as a standalone chat message (not inside the collapsed activity group) ## What to Verify - [ ] Tool appears in the connected tools list (not filtered — `get-time` is callable by the LLM) - [ ] The iframe renders as a standalone chat message with a puzzle-piece icon - [ ] The app loads and is interactive (clock UI, buttons work) - [ ] No "Reconnect to MCP server" overlay (connection is live) - [ ] Console logs show bidirectional communication: - `tools/call` messages from app to host (app calling server tools) - `ui/message` notifications (app sending messages) - [ ] After the app renders, the LLM continues and produces a text response with the time - [ ] Non-UI tools continue to work normally (text-only results) - [ ] Page reload shows the HTML statically with a reconnect overlay until you reconnect ## Console Log Patterns Healthy bidirectional communication looks like: ``` Parsed message { jsonrpc: "2.0", id: N, result: {...} } // Bridge init get-time result: { content: [...] } // Tool result received Calling get-time tool... // App calls tool Sending message { method: "tools/call", ... } // App -> host -> server Parsed message { jsonrpc: "2.0", id: N, result: {...} } // Server response Sending message text to Host: ... // App sends message Sending message { method: "ui/message", ... } // Message notification Message accepted // Host acknowledged ``` Benign warnings to ignore: - `Source map error: ... 
about:srcdoc` — browser devtools can't find source maps for srcdoc iframes - `Ignoring message from unknown source` — duplicate postMessage from iframe navigation - `notifications/cancelled` — app cleaning up previous requests ## Architecture Notes - **No server-side changes needed** — the MCP App protocol runs entirely in the browser - `PostMessageTransport` wraps `window.postMessage` between host and `srcdoc` iframe - `AppBridge` (from `@modelcontextprotocol/ext-apps`) auto-forwards `tools/call`, `resources/read`, `resources/list` from the app to the MCP server via the host's `Client` - The iframe uses `sandbox="allow-scripts allow-forms"` (no `allow-same-origin`) — opaque origin, no access to host cookies/DOM/localStorage - App-only tools (`_meta.ui.visibility: "app-only"`) are filtered from the LLM's tool list but remain callable by the app iframe ## Key Files - `core/http/react-ui/src/components/MCPAppFrame.jsx` — iframe + AppBridge component - `core/http/react-ui/src/hooks/useMCPClient.js` — MCP client hook with app UI helpers (`hasAppUI`, `getAppResource`, `getClientForTool`, `getToolDefinition`) - `core/http/react-ui/src/hooks/useChat.js` — agentic loop, attaches `appUI` to tool_result messages - `core/http/react-ui/src/pages/Chat.jsx` — renders MCPAppFrame as standalone chat messages ## Other Test Servers The `@modelcontextprotocol/ext-apps` repo has many example servers: - `@modelcontextprotocol/server-basic-react` — simple clock (React) - More examples at https://github.com/modelcontextprotocol/ext-apps/tree/main/examples All examples support both stdio and HTTP transport. Run without `--stdio` for HTTP mode on port 3001. 
## Cleanup ```bash docker rm -f mcp-app-test ``` ================================================ FILE: .air.toml ================================================ # .air.toml [build] cmd = "make build" bin = "./local-ai" args_bin = [ "--debug" ] include_ext = ["go", "html", "yaml", "toml", "json", "txt", "md"] exclude_dir = ["pkg/grpc/proto"] delay = 1000 ================================================ FILE: .devcontainer/devcontainer.json ================================================ { "$schema": "https://raw.githubusercontent.com/devcontainers/spec/main/schemas/devContainer.schema.json", "name": "LocalAI", "workspaceFolder": "/workspace", "dockerComposeFile": [ "./docker-compose-devcontainer.yml" ], "service": "api", "shutdownAction": "stopCompose", "customizations": { "vscode": { "extensions": [ "golang.go", "ms-vscode.makefile-tools", "ms-azuretools.vscode-docker", "ms-python.python", "ms-python.debugpy", "wayou.vscode-todo-highlight", "waderyan.gitblame" ] } }, "forwardPorts": [8080, 3000], "postCreateCommand": "bash /.devcontainer-scripts/postcreate.sh", "postStartCommand": "bash /.devcontainer-scripts/poststart.sh" } ================================================ FILE: .devcontainer/docker-compose-devcontainer.yml ================================================ services: api: build: context: .. 
dockerfile: Dockerfile target: devcontainer env_file: - ../.env ports: - 8080:8080 volumes: - localai_workspace:/workspace - models:/host-models - backends:/host-backends - ./customization:/devcontainer-customization command: /bin/sh -c "while sleep 1000; do :; done" cap_add: - SYS_PTRACE security_opt: - seccomp:unconfined prometheus: image: prom/prometheus container_name: prometheus command: - '--config.file=/etc/prometheus/prometheus.yml' ports: - 9090:9090 restart: unless-stopped volumes: - ./prometheus:/etc/prometheus - prom_data:/prometheus grafana: image: grafana/grafana container_name: grafana ports: - 3000:3000 restart: unless-stopped environment: - GF_SECURITY_ADMIN_USER=admin - GF_SECURITY_ADMIN_PASSWORD=grafana volumes: - ./grafana:/etc/grafana/provisioning/datasources volumes: prom_data: localai_workspace: models: backends: ================================================ FILE: .devcontainer/grafana/datasource.yml ================================================ apiVersion: 1 datasources: - name: Prometheus type: prometheus url: http://prometheus:9090 isDefault: true access: proxy editable: true ================================================ FILE: .devcontainer/prometheus/prometheus.yml ================================================ global: scrape_interval: 15s scrape_timeout: 10s evaluation_interval: 15s alerting: alertmanagers: - static_configs: - targets: [] scheme: http timeout: 10s api_version: v1 scrape_configs: - job_name: prometheus honor_timestamps: true scrape_interval: 15s scrape_timeout: 10s metrics_path: /metrics scheme: http static_configs: - targets: - localhost:9090 ================================================ FILE: .devcontainer-scripts/postcreate.sh ================================================ #!/bin/bash cd /workspace # Get the files into the volume without a bind mount if [ ! -d ".git" ]; then git clone https://github.com/mudler/LocalAI.git . else git fetch fi echo "Standard Post-Create script completed." 
# Sets up the ~/.ssh directory and seeds it with files from the customization directory.
# Prints out lines of text to make things pretty.
#
# Params: filenames, relative to /devcontainer-customization, that should be copied to ~/.ssh.
#         Files that already exist in ~/.ssh are left untouched.
#
setup_ssh() {
    local ssh_dir="${HOME}/.ssh"
    local file cfile hfile

    echo "starting ~/.ssh directory setup..."
    # The directory created here must be the same one that is chmod-ed and copied into below,
    # so the path is built once (with the "/" separator) and reused.
    mkdir -p "${ssh_dir}"
    chmod 0700 "${ssh_dir}"
    echo "-----"
    for file in "$@" ; do
        cfile="/devcontainer-customization/${file}"
        hfile="${ssh_dir}/${file}"
        if [ ! -f "${hfile}" ]; then
            # A missing source file is reported and skipped instead of failing in cp and chmod.
            if [ ! -f "${cfile}" ]; then
                echo "skipping \"${file}\": not found in /devcontainer-customization" >&2
                continue
            fi
            echo "copying \"${file}\""
            cp "${cfile}" "${hfile}"
            # Private keys and ssh config must not be readable by group/others.
            chmod 600 "${hfile}"
        fi
    done
    echo "~/.ssh directory setup complete!"
}
================================================ blank_issues_enabled: false contact_links: - name: Community Support url: https://github.com/go-skynet/LocalAI/discussions about: Please ask and answer questions here. - name: Discord url: https://discord.gg/uJAeKSAGDy about: Join our community on Discord! ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for this project title: '' labels: enhancement, up-for-grabs --- **Is your feature request related to a problem? Please describe.** **Describe the solution you'd like** **Describe alternatives you've considered** **Additional context** ================================================ FILE: .github/PULL_REQUEST_TEMPLATE.md ================================================ **Description** This PR fixes # **Notes for Reviewers** **[Signed commits](../CONTRIBUTING.md#signing-off-on-commits-developer-certificate-of-origin)** - [ ] Yes, I signed my commits. ================================================ FILE: .github/bump_deps.sh ================================================ #!/bin/bash set -xe REPO=$1 BRANCH=$2 VAR=$3 FILE=$4 if [ -z "$FILE" ]; then FILE="Makefile" fi LAST_COMMIT=$(curl -s -H "Accept: application/vnd.github.VERSION.sha" "https://api.github.com/repos/$REPO/commits/$BRANCH") # Read $VAR from Makefile (only first match) set +e CURRENT_COMMIT="$(grep -m1 "^$VAR?=" $FILE | cut -d'=' -f2)" set -e sed -i $FILE -e "s/$VAR?=.*/$VAR?=$LAST_COMMIT/" if [ -z "$CURRENT_COMMIT" ]; then echo "Could not find $VAR in Makefile." 
def parse_uri(uri):
    """Classify a model URI and extract what is needed to fetch it.

    Returns a ``(kind, target)`` tuple:

    * ``('huggingface', '<owner>/<repo>')`` for ``huggingface://owner/repo/file``
      URIs and for ``https://huggingface.co/owner/repo/resolve/...`` URLs.
    * ``('direct', uri)`` for everything else; the URI is downloaded as-is.
    """
    if uri.startswith('huggingface://'):
        repo_id = uri.split('://')[1]
        # Drop the trailing file name, keep "<owner>/<repo>".
        return 'huggingface', repo_id.rsplit('/', 1)[0]
    elif 'huggingface.co' in uri:
        parts = uri.split('/resolve/')
        if len(parts) > 1:
            repo_path = parts[0].split('https://huggingface.co/')[-1]
            return 'huggingface', repo_path
    return 'direct', uri


def calculate_sha256(file_path):
    """Return the hex SHA-256 digest of the file at ``file_path``, read in 4 KiB chunks."""
    sha256_hash = hashlib.sha256()
    with open(file_path, 'rb') as f:
        for byte_block in iter(lambda: f.read(4096), b''):
            sha256_hash.update(byte_block)
    return sha256_hash.hexdigest()


def manual_safety_check_hf(repo_id):
    """Query the Hugging Face scan endpoint for ``repo_id``.

    Returns the decoded scan response when it reports ``hasUnsafeFile`` as
    truthy, otherwise ``None``.
    """
    scanResponse = requests.get('https://huggingface.co/api/models/' + repo_id + "/scan")
    scan = scanResponse.json()
    # Only a response that explicitly flags an unsafe file counts as a hazard.
    if 'hasUnsafeFile' in scan and scan['hasUnsafeFile']:
        return scan
    return None


def main(argv=None):
    """Print the SHA-256 checksum of the file referenced by the URI argument.

    Only the checksum is written to stdout (the calling shell script captures
    it); every diagnostic goes to stderr. Exit codes seen by the caller:

    * 0 - checksum printed
    * 1 - usage error or unexpected download failure
    * 2 - file not found / Hugging Face Hub error
    * 5 - Hugging Face flagged the repository as containing unsafe files
    """
    args = sys.argv[1:] if argv is None else argv
    if len(args) < 1:
        print('Usage: check_and_update.py <uri>', file=sys.stderr)
        return 1

    uri = args[0]
    file_name = uri.split('/')[-1]

    download_type, repo_id_or_url = parse_uri(uri)

    new_checksum = None
    file_path = None

    # Decide download method based on URI type
    if download_type == 'huggingface':
        # Check if the repo is flagged as dangerous by HF
        hazard = manual_safety_check_hf(repo_id_or_url)
        if hazard is not None:
            # print() has no `filename` keyword; passing one raised TypeError and
            # turned the intended exit code 5 into 1. Diagnostics go to stderr.
            print(f'Error: HuggingFace has detected security problems for {repo_id_or_url}: {str(hazard)}', file=sys.stderr)
            sys.exit(5)
        # Use HF API to pull sha
        for file in get_paths_info(repo_id_or_url, [file_name], repo_type='model'):
            try:
                new_checksum = file.lfs.sha256
                break
            except Exception as e:
                print(f'Error from Hugging Face Hub: {str(e)}', file=sys.stderr)
                sys.exit(2)
        if new_checksum is None:
            # No checksum available from the API: download the file and hash it locally.
            try:
                file_path = hf_hub_download(repo_id=repo_id_or_url, filename=file_name)
            except Exception as e:
                print(f'Error from Hugging Face Hub: {str(e)}', file=sys.stderr)
                sys.exit(2)
    else:
        response = requests.get(repo_id_or_url)
        if response.status_code == 200:
            with open(file_name, 'wb') as f:
                f.write(response.content)
            file_path = file_name
        elif response.status_code == 404:
            print(f'File not found: {response.status_code}', file=sys.stderr)
            sys.exit(2)
        else:
            print(f'Error downloading file: {response.status_code}', file=sys.stderr)
            sys.exit(1)

    if new_checksum is None:
        new_checksum = calculate_sha256(file_path)
        print(new_checksum)
        # The downloaded copy was only needed for hashing.
        os.remove(file_path)
    else:
        print(new_checksum)
    return 0


if __name__ == '__main__':
    sys.exit(main())
return fi echo "Checksum for $file_name: $new_checksum" # Compare and update the YAML file if checksums do not match if [[ $result -eq 2 ]]; then echo "File not found, deleting entry for $file_name..." # yq eval -i "del(.[$idx].files[] | select(.filename == \"$file_name\"))" "$input_yaml" elif [[ "$old_checksum" != "$new_checksum" ]]; then echo "Checksum mismatch for $file_name. Updating..." yq eval -i "del(.[$idx].files[] | select(.filename == \"$file_name\").sha256)" "$input_yaml" yq eval -i "(.[$idx].files[] | select(.filename == \"$file_name\")).sha256 = \"$new_checksum\"" "$input_yaml" elif [[ $result -ne 0 ]]; then echo "Error downloading file $file_name. Skipping..." else echo "Checksum match for $file_name. No update needed." fi } # Read the YAML and process each file len=$(yq eval '. | length' "$input_yaml") for ((i=0; i<$len; i++)) do name=$(yq eval ".[$i].name" "$input_yaml") files_len=$(yq eval ".[$i].files | length" "$input_yaml") for ((j=0; j<$files_len; j++)) do filename=$(yq eval ".[$i].files[$j].filename" "$input_yaml") uri=$(yq eval ".[$i].files[$j].uri" "$input_yaml") checksum=$(yq eval ".[$i].files[$j].sha256" "$input_yaml") echo "Checking model $name, file $filename. URI = $uri, Checksum = $checksum" check_and_update_checksum "$name" "$filename" "$uri" "$checksum" "$i" done done ================================================ FILE: .github/ci/modelslist.go ================================================ package main import ( "fmt" "html/template" "io/ioutil" "os" "github.com/microcosm-cc/bluemonday" "gopkg.in/yaml.v3" ) var modelPageTemplate string = ` LocalAI models

LocalAI model gallery list


🖼️ Available {{.AvailableModels}} models

Refer to the Model gallery for more information on how to use the models with LocalAI.
You can install models with the CLI command local-ai models install . or by using the WebUI.

{{ range $_, $model := .Models }}
{{ $icon := "https://upload.wikimedia.org/wikipedia/commons/6/65/No-Image-Placeholder.svg" }} {{ if $model.Icon }} {{ $icon = $model.Icon }} {{ end }}
{{$model.Name}}
{{$model.Name}}

{{ $model.Description }}

{{ end }}
` type GalleryModel struct { Name string `json:"name" yaml:"name"` URLs []string `json:"urls" yaml:"urls"` Icon string `json:"icon" yaml:"icon"` Description string `json:"description" yaml:"description"` } func main() { // read the YAML file which contains the models f, err := ioutil.ReadFile(os.Args[1]) if err != nil { fmt.Println("Error reading file:", err) return } models := []*GalleryModel{} err = yaml.Unmarshal(f, &models) if err != nil { // write to stderr os.Stderr.WriteString("Error unmarshaling YAML: " + err.Error() + "\n") return } // Ensure that all arbitrary text content is sanitized before display for i, m := range models { models[i].Name = bluemonday.StrictPolicy().Sanitize(m.Name) models[i].Description = bluemonday.StrictPolicy().Sanitize(m.Description) } // render the template data := struct { Models []*GalleryModel AvailableModels int }{ Models: models, AvailableModels: len(models), } tmpl := template.Must(template.New("modelPage").Parse(modelPageTemplate)) err = tmpl.Execute(os.Stdout, data) if err != nil { fmt.Println("Error executing template:", err) return } } ================================================ FILE: .github/dependabot.yml ================================================ # https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file version: 2 updates: - package-ecosystem: "gitsubmodule" directory: "/" schedule: interval: "weekly" - package-ecosystem: "gomod" directory: "/" schedule: interval: "weekly" ignore: - dependency-name: "github.com/mudler/LocalAI/pkg/grpc/proto" - package-ecosystem: "github-actions" # Workflow files stored in the default location of `.github/workflows`. (You don't need to specify `/.github/workflows` for `directory`. You can use `directory: "/"`.) directory: "/" schedule: # Check for updates to GitHub Actions every weekday interval: "weekly" - package-ecosystem: "pip" # Workflow files stored in the default location of `.github/workflows`. 
(You don't need to specify `/.github/workflows` for `directory`. You can use `directory: "/"`.) directory: "/" schedule: # Check for updates to GitHub Actions every weekday interval: "weekly" - package-ecosystem: "docker" # Workflow files stored in the default location of `.github/workflows`. (You don't need to specify `/.github/workflows` for `directory`. You can use `directory: "/"`.) directory: "/" schedule: # Check for updates to GitHub Actions every weekday interval: "weekly" - package-ecosystem: "pip" directory: "/backend/python/bark" schedule: interval: "weekly" - package-ecosystem: "pip" directory: "/backend/python/common/template" schedule: interval: "weekly" - package-ecosystem: "pip" directory: "/backend/python/coqui" schedule: interval: "weekly" - package-ecosystem: "pip" directory: "/backend/python/diffusers" schedule: interval: "weekly" - package-ecosystem: "pip" directory: "/backend/python/exllama" schedule: interval: "weekly" - package-ecosystem: "pip" directory: "/backend/python/exllama2" schedule: interval: "weekly" - package-ecosystem: "pip" directory: "/backend/python/mamba" schedule: interval: "weekly" - package-ecosystem: "pip" directory: "/backend/python/openvoice" schedule: interval: "weekly" - package-ecosystem: "pip" directory: "/backend/python/rerankers" schedule: interval: "weekly" - package-ecosystem: "pip" directory: "/backend/python/sentencetransformers" schedule: interval: "weekly" - package-ecosystem: "pip" directory: "/backend/python/transformers" schedule: interval: "weekly" - package-ecosystem: "pip" directory: "/backend/python/vllm" schedule: interval: "weekly" - package-ecosystem: "pip" directory: "/examples/chainlit" schedule: interval: "weekly" - package-ecosystem: "pip" directory: "/examples/functions" schedule: interval: "weekly" - package-ecosystem: "pip" directory: "/examples/langchain/langchainpy-localai-example" schedule: interval: "weekly" - package-ecosystem: "pip" directory: "/examples/langchain-chroma" schedule: 
interval: "weekly" - package-ecosystem: "pip" directory: "/examples/streamlit-bot" schedule: interval: "weekly" - package-ecosystem: "docker" directory: "/examples/k8sgpt" schedule: interval: "weekly" - package-ecosystem: "docker" directory: "/examples/kubernetes" schedule: interval: "weekly" - package-ecosystem: "docker" directory: "/examples/langchain" schedule: interval: "weekly" - package-ecosystem: "gomod" directory: "/examples/semantic-todo" schedule: interval: "weekly" - package-ecosystem: "docker" directory: "/examples/telegram-bot" schedule: interval: "weekly" ================================================ FILE: .github/gallery-agent/agent.go ================================================ package main import ( "context" "encoding/json" "fmt" "io" "net/http" "os" "regexp" "slices" "strings" "github.com/ghodss/yaml" hfapi "github.com/mudler/LocalAI/pkg/huggingface-api" "github.com/mudler/cogito" "github.com/mudler/cogito/clients" "github.com/mudler/cogito/structures" "github.com/sashabaranov/go-openai/jsonschema" ) var ( openAIModel = os.Getenv("OPENAI_MODEL") openAIKey = os.Getenv("OPENAI_KEY") openAIBaseURL = os.Getenv("OPENAI_BASE_URL") galleryIndexPath = os.Getenv("GALLERY_INDEX_PATH") //defaultclient llm = clients.NewOpenAILLM(openAIModel, openAIKey, openAIBaseURL) ) // cleanTextContent removes trailing spaces, tabs, and normalizes line endings // to prevent YAML linting issues like trailing spaces and multiple empty lines func cleanTextContent(text string) string { lines := strings.Split(text, "\n") var cleanedLines []string var prevEmpty bool for _, line := range lines { // Remove all trailing whitespace (spaces, tabs, etc.) 
trimmed := strings.TrimRight(line, " \t\r") // Avoid multiple consecutive empty lines if trimmed == "" { if !prevEmpty { cleanedLines = append(cleanedLines, "") } prevEmpty = true } else { cleanedLines = append(cleanedLines, trimmed) prevEmpty = false } } // Remove trailing empty lines from the result result := strings.Join(cleanedLines, "\n") return stripThinkingTags(strings.TrimRight(result, "\n")) } type galleryModel struct { Name string `yaml:"name"` Urls []string `yaml:"urls"` } // isModelExisting checks if a specific model ID exists in the gallery using text search func isModelExisting(modelID string) (bool, error) { indexPath := getGalleryIndexPath() content, err := os.ReadFile(indexPath) if err != nil { return false, fmt.Errorf("failed to read %s: %w", indexPath, err) } var galleryModels []galleryModel err = yaml.Unmarshal(content, &galleryModels) if err != nil { return false, fmt.Errorf("failed to unmarshal %s: %w", indexPath, err) } for _, galleryModel := range galleryModels { if slices.Contains(galleryModel.Urls, modelID) { return true, nil } } return false, nil } // filterExistingModels removes models that already exist in the gallery func filterExistingModels(models []ProcessedModel) ([]ProcessedModel, error) { var filteredModels []ProcessedModel for _, model := range models { exists, err := isModelExisting(model.ModelID) if err != nil { fmt.Printf("Error checking if model %s exists: %v, skipping\n", model.ModelID, err) continue } if !exists { filteredModels = append(filteredModels, model) } else { fmt.Printf("Skipping existing model: %s\n", model.ModelID) } } fmt.Printf("Filtered out %d existing models, %d new models remaining\n", len(models)-len(filteredModels), len(filteredModels)) return filteredModels, nil } // getGalleryIndexPath returns the gallery index file path, with a default fallback func getGalleryIndexPath() string { if galleryIndexPath != "" { return galleryIndexPath } return "gallery/index.yaml" } func stripThinkingTags(content string) 
string { // Remove content between and (including multi-line) content = regexp.MustCompile(`(?s).*?`).ReplaceAllString(content, "") // Remove content between and (including multi-line) content = regexp.MustCompile(`(?s).*?`).ReplaceAllString(content, "") // Clean up any extra whitespace content = strings.TrimSpace(content) return content } func getRealReadme(ctx context.Context, repository string) (string, error) { // Create a conversation fragment fragment := cogito.NewEmptyFragment(). AddMessage("user", `Your task is to get a clear description of a large language model from huggingface by using the provided tool. I will share with you a repository that might be quantized, and as such probably not by the original model author. We need to get the real description of the model, and not the one that might be quantized. You will have to call the tool to get the readme more than once by figuring out from the quantized readme which is the base model readme. This is the repository: `+repository) // Execute with tools result, err := cogito.ExecuteTools(llm, fragment, cogito.WithIterations(3), cogito.WithMaxAttempts(3), cogito.WithTools(&HFReadmeTool{client: hfapi.NewClient()})) if err != nil { return "", err } result = result.AddMessage("user", "Describe the model in a clear and concise way that can be shared in a model gallery.") // Get a response _, err = llm.Ask(ctx, result) if err != nil { return "", err } content := result.LastMessage().Content return cleanTextContent(content), nil } func selectMostInterestingModels(ctx context.Context, searchResult *SearchResult) ([]ProcessedModel, error) { if len(searchResult.Models) == 1 { return searchResult.Models, nil } // Create a conversation fragment fragment := cogito.NewEmptyFragment(). AddMessage("user", `Your task is to analyze a list of AI models and select the most interesting ones for a model gallery. 
You will be given detailed information about multiple models including their metadata, file information, and README content. Consider the following criteria when selecting models: 1. Model popularity (download count) 2. Model recency (last modified date) 3. Model completeness (has preferred model file, README, etc.) 4. Model uniqueness (not duplicates or very similar models) 5. Model quality (based on README content and description) 6. Model utility (practical applications) You should select models that would be most valuable for users browsing a model gallery. Prioritize models that are: - Well-documented with clear READMEs - Recently updated - Popular (high download count) - Have the preferred quantization format available - Offer unique capabilities or are from reputable authors Return your analysis and selection reasoning.`) // Add the search results as context modelsInfo := fmt.Sprintf("Found %d models matching '%s' with quantization preference '%s':\n\n", searchResult.TotalModelsFound, searchResult.SearchTerm, searchResult.Quantization) for i, model := range searchResult.Models { modelsInfo += fmt.Sprintf("Model %d:\n", i+1) modelsInfo += fmt.Sprintf(" ID: %s\n", model.ModelID) modelsInfo += fmt.Sprintf(" Author: %s\n", model.Author) modelsInfo += fmt.Sprintf(" Downloads: %d\n", model.Downloads) modelsInfo += fmt.Sprintf(" Last Modified: %s\n", model.LastModified) modelsInfo += fmt.Sprintf(" Files: %d files\n", len(model.Files)) if model.PreferredModelFile != nil { modelsInfo += fmt.Sprintf(" Preferred Model File: %s (%d bytes)\n", model.PreferredModelFile.Path, model.PreferredModelFile.Size) } else { modelsInfo += " No preferred model file found\n" } if model.ReadmeContent != "" { modelsInfo += fmt.Sprintf(" README: %s\n", model.ReadmeContent) } if model.ProcessingError != "" { modelsInfo += fmt.Sprintf(" Processing Error: %s\n", model.ProcessingError) } modelsInfo += "\n" } fragment = fragment.AddMessage("user", modelsInfo) fragment = 
fragment.AddMessage("user", "Based on your analysis, select the top 5 most interesting models and provide a brief explanation for each selection. Also, create a filtered SearchResult with only the selected models. Return just a list of repositories IDs, you will later be asked to output it as a JSON array with the json tool.") // Get a response newFragment, err := llm.Ask(ctx, fragment) if err != nil { return nil, err } fmt.Println(newFragment.LastMessage().Content) repositories := struct { Repositories []string `json:"repositories"` }{} s := structures.Structure{ Schema: jsonschema.Definition{ Type: jsonschema.Object, AdditionalProperties: false, Properties: map[string]jsonschema.Definition{ "repositories": { Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}, Description: "The trending repositories IDs", }, }, Required: []string{"repositories"}, }, Object: &repositories, } err = newFragment.ExtractStructure(ctx, llm, s) if err != nil { return nil, err } filteredModels := []ProcessedModel{} for _, m := range searchResult.Models { if slices.Contains(repositories.Repositories, m.ModelID) { filteredModels = append(filteredModels, m) } } return filteredModels, nil } // ModelMetadata represents extracted metadata from a model type ModelMetadata struct { Tags []string `json:"tags"` License string `json:"license"` } // extractModelMetadata extracts tags and license from model README and documentation func extractModelMetadata(ctx context.Context, model ProcessedModel) ([]string, string, error) { // Create a conversation fragment fragment := cogito.NewEmptyFragment(). AddMessage("user", `Your task is to extract metadata from an AI model's README and documentation. You will be provided with: 1. Model information (ID, author, description) 2. README content You need to extract: 1. **Tags**: An array of relevant tags that describe the model. 
Use common tags from the gallery such as: - llm, gguf, gpu, cpu, multimodal, image-to-text, text-to-text, text-to-speech, tts - thinking, reasoning, chat, instruction-tuned, code, vision - Model family names (e.g., llama, qwen, mistral, gemma) if applicable - Any other relevant descriptive tags Select 3-8 most relevant tags. 2. **License**: The license identifier (e.g., "apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", "cc-by-4.0"). If no license is found, return an empty string. Return the extracted metadata in a structured format.`) // Add model information modelInfo := "Model Information:\n" modelInfo += fmt.Sprintf(" ID: %s\n", model.ModelID) modelInfo += fmt.Sprintf(" Author: %s\n", model.Author) modelInfo += fmt.Sprintf(" Downloads: %d\n", model.Downloads) if model.ReadmeContent != "" { modelInfo += fmt.Sprintf(" README Content:\n%s\n", model.ReadmeContent) } else if model.ReadmeContentPreview != "" { modelInfo += fmt.Sprintf(" README Preview: %s\n", model.ReadmeContentPreview) } fragment = fragment.AddMessage("user", modelInfo) fragment = fragment.AddMessage("user", "Extract the tags and license from the model information. Return the metadata as a JSON object with 'tags' (array of strings) and 'license' (string).") // Get a response newFragment, err := llm.Ask(ctx, fragment) if err != nil { return nil, "", err } // Extract structured metadata metadata := ModelMetadata{} s := structures.Structure{ Schema: jsonschema.Definition{ Type: jsonschema.Object, AdditionalProperties: false, Properties: map[string]jsonschema.Definition{ "tags": { Type: jsonschema.Array, Items: &jsonschema.Definition{Type: jsonschema.String}, Description: "Array of relevant tags describing the model", }, "license": { Type: jsonschema.String, Description: "License identifier (e.g., apache-2.0, mit, llama2). 
// Image-reference patterns used by extractIconFromReadme. They are compiled
// once at package initialization instead of on every call. All of them are
// case-insensitive and only accept the listed image extensions.
var (
	// Markdown image syntax: ![alt](url)
	readmeMarkdownImageRegex = regexp.MustCompile(`(?i)!\[[^\]]*\]\(([^)]+\.(png|jpg|jpeg|svg|webp|gif))\)`)
	// HTML img tags: <img ... src="url">
	readmeHTMLImageRegex = regexp.MustCompile(`(?i)<img[^>]+src=["']([^"']+\.(png|jpg|jpeg|svg|webp|gif))["']`)
	// Plain URLs ending with an image extension
	readmePlainImageRegex = regexp.MustCompile(`(?i)https?://[^\s<>"']+\.(png|jpg|jpeg|svg|webp|gif)`)
)

// isAbsoluteHTTPURL reports whether url starts with an http:// or https:// scheme.
func isAbsoluteHTTPURL(url string) bool {
	lower := strings.ToLower(url)
	return strings.HasPrefix(lower, "http://") || strings.HasPrefix(lower, "https://")
}

// extractIconFromReadme scans the README content for image references and
// returns the first suitable icon URL found, or "" when there is none.
//
// Formats are tried in priority order: markdown images, HTML <img> tags, then
// bare URLs. Within a format every match is considered, and the first one
// that is an absolute http(s) URL wins; relative paths are skipped because
// they cannot be resolved outside the model repository.
func extractIconFromReadme(readmeContent string) string {
	if readmeContent == "" {
		return ""
	}

	sources := []struct {
		pattern *regexp.Regexp
		group   int // capture group that holds the URL (0 = whole match)
	}{
		{readmeMarkdownImageRegex, 1},
		{readmeHTMLImageRegex, 1},
		{readmePlainImageRegex, 0},
	}

	for _, source := range sources {
		for _, match := range source.pattern.FindAllStringSubmatch(readmeContent, -1) {
			url := strings.TrimSpace(match[source.group])
			if isAbsoluteHTTPURL(url) {
				return url
			}
		}
	}

	return ""
}
"https://huggingface.co" userURL := fmt.Sprintf("%s/api/users/%s", baseURL, author) req, err := http.NewRequest("GET", userURL, nil) if err != nil { return "" } client := &http.Client{} resp, err := client.Do(req) if err != nil { return "" } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return "" } // Parse the response to get avatar URL var userInfo map[string]interface{} body, err := io.ReadAll(resp.Body) if err != nil { return "" } if err := json.Unmarshal(body, &userInfo); err != nil { return "" } // Try to extract avatar URL from response if avatar, ok := userInfo["avatarUrl"].(string); ok && avatar != "" { return avatar } if avatar, ok := userInfo["avatar"].(string); ok && avatar != "" { return avatar } return "" } // extractModelIcon extracts icon URL from README or falls back to HuggingFace avatar func extractModelIcon(model ProcessedModel) string { // First, try to extract icon from README if icon := extractIconFromReadme(model.ReadmeContent); icon != "" { return icon } // Fallback: Try to get HuggingFace user avatar if model.Author != "" { if avatar := getHuggingFaceAvatarURL(model.Author); avatar != "" { return avatar } } return "" } ================================================ FILE: .github/gallery-agent/gallery.go ================================================ package main import ( "context" "encoding/json" "fmt" "os" "strings" "github.com/ghodss/yaml" "github.com/mudler/LocalAI/core/gallery/importers" ) func formatTextContent(text string) string { return formatTextContentWithIndent(text, 4, 6) } // formatTextContentWithIndent formats text content with specified base and list item indentation func formatTextContentWithIndent(text string, baseIndent int, listItemIndent int) string { var formattedLines []string lines := strings.Split(text, "\n") for _, line := range lines { trimmed := strings.TrimRight(line, " \t\r") if trimmed == "" { // Keep empty lines as empty (no indentation) formattedLines = append(formattedLines, "") } else { 
// Preserve relative indentation from yaml.Marshal output // Count existing leading spaces to preserve relative structure leadingSpaces := len(trimmed) - len(strings.TrimLeft(trimmed, " \t")) trimmedStripped := strings.TrimLeft(trimmed, " \t") var totalIndent int if strings.HasPrefix(trimmedStripped, "-") { // List items: use listItemIndent (ignore existing leading spaces) totalIndent = listItemIndent } else { // Regular lines: use baseIndent + preserve relative indentation // This handles both top-level keys (leadingSpaces=0) and nested properties (leadingSpaces>0) totalIndent = baseIndent + leadingSpaces } indentStr := strings.Repeat(" ", totalIndent) formattedLines = append(formattedLines, indentStr+trimmedStripped) } } formattedText := strings.Join(formattedLines, "\n") // Remove any trailing spaces from the formatted description formattedText = strings.TrimRight(formattedText, " \t") return formattedText } // generateYAMLEntry generates a YAML entry for a model using the specified anchor func generateYAMLEntry(model ProcessedModel, quantization string) string { modelConfig, err := importers.DiscoverModelConfig("https://huggingface.co/"+model.ModelID, json.RawMessage(`{ "quantization": "`+quantization+`"}`)) if err != nil { panic(err) } // Extract model name from ModelID parts := strings.Split(model.ModelID, "/") modelName := model.ModelID if len(parts) > 0 { modelName = strings.ToLower(parts[len(parts)-1]) } // Remove common suffixes modelName = strings.ReplaceAll(modelName, "-gguf", "") modelName = strings.ReplaceAll(modelName, "-q4_k_m", "") modelName = strings.ReplaceAll(modelName, "-q4_k_s", "") modelName = strings.ReplaceAll(modelName, "-q3_k_m", "") modelName = strings.ReplaceAll(modelName, "-q2_k", "") description := model.ReadmeContent if description == "" { description = fmt.Sprintf("AI model: %s", modelName) } // Clean up description to prevent YAML linting issues description = cleanTextContent(description) formattedDescription := 
formatTextContent(description) configFile := formatTextContent(modelConfig.ConfigFile) filesYAML, _ := yaml.Marshal(modelConfig.Files) // Files section: list items need 4 spaces (not 6), since files: is at 2 spaces files := formatTextContentWithIndent(string(filesYAML), 4, 4) // Build metadata sections var metadataSections []string // Add license if present if model.License != "" { metadataSections = append(metadataSections, fmt.Sprintf(` license: "%s"`, model.License)) } // Add tags if present if len(model.Tags) > 0 { tagsYAML, _ := yaml.Marshal(model.Tags) tagsFormatted := formatTextContentWithIndent(string(tagsYAML), 4, 4) tagsFormatted = strings.TrimRight(tagsFormatted, "\n") metadataSections = append(metadataSections, fmt.Sprintf(" tags:\n%s", tagsFormatted)) } // Add icon if present if model.Icon != "" { metadataSections = append(metadataSections, fmt.Sprintf(` icon: %s`, model.Icon)) } // Build the metadata block metadataBlock := "" if len(metadataSections) > 0 { metadataBlock = strings.Join(metadataSections, "\n") + "\n" } yamlTemplate := "" yamlTemplate = `- name: "%s" url: "github:mudler/LocalAI/gallery/virtual.yaml@master" urls: - https://huggingface.co/%s description: | %s%s overrides: %s files: %s` // Trim trailing newlines from formatted sections to prevent extra blank lines formattedDescription = strings.TrimRight(formattedDescription, "\n") configFile = strings.TrimRight(configFile, "\n") files = strings.TrimRight(files, "\n") // Add newline before metadata block if present if metadataBlock != "" { metadataBlock = "\n" + strings.TrimRight(metadataBlock, "\n") } return fmt.Sprintf(yamlTemplate, modelName, model.ModelID, formattedDescription, metadataBlock, configFile, files, ) } // generateYAMLForModels generates YAML entries for selected models and appends to index.yaml func generateYAMLForModels(ctx context.Context, models []ProcessedModel, quantization string) error { // Generate YAML entries for each model var yamlEntries []string for _, model := 
range models { fmt.Printf("Generating YAML entry for model: %s\n", model.ModelID) // Generate YAML entry yamlEntry := generateYAMLEntry(model, quantization) yamlEntries = append(yamlEntries, yamlEntry) } // Prepend to index.yaml (write at the top) if len(yamlEntries) > 0 { indexPath := getGalleryIndexPath() fmt.Printf("Prepending YAML entries to %s...\n", indexPath) // Read current content content, err := os.ReadFile(indexPath) if err != nil { return fmt.Errorf("failed to read %s: %w", indexPath, err) } existingContent := string(content) yamlBlock := strings.Join(yamlEntries, "\n") // Check if file starts with "---" var newContent string if strings.HasPrefix(existingContent, "---\n") { // File starts with "---", prepend new entries after it restOfContent := strings.TrimPrefix(existingContent, "---\n") // Ensure proper spacing: "---\n" + new entries + "\n" + rest of content newContent = "---\n" + yamlBlock + "\n" + restOfContent } else if strings.HasPrefix(existingContent, "---") { // File starts with "---" but no newline after restOfContent := strings.TrimPrefix(existingContent, "---") newContent = "---\n" + yamlBlock + "\n" + strings.TrimPrefix(restOfContent, "\n") } else { // No "---" at start, prepend new entries at the very beginning // Trim leading whitespace from existing content existingContent = strings.TrimLeft(existingContent, " \t\n\r") newContent = yamlBlock + "\n" + existingContent } // Write back to file err = os.WriteFile(indexPath, []byte(newContent), 0644) if err != nil { return fmt.Errorf("failed to write %s: %w", indexPath, err) } fmt.Printf("Successfully prepended %d models to %s\n", len(yamlEntries), indexPath) } return nil } ================================================ FILE: .github/gallery-agent/main.go ================================================ package main import ( "context" "encoding/json" "fmt" "os" "strconv" "strings" "time" hfapi "github.com/mudler/LocalAI/pkg/huggingface-api" ) // ProcessedModelFile represents a processed 
// ProcessedModelFile represents a processed model file with additional metadata
type ProcessedModelFile struct {
	Path     string `json:"path"`      // file path as reported for the repository file
	Size     int64  `json:"size"`      // size in bytes
	SHA256   string `json:"sha256"`    // checksum; may be empty
	IsReadme bool   `json:"is_readme"` // true when the file is the repository README
	FileType string `json:"file_type"` // "model", "readme", "other"
}

// ProcessedModel represents a processed model with all gathered metadata
type ProcessedModel struct {
	ModelID      string               `json:"model_id"`      // repository identifier, used to build https://huggingface.co/<ModelID>
	Author       string               `json:"author"`        // repository owner
	Downloads    int                  `json:"downloads"`     // download counter
	LastModified string               `json:"last_modified"` // last modification timestamp, kept as a string
	Files        []ProcessedModelFile `json:"files"`         // every file found in the repository
	// PreferredModelFile is the file chosen according to
	// QuantizationPreferences; nil when no file matched.
	PreferredModelFile *ProcessedModelFile `json:"preferred_model_file,omitempty"`
	// ReadmeFile is the README entry of Files; nil when there is none.
	ReadmeFile *ProcessedModelFile `json:"readme_file,omitempty"`
	// ReadmeContent is the text used as the gallery description.
	ReadmeContent string `json:"readme_content,omitempty"`
	// ReadmeContentPreview is a truncated ReadmeContent for report output.
	ReadmeContentPreview string `json:"readme_content_preview,omitempty"`
	// QuantizationPreferences lists quantization names in order of preference.
	QuantizationPreferences []string `json:"quantization_preferences"`
	// ProcessingError holds the error text when processing the model failed.
	ProcessingError string   `json:"processing_error,omitempty"`
	Tags            []string `json:"tags,omitempty"`    // emitted as the entry's "tags" list
	License         string   `json:"license,omitempty"` // emitted as the entry's "license"
	Icon            string   `json:"icon,omitempty"`    // emitted as the entry's "icon" URL
}

// SearchResult represents the complete result of searching and processing models
type SearchResult struct {
	SearchTerm       string           `json:"search_term"`        // term the search was run with
	Limit            int              `json:"limit"`              // maximum number of models requested
	Quantization     string           `json:"quantization"`       // preferred quantization requested
	TotalModelsFound int              `json:"total_models_found"` // number of models returned by the search
	Models           []ProcessedModel `json:"models"`             // processed models, including ones that failed
	FormattedOutput  string           `json:"formatted_output"`   // human-readable report of the processing
}

// AddedModelSummary represents a summary of models added to the gallery
type AddedModelSummary struct {
	SearchTerm     string   `json:"search_term"`      // term the search was run with
	TotalFound     int      `json:"total_found"`      // number of models returned by the search
	ModelsAdded    int      `json:"models_added"`     // number of models written to the gallery
	AddedModelIDs  []string `json:"added_model_ids"`  // IDs of the added models
	AddedModelURLs []string `json:"added_model_urls"` // Hugging Face URLs of the added models
	Quantization   string   `json:"quantization"`     // preferred quantization requested
	ProcessingTime string   `json:"processing_time"`  // total run time, formatted as a duration string
}
SYNTHETIC MODE - generating random test data") err := runSyntheticMode() if err != nil { fmt.Fprintf(os.Stderr, "Error in synthetic mode: %v\n", err) os.Exit(1) } return } // Get configuration from environment variables searchTerm := os.Getenv("SEARCH_TERM") if searchTerm == "" { searchTerm = "GGUF" } limitStr := os.Getenv("LIMIT") if limitStr == "" { limitStr = "5" } limit, err := strconv.Atoi(limitStr) if err != nil { fmt.Fprintf(os.Stderr, "Error parsing LIMIT: %v\n", err) os.Exit(1) } quantization := os.Getenv("QUANTIZATION") maxModels := os.Getenv("MAX_MODELS") if maxModels == "" { maxModels = "1" } maxModelsInt, err := strconv.Atoi(maxModels) if err != nil { fmt.Fprintf(os.Stderr, "Error parsing MAX_MODELS: %v\n", err) os.Exit(1) } // Print configuration fmt.Printf("Gallery Agent Configuration:\n") fmt.Printf(" Search Term: %s\n", searchTerm) fmt.Printf(" Limit: %d\n", limit) fmt.Printf(" Quantization: %s\n", quantization) fmt.Printf(" Max Models to Add: %d\n", maxModelsInt) fmt.Printf(" Gallery Index Path: %s\n", os.Getenv("GALLERY_INDEX_PATH")) fmt.Println() result, err := searchAndProcessModels(searchTerm, limit, quantization) if err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } fmt.Println(result.FormattedOutput) var models []ProcessedModel if len(result.Models) > 1 { fmt.Println("More than one model found (", len(result.Models), "), using AI agent to select the most interesting models") for _, model := range result.Models { fmt.Println("Model: ", model.ModelID) } // Use AI agent to select the most interesting models fmt.Println("Using AI agent to select the most interesting models...") models, err = selectMostInterestingModels(context.Background(), result) if err != nil { fmt.Fprintf(os.Stderr, "Error in model selection: %v\n", err) // Continue with original result if selection fails models = result.Models } } else if len(result.Models) == 1 { models = result.Models fmt.Println("Only one model found, using it directly") } 
fmt.Print(models) // Filter out models that already exist in the gallery fmt.Println("Filtering out existing models...") models, err = filterExistingModels(models) if err != nil { fmt.Fprintf(os.Stderr, "Error filtering existing models: %v\n", err) os.Exit(1) } // Limit to maxModelsInt after filtering if len(models) > maxModelsInt { models = models[:maxModelsInt] } // Track added models for summary var addedModelIDs []string var addedModelURLs []string // Generate YAML entries and append to gallery/index.yaml if len(models) > 0 { for _, model := range models { addedModelIDs = append(addedModelIDs, model.ModelID) // Generate Hugging Face URL for the model modelURL := fmt.Sprintf("https://huggingface.co/%s", model.ModelID) addedModelURLs = append(addedModelURLs, modelURL) } fmt.Println("Generating YAML entries for selected models...") err = generateYAMLForModels(context.Background(), models, quantization) if err != nil { fmt.Fprintf(os.Stderr, "Error generating YAML entries: %v\n", err) os.Exit(1) } } else { fmt.Println("No new models to add to the gallery.") } // Create and write summary processingTime := time.Since(startTime).String() summary := AddedModelSummary{ SearchTerm: searchTerm, TotalFound: result.TotalModelsFound, ModelsAdded: len(addedModelIDs), AddedModelIDs: addedModelIDs, AddedModelURLs: addedModelURLs, Quantization: quantization, ProcessingTime: processingTime, } // Write summary to file summaryData, err := json.MarshalIndent(summary, "", " ") if err != nil { fmt.Fprintf(os.Stderr, "Error marshaling summary: %v\n", err) } else { err = os.WriteFile("gallery-agent-summary.json", summaryData, 0644) if err != nil { fmt.Fprintf(os.Stderr, "Error writing summary file: %v\n", err) } else { fmt.Printf("Summary written to gallery-agent-summary.json\n") } } } func searchAndProcessModels(searchTerm string, limit int, quantization string) (*SearchResult, error) { client := hfapi.NewClient() var outputBuilder strings.Builder fmt.Println("Searching for models...") 
// searchAndProcessModels fetches up to limit models matching searchTerm and
// gathers the metadata needed to build gallery entries for each of them.
//
// For every model it looks up the repository details, classifies the files,
// selects a preferred model file using the requested quantization followed by
// a fixed fallback list, and - when the repository has a README - obtains the
// description, tags/license and icon through the agent helpers. A
// human-readable report of all steps is accumulated in
// SearchResult.FormattedOutput; progress messages are also printed to stdout.
//
// A failure to fetch the model list is returned as an error. A failure to
// fetch the details of a single model is recorded in that model's
// ProcessingError and processing continues with the next model. Failures of
// the README/metadata helpers are only logged as warnings.
func searchAndProcessModels(searchTerm string, limit int, quantization string) (*SearchResult, error) {
	client := hfapi.NewClient()
	var outputBuilder strings.Builder

	fmt.Println("Searching for models...")

	// Initialize the result struct
	result := &SearchResult{
		SearchTerm:   searchTerm,
		Limit:        limit,
		Quantization: quantization,
		Models:       []ProcessedModel{},
	}

	models, err := client.GetLatest(searchTerm, limit)
	if err != nil {
		return nil, fmt.Errorf("failed to fetch models: %w", err)
	}

	fmt.Println("Models found:", len(models))
	result.TotalModelsFound = len(models)

	// Nothing found: return an empty (but non-nil) result with a short report.
	if len(models) == 0 {
		outputBuilder.WriteString("No models found.\n")
		result.FormattedOutput = outputBuilder.String()
		return result, nil
	}

	outputBuilder.WriteString(fmt.Sprintf("Found %d models matching '%s':\n\n", len(models), searchTerm))

	// Process each model
	for i, model := range models {
		outputBuilder.WriteString(fmt.Sprintf("%d. Processing Model: %s\n", i+1, model.ModelID))
		outputBuilder.WriteString(fmt.Sprintf(" Author: %s\n", model.Author))
		outputBuilder.WriteString(fmt.Sprintf(" Downloads: %d\n", model.Downloads))
		outputBuilder.WriteString(fmt.Sprintf(" Last Modified: %s\n", model.LastModified))

		// Initialize processed model struct
		processedModel := ProcessedModel{
			ModelID:                 model.ModelID,
			Author:                  model.Author,
			Downloads:               model.Downloads,
			LastModified:            model.LastModified,
			QuantizationPreferences: []string{quantization, "Q4_K_M", "Q4_K_S", "Q3_K_M", "Q2_K"},
		}

		// Get detailed model information
		details, err := client.GetModelDetails(model.ModelID)
		if err != nil {
			// Keep the model in the result, flagged with the error, and move on.
			errorMsg := fmt.Sprintf(" Error getting model details: %v\n", err)
			outputBuilder.WriteString(errorMsg)
			processedModel.ProcessingError = err.Error()
			result.Models = append(result.Models, processedModel)
			continue
		}

		// Define quantization preferences (in order of preference)
		// NOTE(review): quantization may be the empty string (QUANTIZATION unset),
		// which then becomes the first preference - confirm how
		// hfapi.FindPreferredModelFile treats an empty entry.
		quantizationPreferences := []string{quantization, "Q4_K_M", "Q4_K_S", "Q3_K_M", "Q2_K"}

		// Find preferred model file
		preferredModelFile := hfapi.FindPreferredModelFile(details.Files, quantizationPreferences)

		// Process files: README wins over "model"; everything else is "other".
		processedFiles := make([]ProcessedModelFile, len(details.Files))
		for j, file := range details.Files {
			fileType := "other"
			if file.IsReadme {
				fileType = "readme"
			} else if preferredModelFile != nil && file.Path == preferredModelFile.Path {
				fileType = "model"
			}

			processedFiles[j] = ProcessedModelFile{
				Path:     file.Path,
				Size:     file.Size,
				SHA256:   file.SHA256,
				IsReadme: file.IsReadme,
				FileType: fileType,
			}
		}

		processedModel.Files = processedFiles

		// Set preferred model file
		if preferredModelFile != nil {
			for _, file := range processedFiles {
				if file.Path == preferredModelFile.Path {
					// &file is the address of the loop variable (a copy of the
					// element), not of the element inside processedFiles.
					processedModel.PreferredModelFile = &file
					break
				}
			}
		}

		// Print file information
		outputBuilder.WriteString(fmt.Sprintf(" Files found: %d\n", len(details.Files)))

		if preferredModelFile != nil {
			outputBuilder.WriteString(fmt.Sprintf(" Preferred Model File: %s (SHA256: %s)\n", preferredModelFile.Path, preferredModelFile.SHA256))
		} else {
			outputBuilder.WriteString(fmt.Sprintf(" No model file found with quantization preferences: %v\n", quantizationPreferences))
		}

		if details.ReadmeFile != nil {
			outputBuilder.WriteString(fmt.Sprintf(" README File: %s\n", details.ReadmeFile.Path))

			// Find and set readme file
			for _, file := range processedFiles {
				if file.IsReadme {
					// Same as above: this stores the address of the loop copy.
					processedModel.ReadmeFile = &file
					break
				}
			}

			fmt.Println("Getting real readme for", model.ModelID, "waiting...")
			// Use agent to get the real readme and prepare the model description
			readmeContent, err := getRealReadme(context.Background(), model.ModelID)
			if err == nil {
				processedModel.ReadmeContent = readmeContent
				processedModel.ReadmeContentPreview = truncateString(readmeContent, 200)
				outputBuilder.WriteString(fmt.Sprintf(" README Content Preview: %s\n",
					processedModel.ReadmeContentPreview))
			} else {
				fmt.Printf(" Warning: Failed to get real readme: %v\n", err)
			}
			// Logs the complete README text to stdout (also when the fetch failed,
			// in which case readmeContent is whatever getRealReadme returned).
			fmt.Println("Real readme got", readmeContent)

			// Extract metadata (tags, license) from README using LLM
			// This runs even when the README fetch above failed, i.e. possibly
			// with an empty processedModel.ReadmeContent.
			fmt.Println("Extracting metadata for", model.ModelID, "waiting...")
			tags, license, err := extractModelMetadata(context.Background(), processedModel)
			if err == nil {
				processedModel.Tags = tags
				processedModel.License = license
				outputBuilder.WriteString(fmt.Sprintf(" Tags: %v\n", tags))
				outputBuilder.WriteString(fmt.Sprintf(" License: %s\n", license))
			} else {
				fmt.Printf(" Warning: Failed to extract metadata: %v\n", err)
			}

			// Extract icon from README or use HuggingFace avatar
			icon := extractModelIcon(processedModel)
			if icon != "" {
				processedModel.Icon = icon
				outputBuilder.WriteString(fmt.Sprintf(" Icon: %s\n", icon))
			}
			// Previous implementation, kept for reference: read the README file
			// directly through the API client instead of going through the agent.
			// Get README content
			// readmeContent, err := client.GetReadmeContent(model.ModelID, details.ReadmeFile.Path)
			// if err == nil {
			// 	processedModel.ReadmeContent = readmeContent
			// 	processedModel.ReadmeContentPreview = truncateString(readmeContent, 200)
			// 	outputBuilder.WriteString(fmt.Sprintf(" README Content Preview: %s\n",
			// 		processedModel.ReadmeContentPreview))
			// }
		}

		// Print all files with their checksums
		outputBuilder.WriteString(" All Files:\n")
		for _, file := range processedFiles {
			outputBuilder.WriteString(fmt.Sprintf(" - %s (%s, %d bytes", file.Path, file.FileType, file.Size))
			if file.SHA256 != "" {
				outputBuilder.WriteString(fmt.Sprintf(", SHA256: %s", file.SHA256))
			}
			outputBuilder.WriteString(")\n")
		}

		outputBuilder.WriteString("\n")
		result.Models = append(result.Models, processedModel)
	}

	result.FormattedOutput = outputBuilder.String()
	return result, nil
}
// truncateString shortens s to at most maxLen bytes and appends "..." when
// anything was cut off. Strings that already fit are returned unchanged.
//
// The cut never splits a multi-byte UTF-8 sequence: if byte maxLen falls
// inside an encoded rune, the cut moves back to the start of that rune, so the
// result may be slightly shorter than maxLen bytes (plus the ellipsis).
// A negative maxLen is treated as 0.
func truncateString(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}

	// len(s) > maxLen here, so s[cut] is always a valid index. UTF-8
	// continuation bytes have the bit pattern 10xxxxxx; step back until the
	// cut position is the first byte of a rune (or the start of the string).
	cut := maxLen
	for cut > 0 && s[cut]&0xC0 == 0x80 {
		cut--
	}
	return s[:cut] + "..."
}
} ================================================ FILE: .github/gallery-agent/testing.go ================================================ package main import ( "context" "fmt" "math/rand" "strings" "time" ) // runSyntheticMode generates synthetic test data and appends it to the gallery func runSyntheticMode() error { generator := NewSyntheticDataGenerator() // Generate a random number of synthetic models (1-3) numModels := generator.rand.Intn(3) + 1 fmt.Printf("Generating %d synthetic models for testing...\n", numModels) var models []ProcessedModel for i := 0; i < numModels; i++ { model := generator.GenerateProcessedModel() models = append(models, model) fmt.Printf("Generated synthetic model: %s\n", model.ModelID) } // Generate YAML entries and append to gallery/index.yaml fmt.Println("Generating YAML entries for synthetic models...") err := generateYAMLForModels(context.Background(), models, "Q4_K_M") if err != nil { return fmt.Errorf("error generating YAML entries: %w", err) } fmt.Printf("Successfully added %d synthetic models to the gallery for testing!\n", len(models)) return nil } // SyntheticDataGenerator provides methods to generate synthetic test data type SyntheticDataGenerator struct { rand *rand.Rand } // NewSyntheticDataGenerator creates a new synthetic data generator func NewSyntheticDataGenerator() *SyntheticDataGenerator { return &SyntheticDataGenerator{ rand: rand.New(rand.NewSource(time.Now().UnixNano())), } } // GenerateProcessedModelFile creates a synthetic ProcessedModelFile func (g *SyntheticDataGenerator) GenerateProcessedModelFile() ProcessedModelFile { fileTypes := []string{"model", "readme", "other"} fileType := fileTypes[g.rand.Intn(len(fileTypes))] var path string var isReadme bool switch fileType { case "model": path = fmt.Sprintf("model-%s.gguf", g.randomString(8)) isReadme = false case "readme": path = "README.md" isReadme = true default: path = fmt.Sprintf("file-%s.txt", g.randomString(6)) isReadme = false } return 
ProcessedModelFile{ Path: path, Size: int64(g.rand.Intn(1000000000) + 1000000), // 1MB to 1GB SHA256: g.randomSHA256(), IsReadme: isReadme, FileType: fileType, } } // GenerateProcessedModel creates a synthetic ProcessedModel func (g *SyntheticDataGenerator) GenerateProcessedModel() ProcessedModel { authors := []string{"microsoft", "meta", "google", "openai", "anthropic", "mistralai", "huggingface"} modelNames := []string{"llama", "gpt", "claude", "mistral", "gemma", "phi", "qwen", "codellama"} author := authors[g.rand.Intn(len(authors))] modelName := modelNames[g.rand.Intn(len(modelNames))] modelID := fmt.Sprintf("%s/%s-%s", author, modelName, g.randomString(6)) // Generate files numFiles := g.rand.Intn(5) + 2 // 2-6 files files := make([]ProcessedModelFile, numFiles) // Ensure at least one model file and one readme hasModelFile := false hasReadme := false for i := 0; i < numFiles; i++ { files[i] = g.GenerateProcessedModelFile() if files[i].FileType == "model" { hasModelFile = true } if files[i].FileType == "readme" { hasReadme = true } } // Add required files if missing if !hasModelFile { modelFile := g.GenerateProcessedModelFile() modelFile.FileType = "model" modelFile.Path = fmt.Sprintf("%s-Q4_K_M.gguf", modelName) files = append(files, modelFile) } if !hasReadme { readmeFile := g.GenerateProcessedModelFile() readmeFile.FileType = "readme" readmeFile.Path = "README.md" readmeFile.IsReadme = true files = append(files, readmeFile) } // Find preferred model file var preferredModelFile *ProcessedModelFile for i := range files { if files[i].FileType == "model" { preferredModelFile = &files[i] break } } // Find readme file var readmeFile *ProcessedModelFile for i := range files { if files[i].FileType == "readme" { readmeFile = &files[i] break } } readmeContent := g.generateReadmeContent(modelName, author) // Generate sample metadata licenses := []string{"apache-2.0", "mit", "llama2", "gpl-3.0", "bsd", ""} license := licenses[g.rand.Intn(len(licenses))] sampleTags := 
[]string{"llm", "gguf", "gpu", "cpu", "text-to-text", "chat", "instruction-tuned"} numTags := g.rand.Intn(4) + 3 // 3-6 tags tags := make([]string, numTags) for i := 0; i < numTags; i++ { tags[i] = sampleTags[g.rand.Intn(len(sampleTags))] } // Remove duplicates tags = g.removeDuplicates(tags) // Optionally include icon (50% chance) icon := "" if g.rand.Intn(2) == 0 { icon = fmt.Sprintf("https://cdn-avatars.huggingface.co/v1/production/uploads/%s.png", g.randomString(24)) } return ProcessedModel{ ModelID: modelID, Author: author, Downloads: g.rand.Intn(1000000) + 1000, LastModified: g.randomDate(), Files: files, PreferredModelFile: preferredModelFile, ReadmeFile: readmeFile, ReadmeContent: readmeContent, ReadmeContentPreview: truncateString(readmeContent, 200), QuantizationPreferences: []string{"Q4_K_M", "Q4_K_S", "Q3_K_M", "Q2_K"}, ProcessingError: "", Tags: tags, License: license, Icon: icon, } } // Helper methods for synthetic data generation func (g *SyntheticDataGenerator) randomString(length int) string { const charset = "abcdefghijklmnopqrstuvwxyz0123456789" b := make([]byte, length) for i := range b { b[i] = charset[g.rand.Intn(len(charset))] } return string(b) } func (g *SyntheticDataGenerator) randomSHA256() string { const charset = "0123456789abcdef" b := make([]byte, 64) for i := range b { b[i] = charset[g.rand.Intn(len(charset))] } return string(b) } func (g *SyntheticDataGenerator) randomDate() string { now := time.Now() daysAgo := g.rand.Intn(365) // Random date within last year pastDate := now.AddDate(0, 0, -daysAgo) return pastDate.Format("2006-01-02T15:04:05.000Z") } func (g *SyntheticDataGenerator) removeDuplicates(slice []string) []string { keys := make(map[string]bool) result := []string{} for _, item := range slice { if !keys[item] { keys[item] = true result = append(result, item) } } return result } func (g *SyntheticDataGenerator) generateReadmeContent(modelName, author string) string { templates := []string{ fmt.Sprintf("# %s Model\n\nThis 
is a %s model developed by %s. It's designed for various natural language processing tasks including text generation, question answering, and conversation.\n\n## Features\n\n- High-quality text generation\n- Efficient inference\n- Multiple quantization options\n- Easy to use with LocalAI\n\n## Usage\n\nUse this model with LocalAI for various AI tasks.", strings.Title(modelName), modelName, author), fmt.Sprintf("# %s\n\nA powerful language model from %s. This model excels at understanding and generating human-like text across multiple domains.\n\n## Capabilities\n\n- Text completion\n- Code generation\n- Creative writing\n- Technical documentation\n\n## Model Details\n\n- Architecture: Transformer-based\n- Training: Large-scale supervised learning\n- Quantization: Available in multiple formats", strings.Title(modelName), author), fmt.Sprintf("# %s Language Model\n\nDeveloped by %s, this model represents state-of-the-art performance in natural language understanding and generation.\n\n## Key Features\n\n- Multilingual support\n- Context-aware responses\n- Efficient memory usage\n- Fast inference speed\n\n## Applications\n\n- Chatbots and virtual assistants\n- Content generation\n- Code completion\n- Educational tools", strings.Title(modelName), author), } return templates[g.rand.Intn(len(templates))] } ================================================ FILE: .github/gallery-agent/tools.go ================================================ package main import ( "fmt" hfapi "github.com/mudler/LocalAI/pkg/huggingface-api" openai "github.com/sashabaranov/go-openai" jsonschema "github.com/sashabaranov/go-openai/jsonschema" ) // Get repository README from HF type HFReadmeTool struct { client *hfapi.Client } func (s *HFReadmeTool) Execute(args map[string]any) (string, any, error) { q, ok := args["repository"].(string) if !ok { return "", nil, fmt.Errorf("no query") } readme, err := s.client.GetReadmeContent(q, "README.md") if err != nil { return "", nil, err } return readme, 
nil, nil } func (s *HFReadmeTool) Tool() openai.Tool { return openai.Tool{ Type: openai.ToolTypeFunction, Function: &openai.FunctionDefinition{ Name: "hf_readme", Description: "A tool to get the README content of a huggingface repository", Parameters: jsonschema.Definition{ Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ "repository": { Type: jsonschema.String, Description: "The huggingface repository to get the README content of", }, }, Required: []string{"repository"}, }, }, } } ================================================ FILE: .github/labeler.yml ================================================ enhancement: - head-branch: ['^feature', 'feature'] dependencies: - any: - changed-files: - any-glob-to-any-file: 'Makefile' - changed-files: - any-glob-to-any-file: '*.mod' - changed-files: - any-glob-to-any-file: '*.sum' kind/documentation: - any: - changed-files: - any-glob-to-any-file: 'docs/*' - changed-files: - any-glob-to-any-file: '*.md' area/ai-model: - any: - changed-files: - any-glob-to-any-file: 'gallery/*' examples: - any: - changed-files: - any-glob-to-any-file: 'examples/*' ci: - any: - changed-files: - any-glob-to-any-file: '.github/*' ================================================ FILE: .github/release.yml ================================================ # .github/release.yml changelog: exclude: labels: - ignore-for-release categories: - title: Breaking Changes 🛠 labels: - Semver-Major - breaking-change - title: "Bug fixes :bug:" labels: - bug - regression - title: "🖧 P2P area" labels: - area/p2p - title: Exciting New Features 🎉 labels: - Semver-Minor - enhancement - ux - roadmap - title: 🧠 Models labels: - area/ai-model - title: 📖 Documentation and examples labels: - kind/documentation - examples - title: 👒 Dependencies labels: - dependencies - title: Other Changes labels: - "*" ================================================ FILE: .github/stale.yml ================================================ # Number of days of 
inactivity before an issue becomes stale daysUntilStale: 45 # Number of days of inactivity before a stale issue is closed daysUntilClose: 10 # Issues with these labels will never be considered stale exemptLabels: - issue/willfix # Label to use when marking an issue as stale staleLabel: issue/stale # Comment to post when marking an issue as stale. Set to `false` to disable markComment: > This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions. # Comment to post when closing a stale issue. Set to `false` to disable closeComment: > This issue is being automatically closed due to inactivity. However, you may choose to reopen this issue. ================================================ FILE: .github/workflows/backend.yml ================================================ --- name: 'build backend container images' on: push: branches: - master tags: - '*' concurrency: group: ci-backends-${{ github.head_ref || github.ref }}-${{ github.repository }} cancel-in-progress: true jobs: backend-jobs: if: github.repository == 'mudler/LocalAI' uses: ./.github/workflows/backend_build.yml with: tag-latest: ${{ matrix.tag-latest }} tag-suffix: ${{ matrix.tag-suffix }} build-type: ${{ matrix.build-type }} cuda-major-version: ${{ matrix.cuda-major-version }} cuda-minor-version: ${{ matrix.cuda-minor-version }} platforms: ${{ matrix.platforms }} runs-on: ${{ matrix.runs-on }} base-image: ${{ matrix.base-image }} backend: ${{ matrix.backend }} dockerfile: ${{ matrix.dockerfile }} skip-drivers: ${{ matrix.skip-drivers }} context: ${{ matrix.context }} ubuntu-version: ${{ matrix.ubuntu-version }} secrets: dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }} dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }} quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }} quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }} strategy: fail-fast: false #max-parallel: ${{ github.event_name != 
'pull_request' && 6 || 4 }} matrix: include: - build-type: 'l4t' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-diffusers' runs-on: 'ubuntu-24.04-arm' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" skip-drivers: 'true' backend: "diffusers" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2204' - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-cpu-diffusers' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'true' backend: "diffusers" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-cpu-chatterbox' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'true' backend: "chatterbox" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-cpu-moonshine' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'true' backend: "moonshine" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-cpu-whisperx' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'true' backend: "whisperx" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-cpu-ace-step' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'true' backend: "ace-step" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" 
platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-cpu-mlx' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'true' backend: "mlx" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-cpu-mlx-vlm' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'true' backend: "mlx-vlm" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-cpu-mlx-audio' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'true' backend: "mlx-audio" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-cpu-mlx-distributed' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'true' backend: "mlx-distributed" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' # CUDA 12 builds - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-vibevoice' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "vibevoice" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-qwen-asr' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "qwen-asr" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: 
'-gpu-nvidia-cuda-12-nemo' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "nemo" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-qwen-tts' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "qwen-tts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-fish-speech' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "fish-speech" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-faster-qwen3-tts' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "faster-qwen3-tts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-voxcpm' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "voxcpm" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-pocket-tts' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "pocket-tts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' 
tag-suffix: '-gpu-nvidia-cuda-12-rerankers' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "rerankers" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-llama-cpp' runs-on: 'bigger-runner' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "llama-cpp" dockerfile: "./backend/Dockerfile.llama-cpp" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-vllm' runs-on: 'arc-runner-set' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "vllm" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-vllm-omni' runs-on: 'arc-runner-set' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "vllm-omni" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-transformers' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "transformers" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-diffusers' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "diffusers" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 
'auto' tag-suffix: '-gpu-nvidia-cuda-12-ace-step' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "ace-step" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-kokoro' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "kokoro" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-faster-whisper' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "faster-whisper" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-whisperx' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "whisperx" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "9" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-coqui' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "coqui" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-outetts' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "outetts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' 
tag-suffix: '-gpu-nvidia-cuda-12-chatterbox' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "chatterbox" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-moonshine' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "moonshine" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-mlx' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "mlx" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-mlx-vlm' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "mlx-vlm" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-mlx-audio' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "mlx-audio" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-mlx-distributed' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "mlx-distributed" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 
'auto' tag-suffix: '-gpu-nvidia-cuda-12-stablediffusion-ggml' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "stablediffusion-ggml" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-whisper' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "whisper" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-acestep-cpp' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "acestep-cpp" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-rfdetr' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "rfdetr" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12-neutts' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "neutts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' # cuda 13 - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-rerankers' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "rerankers" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 
'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-vibevoice' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "vibevoice" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-qwen-asr' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "qwen-asr" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-nemo' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "nemo" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-qwen-tts' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "qwen-tts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-fish-speech' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "fish-speech" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-faster-qwen3-tts' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "faster-qwen3-tts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" 
platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-voxcpm' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "voxcpm" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-pocket-tts' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "pocket-tts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-llama-cpp' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "llama-cpp" dockerfile: "./backend/Dockerfile.llama-cpp" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' skip-drivers: 'false' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-cuda-13-arm64-llama-cpp' base-image: "ubuntu:24.04" runs-on: 'ubuntu-24.04-arm' ubuntu-version: '2404' backend: "llama-cpp" dockerfile: "./backend/Dockerfile.llama-cpp" context: "./" - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-transformers' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "transformers" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-diffusers' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "diffusers" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" 
cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-ace-step' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "ace-step" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'l4t' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-cuda-13-arm64-vibevoice' runs-on: 'ubuntu-24.04-arm' base-image: "ubuntu:24.04" skip-drivers: 'false' ubuntu-version: '2404' backend: "vibevoice" dockerfile: "./backend/Dockerfile.python" context: "./" - build-type: 'l4t' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-cuda-13-arm64-qwen-asr' runs-on: 'ubuntu-24.04-arm' base-image: "ubuntu:24.04" skip-drivers: 'false' ubuntu-version: '2404' backend: "qwen-asr" dockerfile: "./backend/Dockerfile.python" context: "./" - build-type: 'l4t' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-cuda-13-arm64-qwen-tts' runs-on: 'ubuntu-24.04-arm' base-image: "ubuntu:24.04" skip-drivers: 'false' ubuntu-version: '2404' backend: "qwen-tts" dockerfile: "./backend/Dockerfile.python" context: "./" - build-type: 'l4t' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-cuda-13-arm64-fish-speech' runs-on: 'ubuntu-24.04-arm' base-image: "ubuntu:24.04" skip-drivers: 'false' ubuntu-version: '2404' backend: "fish-speech" dockerfile: "./backend/Dockerfile.python" context: "./" - build-type: 'l4t' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-cuda-13-arm64-faster-qwen3-tts' runs-on: 'ubuntu-24.04-arm' base-image: "ubuntu:24.04" skip-drivers: 'false' ubuntu-version: '2404' backend: "faster-qwen3-tts" dockerfile: "./backend/Dockerfile.python" context: "./" - 
build-type: 'l4t' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-cuda-13-arm64-pocket-tts' runs-on: 'ubuntu-24.04-arm' base-image: "ubuntu:24.04" skip-drivers: 'false' ubuntu-version: '2404' backend: "pocket-tts" dockerfile: "./backend/Dockerfile.python" context: "./" - build-type: 'l4t' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-cuda-13-arm64-chatterbox' runs-on: 'ubuntu-24.04-arm' base-image: "ubuntu:24.04" skip-drivers: 'false' ubuntu-version: '2404' backend: "chatterbox" dockerfile: "./backend/Dockerfile.python" context: "./" - build-type: 'l4t' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-cuda-13-arm64-diffusers' runs-on: 'ubuntu-24.04-arm' base-image: "ubuntu:24.04" skip-drivers: 'false' ubuntu-version: '2404' backend: "diffusers" dockerfile: "./backend/Dockerfile.python" context: "./" - build-type: 'l4t' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-cuda-13-arm64-mlx' runs-on: 'ubuntu-24.04-arm' base-image: "ubuntu:24.04" skip-drivers: 'false' ubuntu-version: '2404' backend: "mlx" dockerfile: "./backend/Dockerfile.python" context: "./" - build-type: 'l4t' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-cuda-13-arm64-mlx-vlm' runs-on: 'ubuntu-24.04-arm' base-image: "ubuntu:24.04" skip-drivers: 'false' ubuntu-version: '2404' backend: "mlx-vlm" dockerfile: "./backend/Dockerfile.python" context: "./" - build-type: 'l4t' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-cuda-13-arm64-mlx-audio' runs-on: 'ubuntu-24.04-arm' base-image: "ubuntu:24.04" skip-drivers: 'false' ubuntu-version: '2404' backend: "mlx-audio" dockerfile: 
"./backend/Dockerfile.python" context: "./" - build-type: 'l4t' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-cuda-13-arm64-mlx-distributed' runs-on: 'ubuntu-24.04-arm' base-image: "ubuntu:24.04" skip-drivers: 'false' ubuntu-version: '2404' backend: "mlx-distributed" dockerfile: "./backend/Dockerfile.python" context: "./" - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-kokoro' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "kokoro" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-faster-whisper' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "faster-whisper" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-whisperx' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "whisperx" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-chatterbox' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "chatterbox" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-moonshine' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "moonshine" 
dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-mlx' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "mlx" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-mlx-vlm' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "mlx-vlm" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-mlx-audio' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "mlx-audio" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-mlx-distributed' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "mlx-distributed" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-stablediffusion-ggml' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "stablediffusion-ggml" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' skip-drivers: 'false' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-cuda-13-arm64-stablediffusion-ggml' base-image: "ubuntu:24.04" 
ubuntu-version: '2404' runs-on: 'ubuntu-24.04-arm' backend: "stablediffusion-ggml" dockerfile: "./backend/Dockerfile.golang" context: "./" - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-whisper' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "whisper" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' skip-drivers: 'false' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-cuda-13-arm64-whisper' base-image: "ubuntu:24.04" ubuntu-version: '2404' runs-on: 'ubuntu-24.04-arm' backend: "whisper" dockerfile: "./backend/Dockerfile.golang" context: "./" - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-acestep-cpp' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "acestep-cpp" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' skip-drivers: 'false' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-cuda-13-arm64-acestep-cpp' base-image: "ubuntu:24.04" ubuntu-version: '2404' runs-on: 'ubuntu-24.04-arm' backend: "acestep-cpp" dockerfile: "./backend/Dockerfile.golang" context: "./" - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13-rfdetr' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "rfdetr" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' # hipblas builds - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-rerankers' 
runs-on: 'ubuntu-latest' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "rerankers" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-llama-cpp' runs-on: 'ubuntu-latest' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "llama-cpp" dockerfile: "./backend/Dockerfile.llama-cpp" context: "./" ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-vllm' runs-on: 'arc-runner-set' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "vllm" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-vllm-omni' runs-on: 'arc-runner-set' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "vllm-omni" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-transformers' runs-on: 'arc-runner-set' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "transformers" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-diffusers' runs-on: 'arc-runner-set' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "diffusers" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 
'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-ace-step' runs-on: 'arc-runner-set' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "ace-step" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' # ROCm additional backends - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-kokoro' runs-on: 'arc-runner-set' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "kokoro" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-vibevoice' runs-on: 'arc-runner-set' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "vibevoice" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-qwen-asr' runs-on: 'arc-runner-set' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "qwen-asr" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-nemo' runs-on: 'arc-runner-set' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "nemo" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-qwen-tts' runs-on: 'arc-runner-set' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "qwen-tts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' 
- build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-fish-speech' runs-on: 'arc-runner-set' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "fish-speech" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-voxcpm' runs-on: 'arc-runner-set' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "voxcpm" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-pocket-tts' runs-on: 'arc-runner-set' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "pocket-tts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-faster-whisper' runs-on: 'bigger-runner' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "faster-whisper" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-whisperx' runs-on: 'bigger-runner' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "whisperx" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-coqui' runs-on: 'bigger-runner' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "coqui" 
dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' # sycl builds - build-type: 'intel' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-rerankers' runs-on: 'ubuntu-latest' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "rerankers" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'sycl_f32' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-sycl-f32-llama-cpp' runs-on: 'ubuntu-latest' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "llama-cpp" dockerfile: "./backend/Dockerfile.llama-cpp" context: "./" ubuntu-version: '2404' - build-type: 'sycl_f16' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-sycl-f16-llama-cpp' runs-on: 'ubuntu-latest' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "llama-cpp" dockerfile: "./backend/Dockerfile.llama-cpp" context: "./" ubuntu-version: '2404' - build-type: 'intel' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-vllm' runs-on: 'arc-runner-set' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "vllm" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'intel' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-transformers' runs-on: 'ubuntu-latest' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "transformers" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'intel' cuda-major-version: "" cuda-minor-version: "" platforms: 
'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-diffusers' runs-on: 'ubuntu-latest' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "diffusers" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'intel' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-ace-step' runs-on: 'ubuntu-latest' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "ace-step" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'l4t' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-vibevoice' runs-on: 'ubuntu-24.04-arm' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" skip-drivers: 'true' backend: "vibevoice" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2204' - build-type: 'l4t' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-qwen-asr' runs-on: 'ubuntu-24.04-arm' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" skip-drivers: 'true' backend: "qwen-asr" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2204' - build-type: 'l4t' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-qwen-tts' runs-on: 'ubuntu-24.04-arm' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" skip-drivers: 'true' backend: "qwen-tts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2204' - build-type: 'l4t' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-fish-speech' runs-on: 'ubuntu-24.04-arm' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" skip-drivers: 'true' backend: "fish-speech" dockerfile: "./backend/Dockerfile.python" context: "./" 
ubuntu-version: '2204' - build-type: 'l4t' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-faster-qwen3-tts' runs-on: 'ubuntu-24.04-arm' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" skip-drivers: 'true' backend: "faster-qwen3-tts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2204' - build-type: 'l4t' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-pocket-tts' runs-on: 'ubuntu-24.04-arm' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" skip-drivers: 'true' backend: "pocket-tts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2204' - build-type: 'l4t' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-kokoro' runs-on: 'ubuntu-24.04-arm' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" skip-drivers: 'true' backend: "kokoro" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2204' - build-type: 'l4t' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-mlx' runs-on: 'ubuntu-24.04-arm' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" skip-drivers: 'true' backend: "mlx" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2204' - build-type: 'l4t' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-mlx-vlm' runs-on: 'ubuntu-24.04-arm' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" skip-drivers: 'true' backend: "mlx-vlm" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2204' - build-type: 'l4t' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-mlx-audio' runs-on: 'ubuntu-24.04-arm' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" skip-drivers: 'true' 
backend: "mlx-audio" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2204' - build-type: 'l4t' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-mlx-distributed' runs-on: 'ubuntu-24.04-arm' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" skip-drivers: 'true' backend: "mlx-distributed" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2204' # SYCL additional backends - build-type: 'intel' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-kokoro' runs-on: 'ubuntu-latest' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "kokoro" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'intel' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-faster-whisper' runs-on: 'ubuntu-latest' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "faster-whisper" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'intel' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-vibevoice' runs-on: 'arc-runner-set' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "vibevoice" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'intel' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-qwen-asr' runs-on: 'arc-runner-set' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "qwen-asr" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'intel' cuda-major-version: "" cuda-minor-version: "" 
platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-nemo' runs-on: 'arc-runner-set' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "nemo" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'intel' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-qwen-tts' runs-on: 'arc-runner-set' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "qwen-tts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'intel' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-fish-speech' runs-on: 'arc-runner-set' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "fish-speech" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'intel' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-voxcpm' runs-on: 'arc-runner-set' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "voxcpm" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'intel' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-pocket-tts' runs-on: 'arc-runner-set' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "pocket-tts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'intel' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-coqui' runs-on: 'ubuntu-latest' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "coqui" dockerfile: 
"./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' # piper - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-piper' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "piper" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-cpu-llama-cpp' runs-on: 'bigger-runner' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "llama-cpp" dockerfile: "./backend/Dockerfile.llama-cpp" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' skip-drivers: 'false' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-arm64-llama-cpp' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" runs-on: 'ubuntu-24.04-arm' backend: "llama-cpp" dockerfile: "./backend/Dockerfile.llama-cpp" context: "./" ubuntu-version: '2204' - build-type: 'vulkan' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-gpu-vulkan-llama-cpp' runs-on: 'bigger-runner' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "llama-cpp" dockerfile: "./backend/Dockerfile.llama-cpp" context: "./" ubuntu-version: '2404' # Stablediffusion-ggml - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-cpu-stablediffusion-ggml' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "stablediffusion-ggml" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: 'sycl_f32' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-sycl-f32-stablediffusion-ggml' runs-on: 'ubuntu-latest' base-image: 
"intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "stablediffusion-ggml" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: 'sycl_f16' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-sycl-f16-stablediffusion-ggml' runs-on: 'ubuntu-latest' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "stablediffusion-ggml" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: 'vulkan' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-gpu-vulkan-stablediffusion-ggml' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "stablediffusion-ggml" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' skip-drivers: 'false' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-arm64-stablediffusion-ggml' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" runs-on: 'ubuntu-24.04-arm' backend: "stablediffusion-ggml" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2204' # whisper - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-cpu-whisper' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "whisper" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: 'sycl_f32' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-sycl-f32-whisper' runs-on: 'ubuntu-latest' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "whisper" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: 
'2404' - build-type: 'sycl_f16' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-sycl-f16-whisper' runs-on: 'ubuntu-latest' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "whisper" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: 'vulkan' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-gpu-vulkan-whisper' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "whisper" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' skip-drivers: 'false' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-arm64-whisper' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" runs-on: 'ubuntu-24.04-arm' backend: "whisper" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2204' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-whisper' base-image: "rocm/dev-ubuntu-24.04:6.4.4" runs-on: 'ubuntu-latest' skip-drivers: 'false' backend: "whisper" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' # acestep-cpp - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-cpu-acestep-cpp' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "acestep-cpp" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: 'sycl_f32' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-sycl-f32-acestep-cpp' runs-on: 'ubuntu-latest' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" 
skip-drivers: 'false' backend: "acestep-cpp" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: 'sycl_f16' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-sycl-f16-acestep-cpp' runs-on: 'ubuntu-latest' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "acestep-cpp" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: 'vulkan' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-gpu-vulkan-acestep-cpp' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "acestep-cpp" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' skip-drivers: 'false' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-arm64-acestep-cpp' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" runs-on: 'ubuntu-24.04-arm' backend: "acestep-cpp" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2204' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-acestep-cpp' base-image: "rocm/dev-ubuntu-24.04:6.4.4" runs-on: 'ubuntu-latest' skip-drivers: 'false' backend: "acestep-cpp" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' # voxtral - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-cpu-voxtral' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "voxtral" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' #opus - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' 
tag-suffix: '-cpu-opus' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "opus" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' #silero-vad - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-cpu-silero-vad' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "silero-vad" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' # local-store - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-cpu-local-store' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "local-store" dockerfile: "./backend/Dockerfile.golang" context: "./" ubuntu-version: '2404' # rfdetr - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-cpu-rfdetr' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "rfdetr" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'intel' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-intel-rfdetr' runs-on: 'ubuntu-latest' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" skip-drivers: 'false' backend: "rfdetr" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'l4t' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' skip-drivers: 'true' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-arm64-rfdetr' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" runs-on: 'ubuntu-24.04-arm' backend: "rfdetr" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2204' - build-type: 'l4t' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' 
skip-drivers: 'true' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-arm64-chatterbox' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" runs-on: 'ubuntu-24.04-arm' backend: "chatterbox" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2204' # runs out of space on the runner # - build-type: 'hipblas' # cuda-major-version: "" # cuda-minor-version: "" # platforms: 'linux/amd64' # tag-latest: 'auto' # tag-suffix: '-gpu-hipblas-rfdetr' # base-image: "rocm/dev-ubuntu-24.04:6.4.4" # runs-on: 'ubuntu-latest' # skip-drivers: 'false' # backend: "rfdetr" # dockerfile: "./backend/Dockerfile.python" # context: "./" # kitten-tts - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-kitten-tts' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "kitten-tts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' # neutts - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-cpu-neutts' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "neutts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: 'hipblas' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-rocm-hipblas-neutts' runs-on: 'arc-runner-set' base-image: "rocm/dev-ubuntu-24.04:6.4.4" skip-drivers: 'false' backend: "neutts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-cpu-vibevoice' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "vibevoice" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: '' cuda-major-version: "" 
cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-cpu-qwen-asr' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "qwen-asr" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-cpu-nemo' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "nemo" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-cpu-qwen-tts' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "qwen-tts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-cpu-fish-speech' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "fish-speech" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-cpu-voxcpm' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "voxcpm" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-cpu-pocket-tts' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' backend: "pocket-tts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404' - build-type: '' cuda-major-version: "" cuda-minor-version: "" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-cpu-outetts' 
runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'true' backend: "outetts" dockerfile: "./backend/Dockerfile.python" context: "./" ubuntu-version: '2404'
  # Darwin (arm64) backend images: every matrix entry below is handed to the
  # reusable backend_build_darwin.yml workflow. Entries that do not set `lang`
  # are built as 'python' (see the `with.lang` expression), and `use-pip` is
  # enabled only for the diffusers backend.
  backend-jobs-darwin:
    uses: ./.github/workflows/backend_build_darwin.yml
    strategy:
      matrix:
        include:
          - backend: "diffusers"
            tag-suffix: "-metal-darwin-arm64-diffusers"
            build-type: "mps"
          - backend: "ace-step"
            tag-suffix: "-metal-darwin-arm64-ace-step"
            build-type: "mps"
          - backend: "mlx"
            tag-suffix: "-metal-darwin-arm64-mlx"
            build-type: "mps"
          - backend: "chatterbox"
            tag-suffix: "-metal-darwin-arm64-chatterbox"
            build-type: "mps"
          - backend: "mlx-vlm"
            tag-suffix: "-metal-darwin-arm64-mlx-vlm"
            build-type: "mps"
          - backend: "mlx-audio"
            tag-suffix: "-metal-darwin-arm64-mlx-audio"
            build-type: "mps"
          - backend: "mlx-distributed"
            tag-suffix: "-metal-darwin-arm64-mlx-distributed"
            build-type: "mps"
          - backend: "stablediffusion-ggml"
            tag-suffix: "-metal-darwin-arm64-stablediffusion-ggml"
            build-type: "metal"
            lang: "go"
          - backend: "whisper"
            tag-suffix: "-metal-darwin-arm64-whisper"
            build-type: "metal"
            lang: "go"
          - backend: "acestep-cpp"
            tag-suffix: "-metal-darwin-arm64-acestep-cpp"
            build-type: "metal"
            lang: "go"
          - backend: "voxtral"
            tag-suffix: "-metal-darwin-arm64-voxtral"
            build-type: "metal"
            lang: "go"
          - backend: "vibevoice"
            tag-suffix: "-metal-darwin-arm64-vibevoice"
            build-type: "mps"
          - backend: "qwen-asr"
            tag-suffix: "-metal-darwin-arm64-qwen-asr"
            build-type: "mps"
          - backend: "nemo"
            tag-suffix: "-metal-darwin-arm64-nemo"
            build-type: "mps"
          - backend: "qwen-tts"
            tag-suffix: "-metal-darwin-arm64-qwen-tts"
            build-type: "mps"
          - backend: "fish-speech"
            tag-suffix: "-metal-darwin-arm64-fish-speech"
            build-type: "mps"
          - backend: "voxcpm"
            tag-suffix: "-metal-darwin-arm64-voxcpm"
            build-type: "mps"
          - backend: "pocket-tts"
            tag-suffix: "-metal-darwin-arm64-pocket-tts"
            build-type: "mps"
          - backend: "moonshine"
            tag-suffix: "-metal-darwin-arm64-moonshine"
            build-type: "mps"
          - backend: "whisperx"
            tag-suffix: "-metal-darwin-arm64-whisperx"
            build-type: "mps"
          - backend: "rerankers"
            tag-suffix: "-metal-darwin-arm64-rerankers"
            build-type: "mps"
          - backend: "transformers"
            tag-suffix: "-metal-darwin-arm64-transformers"
            build-type: "mps"
          - backend: "kokoro"
            tag-suffix: "-metal-darwin-arm64-kokoro"
            build-type: "mps"
          - backend: "faster-whisper"
            tag-suffix: "-metal-darwin-arm64-faster-whisper"
            build-type: "mps"
          - backend: "coqui"
            tag-suffix: "-metal-darwin-arm64-coqui"
            build-type: "mps"
          - backend: "rfdetr"
            tag-suffix: "-metal-darwin-arm64-rfdetr"
            build-type: "mps"
          - backend: "kitten-tts"
            tag-suffix: "-metal-darwin-arm64-kitten-tts"
            build-type: "mps"
          - backend: "piper"
            tag-suffix: "-metal-darwin-arm64-piper"
            build-type: "metal"
            lang: "go"
          - backend: "opus"
            tag-suffix: "-metal-darwin-arm64-opus"
            build-type: "metal"
            lang: "go"
          - backend: "silero-vad"
            tag-suffix: "-metal-darwin-arm64-silero-vad"
            build-type: "metal"
            lang: "go"
          - backend: "local-store"
            tag-suffix: "-metal-darwin-arm64-local-store"
            build-type: "metal"
            lang: "go"
    with:
      backend: ${{ matrix.backend }}
      build-type: ${{ matrix.build-type }}
      go-version: "1.24.x"
      tag-suffix: ${{ matrix.tag-suffix }}
      lang: ${{ matrix.lang || 'python' }}
      use-pip: ${{ matrix.backend == 'diffusers' }}
      runs-on: "macos-latest"
    secrets:
      dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }}
      dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }}
      quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
      quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
  # Builds the llama-cpp backend on a macOS runner and uploads the resulting
  # backend-images/llama-cpp.tar as the `llama-cpp-tar` artifact.
  llama-cpp-darwin:
    runs-on: macos-latest
    strategy:
      matrix:
        go-version: ['1.25.x']
    steps:
      - name: Clone
        uses: actions/checkout@v6
        with:
          submodules: true
      - name: Setup Go ${{ matrix.go-version }}
        uses: actions/setup-go@v5
        with:
          go-version: ${{ matrix.go-version }}
          cache: false
      # You can test your matrix by printing the current Go version
      - name: Display Go version
        run: go version
      - name: Dependencies
        run: |
          brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm
      - name: Build llama-cpp-darwin
        run: |
          make protogen-go
          make backends/llama-cpp-darwin
      - name: Upload llama-cpp.tar
        uses: actions/upload-artifact@v7
        with:
          name: llama-cpp-tar
          path: backend-images/llama-cpp.tar
  # Pushes the tarball produced by llama-cpp-darwin to DockerHub and Quay with
  # crane. Skipped on pull requests.
  llama-cpp-darwin-publish:
    needs: llama-cpp-darwin
    if: github.event_name != 'pull_request'
    runs-on: ubuntu-latest
    steps:
      - name: Download llama-cpp.tar
        uses: actions/download-artifact@v8
        with:
          name: llama-cpp-tar
          path: .
      # NOTE(review): this installs whichever crane release is "latest" at run
      # time; consider pinning a release for reproducible publishes.
      - name: Install crane
        run: |
          curl -L https://github.com/google/go-containerregistry/releases/latest/download/go-containerregistry_Linux_x86_64.tar.gz | tar -xz
          sudo mv crane /usr/local/bin/
      # Credentials are handed to the script through the environment instead of
      # being expanded into the script text, so quotes or shell metacharacters
      # inside a secret cannot change the command that runs.
      - name: Log in to DockerHub
        env:
          DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
          DOCKERHUB_PASSWORD: ${{ secrets.DOCKERHUB_PASSWORD }}
        run: |
          printf '%s\n' "$DOCKERHUB_PASSWORD" | crane auth login docker.io -u "$DOCKERHUB_USERNAME" --password-stdin
      - name: Log in to quay.io
        env:
          QUAY_USERNAME: ${{ secrets.LOCALAI_REGISTRY_USERNAME }}
          QUAY_PASSWORD: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }}
        run: |
          printf '%s\n' "$QUAY_PASSWORD" | crane auth login quay.io -u "$QUAY_USERNAME" --password-stdin
      - name: Docker meta (DockerHub)
        id: meta
        uses: docker/metadata-action@v6
        with:
          images: |
            localai/localai-backends
          tags: |
            type=ref,event=branch
            type=semver,pattern={{raw}}
            type=sha
          flavor: |
            latest=auto
            suffix=-metal-darwin-arm64-llama-cpp,onlatest=true
      - name: Docker meta (Quay)
        id: quaymeta
        uses: docker/metadata-action@v6
        with:
          images: |
            quay.io/go-skynet/local-ai-backends
          tags: |
            type=ref,event=branch
            type=semver,pattern={{raw}}
            type=sha
          flavor: |
            latest=auto
            suffix=-metal-darwin-arm64-llama-cpp,onlatest=true
      # The tag list is read from $TAGS and split on commas/whitespace by the
      # unquoted command substitution; each individual tag is quoted on use.
      - name: Push Docker image (DockerHub)
        env:
          TAGS: ${{ steps.meta.outputs.tags }}
        run: |
          for tag in $(echo "$TAGS" | tr ',' '\n'); do
            crane push llama-cpp.tar "$tag"
          done
      - name: Push Docker image (Quay)
        env:
          TAGS: ${{ steps.quaymeta.outputs.tags }}
        run: |
          for tag in $(echo "$TAGS" | tr ',' '\n'); do
            crane push llama-cpp.tar "$tag"
          done
================================================ FILE: .github/workflows/backend_build.yml ================================================ --- name: 'build backend container images (reusable)' on: workflow_call: inputs: base-image: description: 'Base image' required: true type:
string build-type: description: 'Build type' default: '' type: string cuda-major-version: description: 'CUDA major version' default: "12" type: string cuda-minor-version: description: 'CUDA minor version' default: "1" type: string platforms: description: 'Platforms' default: '' type: string tag-latest: description: 'Tag latest' default: '' type: string tag-suffix: description: 'Tag suffix' default: '' type: string runs-on: description: 'Runs on' required: true default: '' type: string backend: description: 'Backend to build' required: true type: string context: description: 'Build context' required: true type: string dockerfile: description: 'Build Dockerfile' required: true type: string skip-drivers: description: 'Skip drivers' default: 'false' type: string ubuntu-version: description: 'Ubuntu version' required: false default: '2204' type: string secrets: dockerUsername: required: false dockerPassword: required: false quayUsername: required: true quayPassword: required: true
jobs:
  # Builds one backend container image with Buildx from the workflow inputs.
  # Non-PR events push to the Quay and DockerHub backend repositories; pull
  # requests build against quay.io/go-skynet/ci-tests and push only when a Quay
  # username is available.
  backend-build:
    runs-on: ${{ inputs.runs-on }}
    env:
      # Exposed as an env value so step-level `if:` / `push:` expressions can
      # test whether the Quay secret is present.
      quay_username: ${{ secrets.quayUsername }}
    steps:
      # NOTE(review): this action (and the QEMU/Buildx setup actions below)
      # track a moving branch rather than a release tag; consider pinning.
      - name: Free Disk Space (Ubuntu)
        if: inputs.runs-on == 'ubuntu-latest'
        uses: jlumbroso/free-disk-space@main
        with:
          # this might remove tools that are actually needed,
          # if set to "true" but frees about 6 GB
          tool-cache: true
          # all of these default to true, but feel free to set to
          # "false" if necessary for your workflow
          android: true
          dotnet: true
          haskell: true
          large-packages: true
          docker-images: true
          swap-storage: true
      - name: Force Install GIT latest
        run: |
          sudo apt-get update \
          && sudo apt-get install -y software-properties-common \
          && sudo apt-get update \
          && sudo add-apt-repository -y ppa:git-core/ppa \
          && sudo apt-get update \
          && sudo apt-get install -y git
      - name: Checkout
        uses: actions/checkout@v6
      # Best-effort cleanup: every removal is followed by `|| true` so a
      # package that is absent on the runner does not fail the job.
      - name: Release space from worker
        if: inputs.runs-on == 'ubuntu-latest'
        run: |
          echo "Listing top largest packages"
          pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
          head -n 30 <<< "${pkgs}"
          echo
          df -h
          echo
          sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true
          # -y added: without it apt-get asks for a confirmation that a
          # non-interactive runner cannot give, and `|| true` hid the abort.
          sudo apt-get remove -y --auto-remove android-sdk-platform-tools snapd || true
          sudo apt-get purge -y --auto-remove android-sdk-platform-tools snapd || true
          sudo rm -rf /usr/local/lib/android
          sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true
          sudo rm -rf /usr/share/dotnet
          sudo apt-get remove -y '^mono-.*' || true
          sudo apt-get remove -y '^ghc-.*' || true
          sudo apt-get remove -y '.*jdk.*|.*jre.*' || true
          sudo apt-get remove -y 'php.*' || true
          sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true
          sudo apt-get remove -y '^google-.*' || true
          sudo apt-get remove -y azure-cli || true
          sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true
          sudo apt-get remove -y '^gfortran-.*' || true
          sudo apt-get remove -y microsoft-edge-stable || true
          sudo apt-get remove -y firefox || true
          sudo apt-get remove -y powershell || true
          sudo apt-get remove -y r-base-core || true
          sudo apt-get autoremove -y
          sudo apt-get clean
          echo
          echo "Listing top largest packages"
          pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr)
          head -n 30 <<< "${pkgs}"
          echo
          sudo rm -rfv build || true
          sudo rm -rf /usr/share/dotnet || true
          sudo rm -rf /opt/ghc || true
          sudo rm -rf "/usr/local/share/boost" || true
          sudo rm -rf "$AGENT_TOOLSDIRECTORY" || true
          df -h
      # Exactly one of the two metadata steps runs, depending on the event.
      - name: Docker meta
        id: meta
        if: github.event_name != 'pull_request'
        uses: docker/metadata-action@v6
        with:
          images: |
            quay.io/go-skynet/local-ai-backends
            localai/localai-backends
          tags: |
            type=ref,event=branch
            type=semver,pattern={{raw}}
            type=sha
          flavor: |
            latest=${{ inputs.tag-latest }}
            suffix=${{ inputs.tag-suffix }},onlatest=true
      - name: Docker meta for PR
        id: meta_pull_request
        if: github.event_name == 'pull_request'
        uses: docker/metadata-action@v6
        with:
          images: |
            quay.io/go-skynet/ci-tests
          tags: |
            type=ref,event=branch,suffix=${{ github.event.number }}-${{ inputs.backend }}-${{ inputs.build-type }}-${{ inputs.cuda-major-version }}-${{ inputs.cuda-minor-version }}
            type=semver,pattern={{raw}},suffix=${{ github.event.number }}-${{ inputs.backend }}-${{ inputs.build-type }}-${{ inputs.cuda-major-version }}-${{ inputs.cuda-minor-version }}
            type=sha,suffix=${{ github.event.number }}-${{ inputs.backend }}-${{ inputs.build-type }}-${{ inputs.cuda-major-version }}-${{ inputs.cuda-minor-version }}
          flavor: |
            latest=${{ inputs.tag-latest }}
            suffix=${{ inputs.tag-suffix }},onlatest=true
      ## End testing image
      - name: Set up QEMU
        uses: docker/setup-qemu-action@master
        with:
          platforms: all
      - name: Set up Docker Buildx
        id: buildx
        uses: docker/setup-buildx-action@master
      - name: Login to DockerHub
        if: github.event_name != 'pull_request'
        uses: docker/login-action@v4
        with:
          username: ${{ secrets.dockerUsername }}
          password: ${{ secrets.dockerPassword }}
      - name: Login to Quay.io
        if: ${{ env.quay_username != '' }}
        uses: docker/login-action@v4
        with:
          registry: quay.io
          username: ${{ secrets.quayUsername }}
          password: ${{ secrets.quayPassword }}
      - name: Build and push
        uses: docker/build-push-action@v7
        if: github.event_name != 'pull_request'
        with:
          builder: ${{ steps.buildx.outputs.name }}
          build-args: |
            BUILD_TYPE=${{ inputs.build-type }}
            SKIP_DRIVERS=${{ inputs.skip-drivers }}
            CUDA_MAJOR_VERSION=${{ inputs.cuda-major-version }}
            CUDA_MINOR_VERSION=${{ inputs.cuda-minor-version }}
            BASE_IMAGE=${{ inputs.base-image }}
            BACKEND=${{ inputs.backend }}
            UBUNTU_VERSION=${{ inputs.ubuntu-version }}
          context: ${{ inputs.context }}
          file: ${{ inputs.dockerfile }}
          cache-from: type=gha
          platforms: ${{ inputs.platforms }}
          push: ${{ github.event_name != 'pull_request' }}
          tags: ${{ steps.meta.outputs.tags }}
          labels: ${{ steps.meta.outputs.labels }}
      - name: Build and push (PR)
        uses: docker/build-push-action@v7
        if: github.event_name == 'pull_request'
        with:
          builder: ${{ steps.buildx.outputs.name }}
          build-args: |
            BUILD_TYPE=${{ inputs.build-type }}
            SKIP_DRIVERS=${{ inputs.skip-drivers }}
            CUDA_MAJOR_VERSION=${{ inputs.cuda-major-version }}
            CUDA_MINOR_VERSION=${{ inputs.cuda-minor-version }}
            BASE_IMAGE=${{ inputs.base-image }}
            BACKEND=${{ inputs.backend }}
            UBUNTU_VERSION=${{ inputs.ubuntu-version }}
          context: ${{ inputs.context }}
          file: ${{ inputs.dockerfile }}
          cache-from: type=gha
          platforms: ${{ inputs.platforms }}
          # Only push the PR image when Quay credentials were provided.
          push: ${{ env.quay_username != '' }}
          tags: ${{ steps.meta_pull_request.outputs.tags }}
          labels: ${{ steps.meta_pull_request.outputs.labels }}
      # The `meta` step is skipped on pull requests, so fall back to the PR
      # metadata step's labels there. The value travels through the environment
      # so quotes inside a label cannot break the shell command.
      - name: job summary
        env:
          IMAGE_LABELS: ${{ steps.meta.outputs.labels || steps.meta_pull_request.outputs.labels }}
        run: |
          echo "Built image: ${IMAGE_LABELS}" >> "$GITHUB_STEP_SUMMARY"
================================================ FILE: .github/workflows/backend_build_darwin.yml ================================================ --- name: 'build darwin python backend container images (reusable)' on: workflow_call: inputs: backend: description: 'Backend to build' required: true type: string build-type: description: 'Build type (e.g., mps)' default: '' type: string use-pip: description: 'Use pip to install dependencies' default: false type: boolean lang: description: 'Programming language (e.g.
go)' default: 'python' type: string go-version: description: 'Go version to use' default: '1.24.x' type: string tag-suffix: description: 'Tag suffix for the built image' required: true type: string runs-on: description: 'Runner to use' default: 'macOS-14' type: string secrets: dockerUsername: required: false dockerPassword: required: false quayUsername: required: true quayPassword: required: true jobs: darwin-backend-build: runs-on: ${{ inputs.runs-on }} strategy: matrix: go-version: ['${{ inputs.go-version }}'] steps: - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Setup Go ${{ matrix.go-version }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} cache: false # You can test your matrix by printing the current Go version - name: Display Go version run: go version - name: Dependencies run: | brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm - name: Build ${{ inputs.backend }}-darwin run: | make protogen-go BACKEND=${{ inputs.backend }} BUILD_TYPE=${{ inputs.build-type }} USE_PIP=${{ inputs.use-pip }} make build-darwin-${{ inputs.lang }}-backend - name: Upload ${{ inputs.backend }}.tar uses: actions/upload-artifact@v7 with: name: ${{ inputs.backend }}-tar path: backend-images/${{ inputs.backend }}.tar darwin-backend-publish: needs: darwin-backend-build if: github.event_name != 'pull_request' runs-on: ubuntu-latest steps: - name: Download ${{ inputs.backend }}.tar uses: actions/download-artifact@v8 with: name: ${{ inputs.backend }}-tar path: . 
- name: Install crane run: | curl -L https://github.com/google/go-containerregistry/releases/latest/download/go-containerregistry_Linux_x86_64.tar.gz | tar -xz sudo mv crane /usr/local/bin/ - name: Log in to DockerHub run: | echo "${{ secrets.dockerPassword }}" | crane auth login docker.io -u "${{ secrets.dockerUsername }}" --password-stdin - name: Log in to quay.io run: | echo "${{ secrets.quayPassword }}" | crane auth login quay.io -u "${{ secrets.quayUsername }}" --password-stdin - name: Docker meta id: meta uses: docker/metadata-action@v6 with: images: | localai/localai-backends tags: | type=ref,event=branch type=semver,pattern={{raw}} type=sha flavor: | latest=auto suffix=${{ inputs.tag-suffix }},onlatest=true - name: Docker meta id: quaymeta uses: docker/metadata-action@v6 with: images: | quay.io/go-skynet/local-ai-backends tags: | type=ref,event=branch type=semver,pattern={{raw}} type=sha flavor: | latest=auto suffix=${{ inputs.tag-suffix }},onlatest=true - name: Push Docker image (DockerHub) run: | for tag in $(echo "${{ steps.meta.outputs.tags }}" | tr ',' '\n'); do crane push ${{ inputs.backend }}.tar $tag done - name: Push Docker image (Quay) run: | for tag in $(echo "${{ steps.quaymeta.outputs.tags }}" | tr ',' '\n'); do crane push ${{ inputs.backend }}.tar $tag done ================================================ FILE: .github/workflows/backend_pr.yml ================================================ name: 'build backend container images (PR-filtered)' on: pull_request: concurrency: group: ci-backends-pr-${{ github.head_ref || github.ref }}-${{ github.repository }} cancel-in-progress: true jobs: generate-matrix: runs-on: ubuntu-latest outputs: matrix: ${{ steps.set-matrix.outputs.matrix }} matrix-darwin: ${{ steps.set-matrix.outputs.matrix-darwin }} has-backends: ${{ steps.set-matrix.outputs.has-backends }} has-backends-darwin: ${{ steps.set-matrix.outputs.has-backends-darwin }} steps: - name: Checkout repository uses: actions/checkout@v6 - name: Setup 
Bun uses: oven-sh/setup-bun@v2 - name: Install dependencies run: | bun add js-yaml bun add @octokit/core # filters the matrix in backend.yml - name: Filter matrix for changed backends id: set-matrix env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_EVENT_PATH: ${{ github.event_path }} run: bun run scripts/changed-backends.js backend-jobs: needs: generate-matrix uses: ./.github/workflows/backend_build.yml if: needs.generate-matrix.outputs.has-backends == 'true' with: tag-latest: ${{ matrix.tag-latest }} tag-suffix: ${{ matrix.tag-suffix }} build-type: ${{ matrix.build-type }} cuda-major-version: ${{ matrix.cuda-major-version }} cuda-minor-version: ${{ matrix.cuda-minor-version }} platforms: ${{ matrix.platforms }} runs-on: ${{ matrix.runs-on }} base-image: ${{ matrix.base-image }} backend: ${{ matrix.backend }} dockerfile: ${{ matrix.dockerfile }} skip-drivers: ${{ matrix.skip-drivers }} context: ${{ matrix.context }} ubuntu-version: ${{ matrix.ubuntu-version }} secrets: quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }} quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }} strategy: fail-fast: true matrix: ${{ fromJson(needs.generate-matrix.outputs.matrix) }} backend-jobs-darwin: needs: generate-matrix uses: ./.github/workflows/backend_build_darwin.yml if: needs.generate-matrix.outputs.has-backends-darwin == 'true' with: backend: ${{ matrix.backend }} build-type: ${{ matrix.build-type }} go-version: "1.24.x" tag-suffix: ${{ matrix.tag-suffix }} lang: ${{ matrix.lang || 'python' }} use-pip: ${{ matrix.backend == 'diffusers' }} runs-on: "macos-latest" secrets: quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }} quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }} strategy: fail-fast: true matrix: ${{ fromJson(needs.generate-matrix.outputs.matrix-darwin) }} ================================================ FILE: .github/workflows/build-test.yaml ================================================ name: Build test on: push: branches: - master pull_request: 
jobs: build-test: runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v6 with: fetch-depth: 0 - name: Set up Go uses: actions/setup-go@v5 with: go-version: 1.25 - name: Run GoReleaser run: | make dev-dist launcher-build-darwin: runs-on: macos-latest steps: - name: Checkout uses: actions/checkout@v6 with: fetch-depth: 0 - name: Set up Go uses: actions/setup-go@v5 with: go-version: 1.25 - name: Build launcher for macOS ARM64 run: | make build-launcher-darwin ls -liah dist - name: Upload macOS launcher artifacts uses: actions/upload-artifact@v7 with: name: launcher-macos path: dist/ retention-days: 30 launcher-build-linux: runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v6 with: fetch-depth: 0 - name: Set up Go uses: actions/setup-go@v5 with: go-version: 1.25 - name: Build launcher for Linux run: | sudo apt-get update sudo apt-get install golang gcc libgl1-mesa-dev xorg-dev libxkbcommon-dev make build-launcher-linux - name: Upload Linux launcher artifacts uses: actions/upload-artifact@v7 with: name: launcher-linux path: local-ai-launcher-linux.tar.xz retention-days: 30 ================================================ FILE: .github/workflows/bump_deps.yaml ================================================ name: Bump Backend dependencies on: schedule: - cron: 0 20 * * * workflow_dispatch: jobs: bump-backends: if: github.repository == 'mudler/LocalAI' strategy: fail-fast: false matrix: include: - repository: "ggml-org/llama.cpp" variable: "LLAMA_VERSION" branch: "master" file: "backend/cpp/llama-cpp/Makefile" - repository: "ggml-org/whisper.cpp" variable: "WHISPER_CPP_VERSION" branch: "master" file: "backend/go/whisper/Makefile" - repository: "leejet/stable-diffusion.cpp" variable: "STABLEDIFFUSION_GGML_VERSION" branch: "master" file: "backend/go/stablediffusion-ggml/Makefile" - repository: "mudler/go-piper" variable: "PIPER_VERSION" branch: "master" file: "backend/go/piper/Makefile" - repository: "antirez/voxtral.c" variable: 
"VOXTRAL_VERSION" branch: "main" file: "backend/go/voxtral/Makefile" - repository: "ace-step/acestep.cpp" variable: "ACESTEP_CPP_VERSION" branch: "master" file: "backend/go/acestep-cpp/Makefile" runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - name: Bump dependencies 🔧 id: bump run: | bash .github/bump_deps.sh ${{ matrix.repository }} ${{ matrix.branch }} ${{ matrix.variable }} ${{ matrix.file }} { echo 'message<> "$GITHUB_OUTPUT" { echo 'commit<> "$GITHUB_OUTPUT" rm -rfv ${{ matrix.variable }}_message.txt rm -rfv ${{ matrix.variable }}_commit.txt - name: Create Pull Request uses: peter-evans/create-pull-request@v8 with: token: ${{ secrets.UPDATE_BOT_TOKEN }} push-to-fork: ci-forks/LocalAI commit-message: ':arrow_up: Update ${{ matrix.repository }}' title: 'chore: :arrow_up: Update ${{ matrix.repository }} to `${{ steps.bump.outputs.commit }}`' branch: "update/${{ matrix.variable }}" body: ${{ steps.bump.outputs.message }} signoff: true ================================================ FILE: .github/workflows/bump_docs.yaml ================================================ name: Bump Documentation on: schedule: - cron: 0 20 * * * workflow_dispatch: jobs: bump-docs: if: github.repository == 'mudler/LocalAI' strategy: fail-fast: false matrix: include: - repository: "mudler/LocalAI" runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - name: Bump dependencies 🔧 run: | bash .github/bump_docs.sh ${{ matrix.repository }} - name: Create Pull Request uses: peter-evans/create-pull-request@v8 with: token: ${{ secrets.UPDATE_BOT_TOKEN }} push-to-fork: ci-forks/LocalAI commit-message: ':arrow_up: Update docs version ${{ matrix.repository }}' title: 'docs: :arrow_up: update docs version ${{ matrix.repository }}' branch: "update/docs" body: Bump of ${{ matrix.repository }} version inside docs signoff: true ================================================ FILE: .github/workflows/checksum_checker.yaml ================================================ name: Check if 
checksums are up-to-date on: schedule: - cron: 0 20 * * * workflow_dispatch: jobs: checksum_check: if: github.repository == 'mudler/LocalAI' runs-on: ubuntu-latest steps: - name: Force Install GIT latest run: | sudo apt-get update \ && sudo apt-get install -y software-properties-common \ && sudo apt-get update \ && sudo add-apt-repository -y ppa:git-core/ppa \ && sudo apt-get update \ && sudo apt-get install -y git - uses: actions/checkout@v6 - name: Install dependencies run: | sudo apt-get update sudo apt-get install -y pip wget pip install huggingface_hub - name: 'Setup yq' uses: dcarbone/install-yq-action@v1.3.1 with: version: 'v4.44.2' download-compressed: true force: true - name: Checksum checker 🔧 run: | export HF_HOME=/hf_cache sudo mkdir /hf_cache sudo chmod 777 /hf_cache bash .github/checksum_checker.sh gallery/index.yaml - name: Create Pull Request uses: peter-evans/create-pull-request@v8 with: token: ${{ secrets.UPDATE_BOT_TOKEN }} push-to-fork: ci-forks/LocalAI commit-message: ':arrow_up: Checksum updates in gallery/index.yaml' title: 'chore(model-gallery): :arrow_up: update checksum' branch: "update/checksum" body: Updating checksums in gallery/index.yaml signoff: true ================================================ FILE: .github/workflows/deploy-explorer.yaml ================================================ name: Explorer deployment on: push: branches: - master tags: - 'v*' concurrency: group: ci-deploy-${{ github.head_ref || github.ref }}-${{ github.repository }} jobs: build-linux: if: github.repository == 'mudler/LocalAI' runs-on: ubuntu-latest steps: - name: Clone uses: actions/checkout@v6 with: submodules: true - uses: actions/setup-go@v5 with: go-version: '1.21.x' cache: false - name: Dependencies run: | sudo apt-get update sudo apt-get install -y wget curl build-essential ffmpeg protobuf-compiler ccache upx-ucl gawk cmake libgmock-dev go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af go install 
google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 make protogen-go - name: Build api run: | CGO_ENABLED=0 make build - name: rm uses: appleboy/ssh-action@v1.2.5 with: host: ${{ secrets.EXPLORER_SSH_HOST }} username: ${{ secrets.EXPLORER_SSH_USERNAME }} key: ${{ secrets.EXPLORER_SSH_KEY }} port: ${{ secrets.EXPLORER_SSH_PORT }} script: | sudo rm -rf local-ai/ || true - name: copy file via ssh uses: appleboy/scp-action@v1.0.0 with: host: ${{ secrets.EXPLORER_SSH_HOST }} username: ${{ secrets.EXPLORER_SSH_USERNAME }} key: ${{ secrets.EXPLORER_SSH_KEY }} port: ${{ secrets.EXPLORER_SSH_PORT }} source: "local-ai" overwrite: true rm: true target: ./local-ai - name: restarting uses: appleboy/ssh-action@v1.2.5 with: host: ${{ secrets.EXPLORER_SSH_HOST }} username: ${{ secrets.EXPLORER_SSH_USERNAME }} key: ${{ secrets.EXPLORER_SSH_KEY }} port: ${{ secrets.EXPLORER_SSH_PORT }} script: | sudo cp -rfv local-ai/local-ai /usr/bin/local-ai sudo systemctl restart local-ai ================================================ FILE: .github/workflows/disabled/comment-pr.yaml ================================================ name: Comment PRs on: pull_request_target: jobs: comment-pr: env: MODEL_NAME: hermes-2-theta-llama-3-8b runs-on: ubuntu-latest steps: - name: Checkout code uses: actions/checkout@v3 with: ref: "${{ github.event.pull_request.merge_commit_sha }}" fetch-depth: 0 # needed to checkout all branches for this Action to work - uses: mudler/localai-github-action@v1 with: model: 'hermes-2-theta-llama-3-8b' # Any from models.localai.io, or from huggingface.com with: "huggingface:///file" # Check the PR diff using the current branch and the base branch of the PR - uses: GrantBirki/git-diff-action@v2.7.0 id: git-diff-action with: json_diff_file_output: diff.json raw_diff_file_output: diff.txt file_output_only: "true" base_branch: ${{ github.event.pull_request.base.sha }} - name: Show diff env: DIFF: ${{ steps.git-diff-action.outputs.raw-diff-path }} run: | cat $DIFF - name: 
Summarize env: DIFF: ${{ steps.git-diff-action.outputs.raw-diff-path }} id: summarize run: | input="$(cat $DIFF)" # Define the LocalAI API endpoint API_URL="http://localhost:8080/chat/completions" # Create a JSON payload using jq to handle special characters json_payload=$(jq -n --arg input "$input" '{ model: "'$MODEL_NAME'", messages: [ { role: "system", content: "You are LocalAI-bot in Github that helps understanding PRs and assess complexity. Explain what has changed in this PR diff and why" }, { role: "user", content: $input } ] }') # Send the request to LocalAI response=$(curl -s -X POST $API_URL \ -H "Content-Type: application/json" \ -d "$json_payload") # Extract the summary from the response summary="$(echo $response | jq -r '.choices[0].message.content')" # Print the summary # -H "Authorization: Bearer $API_KEY" \ echo "Summary:" echo "$summary" echo "payload sent" echo "$json_payload" { echo 'message<> "$GITHUB_OUTPUT" docker logs --tail 10 local-ai - uses: mshick/add-pr-comment@v2 if: always() with: repo-token: ${{ secrets.UPDATE_BOT_TOKEN }} message: ${{ steps.summarize.outputs.message }} message-failure: | Uh oh! Could not analyze this PR, maybe it's too big? 
================================================ FILE: .github/workflows/disabled/dependabot_auto.yml ================================================ name: Dependabot auto-merge on: - pull_request_target permissions: contents: write pull-requests: write packages: read jobs: dependabot: if: github.repository == 'mudler/LocalAI' && github.actor == 'dependabot[bot]' runs-on: ubuntu-latest steps: - name: Dependabot metadata id: metadata uses: dependabot/fetch-metadata@v2.5.0 with: github-token: "${{ secrets.GITHUB_TOKEN }}" skip-commit-verification: true - name: Checkout repository uses: actions/checkout@v6 - name: Approve a PR if not already approved run: | gh pr checkout "$PR_URL" if [ "$(gh pr status --json reviewDecision -q .currentBranch.reviewDecision)" != "APPROVED" ]; then gh pr review --approve "$PR_URL" else echo "PR already approved."; fi env: PR_URL: ${{github.event.pull_request.html_url}} GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} - name: Enable auto-merge for Dependabot PRs if: ${{ contains(github.event.pull_request.title, 'bump')}} run: gh pr merge --auto --squash "$PR_URL" env: PR_URL: ${{github.event.pull_request.html_url}} GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} ================================================ FILE: .github/workflows/disabled/labeler.yml ================================================ name: "Pull Request Labeler" on: - pull_request_target jobs: labeler: permissions: contents: read pull-requests: write runs-on: ubuntu-latest steps: - uses: actions/labeler@v6 ================================================ FILE: .github/workflows/disabled/localaibot_automerge.yml ================================================ name: LocalAI-bot auto-merge on: - pull_request_target permissions: contents: write pull-requests: write packages: read issues: write # for Homebrew/actions/post-comment actions: write # to dispatch publish workflow jobs: dependabot: if: github.repository == 'mudler/LocalAI' && github.actor == 'localai-bot' && 
contains(github.event.pull_request.title, 'chore:') runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v6 - name: Approve a PR if not already approved run: | gh pr checkout "$PR_URL" if [ "$(gh pr status --json reviewDecision -q .currentBranch.reviewDecision)" != "APPROVED" ]; then gh pr review --approve "$PR_URL" else echo "PR already approved."; fi env: PR_URL: ${{github.event.pull_request.html_url}} GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} - name: Enable auto-merge for LocalAIBot PRs run: gh pr merge --auto --squash "$PR_URL" env: PR_URL: ${{github.event.pull_request.html_url}} GITHUB_TOKEN: ${{secrets.GITHUB_TOKEN}} ================================================ FILE: .github/workflows/disabled/notify-models.yaml ================================================ name: Notifications for new models on: pull_request_target: types: - closed permissions: contents: read pull-requests: read jobs: notify-discord: if: github.repository == 'mudler/LocalAI' && (github.event.pull_request.merged == true) && (contains(github.event.pull_request.labels.*.name, 'area/ai-model')) env: MODEL_NAME: gemma-3-12b-it-qat runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 with: fetch-depth: 0 # needed to checkout all branches for this Action to work ref: ${{ github.event.pull_request.head.sha }} # Checkout the PR head to get the actual changes - uses: mudler/localai-github-action@v1 with: model: 'gemma-3-12b-it-qat' # Any from models.localai.io, or from huggingface.com with: "huggingface:///file" # Check the PR diff using the current branch and the base branch of the PR - uses: GrantBirki/git-diff-action@v2.8.1 id: git-diff-action with: json_diff_file_output: diff.json raw_diff_file_output: diff.txt file_output_only: "true" - name: Summarize env: DIFF: ${{ steps.git-diff-action.outputs.raw-diff-path }} id: summarize run: | input="$(cat $DIFF)" # Define the LocalAI API endpoint API_URL="http://localhost:8080/chat/completions" # Create a JSON payload 
using jq to handle special characters json_payload=$(jq -n --arg input "$input" '{ model: "'$MODEL_NAME'", messages: [ { role: "system", content: "You are LocalAI-bot. Write a discord message to notify everyone about the new model from the git diff. Make it informal. An example can include: the URL of the model, the name, and a brief description of the model if exists. Also add an hint on how to install it in LocalAI and that can be browsed over https://models.localai.io. For example: local-ai run model_name_here" }, { role: "user", content: $input } ] }') # Send the request to LocalAI response=$(curl -s -X POST $API_URL \ -H "Content-Type: application/json" \ -d "$json_payload") # Extract the summary from the response summary="$(echo $response | jq -r '.choices[0].message.content')" # Print the summary # -H "Authorization: Bearer $API_KEY" \ echo "Summary:" echo "$summary" echo "payload sent" echo "$json_payload" { echo 'message<> "$GITHUB_OUTPUT" docker logs --tail 10 local-ai - name: Discord notification env: DISCORD_WEBHOOK: ${{ secrets.DISCORD_WEBHOOK_URL }} DISCORD_USERNAME: "LocalAI-Bot" DISCORD_AVATAR: "https://avatars.githubusercontent.com/u/139863280?v=4" uses: Ilshidur/action-discord@master with: args: ${{ steps.summarize.outputs.message }} - name: Setup tmate session if fails if: ${{ failure() }} uses: mxschmitt/action-tmate@v3.23 with: detached: true connect-timeout-seconds: 180 limit-access-to-actor: true notify-twitter: if: github.repository == 'mudler/LocalAI' && (github.event.pull_request.merged == true) && (contains(github.event.pull_request.labels.*.name, 'area/ai-model')) env: MODEL_NAME: gemma-3-12b-it-qat runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 with: fetch-depth: 0 # needed to checkout all branches for this Action to work ref: ${{ github.event.pull_request.head.sha }} # Checkout the PR head to get the actual changes - name: Start LocalAI run: | echo "Starting LocalAI..." 
docker run -e -ti -d --name local-ai -p 8080:8080 localai/localai:master run --debug $MODEL_NAME until [ "`docker inspect -f {{.State.Health.Status}} local-ai`" == "healthy" ]; do echo "Waiting for container to be ready"; docker logs --tail 10 local-ai; sleep 2; done # Check the PR diff using the current branch and the base branch of the PR - uses: GrantBirki/git-diff-action@v2.8.1 id: git-diff-action with: json_diff_file_output: diff.json raw_diff_file_output: diff.txt file_output_only: "true" - name: Summarize env: DIFF: ${{ steps.git-diff-action.outputs.raw-diff-path }} id: summarize run: | input="$(cat $DIFF)" # Define the LocalAI API endpoint API_URL="http://localhost:8080/chat/completions" # Create a JSON payload using jq to handle special characters json_payload=$(jq -n --arg input "$input" '{ model: "'$MODEL_NAME'", messages: [ { role: "system", content: "You are LocalAI-bot. Write a twitter message to notify everyone about the new model from the git diff. Make it informal and really short. An example can include: the name, and a brief description of the model if exists. Also add an hint on how to install it in LocalAI. 
For example: local-ai run model_name_here" }, { role: "user", content: $input } ] }') # Send the request to LocalAI response=$(curl -s -X POST $API_URL \ -H "Content-Type: application/json" \ -d "$json_payload") # Extract the summary from the response summary="$(echo $response | jq -r '.choices[0].message.content')" # Print the summary # -H "Authorization: Bearer $API_KEY" \ echo "Summary:" echo "$summary" echo "payload sent" echo "$json_payload" { echo 'message<> "$GITHUB_OUTPUT" docker logs --tail 10 local-ai - uses: Eomm/why-don-t-you-tweet@v2 with: tweet-message: ${{ steps.summarize.outputs.message }} env: # Get your tokens from https://developer.twitter.com/apps TWITTER_CONSUMER_API_KEY: ${{ secrets.TWITTER_APP_KEY }} TWITTER_CONSUMER_API_SECRET: ${{ secrets.TWITTER_APP_SECRET }} TWITTER_ACCESS_TOKEN: ${{ secrets.TWITTER_ACCESS_TOKEN }} TWITTER_ACCESS_TOKEN_SECRET: ${{ secrets.TWITTER_ACCESS_TOKEN_SECRET }} - name: Setup tmate session if fails if: ${{ failure() }} uses: mxschmitt/action-tmate@v3.23 with: detached: true connect-timeout-seconds: 180 limit-access-to-actor: true ================================================ FILE: .github/workflows/disabled/prlint.yaml ================================================ name: Check PR style on: pull_request_target: types: - opened - reopened - edited - synchronize jobs: title-lint: runs-on: ubuntu-latest permissions: statuses: write steps: - uses: aslafy-z/conventional-pr-title-action@v3 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} # check-pr-description: # runs-on: ubuntu-latest # steps: # - uses: actions/checkout@v2 # - uses: jadrol/pr-description-checker-action@v1.0.0 # id: description-checker # with: # repo-token: ${{ secrets.GITHUB_TOKEN }} # exempt-labels: no qa ================================================ FILE: .github/workflows/disabled/test-gpu.yml ================================================ --- name: 'GPU tests' on: pull_request: push: branches: - master tags: - '*' concurrency: group: 
ci-gpu-tests-${{ github.head_ref || github.ref }}-${{ github.repository }} cancel-in-progress: true jobs: ubuntu-latest: runs-on: gpu strategy: matrix: go-version: ['1.21.x'] steps: - name: Clone uses: actions/checkout@v4 with: submodules: true - name: Setup Go ${{ matrix.go-version }} uses: actions/setup-go@v4 with: go-version: ${{ matrix.go-version }} # You can test your matrix by printing the current Go version - name: Display Go version run: go version - name: Dependencies run: | sudo apt-get update sudo DEBIAN_FRONTEND=noninteractive apt-get install -y make wget - name: Build run: | if [ ! -e /run/systemd/system ]; then sudo mkdir /run/systemd/system fi sudo mkdir -p /host/tests/${{ github.head_ref || github.ref }} sudo chmod -R 777 /host/tests/${{ github.head_ref || github.ref }} make \ TEST_DIR="/host/tests/${{ github.head_ref || github.ref }}" \ BUILD_TYPE=cublas \ prepare-e2e run-e2e-image test-e2e - name: Release space from worker ♻ if: always() run: | sudo rm -rf build || true sudo rm -rf bin || true sudo rm -rf dist || true sudo docker logs $(sudo docker ps -q --filter ancestor=localai-tests) > logs.txt sudo cat logs.txt || true sudo rm -rf logs.txt make clean || true make \ TEST_DIR="/host/tests/${{ github.head_ref || github.ref }}" \ teardown-e2e || true sudo rm -rf /host/tests/${{ github.head_ref || github.ref }} || true docker system prune -f -a --volumes || true ================================================ FILE: .github/workflows/gallery-agent.yaml ================================================ name: Gallery Agent on: schedule: - cron: '0 */3 * * *' # Run every 3 hours workflow_dispatch: inputs: search_term: description: 'Search term for models' required: false default: 'GGUF' type: string limit: description: 'Maximum number of models to process' required: false default: '15' type: string quantization: description: 'Preferred quantization format' required: false default: 'Q4_K_M' type: string max_models: description: 'Maximum number of models 
to add to the gallery' required: false default: '1' type: string jobs: gallery-agent: if: github.repository == 'mudler/LocalAI' runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@v6 with: token: ${{ secrets.GITHUB_TOKEN }} - name: Set up Go uses: actions/setup-go@v5 with: go-version: '1.21' - name: Proto Dependencies run: | # Install protoc curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \ unzip -j -d /usr/local/bin protoc.zip bin/protoc && \ rm protoc.zip go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af PATH="$PATH:$HOME/go/bin" make protogen-go - uses: mudler/localai-github-action@v1.1 with: model: 'https://huggingface.co/unsloth/Qwen3.5-2B-GGUF' - name: Run gallery agent env: #OPENAI_MODEL: ${{ secrets.OPENAI_MODEL }} OPENAI_MODE: Qwen3.5-2B-GGUF OPENAI_BASE_URL: "http://localhost:8080" OPENAI_KEY: ${{ secrets.OPENAI_KEY }} #OPENAI_BASE_URL: ${{ secrets.OPENAI_BASE_URL }} SEARCH_TERM: ${{ github.event.inputs.search_term || 'GGUF' }} LIMIT: ${{ github.event.inputs.limit || '15' }} QUANTIZATION: ${{ github.event.inputs.quantization || 'Q4_K_M' }} MAX_MODELS: ${{ github.event.inputs.max_models || '1' }} run: | export GALLERY_INDEX_PATH=$PWD/gallery/index.yaml go run ./.github/gallery-agent - name: Check for changes id: check_changes run: | if git diff --quiet gallery/index.yaml; then echo "changes=false" >> $GITHUB_OUTPUT echo "No changes detected in gallery/index.yaml" else echo "changes=true" >> $GITHUB_OUTPUT echo "Changes detected in gallery/index.yaml" git diff gallery/index.yaml fi - name: Read gallery agent summary id: read_summary if: steps.check_changes.outputs.changes == 'true' run: | if [ -f "./gallery-agent-summary.json" ]; then echo "summary_exists=true" >> $GITHUB_OUTPUT # Extract summary data using jq echo "search_term=$(jq -r 
'.search_term' ./gallery-agent-summary.json)" >> $GITHUB_OUTPUT echo "total_found=$(jq -r '.total_found' ./gallery-agent-summary.json)" >> $GITHUB_OUTPUT echo "models_added=$(jq -r '.models_added' ./gallery-agent-summary.json)" >> $GITHUB_OUTPUT echo "quantization=$(jq -r '.quantization' ./gallery-agent-summary.json)" >> $GITHUB_OUTPUT echo "processing_time=$(jq -r '.processing_time' ./gallery-agent-summary.json)" >> $GITHUB_OUTPUT # Create a formatted list of added models with URLs added_models=$(jq -r 'range(0; .added_model_ids | length) as $i | "- [\(.added_model_ids[$i])](\(.added_model_urls[$i]))"' ./gallery-agent-summary.json | tr '\n' '\n') echo "added_models<<EOF" >> $GITHUB_OUTPUT echo "$added_models" >> $GITHUB_OUTPUT echo "EOF" >> $GITHUB_OUTPUT rm -f ./gallery-agent-summary.json else echo "summary_exists=false" >> $GITHUB_OUTPUT fi - name: Create Pull Request if: steps.check_changes.outputs.changes == 'true' uses: peter-evans/create-pull-request@v8 with: token: ${{ secrets.UPDATE_BOT_TOKEN }} push-to-fork: ci-forks/LocalAI commit-message: 'chore(model gallery): :robot: add new models via gallery agent' title: 'chore(model gallery): :robot: add ${{ steps.read_summary.outputs.models_added || 0 }} new models via gallery agent' # Branch has to be unique so PRs are not overriding each other branch-suffix: timestamp body: | This PR was automatically created by the gallery agent workflow. 
**Summary:** - **Search Term:** ${{ steps.read_summary.outputs.search_term || github.event.inputs.search_term || 'GGUF' }} - **Models Found:** ${{ steps.read_summary.outputs.total_found || 'N/A' }} - **Models Added:** ${{ steps.read_summary.outputs.models_added || '0' }} - **Quantization:** ${{ steps.read_summary.outputs.quantization || github.event.inputs.quantization || 'Q4_K_M' }} - **Processing Time:** ${{ steps.read_summary.outputs.processing_time || 'N/A' }} **Added Models:** ${{ steps.read_summary.outputs.added_models || '- No models added' }} **Workflow Details:** - Triggered by: `${{ github.event_name }}` - Run ID: `${{ github.run_id }}` - Commit: `${{ github.sha }}` signoff: true delete-branch: true ================================================ FILE: .github/workflows/generate_grpc_cache.yaml ================================================ name: 'generate and publish GRPC docker caches' on: workflow_dispatch: schedule: # daily at midnight - cron: '0 0 * * *' concurrency: group: grpc-cache-${{ github.head_ref || github.ref }}-${{ github.repository }} cancel-in-progress: true jobs: generate_caches: if: github.repository == 'mudler/LocalAI' strategy: matrix: include: - grpc-base-image: ubuntu:24.04 runs-on: 'ubuntu-latest' platforms: 'linux/amd64,linux/arm64' runs-on: ${{matrix.runs-on}} steps: - name: Release space from worker if: matrix.runs-on == 'ubuntu-latest' run: | echo "Listing top largest packages" pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr) head -n 30 <<< "${pkgs}" echo df -h echo sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true sudo apt-get remove --auto-remove android-sdk-platform-tools || true sudo apt-get purge --auto-remove android-sdk-platform-tools || true sudo rm -rf /usr/local/lib/android sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true sudo rm -rf /usr/share/dotnet sudo apt-get remove -y '^mono-.*' || true sudo apt-get remove -y '^ghc-.*' || 
true sudo apt-get remove -y '.*jdk.*|.*jre.*' || true sudo apt-get remove -y 'php.*' || true sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true sudo apt-get remove -y '^google-.*' || true sudo apt-get remove -y azure-cli || true sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true sudo apt-get remove -y '^gfortran-.*' || true sudo apt-get remove -y microsoft-edge-stable || true sudo apt-get remove -y firefox || true sudo apt-get remove -y powershell || true sudo apt-get remove -y r-base-core || true sudo apt-get autoremove -y sudo apt-get clean echo echo "Listing top largest packages" pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr) head -n 30 <<< "${pkgs}" echo sudo rm -rfv build || true sudo rm -rf /usr/share/dotnet || true sudo rm -rf /opt/ghc || true sudo rm -rf "/usr/local/share/boost" || true sudo rm -rf "$AGENT_TOOLSDIRECTORY" || true df -h - name: Set up QEMU uses: docker/setup-qemu-action@master with: platforms: all - name: Set up Docker Buildx id: buildx uses: docker/setup-buildx-action@master - name: Checkout uses: actions/checkout@v6 - name: Cache GRPC uses: docker/build-push-action@v7 with: builder: ${{ steps.buildx.outputs.name }} # The build-args MUST be an EXACT match between the image cache and other workflow steps that want to use that cache. # This means that even the MAKEFLAGS have to be an EXACT match. # If the build-args are not an EXACT match, it will result in a cache miss, which will require GRPC to be built from scratch. build-args: | GRPC_BASE_IMAGE=${{ matrix.grpc-base-image }} GRPC_MAKEFLAGS=--jobs=4 --output-sync=target GRPC_VERSION=v1.65.0 context: . 
file: ./Dockerfile cache-to: type=gha,ignore-error=true cache-from: type=gha target: grpc platforms: ${{ matrix.platforms }} push: false ================================================ FILE: .github/workflows/generate_intel_image.yaml ================================================ name: 'generate and publish intel docker caches' on: workflow_dispatch: push: branches: - master concurrency: group: intel-cache-${{ github.head_ref || github.ref }}-${{ github.repository }} cancel-in-progress: true jobs: generate_caches: if: github.repository == 'mudler/LocalAI' strategy: matrix: include: - base-image: intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04 runs-on: 'arc-runner-set' platforms: 'linux/amd64' runs-on: ${{matrix.runs-on}} steps: - name: Set up QEMU uses: docker/setup-qemu-action@master with: platforms: all - name: Login to DockerHub if: github.event_name != 'pull_request' uses: docker/login-action@v4 with: username: ${{ secrets.DOCKERHUB_USERNAME }} password: ${{ secrets.DOCKERHUB_PASSWORD }} - name: Login to quay if: github.event_name != 'pull_request' uses: docker/login-action@v4 with: registry: quay.io username: ${{ secrets.LOCALAI_REGISTRY_USERNAME }} password: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }} - name: Set up Docker Buildx id: buildx uses: docker/setup-buildx-action@master - name: Checkout uses: actions/checkout@v6 - name: Cache Intel images uses: docker/build-push-action@v7 with: builder: ${{ steps.buildx.outputs.name }} build-args: | BASE_IMAGE=${{ matrix.base-image }} context: . 
file: ./Dockerfile tags: quay.io/go-skynet/intel-oneapi-base:24.04 push: true target: intel platforms: ${{ matrix.platforms }} ================================================ FILE: .github/workflows/image-pr.yml ================================================ --- name: 'build container images tests' on: pull_request: concurrency: group: ci-${{ github.head_ref || github.ref }}-${{ github.repository }} cancel-in-progress: true jobs: image-build: uses: ./.github/workflows/image_build.yml with: tag-latest: ${{ matrix.tag-latest }} tag-suffix: ${{ matrix.tag-suffix }} build-type: ${{ matrix.build-type }} cuda-major-version: ${{ matrix.cuda-major-version }} cuda-minor-version: ${{ matrix.cuda-minor-version }} platforms: ${{ matrix.platforms }} runs-on: ${{ matrix.runs-on }} base-image: ${{ matrix.base-image }} grpc-base-image: ${{ matrix.grpc-base-image }} makeflags: ${{ matrix.makeflags }} ubuntu-version: ${{ matrix.ubuntu-version }} secrets: dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }} dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }} quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }} quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }} strategy: # Pushing with all jobs in parallel # eats the bandwidth of all the nodes max-parallel: ${{ github.event_name != 'pull_request' && 4 || 8 }} fail-fast: false matrix: include: - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'false' tag-suffix: '-gpu-nvidia-cuda-12' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" makeflags: "--jobs=3 --output-sync=target" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/amd64' tag-latest: 'false' tag-suffix: '-gpu-nvidia-cuda-13' runs-on: 'ubuntu-latest' base-image: "ubuntu:22.04" makeflags: "--jobs=3 --output-sync=target" ubuntu-version: '2404' - build-type: 'hipblas' platforms: 'linux/amd64' tag-latest: 'false' tag-suffix: '-hipblas' base-image: 
"rocm/dev-ubuntu-24.04:6.4.4" grpc-base-image: "ubuntu:24.04" runs-on: 'ubuntu-latest' makeflags: "--jobs=3 --output-sync=target" ubuntu-version: '2404' - build-type: 'sycl' platforms: 'linux/amd64' tag-latest: 'false' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" grpc-base-image: "ubuntu:24.04" tag-suffix: 'sycl' runs-on: 'ubuntu-latest' makeflags: "--jobs=3 --output-sync=target" ubuntu-version: '2404' - build-type: 'vulkan' platforms: 'linux/amd64,linux/arm64' tag-latest: 'false' tag-suffix: '-vulkan-core' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" makeflags: "--jobs=4 --output-sync=target" ubuntu-version: '2404' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'false' tag-suffix: '-nvidia-l4t-arm64-cuda-13' base-image: "ubuntu:24.04" runs-on: 'ubuntu-24.04-arm' makeflags: "--jobs=4 --output-sync=target" skip-drivers: 'false' ubuntu-version: '2404' ================================================ FILE: .github/workflows/image.yml ================================================ --- name: 'build container images' on: push: branches: - master tags: - '*' concurrency: group: ci-${{ github.head_ref || github.ref }}-${{ github.repository }} cancel-in-progress: true jobs: hipblas-jobs: if: github.repository == 'mudler/LocalAI' uses: ./.github/workflows/image_build.yml with: tag-latest: ${{ matrix.tag-latest }} tag-suffix: ${{ matrix.tag-suffix }} build-type: ${{ matrix.build-type }} cuda-major-version: ${{ matrix.cuda-major-version }} cuda-minor-version: ${{ matrix.cuda-minor-version }} platforms: ${{ matrix.platforms }} runs-on: ${{ matrix.runs-on }} base-image: ${{ matrix.base-image }} grpc-base-image: ${{ matrix.grpc-base-image }} makeflags: ${{ matrix.makeflags }} ubuntu-version: ${{ matrix.ubuntu-version }} ubuntu-codename: ${{ matrix.ubuntu-codename }} secrets: dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }} dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }} quayUsername: ${{ 
secrets.LOCALAI_REGISTRY_USERNAME }} quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }} strategy: matrix: include: - build-type: 'hipblas' platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-hipblas' base-image: "rocm/dev-ubuntu-24.04:6.4.4" grpc-base-image: "ubuntu:24.04" runs-on: 'ubuntu-latest' makeflags: "--jobs=3 --output-sync=target" ubuntu-version: '2404' ubuntu-codename: 'noble' core-image-build: if: github.repository == 'mudler/LocalAI' uses: ./.github/workflows/image_build.yml with: tag-latest: ${{ matrix.tag-latest }} tag-suffix: ${{ matrix.tag-suffix }} build-type: ${{ matrix.build-type }} cuda-major-version: ${{ matrix.cuda-major-version }} cuda-minor-version: ${{ matrix.cuda-minor-version }} platforms: ${{ matrix.platforms }} runs-on: ${{ matrix.runs-on }} base-image: ${{ matrix.base-image }} grpc-base-image: ${{ matrix.grpc-base-image }} makeflags: ${{ matrix.makeflags }} skip-drivers: ${{ matrix.skip-drivers }} ubuntu-version: ${{ matrix.ubuntu-version }} ubuntu-codename: ${{ matrix.ubuntu-codename }} secrets: dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }} dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }} quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }} quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }} strategy: #max-parallel: ${{ github.event_name != 'pull_request' && 2 || 4 }} matrix: include: - build-type: '' platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '' base-image: "ubuntu:24.04" runs-on: 'ubuntu-latest' makeflags: "--jobs=4 --output-sync=target" skip-drivers: 'false' ubuntu-version: '2404' ubuntu-codename: 'noble' - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "8" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-12' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' makeflags: "--jobs=4 --output-sync=target" ubuntu-version: '2404' ubuntu-codename: 'noble' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: 
"0" platforms: 'linux/amd64' tag-latest: 'auto' tag-suffix: '-gpu-nvidia-cuda-13' runs-on: 'ubuntu-latest' base-image: "ubuntu:22.04" skip-drivers: 'false' makeflags: "--jobs=4 --output-sync=target" ubuntu-version: '2404' ubuntu-codename: 'noble' - build-type: 'vulkan' platforms: 'linux/amd64,linux/arm64' tag-latest: 'auto' tag-suffix: '-gpu-vulkan' runs-on: 'ubuntu-latest' base-image: "ubuntu:24.04" skip-drivers: 'false' makeflags: "--jobs=4 --output-sync=target" ubuntu-version: '2404' ubuntu-codename: 'noble' - build-type: 'intel' platforms: 'linux/amd64' tag-latest: 'auto' base-image: "intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04" grpc-base-image: "ubuntu:24.04" tag-suffix: '-gpu-intel' runs-on: 'ubuntu-latest' makeflags: "--jobs=3 --output-sync=target" ubuntu-version: '2404' ubuntu-codename: 'noble' gh-runner: if: github.repository == 'mudler/LocalAI' uses: ./.github/workflows/image_build.yml with: tag-latest: ${{ matrix.tag-latest }} tag-suffix: ${{ matrix.tag-suffix }} build-type: ${{ matrix.build-type }} cuda-major-version: ${{ matrix.cuda-major-version }} cuda-minor-version: ${{ matrix.cuda-minor-version }} platforms: ${{ matrix.platforms }} runs-on: ${{ matrix.runs-on }} base-image: ${{ matrix.base-image }} grpc-base-image: ${{ matrix.grpc-base-image }} makeflags: ${{ matrix.makeflags }} skip-drivers: ${{ matrix.skip-drivers }} ubuntu-version: ${{ matrix.ubuntu-version }} ubuntu-codename: ${{ matrix.ubuntu-codename }} secrets: dockerUsername: ${{ secrets.DOCKERHUB_USERNAME }} dockerPassword: ${{ secrets.DOCKERHUB_PASSWORD }} quayUsername: ${{ secrets.LOCALAI_REGISTRY_USERNAME }} quayPassword: ${{ secrets.LOCALAI_REGISTRY_PASSWORD }} strategy: matrix: include: - build-type: 'cublas' cuda-major-version: "12" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-arm64' base-image: "nvcr.io/nvidia/l4t-jetpack:r36.4.0" runs-on: 'ubuntu-24.04-arm' makeflags: "--jobs=4 --output-sync=target" skip-drivers: 'true' 
ubuntu-version: "2204" ubuntu-codename: 'jammy' - build-type: 'cublas' cuda-major-version: "13" cuda-minor-version: "0" platforms: 'linux/arm64' tag-latest: 'auto' tag-suffix: '-nvidia-l4t-arm64-cuda-13' base-image: "ubuntu:24.04" runs-on: 'ubuntu-24.04-arm' makeflags: "--jobs=4 --output-sync=target" skip-drivers: 'false' ubuntu-version: '2404' ubuntu-codename: 'noble' ================================================ FILE: .github/workflows/image_build.yml ================================================ --- name: 'build container images (reusable)' on: workflow_call: inputs: base-image: description: 'Base image' required: true type: string grpc-base-image: description: 'GRPC Base image, must be a compatible image with base-image' required: false default: '' type: string build-type: description: 'Build type' default: '' type: string cuda-major-version: description: 'CUDA major version' default: "12" type: string cuda-minor-version: description: 'CUDA minor version' default: "9" type: string platforms: description: 'Platforms' default: '' type: string tag-latest: description: 'Tag latest' default: '' type: string tag-suffix: description: 'Tag suffix' default: '' type: string skip-drivers: description: 'Skip drivers by default' default: 'false' type: string runs-on: description: 'Runs on' required: true default: '' type: string makeflags: description: 'Make Flags' required: false default: '--jobs=4 --output-sync=target' type: string ubuntu-version: description: 'Ubuntu version' required: false default: '2204' type: string ubuntu-codename: description: 'Ubuntu codename' required: false default: 'noble' type: string secrets: dockerUsername: required: true dockerPassword: required: true quayUsername: required: true quayPassword: required: true jobs: reusable_image-build: runs-on: ${{ inputs.runs-on }} steps: - name: Free Disk Space (Ubuntu) if: inputs.runs-on == 'ubuntu-latest' uses: jlumbroso/free-disk-space@main with: # this might remove tools that are actually 
needed, # if set to "true" but frees about 6 GB tool-cache: true # all of these default to true, but feel free to set to # "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true docker-images: true swap-storage: true - name: Force Install GIT latest run: | sudo apt-get update \ && sudo apt-get install -y software-properties-common \ && sudo apt-get update \ && sudo add-apt-repository -y ppa:git-core/ppa \ && sudo apt-get update \ && sudo apt-get install -y git - name: Checkout uses: actions/checkout@v6 - name: Release space from worker if: inputs.runs-on == 'ubuntu-latest' run: | echo "Listing top largest packages" pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr) head -n 30 <<< "${pkgs}" echo df -h echo sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true sudo apt-get remove --auto-remove android-sdk-platform-tools snapd || true sudo apt-get purge --auto-remove android-sdk-platform-tools snapd || true sudo rm -rf /usr/local/lib/android sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true sudo rm -rf /usr/share/dotnet sudo apt-get remove -y '^mono-.*' || true sudo apt-get remove -y '^ghc-.*' || true sudo apt-get remove -y '.*jdk.*|.*jre.*' || true sudo apt-get remove -y 'php.*' || true sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true sudo apt-get remove -y '^google-.*' || true sudo apt-get remove -y azure-cli || true sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true sudo apt-get remove -y '^gfortran-.*' || true sudo apt-get remove -y microsoft-edge-stable || true sudo apt-get remove -y firefox || true sudo apt-get remove -y powershell || true sudo apt-get remove -y r-base-core || true sudo apt-get autoremove -y sudo apt-get clean echo echo "Listing top largest packages" pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | 
sort -nr) head -n 30 <<< "${pkgs}" echo sudo rm -rfv build || true sudo rm -rf /usr/share/dotnet || true sudo rm -rf /opt/ghc || true sudo rm -rf "/usr/local/share/boost" || true sudo rm -rf "$AGENT_TOOLSDIRECTORY" || true df -h - name: Docker meta id: meta if: github.event_name != 'pull_request' uses: docker/metadata-action@v6 with: images: | quay.io/go-skynet/local-ai localai/localai tags: | type=ref,event=branch type=semver,pattern={{raw}} type=sha flavor: | latest=${{ inputs.tag-latest }} suffix=${{ inputs.tag-suffix }},onlatest=true - name: Docker meta for PR id: meta_pull_request if: github.event_name == 'pull_request' uses: docker/metadata-action@v6 with: images: | quay.io/go-skynet/ci-tests tags: | type=ref,event=branch,suffix=localai${{ github.event.number }}-${{ inputs.build-type }}-${{ inputs.cuda-major-version }}-${{ inputs.cuda-minor-version }} type=semver,pattern={{raw}},suffix=localai${{ github.event.number }}-${{ inputs.build-type }}-${{ inputs.cuda-major-version }}-${{ inputs.cuda-minor-version }} type=sha,suffix=localai${{ github.event.number }}-${{ inputs.build-type }}-${{ inputs.cuda-major-version }}-${{ inputs.cuda-minor-version }} flavor: | latest=${{ inputs.tag-latest }} suffix=${{ inputs.tag-suffix }} - name: Set up QEMU uses: docker/setup-qemu-action@master with: platforms: all - name: Set up Docker Buildx id: buildx uses: docker/setup-buildx-action@master - name: Login to DockerHub if: github.event_name != 'pull_request' uses: docker/login-action@v4 with: username: ${{ secrets.dockerUsername }} password: ${{ secrets.dockerPassword }} - name: Login to quay if: github.event_name != 'pull_request' uses: docker/login-action@v4 with: registry: quay.io username: ${{ secrets.quayUsername }} password: ${{ secrets.quayPassword }} - name: Build and push uses: docker/build-push-action@v7 if: github.event_name != 'pull_request' with: builder: ${{ steps.buildx.outputs.name }} # The build-args MUST be an EXACT match between the image cache and
other workflow steps that want to use that cache. # This means that even the MAKEFLAGS have to be an EXACT match. # If the build-args are not an EXACT match, it will result in a cache miss, which will require GRPC to be built from scratch. # This is why some build args like GRPC_VERSION and MAKEFLAGS are hardcoded build-args: | BUILD_TYPE=${{ inputs.build-type }} CUDA_MAJOR_VERSION=${{ inputs.cuda-major-version }} CUDA_MINOR_VERSION=${{ inputs.cuda-minor-version }} BASE_IMAGE=${{ inputs.base-image }} GRPC_BASE_IMAGE=${{ inputs.grpc-base-image || inputs.base-image }} GRPC_MAKEFLAGS=--jobs=4 --output-sync=target GRPC_VERSION=v1.65.0 MAKEFLAGS=${{ inputs.makeflags }} SKIP_DRIVERS=${{ inputs.skip-drivers }} UBUNTU_VERSION=${{ inputs.ubuntu-version }} UBUNTU_CODENAME=${{ inputs.ubuntu-codename }} context: . file: ./Dockerfile cache-from: type=gha platforms: ${{ inputs.platforms }} push: ${{ github.event_name != 'pull_request' }} tags: ${{ steps.meta.outputs.tags }} labels: ${{ steps.meta.outputs.labels }} ### Start testing image - name: Build and push uses: docker/build-push-action@v7 if: github.event_name == 'pull_request' with: builder: ${{ steps.buildx.outputs.name }} # The build-args MUST be an EXACT match between the image cache and other workflow steps that want to use that cache. # This means that even the MAKEFLAGS have to be an EXACT match. # If the build-args are not an EXACT match, it will result in a cache miss, which will require GRPC to be built from scratch. 
# This is why some build args like GRPC_VERSION and MAKEFLAGS are hardcoded build-args: | BUILD_TYPE=${{ inputs.build-type }} CUDA_MAJOR_VERSION=${{ inputs.cuda-major-version }} CUDA_MINOR_VERSION=${{ inputs.cuda-minor-version }} BASE_IMAGE=${{ inputs.base-image }} GRPC_BASE_IMAGE=${{ inputs.grpc-base-image || inputs.base-image }} GRPC_MAKEFLAGS=--jobs=4 --output-sync=target GRPC_VERSION=v1.65.0 MAKEFLAGS=${{ inputs.makeflags }} SKIP_DRIVERS=${{ inputs.skip-drivers }} UBUNTU_VERSION=${{ inputs.ubuntu-version }} UBUNTU_CODENAME=${{ inputs.ubuntu-codename }} context: . file: ./Dockerfile cache-from: type=gha platforms: ${{ inputs.platforms }} #push: true tags: ${{ steps.meta_pull_request.outputs.tags }} labels: ${{ steps.meta_pull_request.outputs.labels }} ## End testing image - name: job summary run: | echo "Built image: ${{ steps.meta.outputs.labels }}" >> $GITHUB_STEP_SUMMARY ================================================ FILE: .github/workflows/notify-releases.yaml ================================================ name: Release notifications on: release: types: - published jobs: notify-discord: if: github.repository == 'mudler/LocalAI' runs-on: ubuntu-latest env: RELEASE_BODY: ${{ github.event.release.body }} RELEASE_TITLE: ${{ github.event.release.name }} RELEASE_TAG_NAME: ${{ github.event.release.tag_name }} MODEL_NAME: gemma-3-12b-it-qat steps: - uses: mudler/localai-github-action@v1 with: model: 'gemma-3-12b-it-qat' # Any from models.localai.io, or from huggingface.com with: "huggingface:///file" - name: Summarize id: summarize run: | input="$RELEASE_TITLE\b$RELEASE_BODY" # Define the LocalAI API endpoint API_URL="http://localhost:8080/chat/completions" # Create a JSON payload using jq to handle special characters json_payload=$(jq -n --arg input "$input" '{ model: "'$MODEL_NAME'", messages: [ { role: "system", content: "Write a discord message with a bullet point summary of the release notes." 
}, { role: "user", content: $input } ] }') # Send the request to LocalAI API response=$(curl -s -X POST $API_URL \ -H "Content-Type: application/json" \ -d "$json_payload") # Extract the summary from the response summary=$(echo $response | jq -r '.choices[0].message.content') # Print the summary # -H "Authorization: Bearer $API_KEY" \ { echo 'message<<EOF' echo "$summary" echo EOF } >> "$GITHUB_OUTPUT" - name: Discord notification env: DISCORD_WEBHOOK: ${{ secrets.DISCORD_WEBHOOK_URL_RELEASE }} DISCORD_USERNAME: "LocalAI-Bot" DISCORD_AVATAR: "https://avatars.githubusercontent.com/u/139863280?v=4" uses: Ilshidur/action-discord@master with: args: ${{ steps.summarize.outputs.message }} ================================================ FILE: .github/workflows/release.yaml ================================================ name: goreleaser on: push: tags: - 'v*' jobs: goreleaser: runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v6 with: fetch-depth: 0 - name: Set up Go uses: actions/setup-go@v5 with: go-version: 1.23 - name: Run GoReleaser uses: goreleaser/goreleaser-action@v7 with: version: v2.11.0 args: release --clean env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} launcher-build-darwin: runs-on: macos-latest steps: - name: Checkout uses: actions/checkout@v6 with: fetch-depth: 0 - name: Set up Go uses: actions/setup-go@v5 with: go-version: 1.23 - name: Build launcher for macOS ARM64 run: | make build-launcher-darwin - name: Upload DMG to Release uses: softprops/action-gh-release@v2 with: files: ./dist/LocalAI.dmg launcher-build-linux: runs-on: ubuntu-latest steps: - name: Checkout uses: actions/checkout@v6 with: fetch-depth: 0 - name: Set up Go uses: actions/setup-go@v5 with: go-version: 1.23 - name: Build launcher for Linux run: | sudo apt-get update sudo apt-get install golang gcc libgl1-mesa-dev xorg-dev libxkbcommon-dev make build-launcher-linux - name: Upload Linux launcher artifacts uses: softprops/action-gh-release@v2 with: files: ./local-ai-launcher-linux.tar.xz
================================================ FILE: .github/workflows/secscan.yaml ================================================ name: "Security Scan" # Run workflow each time code is pushed to your repository and on a schedule. # The scheduled workflow runs every Sunday at 00:00 UTC. on: push: schedule: - cron: '0 0 * * 0' jobs: tests: runs-on: ubuntu-latest env: GO111MODULE: on steps: - name: Checkout Source uses: actions/checkout@v6 if: ${{ github.actor != 'dependabot[bot]' }} - name: Run Gosec Security Scanner if: ${{ github.actor != 'dependabot[bot]' }} uses: securego/gosec@v2.22.9 with: # we let the report content trigger a failure using the GitHub Security features. args: '-no-fail -fmt sarif -out results.sarif ./...' - name: Upload SARIF file if: ${{ github.actor != 'dependabot[bot]' }} uses: github/codeql-action/upload-sarif@v4 with: # Path to SARIF file relative to the root of the repository sarif_file: results.sarif ================================================ FILE: .github/workflows/stalebot.yml ================================================ name: 'Close stale issues and PRs' permissions: issues: write pull-requests: write on: schedule: - cron: '30 1 * * *' jobs: stale: if: github.repository == 'mudler/LocalAI' runs-on: ubuntu-latest steps: - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v9 with: stale-issue-message: 'This issue is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 5 days.' stale-pr-message: 'This PR is stale because it has been open 90 days with no activity. Remove stale label or comment or this will be closed in 10 days.' close-issue-message: 'This issue was closed because it has been stalled for 5 days with no activity.' close-pr-message: 'This PR was closed because it has been stalled for 10 days with no activity.'
days-before-issue-stale: 90 days-before-pr-stale: 90 days-before-issue-close: 5 days-before-pr-close: 10 exempt-issue-labels: 'roadmap' exempt-pr-labels: 'roadmap' ================================================ FILE: .github/workflows/test-extra.yml ================================================ --- name: 'Tests extras backends' on: pull_request: push: branches: - master tags: - '*' concurrency: group: ci-tests-extra-${{ github.head_ref || github.ref }}-${{ github.repository }} cancel-in-progress: true jobs: # Requires CUDA # tests-chatterbox-tts: # runs-on: ubuntu-latest # steps: # - name: Clone # uses: actions/checkout@v6 # with: # submodules: true # - name: Dependencies # run: | # sudo apt-get update # sudo apt-get install build-essential ffmpeg # # Install UV # curl -LsSf https://astral.sh/uv/install.sh | sh # sudo apt-get install -y ca-certificates cmake curl patch python3-pip # sudo apt-get install -y libopencv-dev # pip install --user --no-cache-dir grpcio-tools==1.64.1 # - name: Test chatterbox-tts # run: | # make --jobs=5 --output-sync=target -C backend/python/chatterbox # make --jobs=5 --output-sync=target -C backend/python/chatterbox test tests-transformers: runs-on: ubuntu-latest steps: - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Dependencies run: | sudo apt-get update sudo apt-get install build-essential ffmpeg # Install UV curl -LsSf https://astral.sh/uv/install.sh | sh sudo apt-get install -y ca-certificates cmake curl patch python3-pip sudo apt-get install -y libopencv-dev pip install --user --no-cache-dir grpcio-tools==1.64.1 - name: Test transformers run: | make --jobs=5 --output-sync=target -C backend/python/transformers make --jobs=5 --output-sync=target -C backend/python/transformers test tests-rerankers: runs-on: ubuntu-latest steps: - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Dependencies run: | sudo apt-get update sudo apt-get install build-essential ffmpeg # Install UV curl -LsSf 
https://astral.sh/uv/install.sh | sh sudo apt-get install -y ca-certificates cmake curl patch python3-pip sudo apt-get install -y libopencv-dev pip install --user --no-cache-dir grpcio-tools==1.64.1 - name: Test rerankers run: | make --jobs=5 --output-sync=target -C backend/python/rerankers make --jobs=5 --output-sync=target -C backend/python/rerankers test tests-diffusers: runs-on: ubuntu-latest steps: - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Dependencies run: | sudo apt-get update sudo apt-get install -y build-essential ffmpeg sudo apt-get install -y ca-certificates cmake curl patch python3-pip sudo apt-get install -y libopencv-dev # Install UV curl -LsSf https://astral.sh/uv/install.sh | sh pip install --user --no-cache-dir grpcio-tools==1.64.1 - name: Test diffusers run: | make --jobs=5 --output-sync=target -C backend/python/diffusers make --jobs=5 --output-sync=target -C backend/python/diffusers test #tests-vllm: # runs-on: ubuntu-latest # steps: # - name: Clone # uses: actions/checkout@v6 # with: # submodules: true # - name: Dependencies # run: | # sudo apt-get update # sudo apt-get install -y build-essential ffmpeg # sudo apt-get install -y ca-certificates cmake curl patch python3-pip # sudo apt-get install -y libopencv-dev # # Install UV # curl -LsSf https://astral.sh/uv/install.sh | sh # pip install --user --no-cache-dir grpcio-tools==1.64.1 # - name: Test vllm backend # run: | # make --jobs=5 --output-sync=target -C backend/python/vllm # make --jobs=5 --output-sync=target -C backend/python/vllm test # tests-transformers-musicgen: # runs-on: ubuntu-latest # steps: # - name: Clone # uses: actions/checkout@v6 # with: # submodules: true # - name: Dependencies # run: | # sudo apt-get update # sudo apt-get install build-essential ffmpeg # # Install UV # curl -LsSf https://astral.sh/uv/install.sh | sh # sudo apt-get install -y ca-certificates cmake curl patch python3-pip # sudo apt-get install -y libopencv-dev # pip install --user 
--no-cache-dir grpcio-tools==1.64.1 # - name: Test transformers-musicgen # run: | # make --jobs=5 --output-sync=target -C backend/python/transformers-musicgen # make --jobs=5 --output-sync=target -C backend/python/transformers-musicgen test # tests-bark: # runs-on: ubuntu-latest # steps: # - name: Release space from worker # run: | # echo "Listing top largest packages" # pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr) # head -n 30 <<< "${pkgs}" # echo # df -h # echo # sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true # sudo apt-get remove --auto-remove android-sdk-platform-tools || true # sudo apt-get purge --auto-remove android-sdk-platform-tools || true # sudo rm -rf /usr/local/lib/android # sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true # sudo rm -rf /usr/share/dotnet # sudo apt-get remove -y '^mono-.*' || true # sudo apt-get remove -y '^ghc-.*' || true # sudo apt-get remove -y '.*jdk.*|.*jre.*' || true # sudo apt-get remove -y 'php.*' || true # sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true # sudo apt-get remove -y '^google-.*' || true # sudo apt-get remove -y azure-cli || true # sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true # sudo apt-get remove -y '^gfortran-.*' || true # sudo apt-get remove -y microsoft-edge-stable || true # sudo apt-get remove -y firefox || true # sudo apt-get remove -y powershell || true # sudo apt-get remove -y r-base-core || true # sudo apt-get autoremove -y # sudo apt-get clean # echo # echo "Listing top largest packages" # pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr) # head -n 30 <<< "${pkgs}" # echo # sudo rm -rfv build || true # sudo rm -rf /usr/share/dotnet || true # sudo rm -rf /opt/ghc || true # sudo rm -rf "/usr/local/share/boost" || true # sudo rm -rf "$AGENT_TOOLSDIRECTORY" || true # df -h # - 
name: Clone # uses: actions/checkout@v6 # with: # submodules: true # - name: Dependencies # run: | # sudo apt-get update # sudo apt-get install build-essential ffmpeg # # Install UV # curl -LsSf https://astral.sh/uv/install.sh | sh # sudo apt-get install -y ca-certificates cmake curl patch python3-pip # sudo apt-get install -y libopencv-dev # pip install --user --no-cache-dir grpcio-tools==1.64.1 # - name: Test bark # run: | # make --jobs=5 --output-sync=target -C backend/python/bark # make --jobs=5 --output-sync=target -C backend/python/bark test # Below tests needs GPU. Commented out for now # TODO: Re-enable as soon as we have GPU nodes # tests-vllm: # runs-on: ubuntu-latest # steps: # - name: Clone # uses: actions/checkout@v6 # with: # submodules: true # - name: Dependencies # run: | # sudo apt-get update # sudo apt-get install build-essential ffmpeg # # Install UV # curl -LsSf https://astral.sh/uv/install.sh | sh # sudo apt-get install -y ca-certificates cmake curl patch python3-pip # sudo apt-get install -y libopencv-dev # pip install --user --no-cache-dir grpcio-tools==1.64.1 # - name: Test vllm # run: | # make --jobs=5 --output-sync=target -C backend/python/vllm # make --jobs=5 --output-sync=target -C backend/python/vllm test tests-coqui: runs-on: ubuntu-latest steps: - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Dependencies run: | sudo apt-get update sudo apt-get install -y build-essential ffmpeg sudo apt-get install -y ca-certificates cmake curl patch espeak espeak-ng python3-pip # Install UV curl -LsSf https://astral.sh/uv/install.sh | sh pip install --user --no-cache-dir grpcio-tools==1.64.1 - name: Test coqui run: | make --jobs=5 --output-sync=target -C backend/python/coqui make --jobs=5 --output-sync=target -C backend/python/coqui test tests-moonshine: runs-on: ubuntu-latest steps: - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Dependencies run: | sudo apt-get update sudo apt-get install -y 
build-essential ffmpeg sudo apt-get install -y ca-certificates cmake curl patch python3-pip # Install UV curl -LsSf https://astral.sh/uv/install.sh | sh pip install --user --no-cache-dir grpcio-tools==1.64.1 - name: Test moonshine run: | make --jobs=5 --output-sync=target -C backend/python/moonshine make --jobs=5 --output-sync=target -C backend/python/moonshine test tests-pocket-tts: runs-on: ubuntu-latest steps: - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Dependencies run: | sudo apt-get update sudo apt-get install -y build-essential ffmpeg sudo apt-get install -y ca-certificates cmake curl patch python3-pip # Install UV curl -LsSf https://astral.sh/uv/install.sh | sh pip install --user --no-cache-dir grpcio-tools==1.64.1 - name: Test pocket-tts run: | make --jobs=5 --output-sync=target -C backend/python/pocket-tts make --jobs=5 --output-sync=target -C backend/python/pocket-tts test tests-qwen-tts: runs-on: ubuntu-latest steps: - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Dependencies run: | sudo apt-get update sudo apt-get install -y build-essential ffmpeg sudo apt-get install -y ca-certificates cmake curl patch python3-pip # Install UV curl -LsSf https://astral.sh/uv/install.sh | sh pip install --user --no-cache-dir grpcio-tools==1.64.1 - name: Test qwen-tts run: | make --jobs=5 --output-sync=target -C backend/python/qwen-tts make --jobs=5 --output-sync=target -C backend/python/qwen-tts test # TODO: s2-pro model is too large to load on CPU-only CI runners — re-enable # when we have GPU runners or a smaller test model. 
# tests-fish-speech: # runs-on: ubuntu-latest # timeout-minutes: 45 # steps: # - name: Clone # uses: actions/checkout@v6 # with: # submodules: true # - name: Dependencies # run: | # sudo apt-get update # sudo apt-get install -y build-essential ffmpeg portaudio19-dev # sudo apt-get install -y ca-certificates cmake curl patch python3-pip # # Install UV # curl -LsSf https://astral.sh/uv/install.sh | sh # pip install --user --no-cache-dir grpcio-tools==1.64.1 # - name: Test fish-speech # run: | # make --jobs=5 --output-sync=target -C backend/python/fish-speech # make --jobs=5 --output-sync=target -C backend/python/fish-speech test tests-qwen-asr: runs-on: ubuntu-latest steps: - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Dependencies run: | sudo apt-get update sudo apt-get install -y build-essential ffmpeg sox sudo apt-get install -y ca-certificates cmake curl patch python3-pip # Install UV curl -LsSf https://astral.sh/uv/install.sh | sh pip install --user --no-cache-dir grpcio-tools==1.64.1 - name: Test qwen-asr run: | make --jobs=5 --output-sync=target -C backend/python/qwen-asr make --jobs=5 --output-sync=target -C backend/python/qwen-asr test tests-nemo: runs-on: ubuntu-latest steps: - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Dependencies run: | sudo apt-get update sudo apt-get install -y build-essential ffmpeg sox sudo apt-get install -y ca-certificates cmake curl patch python3-pip # Install UV curl -LsSf https://astral.sh/uv/install.sh | sh pip install --user --no-cache-dir grpcio-tools==1.64.1 - name: Test nemo run: | make --jobs=5 --output-sync=target -C backend/python/nemo make --jobs=5 --output-sync=target -C backend/python/nemo test tests-voxcpm: runs-on: ubuntu-latest steps: - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Dependencies run: | sudo apt-get update sudo apt-get install build-essential ffmpeg sudo apt-get install -y ca-certificates cmake curl patch python3-pip # Install 
UV curl -LsSf https://astral.sh/uv/install.sh | sh pip install --user --no-cache-dir grpcio-tools==1.64.1 - name: Test voxcpm run: | make --jobs=5 --output-sync=target -C backend/python/voxcpm make --jobs=5 --output-sync=target -C backend/python/voxcpm test tests-acestep-cpp: runs-on: ubuntu-latest steps: - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Dependencies run: | sudo apt-get update sudo apt-get install -y build-essential cmake curl libopenblas-dev ffmpeg - name: Setup Go uses: actions/setup-go@v5 - name: Display Go version run: go version - name: Proto Dependencies run: | # Install protoc curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \ unzip -j -d /usr/local/bin protoc.zip bin/protoc && \ rm protoc.zip go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af PATH="$PATH:$HOME/go/bin" make protogen-go - name: Build acestep-cpp run: | make --jobs=5 --output-sync=target -C backend/go/acestep-cpp - name: Test acestep-cpp run: | make --jobs=5 --output-sync=target -C backend/go/acestep-cpp test tests-voxtral: runs-on: ubuntu-latest steps: - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Dependencies run: | sudo apt-get update sudo apt-get install -y build-essential cmake curl libopenblas-dev ffmpeg - name: Setup Go uses: actions/setup-go@v5 # You can test your matrix by printing the current Go version - name: Display Go version run: go version - name: Proto Dependencies run: | # Install protoc curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \ unzip -j -d /usr/local/bin protoc.zip bin/protoc && \ rm protoc.zip go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 go install 
google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af PATH="$PATH:$HOME/go/bin" make protogen-go - name: Build voxtral run: | make --jobs=5 --output-sync=target -C backend/go/voxtral - name: Test voxtral run: | make --jobs=5 --output-sync=target -C backend/go/voxtral test ================================================ FILE: .github/workflows/test.yml ================================================ --- name: 'tests' on: pull_request: push: branches: - master tags: - '*' env: GRPC_VERSION: v1.65.0 concurrency: group: ci-tests-${{ github.head_ref || github.ref }}-${{ github.repository }} cancel-in-progress: true jobs: tests-linux: runs-on: ubuntu-latest strategy: matrix: go-version: ['1.25.x'] steps: - name: Free Disk Space (Ubuntu) uses: jlumbroso/free-disk-space@main with: # this might remove tools that are actually needed, # if set to "true" but frees about 6 GB tool-cache: true # all of these default to true, but feel free to set to # "false" if necessary for your workflow android: true dotnet: true haskell: true large-packages: true docker-images: true swap-storage: true - name: Release space from worker run: | echo "Listing top largest packages" pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr) head -n 30 <<< "${pkgs}" echo df -h echo sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true sudo apt-get remove --auto-remove android-sdk-platform-tools || true sudo apt-get purge --auto-remove android-sdk-platform-tools || true sudo rm -rf /usr/local/lib/android sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true sudo rm -rf /usr/share/dotnet sudo apt-get remove -y '^mono-.*' || true sudo apt-get remove -y '^ghc-.*' || true sudo apt-get remove -y '.*jdk.*|.*jre.*' || true sudo apt-get remove -y 'php.*' || true sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true sudo apt-get remove -y '^google-.*' || true sudo apt-get remove -y 
azure-cli || true sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true sudo apt-get remove -y '^gfortran-.*' || true sudo apt-get autoremove -y sudo apt-get clean echo echo "Listing top largest packages" pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr) head -n 30 <<< "${pkgs}" echo sudo rm -rfv build || true df -h - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Setup Go ${{ matrix.go-version }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} cache: false # You can test your matrix by printing the current Go version - name: Display Go version run: go version - name: Proto Dependencies run: | # Install protoc curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \ unzip -j -d /usr/local/bin protoc.zip bin/protoc && \ rm protoc.zip go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af PATH="$PATH:$HOME/go/bin" make protogen-go - name: Dependencies run: | sudo apt-get update sudo apt-get install curl ffmpeg libopus-dev - name: Setup Node.js uses: actions/setup-node@v6 with: node-version: '22' - name: Build React UI run: make react-ui - name: Build backends run: | make backends/transformers mkdir external && mv backends/transformers external/transformers make backends/llama-cpp backends/local-store backends/silero-vad backends/piper backends/whisper backends/stablediffusion-ggml - name: Test run: | TRANSFORMER_BACKEND=$PWD/external/transformers/run.sh PATH="$PATH:/root/go/bin" GO_TAGS="tts" make --jobs 5 --output-sync=target test - name: Setup tmate session if tests fail if: ${{ failure() }} uses: mxschmitt/action-tmate@v3.23 with: detached: true connect-timeout-seconds: 180 limit-access-to-actor: true tests-e2e-container: runs-on: ubuntu-latest 
steps: - name: Release space from worker run: | echo "Listing top largest packages" pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr) head -n 30 <<< "${pkgs}" echo df -h echo sudo apt-get remove -y '^llvm-.*|^libllvm.*' || true sudo apt-get remove --auto-remove android-sdk-platform-tools || true sudo apt-get purge --auto-remove android-sdk-platform-tools || true sudo rm -rf /usr/local/lib/android sudo apt-get remove -y '^dotnet-.*|^aspnetcore-.*' || true sudo rm -rf /usr/share/dotnet sudo apt-get remove -y '^mono-.*' || true sudo apt-get remove -y '^ghc-.*' || true sudo apt-get remove -y '.*jdk.*|.*jre.*' || true sudo apt-get remove -y 'php.*' || true sudo apt-get remove -y hhvm powershell firefox monodoc-manual msbuild || true sudo apt-get remove -y '^google-.*' || true sudo apt-get remove -y azure-cli || true sudo apt-get remove -y '^mongo.*-.*|^postgresql-.*|^mysql-.*|^mssql-.*' || true sudo apt-get remove -y '^gfortran-.*' || true sudo apt-get autoremove -y sudo apt-get clean echo echo "Listing top largest packages" pkgs=$(dpkg-query -Wf '${Installed-Size}\t${Package}\t${Status}\n' | awk '$NF == "installed"{print $1 "\t" $2}' | sort -nr) head -n 30 <<< "${pkgs}" echo sudo rm -rfv build || true df -h - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Dependencies run: | # Install protoc curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \ unzip -j -d /usr/local/bin protoc.zip bin/protoc && \ rm protoc.zip go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af PATH="$PATH:$HOME/go/bin" make protogen-go - name: Test run: | PATH="$PATH:$HOME/go/bin" make backends/local-store backends/silero-vad backends/llama-cpp backends/whisper backends/piper backends/stablediffusion-ggml docker-build-e2e 
e2e-aio - name: Setup tmate session if tests fail if: ${{ failure() }} uses: mxschmitt/action-tmate@v3.23 with: detached: true connect-timeout-seconds: 180 limit-access-to-actor: true tests-apple: runs-on: macos-latest strategy: matrix: go-version: ['1.25.x'] steps: - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Setup Go ${{ matrix.go-version }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} cache: false # You can test your matrix by printing the current Go version - name: Display Go version run: go version - name: Dependencies run: | brew install protobuf grpc make protoc-gen-go protoc-gen-go-grpc libomp llvm opus pip install --user --no-cache-dir grpcio-tools grpcio - name: Setup Node.js uses: actions/setup-node@v6 with: node-version: '22' - name: Build React UI run: make react-ui - name: Build llama-cpp-darwin run: | make protogen-go make backends/llama-cpp-darwin - name: Test run: | export C_INCLUDE_PATH=/usr/local/include export CPLUS_INCLUDE_PATH=/usr/local/include export CC=/opt/homebrew/opt/llvm/bin/clang # Used to run the newer GNUMake version from brew that supports --output-sync export PATH="/opt/homebrew/opt/make/libexec/gnubin:$PATH" PATH="$PATH:$HOME/go/bin" make protogen-go PATH="$PATH:$HOME/go/bin" BUILD_TYPE="GITHUB_CI_HAS_BROKEN_METAL" CMAKE_ARGS="-DGGML_F16C=OFF -DGGML_AVX512=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF" make --jobs 4 --output-sync=target test - name: Setup tmate session if tests fail if: ${{ failure() }} uses: mxschmitt/action-tmate@v3.23 with: detached: true connect-timeout-seconds: 180 limit-access-to-actor: true ================================================ FILE: .github/workflows/tests-e2e.yml ================================================ --- name: 'E2E Backend Tests' on: pull_request: push: branches: - master tags: - '*' concurrency: group: ci-tests-e2e-backend-${{ github.head_ref || github.ref }}-${{ github.repository }} cancel-in-progress: true jobs: tests-e2e-backend: runs-on: 
ubuntu-latest strategy: matrix: go-version: ['1.25.x'] steps: - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Setup Go ${{ matrix.go-version }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} cache: false - name: Display Go version run: go version - name: Proto Dependencies run: | # Install protoc curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \ unzip -j -d /usr/local/bin protoc.zip bin/protoc && \ rm protoc.zip go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af PATH="$PATH:$HOME/go/bin" make protogen-go - name: Dependencies run: | sudo apt-get update sudo apt-get install -y build-essential libopus-dev - name: Setup Node.js uses: actions/setup-node@v6 with: node-version: '22' - name: Build React UI run: make react-ui - name: Test Backend E2E run: | PATH="$PATH:$HOME/go/bin" make build-mock-backend test-e2e - name: Setup tmate session if tests fail if: ${{ failure() }} uses: mxschmitt/action-tmate@v3.23 with: detached: true connect-timeout-seconds: 180 limit-access-to-actor: true ================================================ FILE: .github/workflows/tests-ui-e2e.yml ================================================ --- name: 'UI E2E Tests' on: pull_request: paths: - 'core/http/**' - 'tests/e2e-ui/**' - 'tests/e2e/mock-backend/**' push: branches: - master concurrency: group: ci-tests-ui-e2e-${{ github.head_ref || github.ref }}-${{ github.repository }} cancel-in-progress: true jobs: tests-ui-e2e: runs-on: ubuntu-latest strategy: matrix: go-version: ['1.26.x'] steps: - name: Clone uses: actions/checkout@v6 with: submodules: true - name: Setup Go ${{ matrix.go-version }} uses: actions/setup-go@v5 with: go-version: ${{ matrix.go-version }} cache: false - name: Setup Node.js uses: actions/setup-node@v6 with: node-version: '22' - name: Proto 
Dependencies run: | curl -L -s https://github.com/protocolbuffers/protobuf/releases/download/v26.1/protoc-26.1-linux-x86_64.zip -o protoc.zip && \ unzip -j -d /usr/local/bin protoc.zip bin/protoc && \ rm protoc.zip go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af - name: System Dependencies run: | sudo apt-get update sudo apt-get install -y build-essential libopus-dev - name: Build UI test server run: PATH="$PATH:$HOME/go/bin" make build-ui-test-server - name: Install Playwright working-directory: core/http/react-ui run: | npm install npx playwright install --with-deps chromium - name: Run Playwright tests working-directory: core/http/react-ui run: npx playwright test - name: Upload Playwright report if: ${{ failure() }} uses: actions/upload-artifact@v7 with: name: playwright-report path: core/http/react-ui/playwright-report/ retention-days: 7 - name: Setup tmate session if tests fail if: ${{ failure() }} uses: mxschmitt/action-tmate@v3.23 with: detached: true connect-timeout-seconds: 180 limit-access-to-actor: true ================================================ FILE: .github/workflows/update_swagger.yaml ================================================ name: Update swagger on: schedule: - cron: 0 20 * * * workflow_dispatch: jobs: swagger: if: github.repository == 'mudler/LocalAI' strategy: fail-fast: false runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - uses: actions/setup-go@v5 with: go-version: 'stable' - name: Dependencies run: | sudo apt-get update sudo apt-get install protobuf-compiler - run: | go install github.com/swaggo/swag/cmd/swag@latest go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 - name: Bump swagger 🔧 run: | make protogen-go swagger - name: Create Pull Request uses: peter-evans/create-pull-request@v8 with: 
token: ${{ secrets.UPDATE_BOT_TOKEN }} push-to-fork: ci-forks/LocalAI commit-message: 'feat(swagger): update swagger' title: 'feat(swagger): update swagger' branch: "update/swagger" body: Update swagger signoff: true ================================================ FILE: .github/workflows/yaml-check.yml ================================================ name: 'Yamllint GitHub Actions' on: - pull_request jobs: yamllint: name: 'Yamllint' runs-on: ubuntu-latest steps: - name: 'Checkout' uses: actions/checkout@master - name: 'Yamllint model gallery' uses: karancode/yamllint-github-action@master with: yamllint_file_or_dir: 'gallery' yamllint_strict: false yamllint_comment: true env: GITHUB_ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: 'Yamllint Backend gallery' uses: karancode/yamllint-github-action@master with: yamllint_file_or_dir: 'backend' yamllint_strict: false yamllint_comment: true env: GITHUB_ACCESS_TOKEN: ${{ secrets.GITHUB_TOKEN }} ================================================ FILE: .gitignore ================================================ # go-llama build artifacts /sources/ __pycache__/ *.a *.o get-sources prepare-sources /backend/cpp/llama-cpp/grpc-server /backend/cpp/llama-cpp/llama.cpp /backend/cpp/llama-* !backend/cpp/llama-cpp /backends /backend-images /result.yaml protoc *.log go-ggml-transformers go-gpt2 whisper.cpp /bloomz go-bert # LocalAI build binary LocalAI /local-ai /local-ai-launcher # prevent above rules from omitting the helm chart !charts/* # prevent above rules from omitting the api/localai folder !api/localai !core/**/localai # Ignore models models/* test-models/ test-dir/ tests/e2e-aio/backends mock-backend release/ # just in case .DS_Store .idea # Generated during build backend-assets/* !backend-assets/.keep prepare /ggml-metal.metal docs/static/gallery.html # Protobuf generated files *.pb.go *pb2.py *pb2_grpc.py # SonarQube .scannerwork # backend virtual environments **/venv # per-developer customization files for the development 
container .devcontainer/customization/* # React UI build artifacts (keep placeholder dist/index.html) core/http/react-ui/node_modules/ core/http/react-ui/dist # Extracted backend binaries for container-based testing local-backends/ # UI E2E test artifacts tests/e2e-ui/ui-test-server core/http/react-ui/playwright-report/ core/http/react-ui/test-results/ ================================================ FILE: .gitmodules ================================================ [submodule "docs/themes/hugo-theme-relearn"] path = docs/themes/hugo-theme-relearn url = https://github.com/McShelby/hugo-theme-relearn.git ================================================ FILE: .goreleaser.yaml ================================================ version: 2 before: hooks: - make protogen-go - make react-ui - go mod tidy dist: release source: enabled: true name_template: '{{ .ProjectName }}-{{ .Tag }}-source' builds: - main: ./cmd/local-ai env: - CGO_ENABLED=0 ldflags: - -s -w - -X "github.com/mudler/LocalAI/internal.Version={{ .Tag }}" - -X "github.com/mudler/LocalAI/internal.Commit={{ .FullCommit }}" goos: - linux - darwin #- windows goarch: - amd64 - arm64 ignore: - goos: darwin goarch: amd64 archives: - formats: [ 'binary' ] # this removes the tar of the archives, leaving the binaries alone name_template: local-ai-{{ .Tag }}-{{ .Os }}-{{ .Arch }}{{ if .Arm }}v{{ .Arm }}{{ end }} checksum: name_template: '{{ .ProjectName }}-{{ .Tag }}-checksums.txt' snapshot: version_template: "{{ .Tag }}-next" changelog: use: github-native ================================================ FILE: .vscode/extensions.json ================================================ { "recommendations": [ "golang.go" ] } ================================================ FILE: .vscode/launch.json ================================================ { "version": "0.2.0", "configurations": [ { "name": "Python: Current File", "type": "debugpy", "request": "launch", "program": "${file}", "console": "integratedTerminal", 
"justMyCode": false, "cwd": "${fileDirname}", "env": { "OPENAI_API_BASE": "http://localhost:8080/v1", "OPENAI_API_KEY": "abc" } }, { "name": "Launch LocalAI API", "type": "go", "request": "launch", "mode": "debug", "program": "${workspaceRoot}", "args": [], "env": { "LOCALAI_LOG_LEVEL": "debug", "LOCALAI_P2P": "true", "LOCALAI_FEDERATED": "true" }, "buildFlags": ["-tags", "", "-v"], "envFile": "${workspaceFolder}/.env", "cwd": "${workspaceRoot}" } ] } ================================================ FILE: .yamllint ================================================ extends: default rules: line-length: disable ================================================ FILE: AGENTS.md ================================================ # LocalAI Agent Instructions This file is an index to detailed topic guides in the `.agents/` directory. Read the relevant file(s) for the task at hand — you don't need to load all of them. ## Topics | File | When to read | |------|-------------| | [.agents/building-and-testing.md](.agents/building-and-testing.md) | Building the project, running tests, Docker builds for specific platforms | | [.agents/adding-backends.md](.agents/adding-backends.md) | Adding a new backend (Python, Go, or C++) — full step-by-step checklist | | [.agents/coding-style.md](.agents/coding-style.md) | Code style, editorconfig, logging, documentation conventions | | [.agents/llama-cpp-backend.md](.agents/llama-cpp-backend.md) | Working on the llama.cpp backend — architecture, updating, tool call parsing | | [.agents/testing-mcp-apps.md](.agents/testing-mcp-apps.md) | Testing MCP Apps (interactive tool UIs) in the React UI | | [.agents/api-endpoints-and-auth.md](.agents/api-endpoints-and-auth.md) | Adding API endpoints, auth middleware, feature permissions, user access control | ## Quick Reference - **Logging**: Use `github.com/mudler/xlog` (same API as slog) - **Go style**: Prefer `any` over `interface{}` - **Comments**: Explain *why*, not *what* - **Docs**: Update 
`docs/content/` when adding features or changing config - **Build**: Inspect `Makefile` and `.github/workflows/` — ask the user before running long builds - **UI**: The active UI is the React app in `core/http/react-ui/`. The older Alpine.js/HTML UI in `core/http/static/` is pending deprecation — all new UI work goes in the React UI ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to LocalAI Thank you for your interest in contributing to LocalAI! We appreciate your time and effort in helping to improve our project. Before you get started, please take a moment to review these guidelines. ## Table of Contents - [Getting Started](#getting-started) - [Prerequisites](#prerequisites) - [Setting up the Development Environment](#setting-up-the-development-environment) - [Environment Variables](#environment-variables) - [Contributing](#contributing) - [Submitting an Issue](#submitting-an-issue) - [Development Workflow](#development-workflow) - [Creating a Pull Request (PR)](#creating-a-pull-request-pr) - [Coding Guidelines](#coding-guidelines) - [Testing](#testing) - [Documentation](#documentation) - [Community and Communication](#community-and-communication) ## Getting Started ### Prerequisites - **Go 1.21+** (the project currently uses Go 1.26 in `go.mod`, but 1.21 is the minimum supported version) - [Download Go](https://go.dev/dl/) or install via your package manager - macOS: `brew install go` - Ubuntu/Debian: follow the [official instructions](https://go.dev/doc/install) (the `apt` version is often outdated) - Verify: `go version` - **Git** - **GNU Make** - **GCC / C/C++ toolchain** (required for CGo and native backends) - **Protocol Buffers compiler** (`protoc`) — needed for gRPC code generation #### System dependencies by platform
Ubuntu / Debian ```bash sudo apt-get update sudo apt-get install -y build-essential gcc g++ cmake git wget \ protobuf-compiler libprotobuf-dev pkg-config \ libopencv-dev libgrpc-dev ```
CentOS / RHEL / Fedora ```bash sudo dnf groupinstall -y "Development Tools" sudo dnf install -y cmake git wget protobuf-compiler protobuf-devel \ opencv-devel grpc-devel ```
macOS ```bash xcode-select --install brew install cmake git protobuf grpc opencv wget ```
Windows Use [WSL 2](https://learn.microsoft.com/en-us/windows/wsl/install) with an Ubuntu distribution, then follow the Ubuntu instructions above.
### Setting up the Development Environment 1. **Clone the repository:** ```bash git clone https://github.com/mudler/LocalAI.git cd LocalAI ``` 2. **Build LocalAI:** ```bash make build ``` This runs protobuf generation, installs Go tools, builds the React UI, and compiles the `local-ai` binary. Key build variables you can set: | Variable | Description | Example | |---|---|---| | `BUILD_TYPE` | GPU/accelerator type (`cublas`, `hipblas`, `intel`, ``) | `BUILD_TYPE=cublas make build` | | `GO_TAGS` | Additional Go build tags | `GO_TAGS=debug make build` | | `CUDA_MAJOR_VERSION` | CUDA major version (default: `13`) | `CUDA_MAJOR_VERSION=12` | 3. **Run LocalAI:** ```bash ./local-ai ``` 4. **Development mode with live reload:** ```bash make build-dev ``` This installs [`air`](https://github.com/air-verse/air) automatically and watches for file changes, rebuilding and restarting the server on each save. 5. **Containerized build** (no local toolchain needed): ```bash make docker ``` For GPU-specific Docker builds, see the `docker-build-*` targets in the Makefile and refer to [CLAUDE.md](CLAUDE.md) for detailed backend build instructions. ### Environment Variables LocalAI is configured primarily through environment variables (or equivalent CLI flags). 
The most useful ones for development are: | Variable | Description | Default | |---|---|---| | `LOCALAI_DEBUG` | Enable debug mode | `false` | | `LOCALAI_LOG_LEVEL` | Log verbosity (`error`, `warn`, `info`, `debug`, `trace`) | — | | `LOCALAI_LOG_FORMAT` | Log format (`default`, `text`, `json`) | `default` | | `LOCALAI_MODELS_PATH` | Path to model files | `./models` | | `LOCALAI_BACKENDS_PATH` | Path to backend binaries | `./backends` | | `LOCALAI_CONFIG_DIR` | Directory for dynamic config files (API keys, external backends) | `./configuration` | | `LOCALAI_THREADS` | Number of threads for inference | — | | `LOCALAI_ADDRESS` | Bind address for the API server | `:8080` | | `LOCALAI_API_KEY` | API key(s) for authentication | — | | `LOCALAI_CORS` | Enable CORS | `false` | | `LOCALAI_DISABLE_WEBUI` | Disable the web UI | `false` | See `core/cli/run.go` for the full list of supported environment variables. ## Contributing We welcome contributions from everyone! To get started, follow these steps: ### Submitting an Issue If you find a bug, have a feature request, or encounter any issues, please check the [issue tracker](https://github.com/mudler/LocalAI/issues) to see if a similar issue has already been reported. If not, feel free to [create a new issue](https://github.com/mudler/LocalAI/issues/new) and provide as much detail as possible. 
### Development Workflow #### Branch naming conventions Use a descriptive branch name that indicates the type and scope of the change: - `feature/` — new functionality - `fix/` — bug fixes - `docs/` — documentation changes - `refactor/` — code refactoring #### Commit messages - Use a short, imperative subject line (e.g., "feat: add whisper backend support", not "Added whisper backend support") - Keep the subject under 72 characters - Use the body to explain **why** the change was made when the subject alone is not sufficient - Use [conventional commits](https://www.conventionalcommits.org/en/v1.0.0/) #### Creating a Pull Request (PR) Before jumping into a PR for a massive feature or big change, it is preferred to discuss it first via an issue. 1. Fork the repository. 2. Create a new branch: `git checkout -b feature/my-change` 3. Make your changes, keeping commits focused and atomic. 4. Run tests locally before pushing (see [Testing](#testing) below). 5. Push to your fork: `git push origin feature/my-change` 6. Open a pull request against the `master` branch. 7. Fill in the PR description with: - What the change does and why - How it was tested - Any breaking changes or migration steps 8. Respond to review feedback promptly. Push follow-up commits rather than force-pushing amended commits so reviewers can see incremental changes. 9. Once approved, a maintainer will merge your PR. ## Coding Guidelines This project uses an [`.editorconfig`](.editorconfig) file to define formatting standards (indentation, line endings, charset, etc.). Please configure your editor to respect it. For AI-assisted development, see [`CLAUDE.md`](CLAUDE.md) for agent-specific guidelines including build instructions and backend architecture details. ### General Principles - Write code that can be tested. All new features and bug fixes should include test coverage. - Use comments sparingly to explain **why** code does something, not **what** it does. 
Comments should add context that would be difficult to deduce from reading the code alone. - Keep changes focused. Avoid unrelated refactors, formatting changes, or feature additions in the same PR. ### Go Code - Prefer modern Go idioms — for example, use `any` instead of `interface{}`. - Use [`golangci-lint`](https://golangci-lint.run) to catch common issues before submitting a PR. - Use [`github.com/mudler/xlog`](https://github.com/mudler/xlog) for logging (same API as `slog`). Do not use `fmt.Println` or the standard `log` package for operational logging. - Use tab indentation for Go files (as defined in `.editorconfig`). ### Python Code - Use 4-space indentation (as defined in `.editorconfig`). - Include a `requirements.txt` for any new dependencies. ### Code Review - All contributions go through code review via pull requests. - Reviewers will check for correctness, test coverage, adherence to these guidelines, and clarity of intent. - Be responsive to review feedback and keep discussions constructive. ## Testing All new features and bug fixes should include test coverage. The project uses [Ginkgo](https://onsi.github.io/ginkgo/) as its test framework. ### Running unit tests ```bash make test ``` This downloads test model fixtures, runs protobuf generation, and executes the full test suite including llama-gguf, TTS, and stable-diffusion tests. Note: some tests require model files to be downloaded, so the first run may take longer. To run tests for a specific package: ```bash go test ./core/config/... go test ./pkg/model/... 
``` To run a specific test by name using Ginkgo's `--focus` flag: ```bash go run github.com/onsi/ginkgo/v2/ginkgo --focus="should load a model" -v -r ./core/ ``` ### Running end-to-end tests The e2e tests run LocalAI in a Docker container and exercise the API: ```bash make test-e2e ``` ### Running E2E container tests These tests build a standard LocalAI Docker image and run it with pre-configured model configs to verify that most endpoints work correctly: ```bash # Build the LocalAI docker image make docker-build-e2e # Run the e2e tests (uses model configs from tests/e2e-aio/models/) make e2e-aio ``` ### Testing backends To prepare and test extra (Python) backends: ```bash make prepare-test-extra # build Python backends for testing make test-extra # run backend-specific tests ``` ## Documentation We welcome contributions to the documentation. Please open a new PR or create a new issue. The documentation lives under [`docs/`](https://github.com/mudler/LocalAI/tree/master/docs). ### Gallery YAML Schema LocalAI provides a JSON Schema for gallery model YAML files at: `core/schema/gallery-model.schema.json` This schema mirrors the internal gallery model configuration and can be used by editors (such as VS Code) to enable autocomplete, validation, and inline documentation when creating or modifying gallery files. To use it with the YAML language server, add the following comment at the top of a gallery YAML file: ```yaml # yaml-language-server: $schema=../core/schema/gallery-model.schema.json ``` ## Community and Communication - You can reach out via the GitHub issue tracker. 
- Open a new discussion at [Discussion](https://github.com/go-skynet/LocalAI/discussions) - Join the Discord channel [Discord](https://discord.gg/uJAeKSAGDy) ================================================ FILE: Dockerfile ================================================ ARG BASE_IMAGE=ubuntu:24.04 ARG GRPC_BASE_IMAGE=${BASE_IMAGE} ARG INTEL_BASE_IMAGE=${BASE_IMAGE} ARG UBUNTU_CODENAME=noble FROM ${BASE_IMAGE} AS requirements ENV DEBIAN_FRONTEND=noninteractive RUN apt-get update && \ apt-get install -y --no-install-recommends \ ca-certificates curl wget espeak-ng libgomp1 \ ffmpeg libopenblas0 libopenblas-dev libopus0 sox && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* # The requirements-drivers target is for BUILD_TYPE specific items. If you need to install something specific to CUDA, or specific to ROCM, it goes here. FROM requirements AS requirements-drivers ARG BUILD_TYPE ARG CUDA_MAJOR_VERSION=12 ARG CUDA_MINOR_VERSION=0 ARG SKIP_DRIVERS=false ARG TARGETARCH ARG TARGETVARIANT ENV BUILD_TYPE=${BUILD_TYPE} ARG UBUNTU_VERSION=2404 RUN mkdir -p /run/localai RUN echo "default" > /run/localai/capability # Vulkan requirements RUN < /run/localai/capability fi EOT # CuBLAS requirements RUN < /run/localai/capability fi EOT RUN < /run/localai/capability fi EOT # https://github.com/NVIDIA/Isaac-GR00T/issues/343 RUN < /run/localai/capability && \ # I have no idea why, but the ROCM lib packages don't trigger ldconfig after they install, which results in local-ai and others not being able # to locate the libraries. 
We run ldconfig ourselves to work around this packaging deficiency ldconfig \ ; fi RUN if [ "${BUILD_TYPE}" = "hipblas" ]; then \ ln -s /opt/rocm-**/lib/llvm/lib/libomp.so /usr/lib/libomp.so \ ; fi RUN expr "${BUILD_TYPE}" = intel && echo "intel" > /run/localai/capability || echo "not intel" # Cuda ENV PATH=/usr/local/cuda/bin:${PATH} # HipBLAS requirements ENV PATH=/opt/rocm/bin:${PATH} ################################### ################################### # The requirements-core target is common to all images. It should not be placed in requirements-core unless every single build will use it. FROM requirements-drivers AS build-requirements ARG GO_VERSION=1.25.4 ARG CMAKE_VERSION=3.31.10 ARG CMAKE_FROM_SOURCE=false ARG TARGETARCH ARG TARGETVARIANT RUN apt-get update && \ apt-get install -y --no-install-recommends \ build-essential \ ccache \ ca-certificates espeak-ng \ curl libssl-dev \ git \ git-lfs \ libopus-dev pkg-config \ unzip upx-ucl python3 python-is-python3 && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* # Install CMake (the version in 22.04 is too old) RUN < /etc/apt/sources.list.d/intel-graphics.list RUN apt-get update && \ apt-get install -y --no-install-recommends \ intel-oneapi-runtime-libs && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* ################################### ################################### # The builder-base target has the arguments, variables, and copies shared between full builder images and the uncompiled devcontainer FROM build-requirements AS builder-base ARG GO_TAGS="auth" ARG GRPC_BACKENDS ARG MAKEFLAGS ARG LD_FLAGS="-s -w" ARG TARGETARCH ARG TARGETVARIANT ENV GRPC_BACKENDS=${GRPC_BACKENDS} ENV GO_TAGS=${GO_TAGS} ENV MAKEFLAGS=${MAKEFLAGS} ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility ENV NVIDIA_REQUIRE_CUDA="cuda>=${CUDA_MAJOR_VERSION}.0" ENV NVIDIA_VISIBLE_DEVICES=all ENV LD_FLAGS=${LD_FLAGS} RUN echo "GO_TAGS: $GO_TAGS" && echo "TARGETARCH: $TARGETARCH" WORKDIR /build # We need protoc installed, and the version 
in 22.04 is too old. RUN < com.apple.security.network.client com.apple.security.network.server ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2023-2025 Ettore Di Giacinto (mudler@localai.io) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: Makefile ================================================ # Disable parallel execution for backend builds .NOTPARALLEL: backends/diffusers backends/llama-cpp backends/outetts backends/piper backends/stablediffusion-ggml backends/whisper backends/faster-whisper backends/silero-vad backends/local-store backends/huggingface backends/rfdetr backends/kitten-tts backends/kokoro backends/chatterbox backends/llama-cpp-darwin backends/neutts build-darwin-python-backend build-darwin-go-backend backends/mlx backends/diffuser-darwin backends/mlx-vlm backends/mlx-audio backends/mlx-distributed backends/stablediffusion-ggml-darwin backends/vllm backends/vllm-omni backends/moonshine backends/pocket-tts backends/qwen-tts backends/faster-qwen3-tts backends/qwen-asr backends/nemo backends/voxcpm backends/whisperx backends/ace-step backends/acestep-cpp backends/fish-speech backends/voxtral backends/opus GOCMD=go GOTEST=$(GOCMD) test GOVET=$(GOCMD) vet BINARY_NAME=local-ai LAUNCHER_BINARY_NAME=local-ai-launcher UBUNTU_VERSION?=2404 UBUNTU_CODENAME?=noble GORELEASER?= export BUILD_TYPE?= export CUDA_MAJOR_VERSION?=13 export CUDA_MINOR_VERSION?=0 GO_TAGS?= BUILD_ID?= NATIVE?=false TEST_DIR=/tmp/test TEST_FLAKES?=5 RANDOM := $(shell bash -c 'echo $$RANDOM') VERSION?=$(shell git describe --always --tags || echo "dev" ) # go tool nm ./local-ai | grep Commit LD_FLAGS?=-s -w override LD_FLAGS += -X "github.com/mudler/LocalAI/internal.Version=$(VERSION)" override LD_FLAGS += -X "github.com/mudler/LocalAI/internal.Commit=$(shell git rev-parse HEAD)" OPTIONAL_TARGETS?= export OS := $(shell uname -s) ARCH := $(shell uname -m) GREEN := $(shell tput -Txterm setaf 2) YELLOW := $(shell tput -Txterm setaf 3) WHITE := $(shell tput -Txterm setaf 7) CYAN := $(shell tput -Txterm setaf 6) RESET := $(shell tput -Txterm sgr0) # Default Docker bridge IP E2E_BRIDGE_IP?=172.17.0.1 ifndef UNAME_S UNAME_S := $(shell uname -s) endif ifeq ($(OS),Darwin) ifeq 
($(OSX_SIGNING_IDENTITY),) OSX_SIGNING_IDENTITY := $(shell security find-identity -v -p codesigning | grep '"' | head -n 1 | sed -E 's/.*"(.*)"/\1/') endif endif # check if goreleaser exists ifeq (, $(shell which goreleaser)) GORELEASER=curl -sfL https://goreleaser.com/static/run | bash -s -- else GORELEASER=$(shell which goreleaser) endif TEST_PATHS?=./api/... ./pkg/... ./core/... .PHONY: all test build vendor all: help ## GENERIC rebuild: ## Rebuilds the project $(GOCMD) clean -cache $(MAKE) build clean: ## Remove build related file $(GOCMD) clean -cache rm -f prepare rm -rf $(BINARY_NAME) rm -rf release/ $(MAKE) protogen-clean rmdir pkg/grpc/proto || true clean-tests: rm -rf test-models rm -rf test-dir ## Install Go tools install-go-tools: go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@1958fcbe2ca8bd93af633f11e97d44e567e945af go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.34.2 ## React UI: react-ui: ifneq ($(wildcard core/http/react-ui/dist),) @echo "react-ui dist already exists, skipping build" else cd core/http/react-ui && npm install && npm run build endif react-ui-docker: docker run --entrypoint /bin/bash -v $(CURDIR):/app:z oven/bun:1 \ -c "cd /app/core/http/react-ui && bun install && bun run build" core/http/react-ui/dist: react-ui ## Build: build: protogen-go install-go-tools core/http/react-ui/dist ## Build the project $(info ${GREEN}I local-ai build info:${RESET}) $(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET}) $(info ${GREEN}I GO_TAGS: ${YELLOW}$(GO_TAGS)${RESET}) $(info ${GREEN}I LD_FLAGS: ${YELLOW}$(LD_FLAGS)${RESET}) $(info ${GREEN}I UPX: ${YELLOW}$(UPX)${RESET}) rm -rf $(BINARY_NAME) || true CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(BINARY_NAME) ./cmd/local-ai build-launcher: ## Build the launcher application $(info ${GREEN}I local-ai launcher build info:${RESET}) $(info ${GREEN}I BUILD_TYPE: ${YELLOW}$(BUILD_TYPE)${RESET}) $(info ${GREEN}I GO_TAGS: 
${YELLOW}$(GO_TAGS)${RESET}) $(info ${GREEN}I LD_FLAGS: ${YELLOW}$(LD_FLAGS)${RESET}) rm -rf $(LAUNCHER_BINARY_NAME) || true CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o $(LAUNCHER_BINARY_NAME) ./cmd/launcher build-all: build build-launcher ## Build both server and launcher build-dev: ## Run LocalAI in dev mode with live reload @command -v air >/dev/null 2>&1 || go install github.com/air-verse/air@latest air -c .air.toml dev-dist: $(GORELEASER) build --snapshot --clean dist: $(GORELEASER) build --clean osx-signed: build codesign --deep --force --sign "$(OSX_SIGNING_IDENTITY)" --entitlements "./Entitlements.plist" "./$(BINARY_NAME)" ## Run run: ## run local-ai CGO_LDFLAGS="$(CGO_LDFLAGS)" $(GOCMD) run ./ test-models/testmodel.ggml: mkdir -p test-models mkdir -p test-dir wget -q https://huggingface.co/mradermacher/gpt2-alpaca-gpt4-GGUF/resolve/main/gpt2-alpaca-gpt4.Q4_K_M.gguf -O test-models/testmodel.ggml wget -q https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-base.en.bin -O test-models/whisper-en wget -q https://huggingface.co/mudler/all-MiniLM-L6-v2/resolve/main/ggml-model-q4_0.bin -O test-models/bert wget -q https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav -O test-dir/audio.wav cp tests/models_fixtures/* test-models prepare-test: protogen-go cp tests/models_fixtures/* test-models ######################################################## ## Tests ######################################################## ## Test targets test: test-models/testmodel.ggml protogen-go @echo 'Running tests' export GO_TAGS="debug" $(MAKE) prepare-test OPUS_SHIM_LIBRARY=$(abspath ./pkg/opus/shim/libopusshim.so) \ HUGGINGFACE_GRPC=$(abspath ./)/backend/python/transformers/run.sh TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models BACKENDS_PATH=$(abspath ./)/backends \ $(GOCMD) run 
github.com/onsi/ginkgo/v2/ginkgo --label-filter="!llama-gguf" --flake-attempts $(TEST_FLAKES) --fail-fast -v -r $(TEST_PATHS) $(MAKE) test-llama-gguf $(MAKE) test-tts $(MAKE) test-stablediffusion ######################################################## ## E2E AIO tests (uses standard image with pre-configured models) ######################################################## docker-build-e2e: docker build \ --build-arg MAKEFLAGS="--jobs=5 --output-sync=target" \ --build-arg BASE_IMAGE=$(BASE_IMAGE) \ --build-arg IMAGE_TYPE=$(IMAGE_TYPE) \ --build-arg BUILD_TYPE=$(BUILD_TYPE) \ --build-arg CUDA_MAJOR_VERSION=$(CUDA_MAJOR_VERSION) \ --build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) \ --build-arg UBUNTU_VERSION=$(UBUNTU_VERSION) \ --build-arg UBUNTU_CODENAME=$(UBUNTU_CODENAME) \ --build-arg GO_TAGS="$(GO_TAGS)" \ -t local-ai:tests -f Dockerfile . e2e-aio: LOCALAI_BACKEND_DIR=$(abspath ./backends) \ LOCALAI_MODELS_DIR=$(abspath ./tests/e2e-aio/models) \ LOCALAI_IMAGE_TAG=tests \ LOCALAI_IMAGE=local-ai \ $(MAKE) run-e2e-aio run-e2e-aio: protogen-go @echo 'Running e2e AIO tests' $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e-aio ######################################################## ## E2E tests ######################################################## prepare-e2e: docker build \ --build-arg IMAGE_TYPE=core \ --build-arg BUILD_TYPE=$(BUILD_TYPE) \ --build-arg BASE_IMAGE=$(BASE_IMAGE) \ --build-arg CUDA_MAJOR_VERSION=$(CUDA_MAJOR_VERSION) \ --build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) \ --build-arg UBUNTU_VERSION=$(UBUNTU_VERSION) \ --build-arg UBUNTU_CODENAME=$(UBUNTU_CODENAME) \ --build-arg GO_TAGS="$(GO_TAGS)" \ --build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \ -t localai-tests . 
run-e2e-image: docker run -p 5390:8080 -e MODELS_PATH=/models -e THREADS=1 -e DEBUG=true -d --rm -v $(TEST_DIR):/models --name e2e-tests-$(RANDOM) localai-tests test-e2e: build-mock-backend prepare-e2e run-e2e-image @echo 'Running e2e tests' BUILD_TYPE=$(BUILD_TYPE) \ LOCALAI_API=http://$(E2E_BRIDGE_IP):5390 \ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e $(MAKE) clean-mock-backend $(MAKE) teardown-e2e docker rmi localai-tests teardown-e2e: rm -rf $(TEST_DIR) || true docker stop $$(docker ps -q --filter ancestor=localai-tests) ######################################################## ## Integration and unit tests ######################################################## test-llama-gguf: prepare-test TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models BACKENDS_PATH=$(abspath ./)/backends \ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="llama-gguf" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) test-tts: prepare-test TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models BACKENDS_PATH=$(abspath ./)/backends \ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="tts" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) test-stablediffusion: prepare-test TEST_DIR=$(abspath ./)/test-dir/ FIXTURES=$(abspath ./)/tests/fixtures CONFIG_FILE=$(abspath ./)/test-models/config.yaml MODELS_PATH=$(abspath ./)/test-models BACKENDS_PATH=$(abspath ./)/backends \ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stablediffusion" --flake-attempts $(TEST_FLAKES) -v -r $(TEST_PATHS) test-stores: $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="stores" --flake-attempts $(TEST_FLAKES) -v -r tests/integration test-opus: @echo 'Running opus backend tests' $(MAKE) -C backend/go/opus 
libopusshim.so $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./backend/go/opus/... test-opus-docker: @echo 'Running opus backend tests in Docker' docker build --target builder \ --build-arg BUILD_TYPE=$(or $(BUILD_TYPE),) \ --build-arg BASE_IMAGE=$(or $(BASE_IMAGE),ubuntu:24.04) \ --build-arg BACKEND=opus \ -t localai-opus-test -f backend/Dockerfile.golang . docker run --rm localai-opus-test \ bash -c 'cd /LocalAI && go run github.com/onsi/ginkgo/v2/ginkgo --flake-attempts $(TEST_FLAKES) -v -r ./backend/go/opus/...' test-realtime: build-mock-backend @echo 'Running realtime e2e tests (mock backend)' $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="Realtime && !real-models" --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e # Real-model realtime tests. Set REALTIME_TEST_MODEL to use your own pipeline, # or leave unset to auto-build one from the component env vars below. REALTIME_VAD?=silero-vad-ggml REALTIME_STT?=whisper-1 REALTIME_LLM?=qwen3-0.6b REALTIME_TTS?=tts-1 REALTIME_BACKENDS_PATH?=$(abspath ./)/backends test-realtime-models: build-mock-backend @echo 'Running realtime e2e tests (real models)' REALTIME_TEST_MODEL=$${REALTIME_TEST_MODEL:-realtime-test-pipeline} \ REALTIME_VAD=$(REALTIME_VAD) \ REALTIME_STT=$(REALTIME_STT) \ REALTIME_LLM=$(REALTIME_LLM) \ REALTIME_TTS=$(REALTIME_TTS) \ REALTIME_BACKENDS_PATH=$(REALTIME_BACKENDS_PATH) \ $(GOCMD) run github.com/onsi/ginkgo/v2/ginkgo --label-filter="Realtime" --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e # --- Container-based real-model testing --- REALTIME_BACKEND_NAMES ?= silero-vad whisper llama-cpp kokoro REALTIME_MODELS_DIR ?= $(abspath ./models) REALTIME_BACKENDS_DIR ?= $(abspath ./local-backends) REALTIME_DOCKER_FLAGS ?= --gpus all local-backends: mkdir -p local-backends extract-backend-%: docker-build-% local-backends @echo "Extracting backend $*..." 
@CID=$$(docker create local-ai-backend:$*) && \ rm -rf local-backends/$* && mkdir -p local-backends/$* && \ docker cp $$CID:/ - | tar -xf - -C local-backends/$* && \ docker rm $$CID > /dev/null extract-realtime-backends: $(addprefix extract-backend-,$(REALTIME_BACKEND_NAMES)) test-realtime-models-docker: build-mock-backend docker build --target build-requirements \ --build-arg BUILD_TYPE=$(or $(BUILD_TYPE),cublas) \ --build-arg CUDA_MAJOR_VERSION=$(or $(CUDA_MAJOR_VERSION),13) \ --build-arg CUDA_MINOR_VERSION=$(or $(CUDA_MINOR_VERSION),0) \ -t localai-test-runner . docker run --rm \ $(REALTIME_DOCKER_FLAGS) \ -v $(abspath ./):/build \ -v $(REALTIME_MODELS_DIR):/models:ro \ -v $(REALTIME_BACKENDS_DIR):/backends \ -v localai-go-cache:/root/go/pkg/mod \ -v localai-go-build-cache:/root/.cache/go-build \ -e REALTIME_TEST_MODEL=$${REALTIME_TEST_MODEL:-realtime-test-pipeline} \ -e REALTIME_VAD=$(REALTIME_VAD) \ -e REALTIME_STT=$(REALTIME_STT) \ -e REALTIME_LLM=$(REALTIME_LLM) \ -e REALTIME_TTS=$(REALTIME_TTS) \ -e REALTIME_BACKENDS_PATH=/backends \ -e REALTIME_MODELS_PATH=/models \ -w /build \ localai-test-runner \ bash -c 'git config --global --add safe.directory /build && \ make protogen-go && make build-mock-backend && \ go run github.com/onsi/ginkgo/v2/ginkgo --label-filter="Realtime" --flake-attempts $(TEST_FLAKES) -v -r ./tests/e2e' test-container: docker build --target requirements -t local-ai-test-container . docker run -ti --rm --entrypoint /bin/bash -ti -v $(abspath ./):/build local-ai-test-container ######################################################## ## Help ######################################################## ## Help: help: ## Show this help. 
@echo '' @echo 'Usage:' @echo ' ${YELLOW}make${RESET} ${GREEN}${RESET}' @echo '' @echo 'Targets:' @awk 'BEGIN {FS = ":.*?## "} { \ if (/^[a-zA-Z_-]+:.*?##.*$$/) {printf " ${YELLOW}%-20s${GREEN}%s${RESET}\n", $$1, $$2} \ else if (/^## .*$$/) {printf " ${CYAN}%s${RESET}\n", substr($$1,4)} \ }' $(MAKEFILE_LIST) ######################################################## ## Backends ######################################################## .PHONY: protogen protogen: protogen-go protoc: @OS_NAME=$$(uname -s | tr '[:upper:]' '[:lower:]'); \ ARCH_NAME=$$(uname -m); \ if [ "$$OS_NAME" = "darwin" ]; then \ if [ "$$ARCH_NAME" = "arm64" ]; then \ FILE=protoc-31.1-osx-aarch_64.zip; \ elif [ "$$ARCH_NAME" = "x86_64" ]; then \ FILE=protoc-31.1-osx-x86_64.zip; \ else \ echo "Unsupported macOS architecture: $$ARCH_NAME"; exit 1; \ fi; \ elif [ "$$OS_NAME" = "linux" ]; then \ if [ "$$ARCH_NAME" = "x86_64" ]; then \ FILE=protoc-31.1-linux-x86_64.zip; \ elif [ "$$ARCH_NAME" = "aarch64" ] || [ "$$ARCH_NAME" = "arm64" ]; then \ FILE=protoc-31.1-linux-aarch_64.zip; \ elif [ "$$ARCH_NAME" = "ppc64le" ]; then \ FILE=protoc-31.1-linux-ppcle_64.zip; \ elif [ "$$ARCH_NAME" = "s390x" ]; then \ FILE=protoc-31.1-linux-s390_64.zip; \ elif [ "$$ARCH_NAME" = "i386" ] || [ "$$ARCH_NAME" = "x86" ]; then \ FILE=protoc-31.1-linux-x86_32.zip; \ else \ echo "Unsupported Linux architecture: $$ARCH_NAME"; exit 1; \ fi; \ else \ echo "Unsupported OS: $$OS_NAME"; exit 1; \ fi; \ URL=https://github.com/protocolbuffers/protobuf/releases/download/v31.1/$$FILE; \ curl -L $$URL -o protoc.zip && \ unzip -j -d $(CURDIR) protoc.zip bin/protoc && rm protoc.zip .PHONY: protogen-go protogen-go: protoc install-go-tools mkdir -p pkg/grpc/proto ./protoc --experimental_allow_proto3_optional -Ibackend/ --go_out=pkg/grpc/proto/ --go_opt=paths=source_relative --go-grpc_out=pkg/grpc/proto/ --go-grpc_opt=paths=source_relative \ backend/backend.proto .PHONY: protogen-go-clean protogen-go-clean: $(RM) pkg/grpc/proto/backend.pb.go 
pkg/grpc/proto/backend_grpc.pb.go $(RM) bin/* prepare-test-extra: protogen-python $(MAKE) -C backend/python/transformers $(MAKE) -C backend/python/outetts $(MAKE) -C backend/python/diffusers $(MAKE) -C backend/python/chatterbox $(MAKE) -C backend/python/vllm $(MAKE) -C backend/python/vllm-omni $(MAKE) -C backend/python/vibevoice $(MAKE) -C backend/python/moonshine $(MAKE) -C backend/python/pocket-tts $(MAKE) -C backend/python/qwen-tts $(MAKE) -C backend/python/fish-speech $(MAKE) -C backend/python/faster-qwen3-tts $(MAKE) -C backend/python/qwen-asr $(MAKE) -C backend/python/nemo $(MAKE) -C backend/python/voxcpm $(MAKE) -C backend/python/whisperx $(MAKE) -C backend/python/ace-step test-extra: prepare-test-extra $(MAKE) -C backend/python/transformers test $(MAKE) -C backend/python/outetts test $(MAKE) -C backend/python/diffusers test $(MAKE) -C backend/python/chatterbox test $(MAKE) -C backend/python/vllm test $(MAKE) -C backend/python/vllm-omni test $(MAKE) -C backend/python/vibevoice test $(MAKE) -C backend/python/moonshine test $(MAKE) -C backend/python/pocket-tts test $(MAKE) -C backend/python/qwen-tts test $(MAKE) -C backend/python/fish-speech test $(MAKE) -C backend/python/faster-qwen3-tts test $(MAKE) -C backend/python/qwen-asr test $(MAKE) -C backend/python/nemo test $(MAKE) -C backend/python/voxcpm test $(MAKE) -C backend/python/whisperx test $(MAKE) -C backend/python/ace-step test DOCKER_IMAGE?=local-ai IMAGE_TYPE?=core BASE_IMAGE?=ubuntu:24.04 docker: docker build \ --build-arg BASE_IMAGE=$(BASE_IMAGE) \ --build-arg IMAGE_TYPE=$(IMAGE_TYPE) \ --build-arg GO_TAGS="$(GO_TAGS)" \ --build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \ --build-arg BUILD_TYPE=$(BUILD_TYPE) \ --build-arg CUDA_MAJOR_VERSION=$(CUDA_MAJOR_VERSION) \ --build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) \ --build-arg UBUNTU_VERSION=$(UBUNTU_VERSION) \ --build-arg UBUNTU_CODENAME=$(UBUNTU_CODENAME) \ -t $(DOCKER_IMAGE) . 
docker-cuda12: docker build \ --build-arg CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION} \ --build-arg CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION} \ --build-arg BASE_IMAGE=$(BASE_IMAGE) \ --build-arg IMAGE_TYPE=$(IMAGE_TYPE) \ --build-arg GO_TAGS="$(GO_TAGS)" \ --build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \ --build-arg BUILD_TYPE=$(BUILD_TYPE) \ --build-arg UBUNTU_VERSION=$(UBUNTU_VERSION) \ --build-arg UBUNTU_CODENAME=$(UBUNTU_CODENAME) \ -t $(DOCKER_IMAGE)-cuda-12 . docker-image-intel: docker build \ --build-arg BASE_IMAGE=intel/oneapi-basekit:2025.3.0-0-devel-ubuntu24.04 \ --build-arg IMAGE_TYPE=$(IMAGE_TYPE) \ --build-arg GO_TAGS="$(GO_TAGS)" \ --build-arg MAKEFLAGS="$(DOCKER_MAKEFLAGS)" \ --build-arg BUILD_TYPE=intel \ --build-arg CUDA_MAJOR_VERSION=$(CUDA_MAJOR_VERSION) \ --build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) \ --build-arg UBUNTU_VERSION=$(UBUNTU_VERSION) \ --build-arg UBUNTU_CODENAME=$(UBUNTU_CODENAME) \ -t $(DOCKER_IMAGE) . ######################################################## ## Backends ######################################################## # Pattern rule for standard backends (docker-based) # This matches all backends that use docker-build-* and docker-save-* backends/%: docker-build-% docker-save-% build ./local-ai backends install "ocifile://$(abspath ./backend-images/$*.tar)" # Darwin-specific backends (keep as explicit targets since they have special build logic) backends/llama-cpp-darwin: build bash ./scripts/build/llama-cpp-darwin.sh ./local-ai backends install "ocifile://$(abspath ./backend-images/llama-cpp.tar)" build-darwin-python-backend: build bash ./scripts/build/python-darwin.sh build-darwin-go-backend: build bash ./scripts/build/golang-darwin.sh backends/mlx: BACKEND=mlx $(MAKE) build-darwin-python-backend ./local-ai backends install "ocifile://$(abspath ./backend-images/mlx.tar)" backends/diffuser-darwin: BACKEND=diffusers $(MAKE) build-darwin-python-backend ./local-ai backends install "ocifile://$(abspath 
./backend-images/diffusers.tar)" backends/mlx-vlm: BACKEND=mlx-vlm $(MAKE) build-darwin-python-backend ./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-vlm.tar)" backends/mlx-audio: BACKEND=mlx-audio $(MAKE) build-darwin-python-backend ./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-audio.tar)" backends/mlx-distributed: BACKEND=mlx-distributed $(MAKE) build-darwin-python-backend ./local-ai backends install "ocifile://$(abspath ./backend-images/mlx-distributed.tar)" backends/stablediffusion-ggml-darwin: BACKEND=stablediffusion-ggml BUILD_TYPE=metal $(MAKE) build-darwin-go-backend ./local-ai backends install "ocifile://$(abspath ./backend-images/stablediffusion-ggml.tar)" backend-images: mkdir -p backend-images # Backend metadata: BACKEND_NAME | DOCKERFILE_TYPE | BUILD_CONTEXT | PROGRESS_FLAG | NEEDS_BACKEND_ARG # llama-cpp is special - uses llama-cpp Dockerfile and doesn't need BACKEND arg BACKEND_LLAMA_CPP = llama-cpp|llama-cpp|.|false|false # Golang backends BACKEND_PIPER = piper|golang|.|false|true BACKEND_LOCAL_STORE = local-store|golang|.|false|true BACKEND_HUGGINGFACE = huggingface|golang|.|false|true BACKEND_SILERO_VAD = silero-vad|golang|.|false|true BACKEND_STABLEDIFFUSION_GGML = stablediffusion-ggml|golang|.|--progress=plain|true BACKEND_WHISPER = whisper|golang|.|false|true BACKEND_VOXTRAL = voxtral|golang|.|false|true BACKEND_ACESTEP_CPP = acestep-cpp|golang|.|false|true BACKEND_OPUS = opus|golang|.|false|true # Python backends with root context BACKEND_RERANKERS = rerankers|python|.|false|true BACKEND_TRANSFORMERS = transformers|python|.|false|true BACKEND_OUTETTS = outetts|python|.|false|true BACKEND_FASTER_WHISPER = faster-whisper|python|.|false|true BACKEND_COQUI = coqui|python|.|false|true BACKEND_RFDETR = rfdetr|python|.|false|true BACKEND_KITTEN_TTS = kitten-tts|python|.|false|true BACKEND_NEUTTS = neutts|python|.|false|true BACKEND_KOKORO = kokoro|python|.|false|true BACKEND_VLLM = 
vllm|python|.|false|true BACKEND_VLLM_OMNI = vllm-omni|python|.|false|true BACKEND_DIFFUSERS = diffusers|python|.|--progress=plain|true BACKEND_CHATTERBOX = chatterbox|python|.|false|true BACKEND_VIBEVOICE = vibevoice|python|.|--progress=plain|true BACKEND_MOONSHINE = moonshine|python|.|false|true BACKEND_POCKET_TTS = pocket-tts|python|.|false|true BACKEND_QWEN_TTS = qwen-tts|python|.|false|true BACKEND_FISH_SPEECH = fish-speech|python|.|false|true BACKEND_FASTER_QWEN3_TTS = faster-qwen3-tts|python|.|false|true BACKEND_QWEN_ASR = qwen-asr|python|.|false|true BACKEND_NEMO = nemo|python|.|false|true BACKEND_VOXCPM = voxcpm|python|.|false|true BACKEND_WHISPERX = whisperx|python|.|false|true BACKEND_ACE_STEP = ace-step|python|.|false|true BACKEND_MLX_DISTRIBUTED = mlx-distributed|python|./|false|true # Helper function to build docker image for a backend # Usage: $(call docker-build-backend,BACKEND_NAME,DOCKERFILE_TYPE,BUILD_CONTEXT,PROGRESS_FLAG,NEEDS_BACKEND_ARG) define docker-build-backend docker build $(if $(filter-out false,$(4)),$(4)) \ --build-arg BUILD_TYPE=$(BUILD_TYPE) \ --build-arg BASE_IMAGE=$(BASE_IMAGE) \ --build-arg CUDA_MAJOR_VERSION=$(CUDA_MAJOR_VERSION) \ --build-arg CUDA_MINOR_VERSION=$(CUDA_MINOR_VERSION) \ --build-arg UBUNTU_VERSION=$(UBUNTU_VERSION) \ --build-arg UBUNTU_CODENAME=$(UBUNTU_CODENAME) \ $(if $(filter true,$(5)),--build-arg BACKEND=$(1)) \ -t local-ai-backend:$(1) -f backend/Dockerfile.$(2) $(3) endef # Generate docker-build targets from backend definitions define generate-docker-build-target docker-build-$(word 1,$(subst |, ,$(1))): $$(call docker-build-backend,$(word 1,$(subst |, ,$(1))),$(word 2,$(subst |, ,$(1))),$(word 3,$(subst |, ,$(1))),$(word 4,$(subst |, ,$(1))),$(word 5,$(subst |, ,$(1)))) endef # Generate all docker-build targets $(eval $(call generate-docker-build-target,$(BACKEND_LLAMA_CPP))) $(eval $(call generate-docker-build-target,$(BACKEND_PIPER))) $(eval $(call generate-docker-build-target,$(BACKEND_LOCAL_STORE))) 
$(eval $(call generate-docker-build-target,$(BACKEND_HUGGINGFACE))) $(eval $(call generate-docker-build-target,$(BACKEND_SILERO_VAD))) $(eval $(call generate-docker-build-target,$(BACKEND_STABLEDIFFUSION_GGML))) $(eval $(call generate-docker-build-target,$(BACKEND_WHISPER))) $(eval $(call generate-docker-build-target,$(BACKEND_VOXTRAL))) $(eval $(call generate-docker-build-target,$(BACKEND_OPUS))) $(eval $(call generate-docker-build-target,$(BACKEND_RERANKERS))) $(eval $(call generate-docker-build-target,$(BACKEND_TRANSFORMERS))) $(eval $(call generate-docker-build-target,$(BACKEND_OUTETTS))) $(eval $(call generate-docker-build-target,$(BACKEND_FASTER_WHISPER))) $(eval $(call generate-docker-build-target,$(BACKEND_COQUI))) $(eval $(call generate-docker-build-target,$(BACKEND_RFDETR))) $(eval $(call generate-docker-build-target,$(BACKEND_KITTEN_TTS))) $(eval $(call generate-docker-build-target,$(BACKEND_NEUTTS))) $(eval $(call generate-docker-build-target,$(BACKEND_KOKORO))) $(eval $(call generate-docker-build-target,$(BACKEND_VLLM))) $(eval $(call generate-docker-build-target,$(BACKEND_VLLM_OMNI))) $(eval $(call generate-docker-build-target,$(BACKEND_DIFFUSERS))) $(eval $(call generate-docker-build-target,$(BACKEND_CHATTERBOX))) $(eval $(call generate-docker-build-target,$(BACKEND_VIBEVOICE))) $(eval $(call generate-docker-build-target,$(BACKEND_MOONSHINE))) $(eval $(call generate-docker-build-target,$(BACKEND_POCKET_TTS))) $(eval $(call generate-docker-build-target,$(BACKEND_QWEN_TTS))) $(eval $(call generate-docker-build-target,$(BACKEND_FISH_SPEECH))) $(eval $(call generate-docker-build-target,$(BACKEND_FASTER_QWEN3_TTS))) $(eval $(call generate-docker-build-target,$(BACKEND_QWEN_ASR))) $(eval $(call generate-docker-build-target,$(BACKEND_NEMO))) $(eval $(call generate-docker-build-target,$(BACKEND_VOXCPM))) $(eval $(call generate-docker-build-target,$(BACKEND_WHISPERX))) $(eval $(call generate-docker-build-target,$(BACKEND_ACE_STEP))) $(eval $(call 
generate-docker-build-target,$(BACKEND_ACESTEP_CPP))) $(eval $(call generate-docker-build-target,$(BACKEND_MLX_DISTRIBUTED))) # Pattern rule for docker-save targets docker-save-%: backend-images docker save local-ai-backend:$* -o backend-images/$*.tar docker-build-backends: docker-build-llama-cpp docker-build-rerankers docker-build-vllm docker-build-vllm-omni docker-build-transformers docker-build-outetts docker-build-diffusers docker-build-kokoro docker-build-faster-whisper docker-build-coqui docker-build-chatterbox docker-build-vibevoice docker-build-moonshine docker-build-pocket-tts docker-build-qwen-tts docker-build-fish-speech docker-build-faster-qwen3-tts docker-build-qwen-asr docker-build-nemo docker-build-voxcpm docker-build-whisperx docker-build-ace-step docker-build-acestep-cpp docker-build-voxtral docker-build-mlx-distributed ######################################################## ### Mock Backend for E2E Tests ######################################################## build-mock-backend: protogen-go $(GOCMD) build -o tests/e2e/mock-backend/mock-backend ./tests/e2e/mock-backend clean-mock-backend: rm -f tests/e2e/mock-backend/mock-backend ######################################################## ### UI E2E Test Server ######################################################## build-ui-test-server: build-mock-backend react-ui protogen-go $(GOCMD) build -o tests/e2e-ui/ui-test-server ./tests/e2e-ui test-ui-e2e: build-ui-test-server cd core/http/react-ui && npm install && npx playwright install --with-deps chromium && npx playwright test test-ui-e2e-docker: docker build -t localai-ui-e2e -f tests/e2e-ui/Dockerfile . docker run --rm localai-ui-e2e clean-ui-test-server: rm -f tests/e2e-ui/ui-test-server ######################################################## ### END Backends ######################################################## .PHONY: swagger swagger: swag init -g core/http/app.go --output swagger # DEPRECATED: gen-assets is for the legacy Alpine.js UI. 
Remove when legacy UI is removed. .PHONY: gen-assets gen-assets: $(GOCMD) run core/dependencies_manager/manager.go webui_static.yaml core/http/static/assets ## Documentation docs/layouts/_default: mkdir -p docs/layouts/_default docs/static/gallery.html: docs/layouts/_default $(GOCMD) run ./.github/ci/modelslist.go ./gallery/index.yaml > docs/static/gallery.html docs/public: docs/layouts/_default docs/static/gallery.html cd docs && hugo --minify docs-clean: rm -rf docs/public rm -rf docs/static/gallery.html .PHONY: docs docs: docs/static/gallery.html cd docs && hugo serve ######################################################## ## Platform-specific builds ######################################################## ## fyne cross-platform build build-launcher-darwin: build-launcher go run github.com/tiagomelo/macos-dmg-creator/cmd/createdmg@latest \ --appName "LocalAI" \ --appBinaryPath "$(LAUNCHER_BINARY_NAME)" \ --bundleIdentifier "com.localai.launcher" \ --iconPath "core/http/static/logo.png" \ --outputDir "dist/" build-launcher-linux: cd cmd/launcher && go run fyne.io/tools/cmd/fyne@latest package -os linux -icon ../../core/http/static/logo.png --executable $(LAUNCHER_BINARY_NAME)-linux && mv launcher.tar.xz ../../$(LAUNCHER_BINARY_NAME)-linux.tar.xz ================================================ FILE: README.md ================================================




LocalAI forks LocalAI stars LocalAI pull-requests

LocalAI License

LocalAI Docker hub LocalAI Quay.io

Follow LocalAI_API Join LocalAI Discord Community

mudler%2FLocalAI | Trendshift

> :bulb: Get help - [❓FAQ](https://localai.io/faq/) [💭Discussions](https://github.com/go-skynet/LocalAI/discussions) [:speech_balloon: Discord](https://discord.gg/uJAeKSAGDy) [:book: Documentation website](https://localai.io/) > > [💻 Quickstart](https://localai.io/basics/getting_started/) [🖼️ Models](https://models.localai.io/) [🚀 Roadmap](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) [🛫 Examples](https://github.com/mudler/LocalAI-examples) Try on [![Telegram](https://img.shields.io/badge/Telegram-2CA5E0?style=for-the-badge&logo=telegram&logoColor=white)](https://t.me/localaiofficial_bot) [![tests](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/test.yml)[![Build and Release](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/release.yaml)[![build container images](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/image.yml)[![Bump dependencies](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml/badge.svg)](https://github.com/go-skynet/LocalAI/actions/workflows/bump_deps.yaml)[![Artifact Hub](https://img.shields.io/endpoint?url=https://artifacthub.io/badge/repository/localai)](https://artifacthub.io/packages/search?repo=localai)

LocalAI Examples Repository

**LocalAI** is the free, Open Source OpenAI alternative. LocalAI acts as a drop-in replacement REST API that's compatible with the OpenAI (and Elevenlabs, Anthropic, ...) API specifications for local AI inferencing. It allows you to run LLMs and generate images, audio, and more, locally or on-prem with consumer-grade hardware, supporting multiple model families. It does not require a GPU. It is created and maintained by [Ettore Di Giacinto](https://github.com/mudler). ## Screenshots / Video ### Chat, Model gallery https://github.com/user-attachments/assets/08cbb692-57da-48f7-963d-2e7b43883c18 ### Agents https://github.com/user-attachments/assets/6270b331-e21d-4087-a540-6290006b381a ### YouTube video




## 💻 Quickstart ### macOS Download: Download LocalAI for macOS > Note: the DMGs are not signed by Apple, so macOS flags them as quarantined. See https://github.com/mudler/LocalAI/issues/6268 for a workaround; the fix is tracked here: https://github.com/mudler/LocalAI/issues/6244 > Install the DMG and paste this command into a terminal: `sudo xattr -d com.apple.quarantine /Applications/LocalAI.app` ### Containers (Docker, podman, ...) > **💡 Docker Run vs Docker Start** > > - `docker run` creates and starts a new container. If a container with the same name already exists, this command will fail. > - `docker start` starts an existing container that was previously created with `docker run`. > > If you've already run LocalAI before and want to start it again, use: `docker start -i local-ai` #### CPU only image: ```bash docker run -ti --name local-ai -p 8080:8080 localai/localai:latest ``` #### NVIDIA GPU Images: ```bash # CUDA 13.0 docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-13 # CUDA 12.0 docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-gpu-nvidia-cuda-12 # NVIDIA Jetson (L4T) ARM64 # CUDA 12 (for Nvidia AGX Orin and similar platforms) docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nvidia-l4t-arm64 # CUDA 13 (for Nvidia DGX Spark) docker run -ti --name local-ai -p 8080:8080 --gpus all localai/localai:latest-nvidia-l4t-arm64-cuda-13 ``` #### AMD GPU Images (ROCm): ```bash docker run -ti --name local-ai -p 8080:8080 --device=/dev/kfd --device=/dev/dri --group-add=video localai/localai:latest-gpu-hipblas ``` #### Intel GPU Images (oneAPI): ```bash docker run -ti --name local-ai -p 8080:8080 --device=/dev/dri/card1 --device=/dev/dri/renderD128 localai/localai:latest-gpu-intel ``` #### Vulkan GPU Images: ```bash docker run -ti --name local-ai -p 8080:8080 localai/localai:latest-gpu-vulkan ``` To load models: ```bash # From the model gallery (see available models with `local-ai models list`, in 
the WebUI from the model tab, or visiting https://models.localai.io) local-ai run llama-3.2-1b-instruct:q4_k_m # Start LocalAI with the phi-2 model directly from huggingface local-ai run huggingface://TheBloke/phi-2-GGUF/phi-2.Q8_0.gguf # Install and run a model from the Ollama OCI registry local-ai run ollama://gemma:2b # Run a model from a configuration file local-ai run https://gist.githubusercontent.com/.../phi-2.yaml # Install and run a model from a standard OCI registry (e.g., Docker Hub) local-ai run oci://localai/phi-2:latest ``` > ⚡ **Automatic Backend Detection**: When you install models from the gallery or YAML files, LocalAI automatically detects your system's GPU capabilities (NVIDIA, AMD, Intel) and downloads the appropriate backend. For advanced configuration options, see [GPU Acceleration](https://localai.io/features/gpu-acceleration/#automatic-backend-detection). For more information, see [💻 Getting started](https://localai.io/basics/getting_started/index.html), if you are interested in our roadmap items and future enhancements, you can see the [Issues labeled as Roadmap here](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) ## 📰 Latest project news - March 2026: [Agent management](https://github.com/mudler/LocalAI/pull/8820), [New React UI](https://github.com/mudler/LocalAI/pull/8772), [WebRTC](https://github.com/mudler/LocalAI/pull/8790),[MLX-distributed via P2P and RDMA](https://github.com/mudler/LocalAI/pull/8801), [MCP Apps, MCP Client-side](https://github.com/mudler/LocalAI/pull/8947) - February 2026: [Realtime API for audio-to-audio with tool calling](https://github.com/mudler/LocalAI/pull/6245), [ACE-Step 1.5 support](https://github.com/mudler/LocalAI/pull/8396) - January 2026: **LocalAI 3.10.0** - Major release with Anthropic API support, Open Responses API for stateful agents, video & image generation suite (LTX-2), unified GPU backends, tool streaming & XML parsing, system-aware backend gallery, crash 
fixes for AVX-only CPUs and AMD VRAM reporting, request tracing, and new backends: **Moonshine** (ultra-fast transcription), **Pocket-TTS** (lightweight TTS). Vulkan arm64 builds now available. [Release notes](https://github.com/mudler/LocalAI/releases/tag/v3.10.0). - December 2025: [Dynamic Memory Resource reclaimer](https://github.com/mudler/LocalAI/pull/7583), [Automatic fitting of models to multiple GPUS(llama.cpp)](https://github.com/mudler/LocalAI/pull/7584), [Added Vibevoice backend](https://github.com/mudler/LocalAI/pull/7494) - November 2025: Major improvements to the UX. Among these: [Import models via URL](https://github.com/mudler/LocalAI/pull/7245) and [Multiple chats and history](https://github.com/mudler/LocalAI/pull/7325) - October 2025: 🔌 [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) support added for agentic capabilities with external tools - September 2025: New Launcher application for MacOS and Linux, extended support to many backends for Mac and Nvidia L4T devices. Models: Added MLX-Audio, WAN 2.2. WebUI improvements and Python-based backends now ships portable python environments. - August 2025: MLX, MLX-VLM, Diffusers and llama.cpp are now supported on Mac M1/M2/M3+ chips ( with `development` suffix in the gallery ): https://github.com/mudler/LocalAI/pull/6049 https://github.com/mudler/LocalAI/pull/6119 https://github.com/mudler/LocalAI/pull/6121 https://github.com/mudler/LocalAI/pull/6060 - July/August 2025: 🔍 [Object Detection](https://localai.io/features/object-detection/) added to the API featuring [rf-detr](https://github.com/roboflow/rf-detr) - July 2025: All backends migrated outside of the main binary. LocalAI is now more lightweight, small, and automatically downloads the required backend to run the model. [Read the release notes](https://github.com/mudler/LocalAI/releases/tag/v3.2.0) - June 2025: [Backend management](https://github.com/mudler/LocalAI/pull/5607) has been added. 
Attention: extras images are going to be deprecated from the next release! Read [the backend management PR](https://github.com/mudler/LocalAI/pull/5607). - May 2025: [Audio input](https://github.com/mudler/LocalAI/pull/5466) and [Reranking](https://github.com/mudler/LocalAI/pull/5396) in llama.cpp backend, [Realtime API](https://github.com/mudler/LocalAI/pull/5392), Support to Gemma, SmollVLM, and more multimodal models (available in the gallery). - May 2025: Important: image name changes [See release](https://github.com/mudler/LocalAI/releases/tag/v2.29.0) - Apr 2025: Rebrand, WebUI enhancements - Apr 2025: [LocalAGI](https://github.com/mudler/LocalAGI) and [LocalRecall](https://github.com/mudler/LocalRecall) join the LocalAI family stack. - Apr 2025: WebUI overhaul - Feb 2025: Backend cleanup, Breaking changes, new backends (kokoro, OutelTTS, faster-whisper), Nvidia L4T images - Jan 2025: LocalAI model release: https://huggingface.co/mudler/LocalAI-functioncall-phi-4-v0.3, SANA support in diffusers: https://github.com/mudler/LocalAI/pull/4603 - Dec 2024: stablediffusion.cpp backend (ggml) added ( https://github.com/mudler/LocalAI/pull/4289 ) - Nov 2024: Bark.cpp backend added ( https://github.com/mudler/LocalAI/pull/4287 ) - Nov 2024: Voice activity detection models (**VAD**) added to the API: https://github.com/mudler/LocalAI/pull/4204 - Oct 2024: examples moved to [LocalAI-examples](https://github.com/mudler/LocalAI-examples) - Aug 2024: 🆕 FLUX-1, [P2P Explorer](https://explorer.localai.io) - July 2024: 🔥🔥 🆕 P2P Dashboard, LocalAI Federated mode and AI Swarms: https://github.com/mudler/LocalAI/pull/2723. P2P Global community pools: https://github.com/mudler/LocalAI/issues/3113 - May 2024: 🔥🔥 Decentralized P2P llama.cpp: https://github.com/mudler/LocalAI/pull/2343 (peer2peer llama.cpp!) 
👉 Docs https://localai.io/features/distribute/ - May 2024: 🔥🔥 Distributed inferencing: https://github.com/mudler/LocalAI/pull/2324 - April 2024: Reranker API: https://github.com/mudler/LocalAI/pull/2121 Roadmap items: [List of issues](https://github.com/mudler/LocalAI/issues?q=is%3Aissue+is%3Aopen+label%3Aroadmap) ## 🚀 [Features](https://localai.io/features/) - 🧩 [Backend Gallery](https://localai.io/backends/): Install/remove backends on the fly, powered by OCI images — fully customizable and API-driven. - 📖 [Text generation with GPTs](https://localai.io/features/text-generation/) (`llama.cpp`, `transformers`, `vllm` ... [:book: and more](https://localai.io/model-compatibility/index.html#model-compatibility-table)) - 🗣 [Text to Audio](https://localai.io/features/text-to-audio/) - 🔈 [Audio to Text](https://localai.io/features/audio-to-text/) - 🎨 [Image generation](https://localai.io/features/image-generation) - 🔥 [OpenAI-alike tools API](https://localai.io/features/openai-functions/) - ⚡ [Realtime API](https://localai.io/features/openai-realtime/) (Speech-to-speech) - 🧠 [Embeddings generation for vector databases](https://localai.io/features/embeddings/) - ✍️ [Constrained grammars](https://localai.io/features/constrained_grammars/) - 🖼️ [Download Models directly from Huggingface ](https://localai.io/models/) - 🥽 [Vision API](https://localai.io/features/gpt-vision/) - 🔍 [Object Detection](https://localai.io/features/object-detection/) - 📈 [Reranker API](https://localai.io/features/reranker/) - 🆕🖧 [P2P Inferencing](https://localai.io/features/distribute/) - 🆕🔌 [Model Context Protocol (MCP)](https://localai.io/docs/features/mcp/) - Agentic capabilities with external tools and [LocalAGI's Agentic capabilities](https://github.com/mudler/LocalAGI) - 🆕🤖 [Built-in Agents](https://localai.io/features/agents/) - Autonomous AI agents with tool use, knowledge base (RAG), skills, SSE streaming, import/export, and [Agent Hub](https://agenthub.localai.io) — powered by 
[LocalAGI](https://github.com/mudler/LocalAGI) - 🔊 Voice activity detection (Silero-VAD support) - 🌍 Integrated WebUI! ## 🧩 Supported Backends & Acceleration LocalAI supports a comprehensive range of AI backends with multiple acceleration options: ### Text Generation & Language Models | Backend | Description | Acceleration Support | |---------|-------------|---------------------| | **llama.cpp** | LLM inference in C/C++ | CUDA 12/13, ROCm, Intel SYCL, Vulkan, Metal, CPU | | **vLLM** | Fast LLM inference with PagedAttention | CUDA 12/13, ROCm, Intel | | **transformers** | HuggingFace transformers framework | CUDA 12/13, ROCm, Intel, CPU | | **MLX** | Apple Silicon LLM inference | Metal (M1/M2/M3+) | | **MLX-VLM** | Apple Silicon Vision-Language Models | Metal (M1/M2/M3+) | | **vLLM Omni** | Multimodal vLLM with vision and audio | CUDA 12/13, ROCm, Intel | ### Audio & Speech Processing | Backend | Description | Acceleration Support | |---------|-------------|---------------------| | **whisper.cpp** | OpenAI Whisper in C/C++ | CUDA 12/13, ROCm, Intel SYCL, Vulkan, CPU | | **faster-whisper** | Fast Whisper with CTranslate2 | CUDA 12/13, ROCm, Intel, CPU | | **moonshine** | Ultra-fast transcription engine for low-end devices | CUDA 12/13, Metal, CPU | | **coqui** | Advanced TTS with 1100+ languages | CUDA 12/13, ROCm, Intel, CPU | | **kokoro** | Lightweight TTS model | CUDA 12/13, ROCm, Intel, CPU | | **chatterbox** | Production-grade TTS | CUDA 12/13, CPU | | **piper** | Fast neural TTS system | CPU | | **kitten-tts** | Kitten TTS models | CPU | | **silero-vad** | Voice Activity Detection | CPU | | **neutts** | Text-to-speech with voice cloning | CUDA 12/13, ROCm, CPU | | **vibevoice** | Real-time TTS with voice cloning | CUDA 12/13, ROCm, Intel, CPU | | **pocket-tts** | Lightweight CPU-based TTS | CUDA 12/13, ROCm, Intel, CPU | | **qwen-tts** | High-quality TTS with custom voice, voice design, and voice cloning | CUDA 12/13, ROCm, Intel, CPU | | **nemo** | NVIDIA NeMo 
framework for speech models | CUDA 12/13, ROCm, Intel, CPU | | **outetts** | OuteTTS with voice cloning | CUDA 12/13, CPU | | **faster-qwen3-tts** | Faster Qwen3 TTS | CUDA 12/13, ROCm, Intel, CPU | | **qwen-asr** | Qwen ASR speech recognition | CUDA 12/13, ROCm, Intel, CPU | | **voxcpm** | VoxCPM speech understanding | CUDA 12/13, Metal, CPU | | **whisperx** | Enhanced Whisper transcription | CUDA 12/13, ROCm, Intel, CPU | | **ace-step** | Music generation from text descriptions, lyrics, or audio samples | CUDA 12/13, ROCm, Intel, Metal, CPU | ### Image & Video Generation | Backend | Description | Acceleration Support | |---------|-------------|---------------------| | **stablediffusion.cpp** | Stable Diffusion in C/C++ | CUDA 12/13, Intel SYCL, Vulkan, CPU | | **diffusers** | HuggingFace diffusion models | CUDA 12/13, ROCm, Intel, Metal, CPU | ### Specialized AI Tasks | Backend | Description | Acceleration Support | |---------|-------------|---------------------| | **rfdetr** | Real-time object detection | CUDA 12/13, Intel, CPU | | **rerankers** | Document reranking API | CUDA 12/13, ROCm, Intel, CPU | | **local-store** | Vector database | CPU | | **huggingface** | HuggingFace API integration | API-based | ### Hardware Acceleration Matrix | Acceleration Type | Supported Backends | Hardware Support | |-------------------|-------------------|------------------| | **NVIDIA CUDA 12** | All CUDA-compatible backends | Nvidia hardware | | **NVIDIA CUDA 13** | All CUDA-compatible backends | Nvidia hardware | | **AMD ROCm** | llama.cpp, whisper, vllm, transformers, diffusers, rerankers, coqui, kokoro, neutts, vibevoice, pocket-tts, qwen-tts, ace-step | AMD Graphics | | **Intel oneAPI** | llama.cpp, whisper, stablediffusion, vllm, transformers, diffusers, rfdetr, rerankers, coqui, kokoro, vibevoice, pocket-tts, qwen-tts, ace-step | Intel Arc, Intel iGPUs | | **Apple Metal** | llama.cpp, whisper, diffusers, MLX, MLX-VLM, moonshine, ace-step | Apple M1/M2/M3+ | | **Vulkan** 
| llama.cpp, whisper, stablediffusion | Cross-platform GPUs | | **NVIDIA Jetson (CUDA 12)** | llama.cpp, whisper, stablediffusion, diffusers, rfdetr, ace-step | ARM64 embedded AI (AGX Orin, etc.) | | **NVIDIA Jetson (CUDA 13)** | llama.cpp, whisper, stablediffusion, diffusers, rfdetr | ARM64 embedded AI (DGX Spark) | | **CPU Optimized** | All backends | AVX/AVX2/AVX512, quantization support | ### 🔗 Community and integrations Build and deploy custom containers: - https://github.com/sozercan/aikit WebUIs: - https://github.com/Jirubizu/localai-admin - https://github.com/go-skynet/LocalAI-frontend - QA-Pilot(An interactive chat project that leverages LocalAI LLMs for rapid understanding and navigation of GitHub code repository) https://github.com/reid41/QA-Pilot Agentic Libraries: - https://github.com/mudler/cogito MCPs: - https://github.com/mudler/MCPs OS Assistant: - https://github.com/mudler/Keygeist - Keygeist is an AI-powered keyboard operator that listens for key combinations and responds with AI-generated text typed directly into your Linux box. 
Model galleries - https://github.com/go-skynet/model-gallery Voice: - https://github.com/richiejp/VoxInput Other: - Helm chart https://github.com/go-skynet/helm-charts - VSCode extension https://github.com/badgooooor/localai-vscode-plugin - Langchain: https://python.langchain.com/docs/integrations/providers/localai/ - Terminal utility https://github.com/djcopley/ShellOracle - Local Smart assistant https://github.com/mudler/LocalAGI - Home Assistant https://github.com/drndos/hass-openai-custom-conversation / https://github.com/valentinfrlch/ha-llmvision / https://github.com/loryanstrant/HA-LocalAI-Monitor - Discord bot https://github.com/mudler/LocalAGI/tree/main/examples/discord - Slack bot https://github.com/mudler/LocalAGI/tree/main/examples/slack - Shell-Pilot(Interact with LLM using LocalAI models via pure shell scripts on your Linux or MacOS system) https://github.com/reid41/shell-pilot - Telegram bot https://github.com/mudler/LocalAI/tree/master/examples/telegram-bot - Another Telegram Bot https://github.com/JackBekket/Hellper - Auto-documentation https://github.com/JackBekket/Reflexia - Github bot which answer on issues, with code and documentation as context https://github.com/JackBekket/GitHelper - Github Actions: https://github.com/marketplace/actions/start-localai - Examples: https://github.com/mudler/LocalAI/tree/master/examples/ ### 🔗 Resources - [LLM finetuning guide](https://localai.io/docs/advanced/fine-tuning/) - [How to build locally](https://localai.io/basics/build/index.html) - [How to install in Kubernetes](https://localai.io/basics/getting_started/index.html#run-localai-in-kubernetes) - [Projects integrating LocalAI](https://localai.io/docs/integrations/) - [How tos section](https://io.midori-ai.xyz/howtos/) (curated by our community) ## :book: 🎥 [Media, Blogs, Social](https://localai.io/basics/news/#media-blogs-social) - 🆕 [LocalAI Autonomous Dev Team Blog 
Post](https://mudler.pm/posts/2026/02/28/a-call-to-open-source-maintainers-stop-babysitting-ai-how-i-built-a-100-local-autonomous-dev-team-to-maintain-localai-and-why-you-should-too/) - [Run Visual Studio Code with LocalAI (SUSE)](https://www.suse.com/c/running-ai-locally/) - 🆕 [Run LocalAI on Jetson Nano Devkit](https://mudler.pm/posts/local-ai-jetson-nano-devkit/) - [Run LocalAI on AWS EKS with Pulumi](https://www.pulumi.com/blog/low-code-llm-apps-with-local-ai-flowise-and-pulumi/) - [Run LocalAI on AWS](https://staleks.hashnode.dev/installing-localai-on-aws-ec2-instance) - [Create a slackbot for teams and OSS projects that answers questions from documentation](https://mudler.pm/posts/smart-slackbot-for-teams/) - [LocalAI meets k8sgpt](https://www.youtube.com/watch?v=PKrDNuJ_dfE) - [Question Answering on Documents locally with LangChain, LocalAI, Chroma, and GPT4All](https://mudler.pm/posts/localai-question-answering/) - [Tutorial to use k8sgpt with LocalAI](https://medium.com/@tyler_97636/k8sgpt-localai-unlock-kubernetes-superpowers-for-free-584790de9b65) ## 🤖 Autonomous Development Team LocalAI is now partly maintained (for small tasks!) by a full team of autonomous AI agents led by an AI Scrum Master! This experiment demonstrates how open source projects can leverage AI agents for sustainable, long-term maintenance. 
- **📊 Live Reports**: [Automatically generated reports](http://reports.localai.io) - **📋 Project Board**: [Agent task tracking](https://github.com/users/mudler/projects/6) - **📝 Blog Post**: [Learn about the autonomous dev team experiment](https://mudler.pm/posts/2026/02/28/a-call-to-open-source-maintainers-stop-babysitting-ai-how-i-built-a-100-local-autonomous-dev-team-to-maintain-localai-and-why-you-should-too/) ## Citation If you use this repository or its data in a downstream project, please consider citing it with: ``` @misc{localai, author = {Ettore Di Giacinto}, title = {LocalAI: The free, Open source OpenAI alternative}, year = {2023}, publisher = {GitHub}, journal = {GitHub repository}, howpublished = {\url{https://github.com/go-skynet/LocalAI}}, } ``` ## ❤️ Sponsors > Do you find LocalAI useful? Support the project by becoming [a backer or sponsor](https://github.com/sponsors/mudler). Your logo will show up here with a link to your website. A huge thank you to our generous sponsors who support this project by covering CI expenses, and to everyone on our [Sponsor list](https://github.com/sponsors/mudler):


### Individual sponsors A special thanks to individual sponsors that contributed to the project, a full list is in [Github](https://github.com/sponsors/mudler) and [buymeacoffee](https://buymeacoffee.com/mudler), a special shout out goes to [drikster80](https://github.com/drikster80) for being generous. Thank you everyone! ## 🌟 Star history [![LocalAI Star history Chart](https://api.star-history.com/svg?repos=go-skynet/LocalAI&type=Date)](https://star-history.com/#go-skynet/LocalAI&Date) ## 📖 License LocalAI is a community-driven project created by [Ettore Di Giacinto](https://github.com/mudler/). MIT - Author Ettore Di Giacinto ## 🙇 Acknowledgements LocalAI couldn't have been built without the help of great software already available from the community. Thank you! - [llama.cpp](https://github.com/ggerganov/llama.cpp) - https://github.com/tatsu-lab/stanford_alpaca - https://github.com/cornelk/llama-go for the initial ideas - https://github.com/antimatter15/alpaca.cpp - https://github.com/EdVince/Stable-Diffusion-NCNN - https://github.com/ggerganov/whisper.cpp - https://github.com/rhasspy/piper - [exo](https://github.com/exo-explore/exo) for the MLX distributed auto-parallel sharding implementation ## 🤗 Contributors This is a community project, a special thanks to our contributors! 🤗 ================================================ FILE: SECURITY.md ================================================ # Security Policy ## Introduction At LocalAI, we take the security of our software seriously. We understand the importance of protecting our community from vulnerabilities and are committed to ensuring the safety and security of our users. ## Supported Versions We provide support and updates for certain versions of our software. 
The following table outlines which versions are currently supported with security updates: | Version Series | Support Level | Details | | -------------- | ------------- | ------- | | 3.x | :white_check_mark: Actively supported | Full security updates and bug fixes for the latest minor versions. | | 2.x | :warning: Security fixes only | Critical security patches only, until **December 31, 2025**. | | 1.x | :x: End-of-life (EOL) | No longer supported as of **January 1, 2024**. No security fixes will be provided. | ### What each support level means - **Actively supported (3.x):** Receives all security updates, bug fixes, and new features. Users should stay on the latest 3.x minor release for the best protection. - **Security fixes only (2.x):** Receives only critical security patches (e.g., remote code execution, authentication bypass, data exposure). No bug fixes or new features. Support ends December 31, 2025. - **End-of-life (1.x):** No updates of any kind. Users on 1.x are strongly encouraged to upgrade immediately, as known vulnerabilities will not be patched. ### Migrating from older versions If you are running an unsupported or soon-to-be-unsupported version, we recommend upgrading as soon as possible: - **From 1.x to 3.x:** Version 1.x reached end-of-life on January 1, 2024. Review the [release notes](https://github.com/mudler/LocalAI/releases) for breaking changes across major versions, and upgrade directly to the latest 3.x release. - **From 2.x to 3.x:** While 2.x still receives critical security patches until December 31, 2025, we recommend planning your migration to 3.x to benefit from ongoing improvements and full support. Please ensure that you are using a supported version to receive the latest security updates. ## Reporting a Vulnerability We encourage the responsible disclosure of any security vulnerabilities. If you believe you've found a security issue in our software, we kindly ask you to follow the steps below to report it to us: 1. 
**Email Us:** Send an email to [security@localai.io](mailto:security@localai.io) with a detailed report. Please do not disclose the vulnerability publicly or to any third parties before it has been addressed by us. 2. **Expect a Response:** We aim to acknowledge receipt of vulnerability reports within 48 hours. Our security team will review your report and work closely with you to understand the impact and ensure a thorough investigation. 3. **Collaboration:** If the vulnerability is accepted, we will work with you and our community to address the issue promptly. We'll keep you informed throughout the resolution process and may request additional information or collaboration. 4. **Disclosure:** Once the vulnerability has been resolved, we encourage a coordinated disclosure. We believe in transparency and will work with you to ensure that our community is informed in a responsible manner. ## Use of Third-Party Platforms As a Free and Open Source Software (FOSS) organization, we do not offer monetary bounties. However, researchers who wish to report vulnerabilities can also do so via [Huntr](https://huntr.dev/bounties), a platform that recognizes contributions to open source security. ## Contact For any security-related inquiries beyond vulnerability reporting, please contact us at [security@localai.io](mailto:security@localai.io). ## Acknowledgments We appreciate the efforts of those who contribute to the security of our project. Your responsible disclosure is invaluable to the safety and integrity of LocalAI. Thank you for helping us keep LocalAI secure. 
================================================ FILE: backend/Dockerfile.golang ================================================ ARG BASE_IMAGE=ubuntu:24.04 FROM ${BASE_IMAGE} AS builder ARG BACKEND=rerankers ARG BUILD_TYPE ENV BUILD_TYPE=${BUILD_TYPE} ARG CUDA_MAJOR_VERSION ARG CUDA_MINOR_VERSION ARG SKIP_DRIVERS=false ENV CUDA_MAJOR_VERSION=${CUDA_MAJOR_VERSION} ENV CUDA_MINOR_VERSION=${CUDA_MINOR_VERSION} ENV DEBIAN_FRONTEND=noninteractive ARG TARGETARCH ARG TARGETVARIANT ARG GO_VERSION=1.25.4 ARG UBUNTU_VERSION=2404 RUN apt-get update && \ apt-get install -y --no-install-recommends \ build-essential \ git ccache \ ca-certificates \ make cmake wget libopenblas-dev \ curl unzip \ libssl-dev && \ apt-get clean && \ rm -rf /var/lib/apt/lists/* # Cuda ENV PATH=/usr/local/cuda/bin:${PATH} # HipBLAS requirements ENV PATH=/opt/rocm/bin:${PATH} # Vulkan requirements RUN < Metadata = 52; // Generic per-request metadata (e.g., enable_thinking) } // ToolCallDelta represents an incremental tool call update from the C++ parser. // Used for both streaming (partial diffs) and non-streaming (final tool calls). message ToolCallDelta { int32 index = 1; // tool call index (0-based) string id = 2; // tool call ID (e.g., "call_abc123") string name = 3; // function name (set on first appearance) string arguments = 4; // arguments chunk (incremental in streaming, full in non-streaming) } // ChatDelta represents incremental content/reasoning/tool_call updates parsed by the C++ backend. 
// One incremental chat update: a slice of answer text, a slice of reasoning
// text, and zero or more tool-call deltas.
message ChatDelta {
  string content = 1;                     // content text delta
  string reasoning_content = 2;           // reasoning/thinking text delta
  repeated ToolCallDelta tool_calls = 3;  // tool call deltas
}

// The response message containing the result
message Reply {
  bytes message = 1;
  int32 tokens = 2;
  int32 prompt_tokens = 3;
  // NOTE(review): the unit of the two timing fields is not stated in this
  // file (seconds vs. milliseconds) — confirm against the producing backend.
  double timing_prompt_processing = 4;
  double timing_token_generation = 5;
  bytes audio = 6;
  bytes logprobs = 7;                  // JSON-encoded logprobs data matching OpenAI format
  repeated ChatDelta chat_deltas = 8;  // Parsed chat deltas from C++ autoparser (streaming + non-streaming)
}

// A single trigger word; used as the element type of
// ModelOptions.GrammarTriggers.
message GrammarTrigger {
  string word = 1;
}

// Model-level options. Fields are grouped by the section comments below
// (Diffusers, llama.cpp, vllm); declaration order does not follow field-number
// order. Field numbers 16 and 22-25 are not used by any field in this message.
// NOTE(review): presumably retired fields — consider marking them `reserved`.
message ModelOptions {
  string Model = 1;
  int32 ContextSize = 2;
  int32 Seed = 3;
  int32 NBatch = 4;
  bool F16Memory = 5;
  bool MLock = 6;
  bool MMap = 7;
  bool VocabOnly = 8;
  bool LowVRAM = 9;
  bool Embeddings = 10;
  bool NUMA = 11;
  int32 NGPULayers = 12;
  string MainGPU = 13;
  string TensorSplit = 14;
  int32 Threads = 15;
  float RopeFreqBase = 17;
  float RopeFreqScale = 18;
  float RMSNormEps = 19;
  int32 NGQA = 20;
  string ModelFile = 21;

  // Diffusers
  string PipelineType = 26;
  string SchedulerType = 27;
  bool CUDA = 28;
  float CFGScale = 29;
  bool IMG2IMG = 30;
  string CLIPModel = 31;
  string CLIPSubfolder = 32;
  int32 CLIPSkip = 33;
  string ControlNet = 48;

  string Tokenizer = 34;

  // LLM (llama.cpp)
  string LoraBase = 35;
  string LoraAdapter = 36;
  float LoraScale = 42;

  bool NoMulMatQ = 37;
  string DraftModel = 39;

  string AudioPath = 38;

  // vllm
  string Quantization = 40;
  float GPUMemoryUtilization = 50;
  bool TrustRemoteCode = 51;
  bool EnforceEager = 52;
  int32 SwapSpace = 53;
  int32 MaxModelLen = 54;
  int32 TensorParallelSize = 55;
  string LoadFormat = 58;
  bool DisableLogStatus = 66;
  string DType = 67;
  int32 LimitImagePerPrompt = 68;
  int32 LimitVideoPerPrompt = 69;
  int32 LimitAudioPerPrompt = 70;

  string MMProj = 41;

  string RopeScaling = 43;
  float YarnExtFactor = 44;
  float YarnAttnFactor = 45;
  float YarnBetaFast = 46;
  float YarnBetaSlow = 47;

  string Type = 49;

  // Declared as a string, not a bool.
  // NOTE(review): the accepted values are not visible here — confirm.
  string FlashAttention = 56;
  bool NoKVOffload = 57;

  string ModelPath = 59;

  // Plural counterparts of LoraAdapter / LoraScale above.
  // NOTE(review): presumably index-aligned with each other — confirm.
  repeated string LoraAdapters = 60;
  repeated float LoraScales = 61;

  repeated string Options = 62;

  string CacheTypeKey = 63;
  string CacheTypeValue = 64;

  repeated GrammarTrigger GrammarTriggers = 65;

  bool Reranking = 71;

  repeated string Overrides = 72;
}

// Generic outcome: a text message plus a success flag.
message Result {
  string message = 1;
  bool success = 2;
}

// A flat list of float embedding values.
message EmbeddingResult {
  repeated float embeddings = 1;
}

// Transcription request. Field number 1 is not used by any field in this
// message; numbering starts at 2.
message TranscriptRequest {
  string dst = 2;
  string language = 3;
  uint32 threads = 4;
  bool translate = 5;
  bool diarize = 6;
  string prompt = 7;
}

// Transcription output: per-segment details plus a single text field.
message TranscriptResult {
  repeated TranscriptSegment segments = 1;
  string text = 2;
}

// One transcribed segment.
// NOTE(review): the unit of start/end (int64) is not stated here — confirm.
message TranscriptSegment {
  int32 id = 1;
  int64 start = 2;
  int64 end = 3;
  string text = 4;
  repeated int32 tokens = 5;
  string speaker = 6;
}

// Image generation request. Field number 3 is not used by any field in this
// message.
message GenerateImageRequest {
  int32 height = 1;
  int32 width = 2;
  int32 step = 4;
  int32 seed = 5;
  string positive_prompt = 6;
  string negative_prompt = 7;
  string dst = 8;
  string src = 9;

  // Diffusers
  string EnableParameters = 10;
  int32 CLIPSkip = 11;

  // Reference images for models that support them (e.g., Flux Kontext)
  repeated string ref_images = 12;
}

// Video generation request.
message GenerateVideoRequest {
  string prompt = 1;
  string negative_prompt = 2;  // Negative prompt for video generation
  string start_image = 3;      // Path or base64 encoded image for the start frame
  string end_image = 4;        // Path or base64 encoded image for the end frame
  int32 width = 5;
  int32 height = 6;
  int32 num_frames = 7;        // Number of frames to generate
  int32 fps = 8;               // Frames per second
  int32 seed = 9;
  float cfg_scale = 10;        // Classifier-free guidance scale
  int32 step = 11;             // Number of inference steps
  string dst = 12;             // Output path for the generated video
}

// Text-to-speech request. `language` is the only field declared `optional`,
// so its presence can be distinguished from an empty string.
message TTSRequest {
  string text = 1;
  string model = 2;
  string dst = 3;
  string voice = 4;
  optional string language = 5;
}

// Voice-activity-detection input: audio as a list of floats.
// NOTE(review): sample rate and channel layout are not stated here — confirm.
message VADRequest {
  repeated float audio = 1;
}

// One detected segment, delimited by start and end.
// NOTE(review): the unit of start/end is not stated here — confirm.
message VADSegment {
  float start = 1;
  float end = 2;
}

// Voice-activity-detection output: the list of detected segments.
message VADResponse {
  repeated VADSegment segments = 1;
}
// Sound/music generation request. Everything after dst is optional (explicit
// presence). Field number 16 does not appear here.
message SoundGenerationRequest {
  string text = 1;
  string model = 2;
  string dst = 3;
  optional float duration = 4;
  optional float temperature = 5;
  optional bool sample = 6;
  optional string src = 7;
  optional int32 src_divisor = 8;
  optional bool think = 9;
  optional string caption = 10;
  optional string lyrics = 11;
  optional int32 bpm = 12;
  optional string keyscale = 13;
  optional string language = 14;
  optional string timesignature = 15;
  optional bool instrumental = 17;
}

// Tokenizer output: a length and the token ids.
// NOTE(review): presumably length == number of tokens - confirm with the backends.
message TokenizationResponse {
  int32 length = 1;
  repeated int32 tokens = 2;
}

// Memory usage: a total plus a breakdown map.
message MemoryUsageData {
  uint64 total = 1;
  // NOTE(review): `map` appears here without key/value type parameters; proto3
  // map fields require them. They look lost in extraction - restore from the
  // canonical backend.proto.
  map breakdown = 2;
}

// Backend status: a state plus memory usage.
message StatusResponse {
  enum State {
    UNINITIALIZED = 0;
    BUSY = 1;
    READY = 2;
    ERROR = -1;
  }
  State state = 1;
  MemoryUsageData memory = 2;
}

// One chat message.
message Message {
  string role = 1;
  string content = 2;
  // Optional fields for OpenAI-compatible message format
  string name = 3;              // Tool name (for tool messages)
  string tool_call_id = 4;      // Tool call ID (for tool messages)
  string reasoning_content = 5; // Reasoning content (for thinking models)
  string tool_calls = 6;        // Tool calls as JSON string (for assistant messages with tool calls)
}

// Object detection input.
// NOTE(review): whether src is a path, URL or encoded payload is not stated here.
message DetectOptions {
  string src = 1;
}

// One detection: a box (x, y, width, height), a confidence and a class name.
// NOTE(review): coordinate system and units are not stated in this file.
message Detection {
  float x = 1;
  float y = 2;
  float width = 3;
  float height = 4;
  float confidence = 5;
  string class_name = 6;
}

// Object detection output.
message DetectResponse {
  repeated Detection Detections = 1;
}

// Markers describing how a model's chat template formats tool calls,
// reasoning and content.
//
// NOTE(review): several `e.g., ""` examples below are empty, and field numbers
// 7 and 10 do not appear in this view. Both look like text lost in extraction;
// check the canonical backend.proto before relying on this listing.
message ToolFormatMarkers {
  string format_type = 1; // "json_native", "tag_with_json", "tag_with_tagged"

  // Tool section markers
  string section_start = 2;  // e.g., "", "[TOOL_CALLS]"
  string section_end = 3;    // e.g., ""
  string per_call_start = 4; // e.g., "<|tool_call_begin|>"
  string per_call_end = 5;   // e.g., "<|tool_call_end|>"

  // Function name markers (TAG_WITH_JSON / TAG_WITH_TAGGED)
  string func_name_prefix = 6; // e.g., ""
  string func_close = 8;       // e.g., ""

  // Argument markers (TAG_WITH_TAGGED)
  string arg_name_prefix = 9; // e.g., ""
  string arg_value_prefix = 11;
  string arg_value_suffix = 12; // e.g., ""
  string arg_separator = 13;    // e.g., "\n"

  // JSON format fields (JSON_NATIVE)
  string name_field = 14; // e.g., "name"
  string args_field = 15; // e.g., "arguments"
  string id_field = 16;   // e.g., "id"
  bool fun_name_is_key = 17;
  bool tools_array_wrapped = 18;
  bool uses_python_dicts = 19;

  // Reasoning markers
  string reasoning_start = 20; // e.g., ""
  string reasoning_end = 21;   // e.g., ""

  // Content markers
  string content_start = 22;
  string content_end = 23;

  // Args wrapper markers
  string args_start = 24; // e.g., ""
  string args_end = 25;   // e.g., ""

  // JSON parameter ordering
  string function_field = 26; // e.g., "function" (wrapper key in JSON)
  repeated string parameter_order = 27;

  // Generated ID field (alternative field name for generated IDs)
  string gen_id_field = 28; // e.g., "call_id"

  // Call ID markers (position and delimiters for tool call IDs)
  string call_id_position = 29; // "none", "pre_func_name", "between_func_and_args", "post_args"
  string call_id_prefix = 30;   // e.g., "[CALL_ID]"
  string call_id_suffix = 31;   // e.g., ""
}

// Audio encode input: raw PCM bytes with sample rate and channel count, plus options.
message AudioEncodeRequest {
  bytes pcm_data = 1;
  int32 sample_rate = 2;
  int32 channels = 3;
  // NOTE(review): `map` without key/value type parameters - see MemoryUsageData.breakdown.
  map options = 4;
}

// Audio encode output: encoded frames with sample rate and samples per frame.
message AudioEncodeResult {
  repeated bytes frames = 1;
  int32 sample_rate = 2;
  int32 samples_per_frame = 3;
}

// Audio decode input: encoded frames plus options.
message AudioDecodeRequest {
  repeated bytes frames = 1;
  // NOTE(review): `map` without key/value type parameters - see MemoryUsageData.breakdown.
  map options = 2;
}

// Audio decode output: raw PCM bytes with sample rate and samples per frame.
message AudioDecodeResult {
  bytes pcm_data = 1;
  int32 sample_rate = 2;
  int32 samples_per_frame = 3;
}

// Model metadata reported by a backend.
message ModelMetadataResponse {
  bool supports_thinking = 1;
  string rendered_template = 2;      // The rendered chat template with enable_thinking=true (empty if not applicable)
  ToolFormatMarkers tool_format = 3; // Auto-detected tool format markers from differential template analysis
}

================================================
FILE: backend/cpp/grpc/.gitignore
================================================
installed_packages/
grpc_build/
grpc_repo/

================================================
FILE: backend/cpp/grpc/Makefile
================================================
# Builds gRPC from source and installs it under $(INSTALLED_PACKAGES).
# Flow: $(GRPC_REPO) checks out $(GIT_REPO_LIB_GRPC) at $(TAG_LIB_GRPC);
# $(GRPC_BUILD) configures, builds and installs it with CMake.

# Basic platform detection
HOST_SYSTEM = $(shell uname | cut -f 1 -d_)
SYSTEM ?= $(HOST_SYSTEM)

TAG_LIB_GRPC?=v1.59.0
GIT_REPO_LIB_GRPC?=https://github.com/grpc/grpc.git
# NOTE(review): GIT_CLONE_DEPTH is not referenced by any rule in this file; the
# fetch in $(GRPC_REPO) below is not depth-limited.
GIT_CLONE_DEPTH?=1

INSTALLED_PACKAGES=installed_packages
GRPC_REPO=grpc_repo
GRPC_BUILD=grpc_build

# CMake configuration for the gRPC build. Output and install paths are relative
# to $(GRPC_BUILD), which is where cmake is invoked from.
export CMAKE_ARGS?=
CMAKE_ARGS+=-DCMAKE_BUILD_TYPE=Release
CMAKE_ARGS+=-DgRPC_INSTALL=ON
CMAKE_ARGS+=-DEXECUTABLE_OUTPUT_PATH=../$(INSTALLED_PACKAGES)/grpc/bin
CMAKE_ARGS+=-DLIBRARY_OUTPUT_PATH=../$(INSTALLED_PACKAGES)/grpc/lib
CMAKE_ARGS+=-DgRPC_BUILD_TESTS=OFF
CMAKE_ARGS+=-DgRPC_BUILD_CSHARP_EXT=OFF
CMAKE_ARGS+=-DgRPC_BUILD_GRPC_CPP_PLUGIN=ON
CMAKE_ARGS+=-DgRPC_BUILD_GRPC_CSHARP_PLUGIN=OFF
CMAKE_ARGS+=-DgRPC_BUILD_GRPC_NODE_PLUGIN=OFF
CMAKE_ARGS+=-DgRPC_BUILD_GRPC_OBJECTIVE_C_PLUGIN=OFF
CMAKE_ARGS+=-DgRPC_BUILD_GRPC_PHP_PLUGIN=OFF
CMAKE_ARGS+=-DgRPC_BUILD_GRPC_PYTHON_PLUGIN=ON
CMAKE_ARGS+=-DgRPC_BUILD_GRPC_RUBY_PLUGIN=OFF
CMAKE_ARGS+=-Dprotobuf_WITH_ZLIB=ON
CMAKE_ARGS+=-DRE2_BUILD_TESTING=OFF
CMAKE_ARGS+=-DCMAKE_INSTALL_PREFIX=../$(INSTALLED_PACKAGES)

# windows need to set OPENSSL_NO_ASM. Results in slower crypto performance but doesn't build otherwise.
# May be resolvable, but for now its set. More info: https://stackoverflow.com/a/75240504/480673
ifeq ($(SYSTEM),MSYS)
CMAKE_ARGS+=-DOPENSSL_NO_ASM=ON
endif
ifeq ($(SYSTEM),MINGW64)
CMAKE_ARGS+=-DOPENSSL_NO_ASM=ON
endif
ifeq ($(SYSTEM),MINGW32)
CMAKE_ARGS+=-DOPENSSL_NO_ASM=ON
endif
ifeq ($(SYSTEM),CYGWIN)
CMAKE_ARGS+=-DOPENSSL_NO_ASM=ON
endif

# The prerequisite is the literal name grpc_build, which matches the default
# value of $(GRPC_BUILD).
$(INSTALLED_PACKAGES): grpc_build

# Check out the requested gRPC tag, including submodules.
$(GRPC_REPO):
	mkdir -p $(GRPC_REPO)/grpc
	cd $(GRPC_REPO)/grpc && \
	git init && \
	git remote add origin $(GIT_REPO_LIB_GRPC) && \
	git fetch origin && \
	git checkout $(TAG_LIB_GRPC) && \
	git submodule update --init --recursive --depth 1 --single-branch

# Configure, build and install in one shell command so the `cd` applies to every step.
$(GRPC_BUILD): $(GRPC_REPO)
	mkdir -p $(GRPC_BUILD)
	cd $(GRPC_BUILD) && cmake $(CMAKE_ARGS) ../$(GRPC_REPO)/grpc && cmake --build . && cmake --build . --target install

build: $(INSTALLED_PACKAGES)

rebuild:
	rm -rf grpc_build
	$(MAKE) grpc_build

clean:
	rm -rf grpc_build
	rm -rf grpc_repo
	rm -rf installed_packages

================================================
FILE: backend/cpp/llama-cpp/CMakeLists.txt
================================================
# Builds the grpc-server executable: generates C++ protobuf/gRPC sources from
# backend.proto, wraps them in the hw_grpc_proto library and links it with the
# llama.cpp libraries (common, llama, mtmd).

# NOTE(review): TARGET is set twice to the same value, once before and once
# after cmake_minimum_required().
set(TARGET grpc-server)
set(CMAKE_CXX_STANDARD 17)
cmake_minimum_required(VERSION 3.15)
set(TARGET grpc-server)
set(_PROTOBUF_LIBPROTOBUF libprotobuf)
set(_REFLECTION grpc++_reflection)

if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
    # Set correct Homebrew install folder for Apple Silicon and Intel Macs
    if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "arm64")
        set(HOMEBREW_DEFAULT_PREFIX "/opt/homebrew")
    else()
        set(HOMEBREW_DEFAULT_PREFIX "/usr/local")
    endif()

    link_directories("${HOMEBREW_DEFAULT_PREFIX}/lib")
    include_directories("${HOMEBREW_DEFAULT_PREFIX}/include")
endif()

find_package(absl CONFIG REQUIRED)
find_package(Protobuf CONFIG REQUIRED)
find_package(gRPC CONFIG REQUIRED)

find_program(_PROTOBUF_PROTOC protoc)
set(_GRPC_GRPCPP grpc++)
find_program(_GRPC_CPP_PLUGIN_EXECUTABLE grpc_cpp_plugin)

# The generated *.pb.h headers land in the binary dir, so it must be on the include path.
include_directories(${CMAKE_CURRENT_BINARY_DIR})
include_directories(${Protobuf_INCLUDE_DIRS})

message(STATUS "Using protobuf version ${Protobuf_VERSION} | Protobuf_INCLUDE_DIRS: ${Protobuf_INCLUDE_DIRS} | CMAKE_CURRENT_BINARY_DIR: ${CMAKE_CURRENT_BINARY_DIR}")

# Proto file
# NOTE(review): the relative path climbs six directories; presumably it is
# resolved from where this file is copied inside the llama.cpp tree - confirm
# against prepare.sh.
get_filename_component(hw_proto "../../../../../../backend/backend.proto" ABSOLUTE)
get_filename_component(hw_proto_path "${hw_proto}" PATH)

# Generated sources
set(hw_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/backend.pb.cc")
set(hw_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/backend.pb.h")
set(hw_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/backend.grpc.pb.cc")
set(hw_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/backend.grpc.pb.h")

# Run protoc with the gRPC C++ plugin; re-runs whenever backend.proto changes.
add_custom_command(
    OUTPUT "${hw_proto_srcs}" "${hw_proto_hdrs}" "${hw_grpc_srcs}" "${hw_grpc_hdrs}"
    COMMAND ${_PROTOBUF_PROTOC}
    ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
        --cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
        -I "${hw_proto_path}"
        --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
        "${hw_proto}"
    DEPENDS "${hw_proto}")

# hw_grpc_proto
add_library(hw_grpc_proto
    ${hw_grpc_srcs}
    ${hw_grpc_hdrs}
    ${hw_proto_srcs}
    ${hw_proto_hdrs} )

add_executable(${TARGET} grpc-server.cpp json.hpp httplib.h)

target_include_directories(${TARGET} PRIVATE ../llava)
target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})

target_link_libraries(${TARGET} PRIVATE common llama mtmd ${CMAKE_THREAD_LIBS_INIT} absl::flags hw_grpc_proto
    absl::flags_parse
    gRPC::${_REFLECTION}
    gRPC::${_GRPC_GRPCPP}
    protobuf::${_PROTOBUF_LIBPROTOBUF})
# NOTE(review): requests cxx_std_11 although CMAKE_CXX_STANDARD is set to 17 above.
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
    add_dependencies(${TARGET} BUILD_INFO)
endif()

================================================
FILE: backend/cpp/llama-cpp/Makefile
================================================
# Builds the llama.cpp-based grpc-server in several CPU/GPU variants.
# Each variant copies this directory to ../llama-cpp-<variant>-build, builds
# grpc-server there and copies the binary back as llama-cpp-<variant>.

LLAMA_VERSION?=5744d7ec430e2f875a393770195fda530560773f
LLAMA_REPO?=https://github.com/ggerganov/llama.cpp

CMAKE_ARGS?=
BUILD_TYPE?=
NATIVE?=false
ONEAPI_VARS?=/opt/intel/oneapi/setvars.sh
TARGET?=--target grpc-server
JOBS?=$(shell nproc 2>/dev/null || sysctl -n hw.ncpu 2>/dev/null || echo 1)
ARCH?=$(shell uname -m)

# Disable Shared libs as we are linking on static gRPC and we can't mix shared and static
CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF -DLLAMA_CURL=OFF

CURRENT_MAKEFILE_DIR := $(dir $(abspath $(lastword $(MAKEFILE_LIST))))
ifeq ($(NATIVE),false)
    CMAKE_ARGS+=-DGGML_NATIVE=OFF -DLLAMA_OPENSSL=OFF
endif
# If build type is cublas, then we set -DGGML_CUDA=ON to CMAKE_ARGS automatically
ifeq ($(BUILD_TYPE),cublas)
    CMAKE_ARGS+=-DGGML_CUDA=ON
# If build type is openblas then we set -DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
# to CMAKE_ARGS automatically
else ifeq ($(BUILD_TYPE),openblas)
    CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS
# If build type is clblas (openCL) we set -DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
else ifeq ($(BUILD_TYPE),clblas)
    CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path
# If it's hipblas we do have also to set CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++
else ifeq ($(BUILD_TYPE),hipblas)
    ROCM_HOME ?= /opt/rocm
    ROCM_PATH ?= /opt/rocm
    export CXX=$(ROCM_HOME)/llvm/bin/clang++
    export CC=$(ROCM_HOME)/llvm/bin/clang
    AMDGPU_TARGETS?=gfx803,gfx900,gfx906,gfx908,gfx90a,gfx942,gfx1010,gfx1030,gfx1032,gfx1100,gfx1101,gfx1102,gfx1200,gfx1201
    CMAKE_ARGS+=-DGGML_HIP=ON -DAMDGPU_TARGETS=$(AMDGPU_TARGETS)
else ifeq ($(BUILD_TYPE),vulkan)
    CMAKE_ARGS+=-DGGML_VULKAN=1
# NOTE(review): OS is not assigned anywhere in this file; presumably it comes
# from the environment or a parent make - confirm, otherwise this branch never matches.
else ifeq ($(OS),Darwin)
    ifeq ($(BUILD_TYPE),)
        BUILD_TYPE=metal
    endif
    ifneq ($(BUILD_TYPE),metal)
        CMAKE_ARGS+=-DGGML_METAL=OFF
    else
        CMAKE_ARGS+=-DGGML_METAL=ON
        CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON
        CMAKE_ARGS+=-DGGML_METAL_USE_BF16=ON
        CMAKE_ARGS+=-DGGML_OPENMP=OFF
    endif
    TARGET+=--target ggml-metal
endif

ifeq ($(BUILD_TYPE),sycl_f16)
    CMAKE_ARGS+=-DGGML_SYCL=ON \
        -DCMAKE_C_COMPILER=icx \
        -DCMAKE_CXX_COMPILER=icpx \
        -DCMAKE_CXX_FLAGS="-fsycl" \
        -DGGML_SYCL_F16=ON
endif

ifeq ($(BUILD_TYPE),sycl_f32)
    CMAKE_ARGS+=-DGGML_SYCL=ON \
        -DCMAKE_C_COMPILER=icx \
        -DCMAKE_CXX_COMPILER=icpx \
        -DCMAKE_CXX_FLAGS="-fsycl"
endif

# Location of the locally built gRPC and the CMake package dirs inside it.
INSTALLED_PACKAGES=$(CURDIR)/../grpc/installed_packages
INSTALLED_LIB_CMAKE=$(INSTALLED_PACKAGES)/lib/cmake
ADDED_CMAKE_ARGS=-Dabsl_DIR=${INSTALLED_LIB_CMAKE}/absl \
    -DProtobuf_DIR=${INSTALLED_LIB_CMAKE}/protobuf \
    -Dutf8_range_DIR=${INSTALLED_LIB_CMAKE}/utf8_range \
    -DgRPC_DIR=${INSTALLED_LIB_CMAKE}/grpc \
    -DCMAKE_CXX_STANDARD_INCLUDE_DIRECTORIES=${INSTALLED_PACKAGES}/include

# Builds grpc-server inside ../$(VARIANT). VARIANT is supplied by the variant
# targets below.
# NOTE(review): `-C ../../grpc` climbs two levels while INSTALLED_PACKAGES above
# uses $(CURDIR)/../grpc - confirm both resolve to the same gRPC directory.
# NOTE(review): _PROTOBUF_PROTOC ends in bin/proto, whereas the CMakeLists
# searches for a program named `protoc` - confirm the intended binary name.
build-llama-cpp-grpc-server:
# Conditionally build grpc for the llama backend to use if needed
ifdef BUILD_GRPC_FOR_BACKEND_LLAMA
	$(MAKE) -C ../../grpc build
	_PROTOBUF_PROTOC=${INSTALLED_PACKAGES}/bin/proto \
	_GRPC_CPP_PLUGIN_EXECUTABLE=${INSTALLED_PACKAGES}/bin/grpc_cpp_plugin \
	PATH="${INSTALLED_PACKAGES}/bin:${PATH}" \
	CMAKE_ARGS="${CMAKE_ARGS} ${ADDED_CMAKE_ARGS}" \
	LLAMA_VERSION=$(LLAMA_VERSION) \
	$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../$(VARIANT) grpc-server
else
	echo "BUILD_GRPC_FOR_BACKEND_LLAMA is not defined."
	LLAMA_VERSION=$(LLAMA_VERSION) $(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../$(VARIANT) grpc-server
endif

# CPU-variant builds. Each copies this directory, purges stale build output in
# the copy, builds with variant-specific GGML flags and copies the binary back.
# GREEN and RESET are not defined in this file (presumably inherited).
llama-cpp-avx2: llama.cpp
	cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx2-build
	$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx2-build purge
	$(info ${GREEN}I llama-cpp build info:avx2${RESET})
	CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) VARIANT="llama-cpp-avx2-build" build-llama-cpp-grpc-server
	cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx2-build/grpc-server llama-cpp-avx2

llama-cpp-avx512: llama.cpp
	cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx512-build
	$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx512-build purge
	$(info ${GREEN}I llama-cpp build info:avx512${RESET})
	CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on" $(MAKE) VARIANT="llama-cpp-avx512-build" build-llama-cpp-grpc-server
	cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx512-build/grpc-server llama-cpp-avx512

llama-cpp-avx: llama.cpp
	cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx-build
	$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx-build purge
	$(info ${GREEN}I llama-cpp build info:avx${RESET})
	CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) VARIANT="llama-cpp-avx-build" build-llama-cpp-grpc-server
	cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-avx-build/grpc-server llama-cpp-avx

llama-cpp-fallback: llama.cpp
	cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-fallback-build
	$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-fallback-build purge
	$(info ${GREEN}I llama-cpp build info:fallback${RESET})
	CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) VARIANT="llama-cpp-fallback-build" build-llama-cpp-grpc-server
	cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-fallback-build/grpc-server llama-cpp-fallback

# RPC-enabled variant: also builds the rpc-server target, which
# llama-cpp-rpc-server then copies out of the build tree.
llama-cpp-grpc: llama.cpp
	cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build
	$(MAKE) -C $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build purge
	$(info ${GREEN}I llama-cpp build info:grpc${RESET})
	CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_RPC=ON -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" TARGET="--target grpc-server --target rpc-server" $(MAKE) VARIANT="llama-cpp-grpc-build" build-llama-cpp-grpc-server
	cp -rfv $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build/grpc-server llama-cpp-grpc

llama-cpp-rpc-server: llama-cpp-grpc
	cp -rf $(CURRENT_MAKEFILE_DIR)/../llama-cpp-grpc-build/llama.cpp/build/bin/rpc-server llama-cpp-rpc-server

# Check out llama.cpp at $(LLAMA_VERSION) on a local branch named `build`.
llama.cpp:
	mkdir -p llama.cpp
	cd llama.cpp && \
	git init && \
	git remote add origin $(LLAMA_REPO) && \
	git fetch origin && \
	git checkout -b build $(LLAMA_VERSION) && \
	git submodule update --init --recursive --depth 1 --single-branch

# Create the tools/grpc-server directory inside the llama.cpp tree and run prepare.sh.
llama.cpp/tools/grpc-server: llama.cpp
	mkdir -p llama.cpp/tools/grpc-server
	bash prepare.sh

rebuild:
	bash prepare.sh
	rm -rf grpc-server
	$(MAKE) grpc-server

package:
	bash package.sh

# Remove build output but keep the llama.cpp checkout.
purge:
	rm -rf llama.cpp/build
	rm -rf llama.cpp/tools/grpc-server
	rm -rf grpc-server

clean: purge
	rm -rf llama.cpp

# Configure and build inside llama.cpp/build. For sycl build types the oneAPI
# environment script is sourced first in the same shell. The leading `+` marks
# the recipe line as a sub-make-style command (see the GNU make manual on `+`).
grpc-server: llama.cpp llama.cpp/tools/grpc-server
	@echo "Building grpc-server with $(BUILD_TYPE) build type and $(CMAKE_ARGS)"
ifneq (,$(findstring sycl,$(BUILD_TYPE)))
	+bash -c "source $(ONEAPI_VARS); \
	cd llama.cpp && mkdir -p build && cd build && cmake .. $(CMAKE_ARGS) && cmake --build . --config Release -j $(JOBS) $(TARGET)"
else
	+cd llama.cpp && mkdir -p build && cd build && cmake .. $(CMAKE_ARGS) && cmake --build . --config Release -j $(JOBS) $(TARGET)
endif
	cp llama.cpp/build/bin/grpc-server .
================================================
FILE: backend/cpp/llama-cpp/grpc-server.cpp
================================================
// llama.cpp gRPC C++ backend server
//
// Ettore Di Giacinto and llama.cpp authors
//
// This is a gRPC server for llama.cpp compatible with the LocalAI proto
// Note: this is a re-adaptation of the original llama.cpp example/server.cpp for HTTP (https://github.com/ggerganov/llama.cpp/tree/master/examples/server),
// but modified to work with gRPC
//
// NOTE(review): in this view several angle-bracket spans are missing: the
// header names of the bare #include lines below, and template arguments on
// std::function, std::vector, std::min, reinterpret_cast, static_cast and
// json .get(). Restore them from the canonical file before compiling.

#include "server-task.cpp"
#include "server-queue.cpp"
#include "server-common.cpp"
#include "server-context.cpp"

// LocalAI

#include "backend.pb.h"
#include "backend.grpc.pb.h"
#include "common.h"
#include "chat-auto-parser.h"
#include
#include
#include
#include
#include
#include
#include
#include
#include
#if defined(_WIN32)
#include
#endif

using grpc::Server;
using grpc::ServerBuilder;
using grpc::ServerContext;
using grpc::Status;
// END LocalAI

/////////////////////////////////
////////////////////////////////
//////// LOCALAI code starts below here
/////////////////////////////////
////////////////////////////////

// Polled by start_llama_server() until it becomes true.
// NOTE(review): this is a plain bool; if it is written from a different thread
// than the one polling it, the access is unsynchronized (see the TODO).
bool loaded_model; // TODO: add a mutex for this, but happens only once loading the model

// Callback invoked by signal_handler(); assigned in start_llama_server().
static std::function shutdown_handler;
// Set on the first handled signal so that a second one can force-exit.
static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT;

// Signal entry point. The first signal is forwarded to shutdown_handler; if a
// signal arrives after the flag is already set, the process prints a message
// and exits with status 1.
static inline void signal_handler(int signal) {
    if (is_terminating.test_and_set()) {
        // in case it hangs, we can force terminate the server by hitting Ctrl+C twice
        // this is for better developer experience, we can remove when the server is stable enough
        fprintf(stderr, "Received second interrupt, terminating immediately.\n");
        exit(1);
    }

    shutdown_handler(signal);
}

// Forward declarations
static void start_llama_server(server_context& ctx_server);
static json parse_options(bool streaming, const backend::PredictOptions* predict, const common_params& params_base, llama_context* ctx);
static ggml_type kv_cache_type_from_str(const std::string & s);
static std::string get_all_kv_cache_types();
static void add_rpc_devices(std::string servers);
static void params_parse(server_context& ctx_server, const backend::ModelOptions* request, common_params & params);

// Waits for loaded_model to become true (polling every 100 ms), installs the
// shutdown handler and the OS signal hooks, then runs ctx_server.start_loop(),
// which blocks until ctx_server.terminate() is called.
static void start_llama_server(server_context& ctx_server) {

    LOG_INF("%s: starting llama server\n", __func__);

    LOG_INF("%s: waiting for model to be loaded\n", __func__);
    // Wait for model to be loaded first
    while (!loaded_model) {
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }

    LOG_INF("%s: model loaded\n", __func__);

    // print sample chat example to make it clear which template is used
    // LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__,
    //     common_chat_templates_source(ctx_server.impl->chat_params.tmpls.get()),
    //     common_chat_format_example(ctx_server.impl->chat_params.tmpls.get(), ctx_server.impl->params_base.use_jinja).c_str(), ctx_server.impl->params_base.default_template_kwargs);

    // Keep the chat templates initialized in load_model() so they can be used when UseTokenizerTemplate is enabled
    // Templates will only be used conditionally in Predict/PredictStream when UseTokenizerTemplate is true and Messages are provided

    // Captures ctx_server by reference; the handler is only valid while the
    // caller's server_context is alive.
    shutdown_handler = [&](int) {
        // this will unblock start_loop()
        ctx_server.terminate();
    };

    // TODO: refactor in common/console
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
    // POSIX: route both SIGINT and SIGTERM to signal_handler.
    struct sigaction sigint_action;
    sigint_action.sa_handler = signal_handler;
    sigemptyset (&sigint_action.sa_mask);
    sigint_action.sa_flags = 0;
    sigaction(SIGINT, &sigint_action, NULL);
    sigaction(SIGTERM, &sigint_action, NULL);
#elif defined (_WIN32)
    // Windows: only CTRL_C_EVENT is handled; other control events return false.
    auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL {
        return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false;
    };
    SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true);
#endif

    // this call blocks the main thread until ctx_server.terminate() is called
    ctx_server.start_loop();
}

// Builds the JSON request object from the gRPC PredictOptions.
//
//   streaming    value stored under "stream"
//   predict      request options; read only
//   params_base  source of grammar triggers and preserved tokens
//   ctx          used to turn preserved token ids into text pieces
//
// Returns the populated json object. JSON parse errors in tools, tool_choice
// and logit_bias are caught and logged here rather than propagated.
json parse_options(bool streaming, const backend::PredictOptions* predict, const common_params& params_base, llama_context* ctx)
{

    // Create now a json data from the prediction options instead
    //
    json data;
    data["stream"] = streaming;
    data["cache_prompt"] = predict->promptcacheall();
    // A token count of 0 in the request is sent on as -1.
    data["n_predict"] = predict->tokens() == 0 ? -1 : predict->tokens();
    data["top_k"] = predict->topk();
    data["top_p"] = predict->topp();
    data["typical_p"] = predict->typicalp();
    data["temperature"] = predict->temperature();
    data["repeat_last_n"] = predict->repeat();
    data["repeat_penalty"] = predict->penalty();
    data["frequency_penalty"] = predict->frequencypenalty();
    data["presence_penalty"] = predict->presencepenalty();
    data["mirostat"] = predict->mirostat();
    data["mirostat_tau"] = predict->mirostattau();
    data["mirostat_eta"] = predict->mirostateta();
    data["n_keep"] = predict->nkeep();
    data["seed"] = predict->seed();

    std::string grammar_str = predict->grammar();

    if (!grammar_str.empty()) {
        data["grammar"] = grammar_str;
        SRV_INF("Using grammar: %s\n", grammar_str.c_str());
    }

    // Only set prompt if UseTokenizerTemplate is false or if no Messages are provided
    // When UseTokenizerTemplate is true and Messages are provided, prompt will be set via chat templates in Predict/PredictStream
    if (!predict->usetokenizertemplate() || predict->messages_size() == 0) {
        data["prompt"] = predict->prompt();
    }

    // Extract tools and tool_choice from proto and add to data JSON
    SRV_INF("[TOOLS DEBUG] parse_options: Checking for tools in proto, tools().empty()=%d, tools().size()=%zu\n",
            predict->tools().empty() ? 1 : 0, predict->tools().size());
    if (!predict->tools().empty()) {
        SRV_INF("[TOOLS DEBUG] parse_options: Tools string from proto (first 500 chars): %s\n",
                predict->tools().substr(0, std::min(500, predict->tools().size())).c_str());
        try {
            // Parse tools JSON string and add to data
            json tools_json = json::parse(predict->tools());
            data["tools"] = tools_json;
            SRV_INF("Extracted tools from proto: %s\n", predict->tools().c_str());
            // Debug: Log tools count and names
            if (tools_json.is_array()) {
                SRV_INF("[TOOLS DEBUG] parse_options: Successfully parsed %zu tools from Go layer\n", tools_json.size());
                for (size_t i = 0; i < tools_json.size(); i++) {
                    // The name is read from either {"function": {"name": ...}} or a top-level "name".
                    if (tools_json[i].contains("function") && tools_json[i]["function"].contains("name")) {
                        SRV_INF("[TOOLS DEBUG] parse_options: Tool %zu: %s\n", i, tools_json[i]["function"]["name"].get().c_str());
                    } else if (tools_json[i].contains("name")) {
                        SRV_INF("[TOOLS DEBUG] parse_options: Tool %zu: %s\n", i, tools_json[i]["name"].get().c_str());
                    }
                }
            } else {
                SRV_WRN("[TOOLS DEBUG] parse_options: Parsed tools JSON is not an array: %s\n", tools_json.dump().c_str());
            }
        } catch (const json::parse_error& e) {
            // On a parse failure "tools" is simply left out of data.
            SRV_WRN("Failed to parse tools JSON from proto: %s\n", e.what());
            SRV_WRN("[TOOLS DEBUG] parse_options: Tools string that failed to parse: %s\n", predict->tools().c_str());
        }
    } else {
        SRV_INF("%s", "[TOOLS DEBUG] parse_options: No tools received from Go layer (predict->tools() is empty)\n");
    }

    // Debug: Verify tools are in data after extraction
    // (the else branch also fires when the request carried no tools at all)
    if (data.contains("tools")) {
        SRV_INF("[TOOLS DEBUG] parse_options: Tools successfully added to data, count: %zu\n",
                data["tools"].is_array() ? data["tools"].size() : 0);
    } else {
        SRV_INF("%s", "[TOOLS DEBUG] parse_options: WARNING - Tools NOT in data after extraction!\n");
    }
    if (!predict->toolchoice().empty()) {
        try {
            // Parse tool_choice JSON string
            json tool_choice_json = json::parse(predict->toolchoice());
            // tool_choice can be a string ("auto", "none", "required") or an object
            // Store it as-is (string or object) so we can convert object to "required" later when adding to body_json
            if (tool_choice_json.is_string()) {
                data["tool_choice"] = tool_choice_json.get();
                SRV_DBG("[TOOLS DEBUG] Received tool_choice from Go layer: %s\n", tool_choice_json.get().c_str());
            } else {
                // Store object as-is so we can detect it later and convert to "required"
                data["tool_choice"] = tool_choice_json;
                SRV_DBG("[TOOLS DEBUG] Received tool_choice object from Go layer: %s\n", tool_choice_json.dump().c_str());
            }
            SRV_INF("Extracted tool_choice from proto: %s\n", predict->toolchoice().c_str());
        } catch (const json::parse_error& e) {
            // If parsing fails, treat as string
            data["tool_choice"] = predict->toolchoice();
            SRV_INF("Extracted tool_choice as string: %s\n", predict->toolchoice().c_str());
        }
    }

    // Extract logprobs and top_logprobs from proto and add to JSON data
    // Following server.cpp pattern: logprobs maps to n_probs when provided
    if (predict->logprobs() > 0) {
        data["logprobs"] = predict->logprobs();
        // Map logprobs to n_probs (following server.cpp line 369 pattern)
        // n_probs will be set by params_from_json_cmpl if logprobs is provided
        data["n_probs"] = predict->logprobs();
        SRV_INF("Using logprobs: %d\n", predict->logprobs());
    }
    if (predict->toplogprobs() > 0) {
        data["top_logprobs"] = predict->toplogprobs();
        SRV_INF("Using top_logprobs: %d\n", predict->toplogprobs());
    }

    // Extract logit_bias from proto and add to JSON data
    if (!predict->logitbias().empty()) {
        try {
            // Parse logit_bias JSON string from proto
            json logit_bias_json = json::parse(predict->logitbias());
            // Add to data - llama.cpp server expects it as an object (map)
            data["logit_bias"] = logit_bias_json;
            SRV_INF("Using logit_bias: %s\n", predict->logitbias().c_str());
        } catch (const json::parse_error& e) {
            // On a parse failure "logit_bias" is left out of data.
            SRV_ERR("Failed to parse logit_bias JSON from proto: %s\n", e.what());
        }
    }

    data["ignore_eos"] = predict->ignoreeos();
    data["embeddings"] = predict->embeddings();

    // Add the correlationid to json data
    data["correlation_id"] = predict->correlationid();

    // for each image in the request, add the image data
    //
    for (int i = 0; i < predict->images_size(); i++) {
        data["image_data"].push_back(json
            {
                {"id", i},
                {"data",    predict->images(i)},
            });
    }

    // for each audio in the request, add the audio data
    for (int i = 0; i < predict->audios_size(); i++) {
        data["audio_data"].push_back(json
            {
                {"id", i},
                {"data",    predict->audios(i)},
            });
    }

    data["stop"] = predict->stopprompts();
    // data["n_probs"] = predict->nprobs();
    //TODO: images,

    // Serialize grammar triggers from server context to JSON array
    if (!params_base.sampling.grammar_triggers.empty()) {
        json grammar_triggers = json::array();
        for (const auto& trigger : params_base.sampling.grammar_triggers) {
            json trigger_json;
            trigger_json["value"] = trigger.value;
            // Always serialize as WORD type since upstream converts WORD to TOKEN internally
            trigger_json["type"] = static_cast(COMMON_GRAMMAR_TRIGGER_TYPE_WORD);
            grammar_triggers.push_back(trigger_json);
        }
        data["grammar_triggers"] = grammar_triggers;
    }

    // Serialize preserved tokens from server context to JSON array
    if (!params_base.sampling.preserved_tokens.empty()) {
        json preserved_tokens = json::array();
        for (const auto& token : params_base.sampling.preserved_tokens) {
            preserved_tokens.push_back(common_token_to_piece(ctx, token));
        }
        data["preserved_tokens"] = preserved_tokens;
    }

    return data;
}

// Cache types accepted by kv_cache_type_from_str() and listed by
// get_all_kv_cache_types().
const std::vector kv_cache_types = {
    GGML_TYPE_F32,
    GGML_TYPE_F16,
    GGML_TYPE_BF16,
    GGML_TYPE_Q8_0,
    GGML_TYPE_Q4_0,
    GGML_TYPE_Q4_1,
    GGML_TYPE_IQ4_NL,
    GGML_TYPE_Q5_0,
    GGML_TYPE_Q5_1,
};

// Returns the entry of kv_cache_types whose ggml_type_name() equals s.
// Throws std::runtime_error if no entry matches.
static ggml_type kv_cache_type_from_str(const std::string & s) {
    for (const auto & type : kv_cache_types) {
        if (ggml_type_name(type) == s) {
            return type;
        }
    }
    throw std::runtime_error("Unsupported cache type: " + s);
}

// Returns the names of all entries in kv_cache_types joined with ", ".
static std::string get_all_kv_cache_types() {
    std::ostringstream msg;
    for (const auto & type : kv_cache_types) {
        // Compare addresses to detect the last element and omit the trailing separator.
        msg << ggml_type_name(type) << (&type == &kv_cache_types.back() ? "" : ", ");
    }
    return msg.str();
}

// Adds an RPC server
// Description here: https://github.com/ggml-org/llama.cpp/blob/master/tools/rpc/README.md
//
// `servers` is a comma-separated list of endpoints; each entry is trimmed of
// surrounding whitespace and registered through the "RPC" backend's
// ggml_backend_rpc_add_server entry point. Throws std::invalid_argument when
// the list is empty, the RPC backend is not found, or the entry point cannot
// be resolved.
static void add_rpc_devices(std::string servers) {
    auto rpc_servers = string_split(servers, ',');
    // Trim whitespace to allow more flexible configurations, such as having entries on separate lines.
    for (std::string & server : rpc_servers)
    {
        server.erase(0, server.find_first_not_of(" \t\n\r"));
        server.erase(server.find_last_not_of(" \t\n\r") + 1);
    }
    if (rpc_servers.empty()) {
        throw std::invalid_argument("no RPC servers specified");
    }
    ggml_backend_reg_t rpc_reg = ggml_backend_reg_by_name("RPC");
    if (!rpc_reg) {
        throw std::invalid_argument("failed to find RPC backend");
    }
    // The add-server function is looked up by name at runtime rather than linked directly.
    typedef ggml_backend_reg_t (*ggml_backend_rpc_add_server_t)(const char * endpoint);
    ggml_backend_rpc_add_server_t ggml_backend_rpc_add_server_fn = (ggml_backend_rpc_add_server_t) ggml_backend_reg_get_proc_address(rpc_reg, "ggml_backend_rpc_add_server");
    if (!ggml_backend_rpc_add_server_fn) {
        throw std::invalid_argument("failed to find RPC add server function");
    }
    for (const auto & server : rpc_servers) {
        ggml_backend_reg_t reg = ggml_backend_rpc_add_server_fn(server.c_str());
        ggml_backend_register(reg);
    }
}

static void params_parse(server_context& /*ctx_server*/, const backend::ModelOptions* request,
                                common_params & params) {

    // this is comparable to: https://github.com/ggerganov/llama.cpp/blob/d9b33fe95bd257b36c84ee5769cc048230067d6f/examples/server/server.cpp#L1809

    params.model.path = request->modelfile();
    if (!request->mmproj().empty()) {
        params.mmproj.path = request->mmproj();
    }
    // params.model_alias ??
params.model_alias.insert(request->modelfile()); if (!request->cachetypekey().empty()) { params.cache_type_k = kv_cache_type_from_str(request->cachetypekey()); } if (!request->cachetypevalue().empty()) { params.cache_type_v = kv_cache_type_from_str(request->cachetypevalue()); } params.n_ctx = request->contextsize(); //params.memory_f16 = request->f16memory(); params.cpuparams.n_threads = request->threads(); params.n_gpu_layers = request->ngpulayers(); params.n_batch = request->nbatch(); //params.verbosity = INT_MAX; // Enable all debug logs by setting verbosity threshold to maximum //common_log_set_verbosity_thold(INT_MAX); params.n_ubatch = request->nbatch(); // fixes issue with reranking models being limited to 512 tokens (the default n_ubatch size); allows for setting the maximum input amount of tokens thereby avoiding this error "input is too large to process. increase the physical batch size" // Initialize ctx_shift to false by default (can be overridden by options) params.ctx_shift = false; // Initialize cache_ram_mib to -1 by default (no limit, can be overridden by options) params.cache_ram_mib = -1; // Initialize n_parallel to 1 by default (can be overridden by options) params.n_parallel = 1; // Initialize grpc_servers to empty (can be overridden by options) std::string grpc_servers_option = ""; // Initialize fit_params options (can be overridden by options) // fit_params: whether to auto-adjust params to fit device memory (default: true as in llama.cpp) params.fit_params = true; // fit_params_target: target margin per device in bytes (default: 1GB per device) // Initialize as vector with default value for all devices params.fit_params_target = std::vector(llama_max_devices(), 1024 * 1024 * 1024); // fit_params_min_ctx: minimum context size for fit (default: 4096) params.fit_params_min_ctx = 4096; // Initialize additional server options (can be overridden by options) // n_cache_reuse: min chunk size for KV cache reuse via shifting (default: 0 = disabled) 
params.n_cache_reuse = 0; // slot_prompt_similarity: threshold for slot prompt matching (default: 0.1) params.slot_prompt_similarity = 0.1f; // swa_full: use full-size SWA cache (default: false) params.swa_full = false; // cont_batching: continuous batching (default: true, auto-enabled when n_parallel > 1) params.cont_batching = true; // check_tensors: validate tensor data (default: false) params.check_tensors = false; // warmup: enable warmup run (default: true) params.warmup = true; // no_op_offload: disable host tensor op offload (default: false) params.no_op_offload = false; // kv_unified: enable unified KV cache (default: false) params.kv_unified = false; // n_ctx_checkpoints: max context checkpoints per slot (default: 8) params.n_ctx_checkpoints = 8; // llama memory fit fails if we don't provide a buffer for tensor overrides const size_t ntbo = llama_max_tensor_buft_overrides(); while (params.tensor_buft_overrides.size() < ntbo) { params.tensor_buft_overrides.push_back({nullptr, nullptr}); } // decode options. Options are in form optname:optvale, or if booleans only optname. for (int i = 0; i < request->options_size(); i++) { std::string opt = request->options(i); std::vector opt_buf(opt.begin(), opt.end()); opt_buf.push_back('\0'); char *optname = strtok(opt_buf.data(), ":"); char *optval = strtok(NULL, ":"); std::string optval_str = (optval == NULL) ? 
"true" : optval; if (!strcmp(optname, "context_shift")) { if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") { params.ctx_shift = true; } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") { params.ctx_shift = false; } } else if (!strcmp(optname, "use_jinja") || !strcmp(optname, "jinja")) { if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") { params.use_jinja = true; } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") { params.use_jinja = false; } } else if (!strcmp(optname, "cache_ram")) { if (optval != NULL) { try { params.cache_ram_mib = std::stoi(optval_str); } catch (const std::exception& e) { // If conversion fails, keep default value (-1) } } } else if (!strcmp(optname, "parallel") || !strcmp(optname, "n_parallel")) { if (optval != NULL) { try { params.n_parallel = std::stoi(optval_str); if (params.n_parallel > 1) { params.cont_batching = true; } } catch (const std::exception& e) { // If conversion fails, keep default value (1) } } } else if (!strcmp(optname, "grpc_servers") || !strcmp(optname, "rpc_servers")) { if (optval != NULL) { grpc_servers_option = optval_str; } } else if (!strcmp(optname, "fit_params") || !strcmp(optname, "fit")) { if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") { params.fit_params = true; } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") { params.fit_params = false; } } else if (!strcmp(optname, "fit_params_target") || !strcmp(optname, "fit_target")) { if (optval != NULL) { try { // Value is in MiB, can be comma-separated list for multiple devices // Single value is broadcast across all devices 
std::string arg_next = optval_str; const std::regex regex{ R"([,/]+)" }; std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 }; std::vector split_arg{ it, {} }; if (split_arg.size() >= llama_max_devices()) { // Too many values provided continue; } if (split_arg.size() == 1) { // Single value: broadcast to all devices size_t value_mib = std::stoul(split_arg[0]); std::fill(params.fit_params_target.begin(), params.fit_params_target.end(), value_mib * 1024 * 1024); } else { // Multiple values: set per device for (size_t i = 0; i < split_arg.size() && i < params.fit_params_target.size(); i++) { params.fit_params_target[i] = std::stoul(split_arg[i]) * 1024 * 1024; } } } catch (const std::exception& e) { // If conversion fails, keep default value (1GB per device) } } } else if (!strcmp(optname, "fit_params_min_ctx") || !strcmp(optname, "fit_ctx")) { if (optval != NULL) { try { params.fit_params_min_ctx = std::stoi(optval_str); } catch (const std::exception& e) { // If conversion fails, keep default value (4096) } } } else if (!strcmp(optname, "n_cache_reuse") || !strcmp(optname, "cache_reuse")) { if (optval != NULL) { try { params.n_cache_reuse = std::stoi(optval_str); } catch (const std::exception& e) { // If conversion fails, keep default value (0) } } } else if (!strcmp(optname, "slot_prompt_similarity") || !strcmp(optname, "sps")) { if (optval != NULL) { try { params.slot_prompt_similarity = std::stof(optval_str); } catch (const std::exception& e) { // If conversion fails, keep default value (0.1) } } } else if (!strcmp(optname, "swa_full")) { if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") { params.swa_full = true; } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") { params.swa_full = false; } } else if (!strcmp(optname, "cont_batching") || !strcmp(optname, "continuous_batching")) { if (optval_str 
== "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") { params.cont_batching = true; } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") { params.cont_batching = false; } } else if (!strcmp(optname, "check_tensors")) { if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") { params.check_tensors = true; } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") { params.check_tensors = false; } } else if (!strcmp(optname, "warmup")) { if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") { params.warmup = true; } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") { params.warmup = false; } } else if (!strcmp(optname, "no_op_offload")) { if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") { params.no_op_offload = true; } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") { params.no_op_offload = false; } } else if (!strcmp(optname, "kv_unified") || !strcmp(optname, "unified_kv")) { if (optval_str == "true" || optval_str == "1" || optval_str == "yes" || optval_str == "on" || optval_str == "enabled") { params.kv_unified = true; } else if (optval_str == "false" || optval_str == "0" || optval_str == "no" || optval_str == "off" || optval_str == "disabled") { params.kv_unified = false; } } else if (!strcmp(optname, "n_ctx_checkpoints") || !strcmp(optname, "ctx_checkpoints")) { if (optval != NULL) { try { params.n_ctx_checkpoints = std::stoi(optval_str); } catch (const std::exception& e) { // If conversion fails, keep default value (8) } } } 
} // Set params.n_parallel from environment variable if not set via options (fallback) if (params.n_parallel == 1) { const char *env_parallel = std::getenv("LLAMACPP_PARALLEL"); if (env_parallel != NULL) { try { params.n_parallel = std::stoi(env_parallel); if (params.n_parallel > 1) { params.cont_batching = true; } } catch (const std::exception& e) { // If conversion fails, keep default value (1) } } } // Add RPC devices from option or environment variable (fallback) if (!grpc_servers_option.empty()) { add_rpc_devices(grpc_servers_option); } else { const char *llama_grpc_servers = std::getenv("LLAMACPP_GRPC_SERVERS"); if (llama_grpc_servers != NULL) { add_rpc_devices(std::string(llama_grpc_servers)); } } // Add kv_overrides if (request->overrides_size() > 0) { for (int i = 0; i < request->overrides_size(); i++) { string_parse_kv_override(request->overrides(i).c_str(), params.kv_overrides); } } if (!params.kv_overrides.empty()) { params.kv_overrides.emplace_back(); params.kv_overrides.back().key[0] = 0; } // TODO: Add yarn if (!request->tensorsplit().empty()) { std::string arg_next = request->tensorsplit(); // split string by , and / const std::regex regex{ R"([,/]+)" }; std::sregex_token_iterator it{ arg_next.begin(), arg_next.end(), regex, -1 }; std::vector split_arg{ it, {} }; GGML_ASSERT(split_arg.size() <= llama_max_devices()); for (size_t i_device = 0; i_device < llama_max_devices(); ++i_device) { if (i_device < split_arg.size()) { params.tensor_split[i_device] = std::stof(split_arg[i_device]); } else { params.tensor_split[i_device] = 0.0f; } } } if (!request->maingpu().empty()) { params.main_gpu = std::stoi(request->maingpu()); } if (!request->loraadapter().empty() && !request->lorabase().empty()) { float scale_factor = 1.0f; if (request->lorascale() != 0.0f) { scale_factor = request->lorascale(); } // get the directory of modelfile std::string model_dir = params.model.path.substr(0, params.model.path.find_last_of("/\\")); common_adapter_lora_info lora_info; 
lora_info.path = model_dir + "/" + request->loraadapter(); lora_info.scale = scale_factor; lora_info.task_name = ""; lora_info.prompt_prefix = ""; lora_info.ptr = nullptr; params.lora_adapters.push_back(std::move(lora_info)); } params.use_mlock = request->mlock(); params.use_mmap = request->mmap(); if (request->flashattention() == "on" || request->flashattention() == "enabled") { params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED; } else if (request->flashattention() == "off" || request->flashattention() == "disabled") { params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; } else if (request->flashattention() == "auto") { params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; } params.no_kv_offload = request->nokvoffload(); params.embedding = request->embeddings() || request->reranking(); if (request->reranking()) { params.pooling_type = LLAMA_POOLING_TYPE_RANK; } if (request->ropescaling() == "none") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_NONE; } else if (request->ropescaling() == "yarn") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_YARN; } else if (request->ropescaling() == "linear") { params.rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_LINEAR; } if ( request->yarnextfactor() != 0.0f ) { params.yarn_ext_factor = request->yarnextfactor(); } if ( request->yarnattnfactor() != 0.0f ) { params.yarn_attn_factor = request->yarnattnfactor(); } if ( request->yarnbetafast() != 0.0f ) { params.yarn_beta_fast = request->yarnbetafast(); } if ( request->yarnbetaslow() != 0.0f ) { params.yarn_beta_slow = request->yarnbetaslow(); } if ( request->ropefreqbase() != 0.0f ) { params.rope_freq_base = request->ropefreqbase(); } if ( request->ropefreqscale() != 0.0f ) { params.rope_freq_scale = request->ropefreqscale(); } if (request->grammartriggers_size() > 0) { //params.sampling.grammar_lazy = true; // Store grammar trigger words for processing after model is loaded for (int i = 0; i < request->grammartriggers_size(); i++) { const auto & word = 
request->grammartriggers(i).word(); common_grammar_trigger trigger; trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_WORD; trigger.value = word; params.sampling.grammar_triggers.push_back(std::move(trigger)); } } }

// GRPC Server start
//
// Implementation of the backend::Backend gRPC service. It keeps a reference to
// the server_context passed to the constructor and a copy of the common_params
// that were used for the last successful LoadModel call.
class BackendServiceImpl final : public backend::Backend::Service {
private:
    server_context& ctx_server;
    common_params params_base; // Store copy of params_base, set after model load

public:
    // Stores `ctx` by reference; no other state is initialized here.
    BackendServiceImpl(server_context& ctx) : ctx_server(ctx) {}

    // Health RPC: always sets the reply message to "OK" and returns Status::OK.
    grpc::Status Health(ServerContext* /*context*/, const backend::HealthMessage* /*request*/, backend::Reply* reply) override {
        // Implement Health RPC
        reply->set_message("OK");
        return Status::OK;
    }

    // LoadModel RPC: fills a common_params from `request` via params_parse(),
    // initializes logging / the llama backend / NUMA, and asks ctx_server to
    // load the model.
    //
    // While load_model() runs, the llama log callback is replaced by one that
    // collects GGML_LOG_LEVEL_ERROR messages so they can be appended to the
    // failure message; the previous callback is reinstalled afterwards.
    //
    // On failure: result->success is set to false, result->message carries the
    // error text, and a grpc::StatusCode::INTERNAL status with the same text is
    // returned.
    // On success: WORD grammar triggers that tokenize to exactly one token are
    // rewritten as TOKEN triggers, result is set to "Loading succeeded" /
    // success=true, loaded_model is set, `params` is copied into params_base,
    // and Status::OK is returned.
    grpc::Status LoadModel(ServerContext* /*context*/, const backend::ModelOptions* request, backend::Result* result) override {
        // Implement LoadModel RPC
        common_params params;
        params_parse(ctx_server, request, params);

        common_init();
        // Ensure debug logs are enabled after common_init() sets up logging
        common_log_set_verbosity_thold(params.verbosity);

        llama_backend_init();
        llama_numa_init(params.numa);

        LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency());
        LOG_INF("\n");
        LOG_INF("%s\n", common_params_get_system_info(params).c_str());
        LOG_INF("\n");

        // Capture error messages during model loading
        // The struct also remembers the previously installed callback and its
        // user data so both can be chained to and restored later.
        struct error_capture {
            std::string captured_error;
            std::mutex error_mutex;
            ggml_log_callback original_callback;
            void* original_user_data;
        } error_capture_data;

        // Get original log callback
        llama_log_get(&error_capture_data.original_callback, &error_capture_data.original_user_data);

        // Set custom callback to capture errors
        // The lambda captures nothing; its state arrives through `user_data`,
        // which is the address of error_capture_data (a local of this call).
        llama_log_set([](ggml_log_level level, const char * text, void * user_data) {
            // NOTE(review): this cast has no template argument in this copy of the
            // file; it presumably targets error_capture* - confirm against upstream.
            auto* capture = static_cast(user_data);

            // Capture error messages
            if (level == GGML_LOG_LEVEL_ERROR) {
                std::lock_guard lock(capture->error_mutex);
                // Append error message, removing trailing newlines
                std::string msg(text);
                while (!msg.empty() && (msg.back() == '\n' || msg.back() == '\r')) {
                    msg.pop_back();
                }
                if (!msg.empty()) {
                    // Multiple error lines are joined with "; ".
                    if (!capture->captured_error.empty()) {
                        capture->captured_error.append("; ");
                    }
                    capture->captured_error.append(msg);
                }
            }

            // Also call original callback to preserve logging
            if (capture->original_callback) {
                capture->original_callback(level, text, capture->original_user_data);
            }
        }, &error_capture_data);

        // load the model
        // NOTE(review): the callback is only restored on the line after this call.
        // If load_model() can throw, the installed callback would keep pointing at
        // error_capture_data after it goes out of scope - confirm it cannot throw.
        bool load_success = ctx_server.load_model(params);

        // Restore original log callback
        llama_log_set(error_capture_data.original_callback, error_capture_data.original_user_data);

        if (!load_success) {
            // Build a message naming the model path plus, when set, the mmproj
            // and draft-model paths.
            std::string error_msg = "Failed to load model: " + params.model.path;
            if (!params.mmproj.path.empty()) {
                error_msg += " (with mmproj: " + params.mmproj.path + ")";
            }
            if (params.speculative.has_dft() && !params.speculative.mparams_dft.path.empty()) {
                error_msg += " (with draft model: " + params.speculative.mparams_dft.path + ")";
            }

            // Add captured error details if available
            {
                std::lock_guard lock(error_capture_data.error_mutex);
                if (!error_capture_data.captured_error.empty()) {
                    error_msg += ". Error: " + error_capture_data.captured_error;
                } else {
                    error_msg += ". Model file may not exist or be invalid.";
                }
            }

            // The same text is reported both in the Result payload and in the
            // gRPC status.
            result->set_message(error_msg);
            result->set_success(false);
            return grpc::Status(grpc::StatusCode::INTERNAL, error_msg);
        }

        // Process grammar triggers now that vocab is available
        if (!params.sampling.grammar_triggers.empty()) {
            // NOTE(review): the vector's element type is missing in this copy of the
            // file; presumably common_grammar_trigger (the type pushed below) - confirm.
            std::vector processed_triggers;
            for (const auto& trigger : params.sampling.grammar_triggers) {
                if (trigger.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) {
                    auto ids = common_tokenize(ctx_server.impl->vocab, trigger.value, /* add_special= */ false, /* parse_special= */ true);
                    if (ids.size() == 1) {
                        // The word is a single token: replace it with a TOKEN trigger.
                        auto token = ids[0];
                        // Add the token to preserved_tokens if not already present
                        if (params.sampling.preserved_tokens.find(token) == params.sampling.preserved_tokens.end()) {
                            params.sampling.preserved_tokens.insert(token);
                            LOG_INF("Added grammar trigger token to preserved tokens: %d (`%s`)\n", token, trigger.value.c_str());
                        }
                        LOG_INF("Grammar trigger token: %d (`%s`)\n", token, trigger.value.c_str());
                        common_grammar_trigger processed_trigger;
                        processed_trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN;
                        processed_trigger.value = trigger.value;
                        processed_trigger.token = token;
                        processed_triggers.push_back(std::move(processed_trigger));
                    } else {
                        // Zero or several tokens: keep the WORD trigger unchanged.
                        LOG_INF("Grammar trigger word: `%s`\n", trigger.value.c_str());
                        processed_triggers.push_back(trigger);
                    }
                } else {
                    // Non-WORD triggers are passed through untouched.
                    processed_triggers.push_back(trigger);
                }
            }
            // Update the grammar triggers in params
            params.sampling.grammar_triggers = std::move(processed_triggers);
        }

        //ctx_server.init();
        result->set_message("Loading succeeded");
        result->set_success(true);
        // NOTE(review): loaded_model is not declared in this class; presumably a
        // file-level flag defined elsewhere - confirm.
        loaded_model = true;
        // Store copy of params_base for use in parse_options and other methods
        params_base = params;

        return Status::OK;
    }

    // Helper function to extract logprobs from JSON response
    //
    // Returns the first match among, in order: choices[0].logprobs,
    // completion_probabilities (placed under a "content" key), and a top-level
    // logprobs field. Returns an empty JSON object when none is present.
    static json extract_logprobs_from_json(const json& res_json) {
        json logprobs_json = json::object();

        // Check for OAI-compatible format: choices[0].logprobs
        if (res_json.contains("choices") && res_json["choices"].is_array() &&
res_json["choices"].size() > 0 && res_json["choices"][0].contains("logprobs")) {
            logprobs_json = res_json["choices"][0]["logprobs"];
        }
        // Check for non-OAI format: completion_probabilities
        else if (res_json.contains("completion_probabilities")) {
            // Convert completion_probabilities to OAI format
            // (the value is stored as-is under the "content" key)
            logprobs_json["content"] = res_json["completion_probabilities"];
        }
        // Check for direct logprobs field
        else if (res_json.contains("logprobs")) {
            logprobs_json = res_json["logprobs"];
        }

        // If no branch matched, logprobs_json is returned without having been
        // assigned in this chain.
        return logprobs_json;
    }

    // Helper: populate chat_deltas on a Reply from oaicompat_msg_diffs (streaming chunks)
    //
    // Appends one chat delta to `reply` for every element of `diffs`, including
    // elements where nothing ends up being set. Content, reasoning content and
    // the tool-call id/name/arguments are only set when non-empty. A tool call
    // entry is added only when diff.tool_call_index != std::string::npos.
    //
    // NOTE(review): the element type of `diffs` and the target type of the
    // static_cast below are missing in this copy of the file - confirm both
    // against upstream.
    static void populate_chat_deltas_from_diffs(backend::Reply & reply, const std::vector & diffs) {
        for (const auto & diff : diffs) {
            auto* delta = reply.add_chat_deltas();
            if (!diff.content_delta.empty()) {
                delta->set_content(diff.content_delta);
            }
            if (!diff.reasoning_content_delta.empty()) {
                delta->set_reasoning_content(diff.reasoning_content_delta);
            }
            // npos is used as the "no tool call in this diff" marker.
            if (diff.tool_call_index != std::string::npos) {
                auto* tc = delta->add_tool_calls();
                tc->set_index(static_cast(diff.tool_call_index));
                if (!diff.tool_call_delta.id.empty()) {
                    tc->set_id(diff.tool_call_delta.id);
                }
                if (!diff.tool_call_delta.name.empty()) {
                    tc->set_name(diff.tool_call_delta.name);
                }
                if (!diff.tool_call_delta.arguments.empty()) {
                    tc->set_arguments(diff.tool_call_delta.arguments);
                }
            }
        }
    }

    // Helper: populate chat_deltas on a Reply from final oaicompat_msg (non-streaming)
    //
    // Adds at most one chat delta to `reply`: nothing is added when content,
    // reasoning_content and tool_calls are all empty. Every tool call of `msg`
    // is attached to that single delta with its position as index; id, name and
    // arguments are set unconditionally here (the streaming helper above skips
    // empty ones).
    //
    // NOTE(review): the target type of the static_cast below is missing in this
    // copy of the file - confirm against upstream.
    static void populate_chat_deltas_from_final(backend::Reply & reply, const common_chat_msg & msg) {
        // Content delta
        if (!msg.content.empty() || !msg.reasoning_content.empty() || !msg.tool_calls.empty()) {
            auto* delta = reply.add_chat_deltas();
            if (!msg.content.empty()) {
                delta->set_content(msg.content);
            }
            if (!msg.reasoning_content.empty()) {
                delta->set_reasoning_content(msg.reasoning_content);
            }
            // Tool calls as individual deltas within the same ChatDelta
            for (size_t i = 0; i < msg.tool_calls.size(); i++) {
                auto* tc = delta->add_tool_calls();
                tc->set_index(static_cast(i));
                tc->set_id(msg.tool_calls[i].id);
                tc->set_name(msg.tool_calls[i].name);
                tc->set_arguments(msg.tool_calls[i].arguments);
            }
        }
    }

    grpc::Status PredictStream(grpc::ServerContext* context, const backend::PredictOptions* request, grpc::ServerWriter* writer) override { if (params_base.model.path.empty()) { return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded"); } json data = parse_options(true, request, params_base, ctx_server.get_llama_context()); //Raise error if embeddings is set to true if (params_base.embedding) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Embedding is not supported in streaming mode"); } auto completion_id = gen_chatcmplid(); // get response reader - it contains references to the queues and will stay valid auto rd = ctx_server.get_response_reader(); try { std::vector tasks; std::string prompt_str; std::vector files; // Declare files early so it's accessible in both branches // Handle chat templates when UseTokenizerTemplate is enabled and Messages are provided if (request->usetokenizertemplate() && request->messages_size() > 0 && ctx_server.impl->chat_params.tmpls != nullptr) { // Convert proto Messages to JSON format compatible with oaicompat_chat_params_parse json body_json; json messages_json = json::array(); // Find the last user message index to attach images/audio to int last_user_msg_idx = -1; for (int i = request->messages_size() - 1; i >= 0; i--) { if (request->messages(i).role() == "user") { last_user_msg_idx = i; break; } } for (int i = 0; i < request->messages_size(); i++) { const auto& msg = request->messages(i); json msg_json; msg_json["role"] = msg.role(); bool is_last_user_msg = (i == last_user_msg_idx); bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0); // Handle content - can be string, null, or array // For multimodal content, we'll embed images/audio from separate fields if (!msg.content().empty()) { // Try to parse content as
JSON to see if it's already an array json content_val; try { content_val = json::parse(msg.content()); // Handle null values - convert to empty string to avoid template errors if (content_val.is_null()) { content_val = ""; } } catch (const json::parse_error&) { // Not JSON, treat as plain string content_val = msg.content(); } // If content is an object (e.g., from tool call failures), convert to string if (content_val.is_object()) { content_val = content_val.dump(); } // If content is a string and this is the last user message with images/audio, combine them if (content_val.is_string() && is_last_user_msg && has_images_or_audio) { json content_array = json::array(); // Add text first content_array.push_back({{"type", "text"}, {"text", content_val.get()}}); // Add images if (request->images_size() > 0) { for (int j = 0; j < request->images_size(); j++) { json image_chunk; image_chunk["type"] = "image_url"; json image_url; image_url["url"] = "data:image/jpeg;base64," + request->images(j); image_chunk["image_url"] = image_url; content_array.push_back(image_chunk); } } // Add audios if (request->audios_size() > 0) { for (int j = 0; j < request->audios_size(); j++) { json audio_chunk; audio_chunk["type"] = "input_audio"; json input_audio; input_audio["data"] = request->audios(j); input_audio["format"] = "wav"; // default, could be made configurable audio_chunk["input_audio"] = input_audio; content_array.push_back(audio_chunk); } } msg_json["content"] = content_array; } else { // Use content as-is (already array or not last user message) // Ensure null values are converted to empty string if (content_val.is_null()) { msg_json["content"] = ""; } else { msg_json["content"] = content_val; } } } else if (is_last_user_msg && has_images_or_audio) { // If no content but this is the last user message with images/audio, create content array json content_array = json::array(); if (request->images_size() > 0) { for (int j = 0; j < request->images_size(); j++) { json image_chunk; 
image_chunk["type"] = "image_url"; json image_url; image_url["url"] = "data:image/jpeg;base64," + request->images(j); image_chunk["image_url"] = image_url; content_array.push_back(image_chunk); } } if (request->audios_size() > 0) { for (int j = 0; j < request->audios_size(); j++) { json audio_chunk; audio_chunk["type"] = "input_audio"; json input_audio; input_audio["data"] = request->audios(j); input_audio["format"] = "wav"; // default, could be made configurable audio_chunk["input_audio"] = input_audio; content_array.push_back(audio_chunk); } } msg_json["content"] = content_array; } else if (msg.role() == "tool") { // Tool role messages must have content field set, even if empty // Jinja templates expect content to be a string, not null or object SRV_INF("[CONTENT DEBUG] PredictStream: Message %d is tool role, content_empty=%d\n", i, msg.content().empty() ? 1 : 0); if (msg.content().empty()) { msg_json["content"] = ""; SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): empty content, set to empty string\n", i); } else { SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): content exists: %s\n", i, msg.content().substr(0, std::min(200, msg.content().size())).c_str()); // Content exists, parse and ensure it's a string json content_val; try { content_val = json::parse(msg.content()); SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): parsed JSON, type=%s\n", i, content_val.is_null() ? "null" : content_val.is_object() ? "object" : content_val.is_string() ? "string" : content_val.is_array() ? 
"array" : "other"); // Handle null values - Jinja templates expect content to be a string, not null if (content_val.is_null()) { msg_json["content"] = ""; SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): null content, converted to empty string\n", i); } else if (content_val.is_object()) { // If content is an object (e.g., from tool call failures/errors), convert to string msg_json["content"] = content_val.dump(); SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): object content, converted to string: %s\n", i, content_val.dump().substr(0, std::min(200, content_val.dump().size())).c_str()); } else if (content_val.is_string()) { msg_json["content"] = content_val.get(); SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): string content, using as-is\n", i); } else { // For arrays or other types, convert to string msg_json["content"] = content_val.dump(); SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): %s content, converted to string\n", i, content_val.is_array() ? 
"array" : "other type"); } } catch (const json::parse_error&) { // Not JSON, treat as plain string msg_json["content"] = msg.content(); SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (tool): not JSON, using as string\n", i); } } } else { // Ensure all messages have content set (fallback for any unhandled cases) // Jinja templates expect content to be present, default to empty string if not set if (!msg_json.contains("content")) { SRV_INF("[CONTENT DEBUG] PredictStream: Message %d (role=%s): no content field, adding empty string\n", i, msg.role().c_str()); msg_json["content"] = ""; } } // Add optional fields for OpenAI-compatible message format if (!msg.name().empty()) { msg_json["name"] = msg.name(); } if (!msg.tool_call_id().empty()) { msg_json["tool_call_id"] = msg.tool_call_id(); } if (!msg.reasoning_content().empty()) { msg_json["reasoning_content"] = msg.reasoning_content(); } if (!msg.tool_calls().empty()) { // Parse tool_calls JSON string and add to message try { json tool_calls = json::parse(msg.tool_calls()); msg_json["tool_calls"] = tool_calls; SRV_INF("[TOOL CALLS DEBUG] PredictStream: Message %d has tool_calls: %s\n", i, tool_calls.dump().c_str()); // IMPORTANT: If message has tool_calls but content is empty or not set, // set content to space " " instead of empty string "", because llama.cpp's // common_chat_msgs_to_json_oaicompat converts empty strings to null (line 312), // which causes template errors when accessing message.content[:tool_start_length] if (!msg_json.contains("content") || (msg_json.contains("content") && msg_json["content"].is_string() && msg_json["content"].get().empty())) { SRV_INF("[CONTENT DEBUG] PredictStream: Message %d has tool_calls but empty content, setting to space\n", i); msg_json["content"] = " "; } // Log each tool call with name and arguments if (tool_calls.is_array()) { for (size_t tc_idx = 0; tc_idx < tool_calls.size(); tc_idx++) { const auto& tc = tool_calls[tc_idx]; std::string tool_name = "unknown"; 
std::string tool_args = "{}"; if (tc.contains("function")) { const auto& func = tc["function"]; if (func.contains("name")) { tool_name = func["name"].get(); } if (func.contains("arguments")) { tool_args = func["arguments"].is_string() ? func["arguments"].get() : func["arguments"].dump(); } } else if (tc.contains("name")) { tool_name = tc["name"].get(); if (tc.contains("arguments")) { tool_args = tc["arguments"].is_string() ? tc["arguments"].get() : tc["arguments"].dump(); } } SRV_INF("[TOOL CALLS DEBUG] PredictStream: Message %d, tool_call %zu: name=%s, arguments=%s\n", i, tc_idx, tool_name.c_str(), tool_args.c_str()); } } } catch (const json::parse_error& e) { SRV_WRN("Failed to parse tool_calls JSON: %s\n", e.what()); } } // Debug: Log final content state before adding to array if (msg_json.contains("content")) { if (msg_json["content"].is_null()) { SRV_INF("[CONTENT DEBUG] PredictStream: Message %d FINAL STATE: content is NULL - THIS WILL CAUSE ERROR!\n", i); } else { SRV_INF("[CONTENT DEBUG] PredictStream: Message %d FINAL STATE: content type=%s, has_value=%d\n", i, msg_json["content"].is_string() ? "string" : msg_json["content"].is_array() ? "array" : msg_json["content"].is_object() ? "object" : "other", msg_json["content"].is_null() ? 
0 : 1); } } else { SRV_INF("[CONTENT DEBUG] PredictStream: Message %d FINAL STATE: NO CONTENT FIELD - THIS WILL CAUSE ERROR!\n", i); } messages_json.push_back(msg_json); } // Final safety check: Ensure no message has null content (Jinja templates require strings) SRV_INF("[CONTENT DEBUG] PredictStream: Running final safety check on %zu messages\n", messages_json.size()); for (size_t idx = 0; idx < messages_json.size(); idx++) { auto& msg = messages_json[idx]; if (msg.contains("content") && msg["content"].is_null()) { SRV_INF("[CONTENT DEBUG] PredictStream: Safety check found message %zu with NULL content, converting to empty string\n", idx); msg["content"] = ""; } else if (!msg.contains("content")) { SRV_INF("[CONTENT DEBUG] PredictStream: Safety check found message %zu without content field, adding empty string\n", idx); msg["content"] = ""; } else { SRV_INF("[CONTENT DEBUG] PredictStream: Safety check message %zu: content OK, type=%s\n", idx, msg["content"].is_string() ? "string" : msg["content"].is_array() ? "array" : msg["content"].is_object() ? 
"object" : "other"); } } // Debug: Count tool messages int tool_msg_count = 0; for (const auto& msg : messages_json) { if (msg.contains("role") && msg["role"] == "tool") { tool_msg_count++; } } SRV_DBG("[TOOLS DEBUG] PredictStream: Built %d tool messages out of %zu total messages\n", tool_msg_count, messages_json.size()); // Debug: Print full conversation (messages) SRV_DBG("[CONVERSATION DEBUG] PredictStream: Full messages array:\n%s\n", messages_json.dump(2).c_str()); body_json["messages"] = messages_json; body_json["stream"] = true; // PredictStream is always streaming // Check if grammar is provided from Go layer (NoGrammar=false) // If grammar is provided, we must use it and NOT let template generate grammar from tools // oaicompat_chat_params_parse throws an error if both grammar and tools are provided bool has_grammar_from_go = data.contains("grammar") && data["grammar"].is_string() && !data["grammar"].get().empty(); SRV_INF("[TOOLS DEBUG] PredictStream: has_grammar_from_go=%d, data.contains(\"tools\")=%d, data.contains(\"grammar\")=%d\n", has_grammar_from_go ? 1 : 0, data.contains("tools") ? 1 : 0, data.contains("grammar") ? 1 : 0); if (data.contains("grammar")) { SRV_INF("[TOOLS DEBUG] PredictStream: grammar type=%s, empty=%d\n", data["grammar"].is_string() ? "string" : "other", data["grammar"].is_string() && data["grammar"].get().empty() ? 
1 : 0); } // Copy other relevant fields from data that oaicompat_chat_params_parse expects // Tools and tool_choice are only passed when NoGrammar is true (grammar not provided) // When grammar is provided from Go layer, we use it instead of template-generated grammar if (!has_grammar_from_go) { // NoGrammar=true: pass tools and let template generate grammar if (data.contains("tools")) { body_json["tools"] = data["tools"]; std::string tools_str = data["tools"].dump(); SRV_INF("Using tools from data (NoGrammar=true): %s\n", tools_str.c_str()); // Debug: Log tools count and details before template processing if (data["tools"].is_array()) { SRV_INF("[TOOLS DEBUG] PredictStream: Passing %zu tools to oaicompat_chat_params_parse\n", data["tools"].size()); for (size_t t_idx = 0; t_idx < data["tools"].size(); t_idx++) { const auto& tool = data["tools"][t_idx]; std::string tool_name = "unknown"; std::string tool_desc = ""; if (tool.contains("function")) { const auto& func = tool["function"]; if (func.contains("name")) { tool_name = func["name"].get(); } if (func.contains("description")) { tool_desc = func["description"].is_string() ? func["description"].get() : ""; } } else if (tool.contains("name")) { tool_name = tool["name"].get(); if (tool.contains("description")) { tool_desc = tool["description"].is_string() ? tool["description"].get() : ""; } } SRV_INF("[TOOLS DEBUG] PredictStream: Tool %zu: name=%s, description=%s\n", t_idx, tool_name.c_str(), tool_desc.substr(0, 100).c_str()); } } } else { SRV_WRN("%s", "No tools found in data - tool calls will not work without tools field\n"); SRV_DBG("[TOOLS DEBUG] PredictStream: No tools in data, tool_choice=%s\n", data.contains("tool_choice") ? 
data["tool_choice"].dump().c_str() : "not set"); } if (data.contains("tool_choice")) { // tool_choice can be a string or object, but oaicompat_chat_params_parse expects a string // Convert object tool_choice to "required" (since a specific function is requested) if (data["tool_choice"].is_string()) { body_json["tool_choice"] = data["tool_choice"].get(); } else if (data["tool_choice"].is_object()) { // Object tool_choice means a specific function is requested, use "required" body_json["tool_choice"] = "required"; std::string tool_choice_obj_str = data["tool_choice"].dump(); SRV_INF("Converted object tool_choice to 'required': %s\n", tool_choice_obj_str.c_str()); } else { // Fallback: convert to string body_json["tool_choice"] = data["tool_choice"].dump(); } std::string tool_choice_str = body_json["tool_choice"].get(); SRV_INF("Using tool_choice: %s\n", tool_choice_str.c_str()); } else { // Default to "auto" if not specified body_json["tool_choice"] = "auto"; } } else { // Grammar is provided from Go layer (NoGrammar=false) - use it, don't pass tools SRV_INF("%s", "Grammar provided from Go layer - using it instead of template-generated grammar\n"); // Grammar will be copied from data after parsing (it's already in data) } if (data.contains("json_schema")) { body_json["json_schema"] = data["json_schema"]; } // If grammar is provided from Go layer, copy it to body_json so it's preserved // (though oaicompat_chat_params_parse may not use it if tools are present) if (has_grammar_from_go) { body_json["grammar"] = data["grammar"]; } if (data.contains("response_format")) { body_json["response_format"] = data["response_format"]; } if (data.contains("chat_template_kwargs")) { body_json["chat_template_kwargs"] = data["chat_template_kwargs"]; } // Pass parallel_tool_calls if present (used by oaicompat_chat_params_parse) if (data.contains("parallel_tool_calls")) { body_json["parallel_tool_calls"] = data["parallel_tool_calls"]; } // Pass add_generation_prompt if present (used by 
oaicompat_chat_params_parse) if (data.contains("add_generation_prompt")) { body_json["add_generation_prompt"] = data["add_generation_prompt"]; } // Pass sampling parameters to body_json so oaicompat_chat_params_parse respects them // and doesn't overwrite them with defaults in the returned parsed_data if (data.contains("n_predict")) { body_json["max_tokens"] = data["n_predict"]; } if (data.contains("ignore_eos")) { body_json["ignore_eos"] = data["ignore_eos"]; } if (data.contains("stop")) { body_json["stop"] = data["stop"]; } if (data.contains("temperature")) { body_json["temperature"] = data["temperature"]; } if (data.contains("top_p")) { body_json["top_p"] = data["top_p"]; } if (data.contains("frequency_penalty")) { body_json["frequency_penalty"] = data["frequency_penalty"]; } if (data.contains("presence_penalty")) { body_json["presence_penalty"] = data["presence_penalty"]; } if (data.contains("seed")) { body_json["seed"] = data["seed"]; } if (data.contains("logit_bias")) { body_json["logit_bias"] = data["logit_bias"]; } if (data.contains("top_k")) { body_json["top_k"] = data["top_k"]; } if (data.contains("min_p")) { body_json["min_p"] = data["min_p"]; } // Pass enable_thinking via chat_template_kwargs (where oaicompat_chat_params_parse reads it) const auto& metadata = request->metadata(); auto et_it = metadata.find("enable_thinking"); if (et_it != metadata.end()) { if (!body_json.contains("chat_template_kwargs")) { body_json["chat_template_kwargs"] = json::object(); } body_json["chat_template_kwargs"]["enable_thinking"] = (et_it->second == "true"); } // Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.) SRV_DBG("[CONVERSATION DEBUG] PredictStream: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str()); // Use the same approach as server.cpp: call oaicompat_chat_params_parse // This handles all template application, grammar merging, etc. 
automatically // Files extracted from multimodal content in messages will be added to the files vector // chat_params already contains tmpls, allow_image, and allow_audio set during model loading // Debug: Log tools before template processing if (body_json.contains("tools")) { SRV_DBG("[TOOLS DEBUG] PredictStream: Before oaicompat_chat_params_parse - tools count: %zu\n", body_json["tools"].is_array() ? body_json["tools"].size() : 0); } // Debug: Verify messages content before template processing // Also ensure ALL messages have content set to string (not null) - templates expect strings if (body_json.contains("messages") && body_json["messages"].is_array()) { SRV_INF("[CONTENT DEBUG] PredictStream: Before oaicompat_chat_params_parse - checking %zu messages\n", body_json["messages"].size()); for (size_t idx = 0; idx < body_json["messages"].size(); idx++) { auto& msg = body_json["messages"][idx]; std::string role_str = msg.contains("role") ? msg["role"].get() : "unknown"; if (msg.contains("content")) { if (msg["content"].is_null()) { SRV_INF("[CONTENT DEBUG] PredictStream: BEFORE TEMPLATE - Message %zu (role=%s) has NULL content - FIXING!\n", idx, role_str.c_str()); msg["content"] = ""; // Fix null content } else if (role_str == "tool" && msg["content"].is_array()) { // Tool messages must have string content, not array // oaicompat_chat_params_parse expects tool messages to have string content SRV_INF("[CONTENT DEBUG] PredictStream: BEFORE TEMPLATE - Message %zu (role=tool) has array content, converting to string\n", idx); msg["content"] = msg["content"].dump(); } else if (!msg["content"].is_string() && !msg["content"].is_array()) { // If content is object or other non-string type, convert to string for templates SRV_INF("[CONTENT DEBUG] PredictStream: BEFORE TEMPLATE - Message %zu (role=%s) content is not string/array, converting\n", idx, role_str.c_str()); if (msg["content"].is_object()) { msg["content"] = msg["content"].dump(); } else { msg["content"] = ""; } } 
else { SRV_INF("[CONTENT DEBUG] PredictStream: BEFORE TEMPLATE - Message %zu (role=%s): content type=%s\n", idx, role_str.c_str(), msg["content"].is_string() ? "string" : msg["content"].is_array() ? "array" : msg["content"].is_object() ? "object" : "other"); } } else { SRV_INF("[CONTENT DEBUG] PredictStream: BEFORE TEMPLATE - Message %zu (role=%s) MISSING content field - ADDING!\n", idx, role_str.c_str()); msg["content"] = ""; // Add missing content } } } json parsed_data = oaicompat_chat_params_parse(body_json, ctx_server.impl->chat_params, files); // Debug: Log tools after template processing if (parsed_data.contains("tools")) { SRV_DBG("[TOOLS DEBUG] PredictStream: After oaicompat_chat_params_parse - tools count: %zu\n", parsed_data["tools"].is_array() ? parsed_data["tools"].size() : 0); } else { SRV_DBG("%s", "[TOOLS DEBUG] PredictStream: After oaicompat_chat_params_parse - no tools in parsed_data\n"); } // Extract the prompt from parsed data prompt_str = parsed_data.at("prompt").get(); // Preserve grammar from Go layer if it was provided (NoGrammar=false) // Otherwise, use grammar from parsed_data (template-generated when NoGrammar=true) json preserved_grammar; if (has_grammar_from_go && data.contains("grammar")) { preserved_grammar = data["grammar"]; } // Merge all fields from parsed_data into data (grammar, grammar_triggers, preserved_tokens, parse_tool_calls, etc.) 
// This ensures all template-generated fields are included // parse_tool_calls is set by oaicompat_chat_params_parse when tools are present for (const auto& item : parsed_data.items()) { if (item.key() != "prompt") { // Don't overwrite prompt_str, we already extracted it // If grammar was provided from Go layer, preserve it instead of template-generated grammar if (item.key() == "grammar" && has_grammar_from_go && !preserved_grammar.is_null()) { data["grammar"] = preserved_grammar; } else { data[item.key()] = item.value(); } } } // Debug: Log parse_tool_calls if present (set by oaicompat_chat_params_parse when tools are present) if (data.contains("parse_tool_calls")) { SRV_DBG("[TOOLS DEBUG] PredictStream: parse_tool_calls=%s\n", data["parse_tool_calls"].get() ? "true" : "false"); } } else { // Use prompt directly from data if (data.contains("prompt") && data["prompt"].is_string()) { prompt_str = data["prompt"].get(); } else { prompt_str = request->prompt(); } } const auto type = SERVER_TASK_TYPE_COMPLETION; // TODO: this log can become very long, put it behind a flag or think about a more compact format //SRV_DBG("Prompt: %s\n", prompt.is_string() ? 
prompt.get().c_str() : prompt.dump(2).c_str()); // If not using chat templates, extract files from image_data/audio_data fields // (If using chat templates, files were already extracted by oaicompat_chat_params_parse) if (!request->usetokenizertemplate() || request->messages_size() == 0 || ctx_server.impl->chat_params.tmpls == nullptr) { const auto &images_data = data.find("image_data"); if (images_data != data.end() && images_data->is_array()) { for (const auto &img : *images_data) { auto decoded_data = base64_decode(img["data"].get()); files.push_back(decoded_data); } } const auto &audio_data = data.find("audio_data"); if (audio_data != data.end() && audio_data->is_array()) { for (const auto &audio : *audio_data) { auto decoded_data = base64_decode(audio["data"].get()); files.push_back(decoded_data); } } } const bool has_mtmd = ctx_server.impl->mctx != nullptr; // process prompt std::vector inputs; if (has_mtmd) { // multimodal inputs.push_back(process_mtmd_prompt(ctx_server.impl->mctx, prompt_str, files)); } else { // Everything else, including multimodal completions. 
inputs = tokenize_input_prompts(ctx_server.impl->vocab, ctx_server.impl->mctx, prompt_str, true, true); } tasks.reserve(inputs.size()); for (size_t i = 0; i < inputs.size(); i++) { server_task task = server_task(type); task.id = rd.queue_tasks.get_new_id(); task.index = i; task.tokens = std::move(inputs[i]); task.params = server_task::params_from_json_cmpl( ctx_server.impl->vocab, params_base, ctx_server.get_meta().slot_n_ctx, data); task.id_slot = json_value(data, "id_slot", -1); // OAI-compat task.params.res_type = TASK_RESPONSE_TYPE_NONE; task.params.oaicompat_cmpl_id = completion_id; // oaicompat_model is already populated by params_from_json_cmpl tasks.push_back(std::move(task)); } rd.post_tasks(std::move(tasks)); } catch (const std::exception & e) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what()); } // Get first result for error checking (following server.cpp pattern) server_task_result_ptr first_result = rd.next([&context]() { return context->IsCancelled(); }); if (first_result == nullptr) { // connection is closed return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client"); } else if (first_result->is_error()) { json error_json = first_result->to_json(); backend::Reply reply; reply.set_message(error_json.value("message", "")); writer->Write(reply); return grpc::Status(grpc::StatusCode::INTERNAL, error_json.value("message", "Error occurred")); } // Lambda to build a Reply from JSON + attach chat deltas from a result auto build_reply_from_json = [](const json & res_json, server_task_result * raw_result) -> backend::Reply { backend::Reply reply; std::string completion_text = res_json.value("content", ""); reply.set_message(completion_text); reply.set_tokens(res_json.value("tokens_predicted", 0)); reply.set_prompt_tokens(res_json.value("tokens_evaluated", 0)); if (res_json.contains("timings")) { reply.set_timing_prompt_processing(res_json.at("timings").value("prompt_ms", 0.0)); 
reply.set_timing_token_generation(res_json.at("timings").value("predicted_ms", 0.0)); } json logprobs_json = extract_logprobs_from_json(res_json); if (!logprobs_json.empty() && !logprobs_json.is_null()) { reply.set_logprobs(logprobs_json.dump()); } return reply; }; auto attach_chat_deltas = [](backend::Reply & reply, server_task_result * raw_result) { // Try streaming partial result first auto* partial = dynamic_cast(raw_result); if (partial && !partial->oaicompat_msg_diffs.empty()) { populate_chat_deltas_from_diffs(reply, partial->oaicompat_msg_diffs); return; } // Try final result auto* final_res = dynamic_cast(raw_result); if (final_res && final_res->is_updated) { populate_chat_deltas_from_diffs(reply, final_res->oaicompat_msg_diffs); } }; // Process first result json first_res_json = first_result->to_json(); if (first_res_json.is_array()) { for (const auto & res : first_res_json) { auto reply = build_reply_from_json(res, first_result.get()); attach_chat_deltas(reply, first_result.get()); writer->Write(reply); } } else { auto reply = build_reply_from_json(first_res_json, first_result.get()); attach_chat_deltas(reply, first_result.get()); writer->Write(reply); } // Process subsequent results while (rd.has_next()) { if (context->IsCancelled()) { break; } auto result = rd.next([&context]() { return context->IsCancelled(); }); if (result == nullptr) { break; } json res_json = result->to_json(); if (res_json.is_array()) { for (const auto & res : res_json) { auto reply = build_reply_from_json(res, result.get()); attach_chat_deltas(reply, result.get()); writer->Write(reply); } } else { auto reply = build_reply_from_json(res_json, result.get()); attach_chat_deltas(reply, result.get()); writer->Write(reply); } } // Check if context was cancelled during processing if (context->IsCancelled()) { return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client"); } return grpc::Status::OK; } grpc::Status Predict(ServerContext* context, const 
backend::PredictOptions* request, backend::Reply* reply) override { if (params_base.model.path.empty()) { return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded"); } json data = parse_options(true, request, params_base, ctx_server.get_llama_context()); data["stream"] = false; //Raise error if embeddings is set to true if (params_base.embedding) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Embedding is not supported in Predict mode"); } std::cout << "[PREDICT] Received result: " << data.dump(2) << std::endl; auto completion_id = gen_chatcmplid(); auto rd = ctx_server.get_response_reader(); try { std::vector tasks; std::string prompt_str; std::vector files; // Declare files early so it's accessible in both branches // Handle chat templates when UseTokenizerTemplate is enabled and Messages are provided if (request->usetokenizertemplate() && request->messages_size() > 0 && ctx_server.impl->chat_params.tmpls != nullptr) { // Convert proto Messages to JSON format compatible with oaicompat_chat_params_parse json body_json; json messages_json = json::array(); // Find the last user message index to attach images/audio to int last_user_msg_idx = -1; for (int i = request->messages_size() - 1; i >= 0; i--) { if (request->messages(i).role() == "user") { last_user_msg_idx = i; break; } } SRV_INF("[CONTENT DEBUG] Predict: Processing %d messages\n", request->messages_size()); for (int i = 0; i < request->messages_size(); i++) { const auto& msg = request->messages(i); json msg_json; msg_json["role"] = msg.role(); SRV_INF("[CONTENT DEBUG] Predict: Message %d: role=%s, content_empty=%d, content_length=%zu\n", i, msg.role().c_str(), msg.content().empty() ? 
1 : 0, msg.content().size()); if (!msg.content().empty()) { SRV_INF("[CONTENT DEBUG] Predict: Message %d content (first 200 chars): %s\n", i, msg.content().substr(0, std::min(200, msg.content().size())).c_str()); } bool is_last_user_msg = (i == last_user_msg_idx); bool has_images_or_audio = (request->images_size() > 0 || request->audios_size() > 0); // Handle content - can be string, null, or array // For multimodal content, we'll embed images/audio from separate fields if (!msg.content().empty()) { // Try to parse content as JSON to see if it's already an array json content_val; try { content_val = json::parse(msg.content()); // Handle null values - convert to empty string to avoid template errors if (content_val.is_null()) { SRV_INF("[CONTENT DEBUG] Predict: Message %d parsed JSON is null, converting to empty string\n", i); content_val = ""; } } catch (const json::parse_error&) { // Not JSON, treat as plain string content_val = msg.content(); } // If content is an object (e.g., from tool call failures), convert to string if (content_val.is_object()) { SRV_INF("[CONTENT DEBUG] Predict: Message %d content is object, converting to string\n", i); content_val = content_val.dump(); } // If content is a string and this is the last user message with images/audio, combine them if (content_val.is_string() && is_last_user_msg && has_images_or_audio) { json content_array = json::array(); // Add text first content_array.push_back({{"type", "text"}, {"text", content_val.get()}}); // Add images if (request->images_size() > 0) { for (int j = 0; j < request->images_size(); j++) { json image_chunk; image_chunk["type"] = "image_url"; json image_url; image_url["url"] = "data:image/jpeg;base64," + request->images(j); image_chunk["image_url"] = image_url; content_array.push_back(image_chunk); } } // Add audios if (request->audios_size() > 0) { for (int j = 0; j < request->audios_size(); j++) { json audio_chunk; audio_chunk["type"] = "input_audio"; json input_audio; input_audio["data"] 
= request->audios(j); input_audio["format"] = "wav"; // default, could be made configurable audio_chunk["input_audio"] = input_audio; content_array.push_back(audio_chunk); } } msg_json["content"] = content_array; } else { // Use content as-is (already array or not last user message) // Ensure null values are converted to empty string if (content_val.is_null()) { SRV_INF("[CONTENT DEBUG] Predict: Message %d content_val was null, setting to empty string\n", i); msg_json["content"] = ""; } else { msg_json["content"] = content_val; SRV_INF("[CONTENT DEBUG] Predict: Message %d content set, type=%s\n", i, content_val.is_string() ? "string" : content_val.is_array() ? "array" : content_val.is_object() ? "object" : "other"); } } } else if (is_last_user_msg && has_images_or_audio) { // If no content but this is the last user message with images/audio, create content array json content_array = json::array(); if (request->images_size() > 0) { for (int j = 0; j < request->images_size(); j++) { json image_chunk; image_chunk["type"] = "image_url"; json image_url; image_url["url"] = "data:image/jpeg;base64," + request->images(j); image_chunk["image_url"] = image_url; content_array.push_back(image_chunk); } } if (request->audios_size() > 0) { for (int j = 0; j < request->audios_size(); j++) { json audio_chunk; audio_chunk["type"] = "input_audio"; json input_audio; input_audio["data"] = request->audios(j); input_audio["format"] = "wav"; // default, could be made configurable audio_chunk["input_audio"] = input_audio; content_array.push_back(audio_chunk); } } msg_json["content"] = content_array; SRV_INF("[CONTENT DEBUG] Predict: Message %d created content array with media\n", i); } else if (!msg.tool_calls().empty()) { // Tool call messages may have null content, but templates expect string // IMPORTANT: Set to space " " instead of empty string "", because llama.cpp's // common_chat_msgs_to_json_oaicompat converts empty strings to null (line 312), // which causes template errors when 
accessing message.content[:tool_start_length] SRV_INF("[CONTENT DEBUG] Predict: Message %d has tool_calls, setting content to space (not empty string)\n", i); msg_json["content"] = " "; } else if (msg.role() == "tool") { // Tool role messages must have content field set, even if empty // Jinja templates expect content to be a string, not null or object SRV_INF("[CONTENT DEBUG] Predict: Message %d is tool role, content_empty=%d\n", i, msg.content().empty() ? 1 : 0); if (msg.content().empty()) { msg_json["content"] = ""; SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): empty content, set to empty string\n", i); } else { SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): content exists: %s\n", i, msg.content().substr(0, std::min(200, msg.content().size())).c_str()); // Content exists, parse and ensure it's a string json content_val; try { content_val = json::parse(msg.content()); SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): parsed JSON, type=%s\n", i, content_val.is_null() ? "null" : content_val.is_object() ? "object" : content_val.is_string() ? "string" : content_val.is_array() ? 
"array" : "other"); // Handle null values - Jinja templates expect content to be a string, not null if (content_val.is_null()) { msg_json["content"] = ""; SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): null content, converted to empty string\n", i); } else if (content_val.is_object()) { // If content is an object (e.g., from tool call failures/errors), convert to string msg_json["content"] = content_val.dump(); SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): object content, converted to string: %s\n", i, content_val.dump().substr(0, std::min(200, content_val.dump().size())).c_str()); } else if (content_val.is_string()) { msg_json["content"] = content_val.get(); SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): string content, using as-is\n", i); } else { // For arrays or other types, convert to string msg_json["content"] = content_val.dump(); SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): %s content, converted to string\n", i, content_val.is_array() ? "array" : "other type"); } } catch (const json::parse_error&) { // Not JSON, treat as plain string msg_json["content"] = msg.content(); SRV_INF("[CONTENT DEBUG] Predict: Message %d (tool): not JSON, using as string\n", i); } } } else { // Ensure all messages have content set (fallback for any unhandled cases) // Jinja templates expect content to be present, default to empty string if not set if (!msg_json.contains("content")) { SRV_INF("[CONTENT DEBUG] Predict: Message %d (role=%s): no content field, adding empty string\n", i, msg.role().c_str()); msg_json["content"] = ""; } } // Add optional fields for OpenAI-compatible message format if (!msg.name().empty()) { msg_json["name"] = msg.name(); } if (!msg.tool_call_id().empty()) { msg_json["tool_call_id"] = msg.tool_call_id(); } if (!msg.reasoning_content().empty()) { msg_json["reasoning_content"] = msg.reasoning_content(); } if (!msg.tool_calls().empty()) { // Parse tool_calls JSON string and add to message try { json tool_calls = 
json::parse(msg.tool_calls()); msg_json["tool_calls"] = tool_calls; SRV_INF("[TOOL CALLS DEBUG] Predict: Message %d has tool_calls: %s\n", i, tool_calls.dump().c_str()); // IMPORTANT: If message has tool_calls but content is empty or not set, // set content to space " " instead of empty string "", because llama.cpp's // common_chat_msgs_to_json_oaicompat converts empty strings to null (line 312), // which causes template errors when accessing message.content[:tool_start_length] if (!msg_json.contains("content") || (msg_json.contains("content") && msg_json["content"].is_string() && msg_json["content"].get().empty())) { SRV_INF("[CONTENT DEBUG] Predict: Message %d has tool_calls but empty content, setting to space\n", i); msg_json["content"] = " "; } // Log each tool call with name and arguments if (tool_calls.is_array()) { for (size_t tc_idx = 0; tc_idx < tool_calls.size(); tc_idx++) { const auto& tc = tool_calls[tc_idx]; std::string tool_name = "unknown"; std::string tool_args = "{}"; if (tc.contains("function")) { const auto& func = tc["function"]; if (func.contains("name")) { tool_name = func["name"].get(); } if (func.contains("arguments")) { tool_args = func["arguments"].is_string() ? func["arguments"].get() : func["arguments"].dump(); } } else if (tc.contains("name")) { tool_name = tc["name"].get(); if (tc.contains("arguments")) { tool_args = tc["arguments"].is_string() ? 
tc["arguments"].get() : tc["arguments"].dump(); } } SRV_INF("[TOOL CALLS DEBUG] Predict: Message %d, tool_call %zu: name=%s, arguments=%s\n", i, tc_idx, tool_name.c_str(), tool_args.c_str()); } } } catch (const json::parse_error& e) { SRV_WRN("Failed to parse tool_calls JSON: %s\n", e.what()); } } // Debug: Log final content state before adding to array if (msg_json.contains("content")) { if (msg_json["content"].is_null()) { SRV_INF("[CONTENT DEBUG] Predict: Message %d FINAL STATE: content is NULL - THIS WILL CAUSE ERROR!\n", i); } else { SRV_INF("[CONTENT DEBUG] Predict: Message %d FINAL STATE: content type=%s, has_value=%d\n", i, msg_json["content"].is_string() ? "string" : msg_json["content"].is_array() ? "array" : msg_json["content"].is_object() ? "object" : "other", msg_json["content"].is_null() ? 0 : 1); } } else { SRV_INF("[CONTENT DEBUG] Predict: Message %d FINAL STATE: NO CONTENT FIELD - THIS WILL CAUSE ERROR!\n", i); } messages_json.push_back(msg_json); } // Final safety check: Ensure no message has null content (Jinja templates require strings) SRV_INF("[CONTENT DEBUG] Predict: Running final safety check on %zu messages\n", messages_json.size()); for (size_t idx = 0; idx < messages_json.size(); idx++) { auto& msg = messages_json[idx]; std::string role_str = msg.contains("role") ? msg["role"].get() : "unknown"; if (msg.contains("content") && msg["content"].is_null()) { SRV_INF("[CONTENT DEBUG] Predict: Safety check found message %zu (role=%s) with NULL content, converting to empty string\n", idx, role_str.c_str()); msg["content"] = ""; } else if (!msg.contains("content")) { SRV_INF("[CONTENT DEBUG] Predict: Safety check found message %zu (role=%s) without content field, adding empty string\n", idx, role_str.c_str()); msg["content"] = ""; } else { SRV_INF("[CONTENT DEBUG] Predict: Safety check message %zu (role=%s): content OK, type=%s\n", idx, role_str.c_str(), msg["content"].is_string() ? "string" : msg["content"].is_array() ? 
"array" : msg["content"].is_object() ? "object" : "other"); } } // Debug: Count tool messages int tool_msg_count = 0; for (const auto& msg : messages_json) { if (msg.contains("role") && msg["role"] == "tool") { tool_msg_count++; } } SRV_DBG("[TOOLS DEBUG] Predict: Built %d tool messages out of %zu total messages\n", tool_msg_count, messages_json.size()); // Debug: Print full conversation (messages) SRV_DBG("[CONVERSATION DEBUG] Predict: Full messages array:\n%s\n", messages_json.dump(2).c_str()); body_json["messages"] = messages_json; body_json["stream"] = false; // Check if grammar is provided from Go layer (NoGrammar=false) // If grammar is provided, we must use it and NOT let template generate grammar from tools // oaicompat_chat_params_parse throws an error if both grammar and tools are provided bool has_grammar_from_go = data.contains("grammar") && data["grammar"].is_string() && !data["grammar"].get().empty(); SRV_INF("[TOOLS DEBUG] Predict: has_grammar_from_go=%d, data.contains(\"tools\")=%d, data.contains(\"grammar\")=%d\n", has_grammar_from_go ? 1 : 0, data.contains("tools") ? 1 : 0, data.contains("grammar") ? 1 : 0); if (data.contains("grammar")) { SRV_INF("[TOOLS DEBUG] Predict: grammar type=%s, empty=%d\n", data["grammar"].is_string() ? "string" : "other", data["grammar"].is_string() && data["grammar"].get().empty() ? 
1 : 0); } // Copy other relevant fields from data that oaicompat_chat_params_parse expects // Tools and tool_choice are only passed when NoGrammar is true (grammar not provided) // When grammar is provided from Go layer, we use it instead of template-generated grammar if (!has_grammar_from_go) { // NoGrammar=true: pass tools and let template generate grammar if (data.contains("tools")) { body_json["tools"] = data["tools"]; std::string tools_str = data["tools"].dump(); SRV_INF("Using tools from data (NoGrammar=true): %s\n", tools_str.c_str()); // Debug: Log tools count and details before template processing if (data["tools"].is_array()) { SRV_INF("[TOOLS DEBUG] Predict: Passing %zu tools to oaicompat_chat_params_parse\n", data["tools"].size()); for (size_t t_idx = 0; t_idx < data["tools"].size(); t_idx++) { const auto& tool = data["tools"][t_idx]; std::string tool_name = "unknown"; std::string tool_desc = ""; if (tool.contains("function")) { const auto& func = tool["function"]; if (func.contains("name")) { tool_name = func["name"].get(); } if (func.contains("description")) { tool_desc = func["description"].is_string() ? func["description"].get() : ""; } } else if (tool.contains("name")) { tool_name = tool["name"].get(); if (tool.contains("description")) { tool_desc = tool["description"].is_string() ? tool["description"].get() : ""; } } SRV_INF("[TOOLS DEBUG] Predict: Tool %zu: name=%s, description=%s\n", t_idx, tool_name.c_str(), tool_desc.substr(0, 100).c_str()); } } } else { SRV_WRN("%s", "No tools found in data - tool calls will not work without tools field\n"); SRV_DBG("[TOOLS DEBUG] Predict: No tools in data, tool_choice=%s\n", data.contains("tool_choice") ? 
data["tool_choice"].dump().c_str() : "not set"); } if (data.contains("tool_choice")) { // tool_choice can be a string or object, but oaicompat_chat_params_parse expects a string // Convert object tool_choice to "required" (since a specific function is requested) if (data["tool_choice"].is_string()) { body_json["tool_choice"] = data["tool_choice"].get(); } else if (data["tool_choice"].is_object()) { // Object tool_choice means a specific function is requested, use "required" body_json["tool_choice"] = "required"; std::string tool_choice_obj_str = data["tool_choice"].dump(); SRV_INF("Converted object tool_choice to 'required': %s\n", tool_choice_obj_str.c_str()); } else { // Fallback: convert to string body_json["tool_choice"] = data["tool_choice"].dump(); } std::string tool_choice_str = body_json["tool_choice"].get(); SRV_INF("Using tool_choice: %s\n", tool_choice_str.c_str()); } else { // Default to "auto" if not specified body_json["tool_choice"] = "auto"; } } else { // Grammar is provided from Go layer (NoGrammar=false) - use it, don't pass tools SRV_INF("%s", "Grammar provided from Go layer - using it instead of template-generated grammar\n"); // Grammar will be copied from data after parsing (it's already in data) } if (data.contains("json_schema")) { body_json["json_schema"] = data["json_schema"]; } // If grammar is provided from Go layer, copy it to body_json so it's preserved // (though oaicompat_chat_params_parse may not use it if tools are present) if (has_grammar_from_go) { body_json["grammar"] = data["grammar"]; } if (data.contains("response_format")) { body_json["response_format"] = data["response_format"]; } if (data.contains("chat_template_kwargs")) { body_json["chat_template_kwargs"] = data["chat_template_kwargs"]; } // Pass parallel_tool_calls if present (used by oaicompat_chat_params_parse) if (data.contains("parallel_tool_calls")) { body_json["parallel_tool_calls"] = data["parallel_tool_calls"]; } // Pass add_generation_prompt if present (used by 
oaicompat_chat_params_parse) if (data.contains("add_generation_prompt")) { body_json["add_generation_prompt"] = data["add_generation_prompt"]; } // Pass sampling parameters to body_json so oaicompat_chat_params_parse respects them // and doesn't overwrite them with defaults in the returned parsed_data if (data.contains("n_predict")) { body_json["max_tokens"] = data["n_predict"]; } if (data.contains("ignore_eos")) { body_json["ignore_eos"] = data["ignore_eos"]; } if (data.contains("stop")) { body_json["stop"] = data["stop"]; } if (data.contains("temperature")) { body_json["temperature"] = data["temperature"]; } if (data.contains("top_p")) { body_json["top_p"] = data["top_p"]; } if (data.contains("frequency_penalty")) { body_json["frequency_penalty"] = data["frequency_penalty"]; } if (data.contains("presence_penalty")) { body_json["presence_penalty"] = data["presence_penalty"]; } if (data.contains("seed")) { body_json["seed"] = data["seed"]; } if (data.contains("logit_bias")) { body_json["logit_bias"] = data["logit_bias"]; } if (data.contains("top_k")) { body_json["top_k"] = data["top_k"]; } if (data.contains("min_p")) { body_json["min_p"] = data["min_p"]; } // Pass enable_thinking via chat_template_kwargs (where oaicompat_chat_params_parse reads it) const auto& predict_metadata = request->metadata(); auto predict_et_it = predict_metadata.find("enable_thinking"); if (predict_et_it != predict_metadata.end()) { if (!body_json.contains("chat_template_kwargs")) { body_json["chat_template_kwargs"] = json::object(); } body_json["chat_template_kwargs"]["enable_thinking"] = (predict_et_it->second == "true"); } // Debug: Print full body_json before template processing (includes messages, tools, tool_choice, etc.) SRV_DBG("[CONVERSATION DEBUG] Predict: Full body_json before oaicompat_chat_params_parse:\n%s\n", body_json.dump(2).c_str()); // Use the same approach as server.cpp: call oaicompat_chat_params_parse // This handles all template application, grammar merging, etc. 
automatically // Files extracted from multimodal content in messages will be added to the files vector // chat_params already contains tmpls, allow_image, and allow_audio set during model loading // Debug: Log tools before template processing if (body_json.contains("tools")) { SRV_DBG("[TOOLS DEBUG] Predict: Before oaicompat_chat_params_parse - tools count: %zu\n", body_json["tools"].is_array() ? body_json["tools"].size() : 0); } // Debug: Verify messages content before template processing // Also ensure ALL messages have content set to string (not null) - templates expect strings if (body_json.contains("messages") && body_json["messages"].is_array()) { SRV_INF("[CONTENT DEBUG] Predict: Before oaicompat_chat_params_parse - checking %zu messages\n", body_json["messages"].size()); for (size_t idx = 0; idx < body_json["messages"].size(); idx++) { auto& msg = body_json["messages"][idx]; std::string role_str = msg.contains("role") ? msg["role"].get() : "unknown"; if (msg.contains("content")) { if (msg["content"].is_null()) { SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=%s) has NULL content - FIXING!\n", idx, role_str.c_str()); msg["content"] = ""; // Fix null content } else if (role_str == "tool" && msg["content"].is_array()) { // Tool messages must have string content, not array // oaicompat_chat_params_parse expects tool messages to have string content SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=tool) has array content, converting to string\n", idx); msg["content"] = msg["content"].dump(); } else if (!msg["content"].is_string() && !msg["content"].is_array()) { // If content is object or other non-string type, convert to string for templates SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=%s) content is not string/array, converting\n", idx, role_str.c_str()); if (msg["content"].is_object()) { msg["content"] = msg["content"].dump(); } else { msg["content"] = ""; } } else { SRV_INF("[CONTENT DEBUG] 
Predict: BEFORE TEMPLATE - Message %zu (role=%s): content type=%s\n", idx, role_str.c_str(), msg["content"].is_string() ? "string" : msg["content"].is_array() ? "array" : msg["content"].is_object() ? "object" : "other"); } } else { SRV_INF("[CONTENT DEBUG] Predict: BEFORE TEMPLATE - Message %zu (role=%s) MISSING content field - ADDING!\n", idx, role_str.c_str()); msg["content"] = ""; // Add missing content } } } json parsed_data = oaicompat_chat_params_parse(body_json, ctx_server.impl->chat_params, files); // Debug: Log tools after template processing if (parsed_data.contains("tools")) { SRV_DBG("[TOOLS DEBUG] Predict: After oaicompat_chat_params_parse - tools count: %zu\n", parsed_data["tools"].is_array() ? parsed_data["tools"].size() : 0); } else { SRV_DBG("%s", "[TOOLS DEBUG] Predict: After oaicompat_chat_params_parse - no tools in parsed_data\n"); } // Extract the prompt from parsed data prompt_str = parsed_data.at("prompt").get(); // Preserve grammar from Go layer if it was provided (NoGrammar=false) // Otherwise, use grammar from parsed_data (template-generated when NoGrammar=true) json preserved_grammar; if (has_grammar_from_go && data.contains("grammar")) { preserved_grammar = data["grammar"]; } // Merge all fields from parsed_data into data (grammar, grammar_triggers, preserved_tokens, parse_tool_calls, etc.) 
// This ensures all template-generated fields are included // parse_tool_calls is set by oaicompat_chat_params_parse when tools are present for (const auto& item : parsed_data.items()) { if (item.key() != "prompt") { // Don't overwrite prompt_str, we already extracted it // If grammar was provided from Go layer, preserve it instead of template-generated grammar if (item.key() == "grammar" && has_grammar_from_go && !preserved_grammar.is_null()) { data["grammar"] = preserved_grammar; } else { data[item.key()] = item.value(); } } } // Debug: Log parse_tool_calls if present (set by oaicompat_chat_params_parse when tools are present) if (data.contains("parse_tool_calls")) { SRV_DBG("[TOOLS DEBUG] Predict: parse_tool_calls=%s\n", data["parse_tool_calls"].get() ? "true" : "false"); } } else { // Use prompt directly from data if (data.contains("prompt") && data["prompt"].is_string()) { prompt_str = data["prompt"].get(); } else { prompt_str = request->prompt(); } } const auto type = SERVER_TASK_TYPE_COMPLETION; // TODO: this log can become very long, put it behind a flag or think about a more compact format //SRV_DBG("Prompt: %s\n", prompt.is_string() ? 
prompt.get().c_str() : prompt.dump(2).c_str()); // If not using chat templates, extract files from image_data/audio_data fields // (If using chat templates, files were already extracted by oaicompat_chat_params_parse) if (!request->usetokenizertemplate() || request->messages_size() == 0 || ctx_server.impl->chat_params.tmpls == nullptr) { const auto &images_data = data.find("image_data"); if (images_data != data.end() && images_data->is_array()) { std::cout << "[PREDICT] Processing " << images_data->size() << " images" << std::endl; for (const auto &img : *images_data) { std::cout << "[PREDICT] Processing image" << std::endl; auto decoded_data = base64_decode(img["data"].get()); files.push_back(decoded_data); } } const auto &audio_data = data.find("audio_data"); if (audio_data != data.end() && audio_data->is_array()) { for (const auto &audio : *audio_data) { auto decoded_data = base64_decode(audio["data"].get()); files.push_back(decoded_data); } } } // process files const bool has_mtmd = ctx_server.impl->mctx != nullptr; // process prompt std::vector inputs; if (has_mtmd) { // multimodal inputs.push_back(process_mtmd_prompt(ctx_server.impl->mctx, prompt_str, files)); } else { // Everything else, including multimodal completions. 
inputs = tokenize_input_prompts(ctx_server.impl->vocab, ctx_server.impl->mctx, prompt_str, true, true); } tasks.reserve(inputs.size()); for (size_t i = 0; i < inputs.size(); i++) { server_task task = server_task(type); task.id = rd.queue_tasks.get_new_id(); task.index = i; task.tokens = std::move(inputs[i]); task.params = server_task::params_from_json_cmpl( ctx_server.impl->vocab, params_base, ctx_server.get_meta().slot_n_ctx, data); task.id_slot = json_value(data, "id_slot", -1); // OAI-compat task.params.res_type = TASK_RESPONSE_TYPE_NONE; task.params.oaicompat_cmpl_id = completion_id; // oaicompat_model is already populated by params_from_json_cmpl tasks.push_back(std::move(task)); } rd.post_tasks(std::move(tasks)); } catch (const std::exception & e) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, e.what()); } std::cout << "[DEBUG] Waiting for results..." << std::endl; // Wait for all results auto all_results = rd.wait_for_all([&context]() { return context->IsCancelled(); }); if (all_results.is_terminated) { return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client"); } else if (all_results.error) { std::cout << "[DEBUG] Error in results: " << all_results.error->to_json().value("message", "") << std::endl; reply->set_message(all_results.error->to_json().value("message", "")); return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error occurred")); } else { std::cout << "[DEBUG] Received " << all_results.results.size() << " results" << std::endl; if (all_results.results.size() == 1) { // single result auto* final_res = dynamic_cast(all_results.results[0].get()); GGML_ASSERT(final_res != nullptr); json result_json = all_results.results[0]->to_json(); reply->set_message(result_json.value("content", "")); int32_t tokens_predicted = result_json.value("tokens_predicted", 0); reply->set_tokens(tokens_predicted); int32_t tokens_evaluated = result_json.value("tokens_evaluated", 0); 
reply->set_prompt_tokens(tokens_evaluated); if (result_json.contains("timings")) { double timing_prompt_processing = result_json.at("timings").value("prompt_ms", 0.0); reply->set_timing_prompt_processing(timing_prompt_processing); double timing_token_generation = result_json.at("timings").value("predicted_ms", 0.0); reply->set_timing_token_generation(timing_token_generation); } // Extract and set logprobs if present json logprobs_json = extract_logprobs_from_json(result_json); if (!logprobs_json.empty() && !logprobs_json.is_null()) { std::string logprobs_str = logprobs_json.dump(); reply->set_logprobs(logprobs_str); } // Populate chat deltas from the autoparser's final parsed message if (final_res->is_updated) { populate_chat_deltas_from_final(*reply, final_res->oaicompat_msg); } } else { // multiple results (multitask) json arr = json::array(); json logprobs_arr = json::array(); bool has_logprobs = false; for (auto & res : all_results.results) { GGML_ASSERT(dynamic_cast(res.get()) != nullptr); json res_json = res->to_json(); arr.push_back(res_json.value("content", "")); // Extract logprobs for each result json logprobs_json = extract_logprobs_from_json(res_json); if (!logprobs_json.empty() && !logprobs_json.is_null()) { has_logprobs = true; logprobs_arr.push_back(logprobs_json); } else { logprobs_arr.push_back(json::object()); } } reply->set_message(arr); // Set logprobs if any result has them if (has_logprobs) { std::string logprobs_str = logprobs_arr.dump(); reply->set_logprobs(logprobs_str); } } } std::cout << "[DEBUG] Predict request completed successfully" << std::endl; // Check if context was cancelled during processing if (context->IsCancelled()) { return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client"); } return grpc::Status::OK; } grpc::Status Embedding(ServerContext* context, const backend::PredictOptions* request, backend::EmbeddingResult* embeddingResult) override { if (params_base.model.path.empty()) { return 
grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded"); } json body = parse_options(false, request, params_base, ctx_server.get_llama_context()); body["stream"] = false; /* if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Pooling type 'none' is not OAI compatible. Please use a different pooling type"); } */ // for the shape of input/content, see tokenize_input_prompts() json prompt = body.at("embeddings"); auto tokenized_prompts = tokenize_input_prompts(ctx_server.impl->vocab, ctx_server.impl->mctx, prompt, true, true); for (const auto & tokens : tokenized_prompts) { // this check is necessary for models that do not add BOS token to the input if (tokens.empty()) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "Input content cannot be empty"); } } int embd_normalize = 2; // default to Euclidean/L2 norm // create and queue the task auto rd = ctx_server.get_response_reader(); { std::vector tasks; for (size_t i = 0; i < tokenized_prompts.size(); i++) { server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); task.id = rd.queue_tasks.get_new_id(); task.index = i; task.tokens = std::move(tokenized_prompts[i]); task.params.res_type = TASK_RESPONSE_TYPE_NONE; task.params.embd_normalize = embd_normalize; tasks.push_back(std::move(task)); } rd.post_tasks(std::move(tasks)); } // Wait for all results auto all_results = rd.wait_for_all([&context]() { return context->IsCancelled(); }); if (all_results.is_terminated) { return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client"); } else if (all_results.error) { return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error in receiving results")); } // Collect responses json responses = json::array(); for (auto & res : all_results.results) { GGML_ASSERT(dynamic_cast(res.get()) != nullptr); responses.push_back(res->to_json()); } std::cout << "[DEBUG] Responses size: " 
<< responses.size() << std::endl; // Process the responses and extract embeddings for (const auto & response_elem : responses) { // Check if the response has an "embedding" field if (response_elem.contains("embedding")) { json embedding_data = json_value(response_elem, "embedding", json::array()); if (embedding_data.is_array() && !embedding_data.empty()) { for (const auto & embedding_vector : embedding_data) { if (embedding_vector.is_array()) { for (const auto & embedding_value : embedding_vector) { embeddingResult->add_embeddings(embedding_value.get()); } } } } } else { // Check if the response itself contains the embedding data directly if (response_elem.is_array()) { for (const auto & embedding_value : response_elem) { embeddingResult->add_embeddings(embedding_value.get()); } } } } return grpc::Status::OK; } grpc::Status Rerank(ServerContext* context, const backend::RerankRequest* request, backend::RerankResult* rerankResult) override { if (!params_base.embedding || params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) { return grpc::Status(grpc::StatusCode::UNIMPLEMENTED, "This server does not support reranking. 
Start it with `--reranking` and without `--embedding`"); } // Validate request if (request->query().empty()) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"query\" must be provided"); } if (request->documents_size() == 0) { return grpc::Status(grpc::StatusCode::INVALID_ARGUMENT, "\"documents\" must be a non-empty string array"); } // Create and queue the task auto rd = ctx_server.get_response_reader(); { std::vector tasks; std::vector documents; for (int i = 0; i < request->documents_size(); i++) { documents.push_back(request->documents(i)); } tasks.reserve(documents.size()); for (size_t i = 0; i < documents.size(); i++) { auto tmp = format_prompt_rerank(ctx_server.impl->model, ctx_server.impl->vocab, ctx_server.impl->mctx, request->query(), documents[i]); server_task task = server_task(SERVER_TASK_TYPE_RERANK); task.id = rd.queue_tasks.get_new_id(); task.index = i; task.tokens = std::move(tmp); tasks.push_back(std::move(task)); } rd.post_tasks(std::move(tasks)); } // Wait for all results auto all_results = rd.wait_for_all([&context]() { return context->IsCancelled(); }); if (all_results.is_terminated) { return grpc::Status(grpc::StatusCode::CANCELLED, "Request cancelled by client"); } else if (all_results.error) { return grpc::Status(grpc::StatusCode::INTERNAL, all_results.error->to_json().value("message", "Error in receiving results")); } // Collect responses json responses = json::array(); for (auto & res : all_results.results) { GGML_ASSERT(dynamic_cast(res.get()) != nullptr); responses.push_back(res->to_json()); } // Sort responses by score in descending order std::sort(responses.begin(), responses.end(), [](const json& a, const json& b) { return a.value("score", 0.0f) > b.value("score", 0.0f); }); // Crop results by request.top_n if specified int top_n = request->top_n(); if (top_n > 0 && top_n < static_cast(responses.size())) { responses = json(responses.begin(), responses.begin() + top_n); } // Set usage information backend::Usage* usage = 
rerankResult->mutable_usage(); int total_tokens = 0; int prompt_tokens = 0; // Create document results for (const auto& response : responses) { backend::DocumentResult* doc_result = rerankResult->add_results(); doc_result->set_index(response.value("index", 0)); doc_result->set_text(request->documents(response.value("index", 0))); doc_result->set_relevance_score(response.value("score", 0.0f)); // Add tokens evaluated for this document int tokens_evaluated = response.value("tokens_evaluated", 0); total_tokens += tokens_evaluated; prompt_tokens += tokens_evaluated; } // Set the total tokens in usage usage->set_total_tokens(total_tokens); usage->set_prompt_tokens(prompt_tokens); return grpc::Status::OK; } grpc::Status TokenizeString(ServerContext* /*context*/, const backend::PredictOptions* request, backend::TokenizationResponse* response) override { if (params_base.model.path.empty()) { return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded"); } json body = parse_options(false, request, params_base, ctx_server.get_llama_context()); body["stream"] = false; json tokens_response = json::array(); if (body.count("prompt") != 0) { const bool add_special = json_value(body, "add_special", false); llama_tokens tokens = tokenize_mixed(ctx_server.impl->vocab, body.at("content"), add_special, true); for (const auto& token : tokens) { std::string piece = common_token_to_piece(ctx_server.get_llama_context(), token); response->add_tokens(token); } } return grpc::Status::OK; } grpc::Status GetMetrics(ServerContext* /*context*/, const backend::MetricsRequest* /*request*/, backend::MetricsResponse* response) override { // request slots data using task queue auto rd = ctx_server.get_response_reader(); int task_id = rd.queue_tasks.get_new_id(); { server_task task(SERVER_TASK_TYPE_METRICS); task.id = task_id; rd.queue_results.add_waiting_task_id(task_id); rd.queue_tasks.post(std::move(task), true); // high-priority task } // get the result server_task_result_ptr 
result = rd.queue_results.recv(task_id); rd.queue_results.remove_waiting_task_id(task_id); if (result->is_error()) { // Handle case when no active slot exists response->set_slot_id(0); response->set_prompt_json_for_slot(""); response->set_tokens_per_second(0); response->set_tokens_generated(0); response->set_prompt_tokens_processed(0); return grpc::Status(grpc::StatusCode::INTERNAL, "Error in receiving results"); } // TODO: get rid of this dynamic_cast auto res_metrics = dynamic_cast(result.get()); GGML_ASSERT(res_metrics != nullptr); // Populate the response with metrics response->set_slot_id(0); response->set_prompt_json_for_slot(""); response->set_tokens_per_second(res_metrics->n_prompt_tokens_processed ? 1.e3 / res_metrics->t_prompt_processing * res_metrics->n_prompt_tokens_processed : 0.); response->set_tokens_generated(res_metrics->n_tokens_predicted_total); response->set_prompt_tokens_processed(res_metrics->n_prompt_tokens_processed_total); return grpc::Status::OK; } grpc::Status ModelMetadata(ServerContext* /*context*/, const backend::ModelOptions* /*request*/, backend::ModelMetadataResponse* response) override { // Check if model is loaded if (params_base.model.path.empty()) { return grpc::Status(grpc::StatusCode::FAILED_PRECONDITION, "Model not loaded"); } // Check if chat templates are initialized if (ctx_server.impl->chat_params.tmpls == nullptr) { // If templates are not initialized, we can't detect thinking support // Return false as default response->set_supports_thinking(false); response->set_rendered_template(""); return grpc::Status::OK; } // Detect thinking support using llama.cpp's function bool supports_thinking = common_chat_templates_support_enable_thinking(ctx_server.impl->chat_params.tmpls.get()); response->set_supports_thinking(supports_thinking); // Render the template with enable_thinking=true so Go code can detect thinking tokens // This allows reusing existing detection functions in Go std::string rendered_template = ""; if 
(params_base.use_jinja) { // Render the template with enable_thinking=true to see what the actual prompt looks like common_chat_templates_inputs dummy_inputs; common_chat_msg msg; msg.role = "user"; msg.content = "test"; dummy_inputs.messages = {msg}; dummy_inputs.enable_thinking = true; dummy_inputs.use_jinja = params_base.use_jinja; const auto rendered = common_chat_templates_apply(ctx_server.impl->chat_params.tmpls.get(), dummy_inputs); rendered_template = rendered.prompt; } response->set_rendered_template(rendered_template); // Run differential template analysis to detect tool format markers if (params_base.use_jinja) { try { // Get template source and reconstruct a common_chat_template for analysis std::string tmpl_src = common_chat_templates_source(ctx_server.impl->chat_params.tmpls.get()); if (!tmpl_src.empty()) { const auto * vocab = llama_model_get_vocab(ctx_server.impl->model); std::string token_bos, token_eos; if (vocab) { auto bos_id = llama_vocab_bos(vocab); auto eos_id = llama_vocab_eos(vocab); if (bos_id != LLAMA_TOKEN_NULL) { token_bos = common_token_to_piece(vocab, bos_id, true); } if (eos_id != LLAMA_TOKEN_NULL) { token_eos = common_token_to_piece(vocab, eos_id, true); } } common_chat_template tmpl(tmpl_src, token_bos, token_eos); struct autoparser::autoparser ap; ap.analyze_template(tmpl); if (ap.analysis_complete && ap.tools.format.mode != autoparser::tool_format::NONE) { auto * tf = response->mutable_tool_format(); // Format type switch (ap.tools.format.mode) { case autoparser::tool_format::JSON_NATIVE: tf->set_format_type("json_native"); break; case autoparser::tool_format::TAG_WITH_JSON: tf->set_format_type("tag_with_json"); break; case autoparser::tool_format::TAG_WITH_TAGGED: tf->set_format_type("tag_with_tagged"); break; default: break; } // Tool section markers tf->set_section_start(ap.tools.format.section_start); tf->set_section_end(ap.tools.format.section_end); tf->set_per_call_start(ap.tools.format.per_call_start); 
tf->set_per_call_end(ap.tools.format.per_call_end); // Function markers tf->set_func_name_prefix(ap.tools.function.name_prefix); tf->set_func_name_suffix(ap.tools.function.name_suffix); tf->set_func_close(ap.tools.function.close); // Argument markers tf->set_arg_name_prefix(ap.tools.arguments.name_prefix); tf->set_arg_name_suffix(ap.tools.arguments.name_suffix); tf->set_arg_value_prefix(ap.tools.arguments.value_prefix); tf->set_arg_value_suffix(ap.tools.arguments.value_suffix); tf->set_arg_separator(ap.tools.arguments.separator); tf->set_args_start(ap.tools.arguments.start); tf->set_args_end(ap.tools.arguments.end); // JSON format fields tf->set_name_field(ap.tools.format.name_field); tf->set_args_field(ap.tools.format.args_field); tf->set_id_field(ap.tools.format.id_field); tf->set_fun_name_is_key(ap.tools.format.fun_name_is_key); tf->set_tools_array_wrapped(ap.tools.format.tools_array_wrapped); tf->set_uses_python_dicts(ap.tools.format.uses_python_dicts); tf->set_function_field(ap.tools.format.function_field); tf->set_gen_id_field(ap.tools.format.gen_id_field); for (const auto & p : ap.tools.format.parameter_order) { tf->add_parameter_order(p); } // Call ID markers switch (ap.tools.call_id.pos) { case autoparser::call_id_position::NONE: tf->set_call_id_position("none"); break; case autoparser::call_id_position::PRE_FUNC_NAME: tf->set_call_id_position("pre_func_name"); break; case autoparser::call_id_position::BETWEEN_FUNC_AND_ARGS: tf->set_call_id_position("between_func_and_args"); break; case autoparser::call_id_position::POST_ARGS: tf->set_call_id_position("post_args"); break; } tf->set_call_id_prefix(ap.tools.call_id.prefix); tf->set_call_id_suffix(ap.tools.call_id.suffix); // Reasoning markers tf->set_reasoning_start(ap.reasoning.start); tf->set_reasoning_end(ap.reasoning.end); // Content markers tf->set_content_start(ap.content.start); tf->set_content_end(ap.content.end); } } } catch (const std::exception & e) { SRV_WRN("ModelMetadata: failed to run 
autoparser analysis: %s\n", e.what()); } } return grpc::Status::OK; } }; int main(int argc, char** argv) { std::string server_address("localhost:50051"); // Define long and short options struct option long_options[] = { {"addr", required_argument, nullptr, 'a'}, {nullptr, 0, nullptr, 0} }; // Parse command-line arguments int option; int option_index = 0; while ((option = getopt_long(argc, argv, "a:", long_options, &option_index)) != -1) { switch (option) { case 'a': server_address = optarg; break; default: std::cerr << "Usage: " << argv[0] << " [--addr=
] or [-a
]" << std::endl; return 1; } } server_context ctx_server; BackendServiceImpl service(ctx_server); ServerBuilder builder; builder.AddListeningPort(server_address, grpc::InsecureServerCredentials()); builder.RegisterService(&service); builder.SetMaxMessageSize(50 * 1024 * 1024); // 50MB builder.SetMaxSendMessageSize(50 * 1024 * 1024); // 50MB builder.SetMaxReceiveMessageSize(50 * 1024 * 1024); // 50MB std::unique_ptr server(builder.BuildAndStart()); // run the HTTP server in a thread - see comment below std::thread t([&]() { std::cout << "Server listening on " << server_address << std::endl; server->Wait(); return 0; }); // clean up function, to be called before exit auto clean_up = [&server, &ctx_server]() { SRV_INF("%s: cleaning up before exit...\n", __func__); server->Shutdown(); ctx_server.terminate(); llama_backend_free(); }; //); start_llama_server(ctx_server); std::cout << "stopping" << std::endl; clean_up(); t.join(); return 0; } ================================================ FILE: backend/cpp/llama-cpp/package.sh ================================================ #!/bin/bash # Script to copy the appropriate libraries based on architecture # This script is used in the final stage of the Dockerfile set -e CURDIR=$(dirname "$(realpath $0)") REPO_ROOT="${CURDIR}/../../.." # Create lib directory mkdir -p $CURDIR/package/lib cp -avrf $CURDIR/llama-cpp-* $CURDIR/package/ cp -rfv $CURDIR/run.sh $CURDIR/package/ # Detect architecture and copy appropriate libraries if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then # x86_64 architecture echo "Detected x86_64 architecture, copying x86_64 libraries..." 
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6 cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1 cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6 cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6 cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1 elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then # ARM64 architecture echo "Detected ARM64 architecture, copying ARM64 libraries..." cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6 cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1 cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6 cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6 cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1 else echo "Error: Could not detect architecture" exit 1 fi # Package GPU libraries based on BUILD_TYPE # The GPU library packaging script will detect BUILD_TYPE and copy appropriate GPU libraries GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh" if [ -f "$GPU_LIB_SCRIPT" ]; then echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..." 
source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib" package_gpu_libs fi echo "Packaging completed successfully" ls -liah $CURDIR/package/ ls -liah $CURDIR/package/lib/ ================================================ FILE: backend/cpp/llama-cpp/prepare.sh ================================================ #!/bin/bash ## Patches ## Apply patches from the `patches` directory if [ -d "patches" ]; then for patch in $(ls patches); do echo "Applying patch $patch" patch -d llama.cpp/ -p1 < patches/$patch done fi set -e for file in $(ls llama.cpp/tools/server/); do cp -rfv llama.cpp/tools/server/$file llama.cpp/tools/grpc-server/ done cp -r CMakeLists.txt llama.cpp/tools/grpc-server/ cp -r grpc-server.cpp llama.cpp/tools/grpc-server/ cp -rfv llama.cpp/vendor/nlohmann/json.hpp llama.cpp/tools/grpc-server/ cp -rfv llama.cpp/vendor/cpp-httplib/httplib.h llama.cpp/tools/grpc-server/ set +e if grep -q "grpc-server" llama.cpp/tools/CMakeLists.txt; then echo "grpc-server already added" else echo "add_subdirectory(grpc-server)" >> llama.cpp/tools/CMakeLists.txt fi set -e ================================================ FILE: backend/cpp/llama-cpp/run.sh ================================================ #!/bin/bash set -ex # Get the absolute current dir where the script is located CURDIR=$(dirname "$(realpath $0)") cd / echo "CPU info:" grep -e "model\sname" /proc/cpuinfo | head -1 grep -e "flags" /proc/cpuinfo | head -1 BINARY=llama-cpp-fallback if grep -q -e "\savx\s" /proc/cpuinfo ; then echo "CPU: AVX found OK" if [ -e $CURDIR/llama-cpp-avx ]; then BINARY=llama-cpp-avx fi fi if grep -q -e "\savx2\s" /proc/cpuinfo ; then echo "CPU: AVX2 found OK" if [ -e $CURDIR/llama-cpp-avx2 ]; then BINARY=llama-cpp-avx2 fi fi # Check avx 512 if grep -q -e "\savx512f\s" /proc/cpuinfo ; then echo "CPU: AVX512F found OK" if [ -e $CURDIR/llama-cpp-avx512 ]; then BINARY=llama-cpp-avx512 fi fi if [ -n "$LLAMACPP_GRPC_SERVERS" ]; then if [ -e $CURDIR/llama-cpp-grpc ]; then BINARY=llama-cpp-grpc fi fi # 
Extend ld library path with the dir where this script is located/lib if [ "$(uname)" == "Darwin" ]; then export DYLD_LIBRARY_PATH=$CURDIR/lib:$DYLD_LIBRARY_PATH #export DYLD_FALLBACK_LIBRARY_PATH=$CURDIR/lib:$DYLD_FALLBACK_LIBRARY_PATH else export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH fi # If there is a lib/ld.so, use it if [ -f $CURDIR/lib/ld.so ]; then echo "Using lib/ld.so" echo "Using binary: $BINARY" exec $CURDIR/lib/ld.so $CURDIR/$BINARY "$@" fi echo "Using binary: $BINARY" exec $CURDIR/$BINARY "$@" # We should never reach this point, however just in case we do, run fallback exec $CURDIR/llama-cpp-fallback "$@" ================================================ FILE: backend/go/acestep-cpp/CMakeLists.txt ================================================ cmake_minimum_required(VERSION 3.14) project(goacestepcpp LANGUAGES C CXX) set(CMAKE_POSITION_INDEPENDENT_CODE ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(ACESTEP_DIR ${CMAKE_CURRENT_SOURCE_DIR}/sources/acestep.cpp) # Override upstream's CMAKE_CUDA_ARCHITECTURES before add_subdirectory. # Upstream sets 120a/121a for CUDA >= 12.8, but those archs require a newer # toolkit than 12.8.x ships. Pre-defining this variable makes the upstream # "if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES)" guard skip its broken defaults. if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) set(CMAKE_CUDA_ARCHITECTURES "75-virtual;80-virtual;86-real;89-real") endif() # EXCLUDE_FROM_ALL: only build targets we explicitly depend on (acestep-core, ggml), # skip upstream standalone executables (ace-understand, dit-vae, etc.) 
add_subdirectory(${ACESTEP_DIR} acestep EXCLUDE_FROM_ALL) add_library(goacestepcpp MODULE cpp/goacestepcpp.cpp) target_link_libraries(goacestepcpp PRIVATE acestep-core ggml ggml-base ggml-cpu) # Include dirs matching link_ggml_backends macro, but with absolute paths target_include_directories(goacestepcpp PRIVATE ${ACESTEP_DIR}/src ${ACESTEP_DIR}) target_include_directories(goacestepcpp SYSTEM PRIVATE ${ACESTEP_DIR}/ggml/include) # Link GPU backends if available (mirrors link_ggml_backends macro) foreach(backend blas cuda metal vulkan) if(TARGET ggml-${backend}) target_link_libraries(goacestepcpp PRIVATE ggml-${backend}) string(TOUPPER ${backend} BACKEND_UPPER) target_compile_definitions(goacestepcpp PRIVATE ACESTEP_HAVE_${BACKEND_UPPER}) if(backend STREQUAL "cuda") find_package(CUDAToolkit QUIET) if(CUDAToolkit_FOUND) target_link_libraries(goacestepcpp PRIVATE CUDA::cudart) endif() endif() endif() endforeach() if(MSVC) target_compile_options(goacestepcpp PRIVATE /W4 /wd4100 /wd4505) else() target_compile_options(goacestepcpp PRIVATE -Wall -Wextra -Wshadow -Wconversion -Wno-unused-parameter -Wno-unused-function -Wno-sign-conversion) endif() if(CMAKE_CXX_COMPILER_ID MATCHES "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 9.0) target_link_libraries(goacestepcpp PRIVATE stdc++fs) endif() set_property(TARGET goacestepcpp PROPERTY CXX_STANDARD 17) set_target_properties(goacestepcpp PROPERTIES LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}) ================================================ FILE: backend/go/acestep-cpp/Makefile ================================================ CMAKE_ARGS?= BUILD_TYPE?= NATIVE?=false GOCMD?=go GO_TAGS?= JOBS?=$(shell nproc --ignore=1) # acestep.cpp version ACESTEP_REPO?=https://github.com/ace-step/acestep.cpp ACESTEP_CPP_VERSION?=ab020a9aefcd364423e0665da12babc6b0c7b507 SO_TARGET?=libgoacestepcpp.so CMAKE_ARGS+=-DBUILD_SHARED_LIBS=OFF ifeq ($(NATIVE),false) CMAKE_ARGS+=-DGGML_NATIVE=OFF endif ifeq ($(BUILD_TYPE),cublas) 
CMAKE_ARGS+=-DGGML_CUDA=ON else ifeq ($(BUILD_TYPE),openblas) CMAKE_ARGS+=-DGGML_BLAS=ON -DGGML_BLAS_VENDOR=OpenBLAS else ifeq ($(BUILD_TYPE),clblas) CMAKE_ARGS+=-DGGML_CLBLAST=ON -DCLBlast_DIR=/some/path else ifeq ($(BUILD_TYPE),hipblas) CMAKE_ARGS+=-DGGML_HIPBLAS=ON else ifeq ($(BUILD_TYPE),vulkan) CMAKE_ARGS+=-DGGML_VULKAN=ON else ifeq ($(OS),Darwin) ifneq ($(BUILD_TYPE),metal) CMAKE_ARGS+=-DGGML_METAL=OFF else CMAKE_ARGS+=-DGGML_METAL=ON CMAKE_ARGS+=-DGGML_METAL_EMBED_LIBRARY=ON endif endif ifeq ($(BUILD_TYPE),sycl_f16) CMAKE_ARGS+=-DGGML_SYCL=ON \ -DCMAKE_C_COMPILER=icx \ -DCMAKE_CXX_COMPILER=icpx \ -DGGML_SYCL_F16=ON endif ifeq ($(BUILD_TYPE),sycl_f32) CMAKE_ARGS+=-DGGML_SYCL=ON \ -DCMAKE_C_COMPILER=icx \ -DCMAKE_CXX_COMPILER=icpx endif sources/acestep.cpp: mkdir -p sources/acestep.cpp cd sources/acestep.cpp && \ git init && \ git remote add origin $(ACESTEP_REPO) && \ git fetch origin && \ git checkout $(ACESTEP_CPP_VERSION) && \ git submodule update --init --recursive --depth 1 --single-branch # Detect OS UNAME_S := $(shell uname -s) # Only build CPU variants on Linux ifeq ($(UNAME_S),Linux) VARIANT_TARGETS = libgoacestepcpp-avx.so libgoacestepcpp-avx2.so libgoacestepcpp-avx512.so libgoacestepcpp-fallback.so else # On non-Linux (e.g., Darwin), build only fallback variant VARIANT_TARGETS = libgoacestepcpp-fallback.so endif acestep-cpp: main.go goacestepcpp.go $(VARIANT_TARGETS) CGO_ENABLED=0 $(GOCMD) build -tags "$(GO_TAGS)" -o acestep-cpp ./ package: acestep-cpp bash package.sh build: package clean: purge rm -rf libgoacestepcpp*.so package sources/acestep.cpp acestep-cpp purge: rm -rf build* # Variants must build sequentially: each uses its own build- directory, # but parallel builds can still race on shared resources (jobserver, disk I/O). 
.NOTPARALLEL: # Build all variants (Linux only) ifeq ($(UNAME_S),Linux) libgoacestepcpp-avx.so: sources/acestep.cpp $(info ${GREEN}I acestep-cpp build info:avx${RESET}) SO_TARGET=libgoacestepcpp-avx.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgoacestepcpp-custom rm -rf build-libgoacestepcpp-avx.so libgoacestepcpp-avx2.so: sources/acestep.cpp $(info ${GREEN}I acestep-cpp build info:avx2${RESET}) SO_TARGET=libgoacestepcpp-avx2.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=off -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgoacestepcpp-custom rm -rf build-libgoacestepcpp-avx2.so libgoacestepcpp-avx512.so: sources/acestep.cpp $(info ${GREEN}I acestep-cpp build info:avx512${RESET}) SO_TARGET=libgoacestepcpp-avx512.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=on -DGGML_AVX2=on -DGGML_AVX512=on -DGGML_FMA=on -DGGML_F16C=on -DGGML_BMI2=on" $(MAKE) libgoacestepcpp-custom rm -rf build-libgoacestepcpp-avx512.so endif # Build fallback variant (all platforms) libgoacestepcpp-fallback.so: sources/acestep.cpp $(info ${GREEN}I acestep-cpp build info:fallback${RESET}) SO_TARGET=libgoacestepcpp-fallback.so CMAKE_ARGS="$(CMAKE_ARGS) -DGGML_AVX=off -DGGML_AVX2=off -DGGML_AVX512=off -DGGML_FMA=off -DGGML_F16C=off -DGGML_BMI2=off" $(MAKE) libgoacestepcpp-custom rm -rf build-libgoacestepcpp-fallback.so libgoacestepcpp-custom: CMakeLists.txt cpp/goacestepcpp.cpp cpp/goacestepcpp.h mkdir -p build-$(SO_TARGET) && \ cd build-$(SO_TARGET) && \ cmake .. $(CMAKE_ARGS) && \ cmake --build . --config Release -j$(JOBS) --target goacestepcpp && \ cd .. && \ mv build-$(SO_TARGET)/libgoacestepcpp.so ./$(SO_TARGET) test: acestep-cpp @echo "Running acestep-cpp tests..." bash test.sh @echo "acestep-cpp tests completed." 
all: acestep-cpp package ================================================ FILE: backend/go/acestep-cpp/acestepcpp_test.go ================================================ package main import ( "context" "os" "os/exec" "path/filepath" "testing" "time" pb "github.com/mudler/LocalAI/pkg/grpc/proto" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" ) const ( testAddr = "localhost:50051" startupWait = 5 * time.Second ) func skipIfNoModel(t *testing.T) string { t.Helper() modelDir := os.Getenv("ACESTEP_MODEL_DIR") if modelDir == "" { t.Skip("ACESTEP_MODEL_DIR not set, skipping test (set to directory with GGUF models)") } if _, err := os.Stat(filepath.Join(modelDir, "acestep-5Hz-lm-0.6B-Q8_0.gguf")); os.IsNotExist(err) { t.Skipf("LM model file not found in %s, skipping", modelDir) } if _, err := os.Stat(filepath.Join(modelDir, "Qwen3-Embedding-0.6B-Q8_0.gguf")); os.IsNotExist(err) { t.Skipf("Text encoder model file not found in %s, skipping", modelDir) } if _, err := os.Stat(filepath.Join(modelDir, "acestep-v15-turbo-Q8_0.gguf")); os.IsNotExist(err) { t.Skipf("DiT model file not found in %s, skipping", modelDir) } if _, err := os.Stat(filepath.Join(modelDir, "vae-BF16.gguf")); os.IsNotExist(err) { t.Skipf("VAE model file not found in %s, skipping", modelDir) } return modelDir } func startServer(t *testing.T) *exec.Cmd { t.Helper() binary := os.Getenv("ACESTEP_BINARY") if binary == "" { binary = "./acestep-cpp" } if _, err := os.Stat(binary); os.IsNotExist(err) { t.Skipf("Backend binary not found at %s, skipping", binary) } cmd := exec.Command(binary, "--addr", testAddr) cmd.Stdout = os.Stderr cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { t.Fatalf("Failed to start server: %v", err) } time.Sleep(startupWait) return cmd } func stopServer(cmd *exec.Cmd) { if cmd != nil && cmd.Process != nil { cmd.Process.Kill() cmd.Wait() } } func dialGRPC(t *testing.T) *grpc.ClientConn { t.Helper() conn, err := grpc.Dial(testAddr, 
grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithDefaultCallOptions( grpc.MaxCallRecvMsgSize(50*1024*1024), grpc.MaxCallSendMsgSize(50*1024*1024), ), ) if err != nil { t.Fatalf("Failed to dial gRPC: %v", err) } return conn } func TestServerHealth(t *testing.T) { cmd := startServer(t) defer stopServer(cmd) conn := dialGRPC(t) defer conn.Close() client := pb.NewBackendClient(conn) resp, err := client.Health(context.Background(), &pb.HealthMessage{}) if err != nil { t.Fatalf("Health check failed: %v", err) } if string(resp.Message) != "OK" { t.Fatalf("Expected OK, got %s", string(resp.Message)) } } func TestLoadModel(t *testing.T) { modelDir := skipIfNoModel(t) cmd := startServer(t) defer stopServer(cmd) conn := dialGRPC(t) defer conn.Close() client := pb.NewBackendClient(conn) // Get base directory from main model file for relative paths mainModelPath := filepath.Join(modelDir, "acestep-5Hz-lm-0.6B-Q8_0.gguf") resp, err := client.LoadModel(context.Background(), &pb.ModelOptions{ ModelFile: mainModelPath, ModelPath: modelDir, Options: []string{ "text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf", "dit_model:acestep-v15-turbo-Q8_0.gguf", "vae_model:vae-BF16.gguf", }, }) if err != nil { t.Fatalf("LoadModel failed: %v", err) } if !resp.Success { t.Fatalf("LoadModel returned failure: %s", resp.Message) } } func TestSoundGeneration(t *testing.T) { modelDir := skipIfNoModel(t) tmpDir, err := os.MkdirTemp("", "acestep-test") if err != nil { t.Fatal(err) } defer os.RemoveAll(tmpDir) outputFile := filepath.Join(tmpDir, "output.wav") cmd := startServer(t) defer stopServer(cmd) conn := dialGRPC(t) defer conn.Close() client := pb.NewBackendClient(conn) // Get base directory from main model file for relative paths mainModelPath := filepath.Join(modelDir, "acestep-5Hz-lm-0.6B-Q8_0.gguf") // Load models loadResp, err := client.LoadModel(context.Background(), &pb.ModelOptions{ ModelFile: mainModelPath, ModelPath: modelDir, Options: []string{ 
"text_encoder_model:Qwen3-Embedding-0.6B-Q8_0.gguf", "dit_model:acestep-v15-turbo-Q8_0.gguf", "vae_model:vae-BF16.gguf", }, }) if err != nil { t.Fatalf("LoadModel failed: %v", err) } if !loadResp.Success { t.Fatalf("LoadModel returned failure: %s", loadResp.Message) } // Generate music duration := float32(10.0) temperature := float32(0.85) bpm := int32(120) caption := "A cheerful electronic dance track" timesig := "4/4" _, err = client.SoundGeneration(context.Background(), &pb.SoundGenerationRequest{ Text: caption, Caption: &caption, Dst: outputFile, Duration: &duration, Temperature: &temperature, Bpm: &bpm, Timesignature: ×ig, }) if err != nil { t.Fatalf("SoundGeneration failed: %v", err) } // Verify output file exists and has content info, err := os.Stat(outputFile) if os.IsNotExist(err) { t.Fatal("Output audio file was not created") } if err != nil { t.Fatalf("Failed to stat output file: %v", err) } t.Logf("Output file size: %d bytes", info.Size()) // WAV header is 44 bytes minimum; any real audio should be much larger if info.Size() < 1000 { t.Errorf("Output file too small (%d bytes), expected real audio data", info.Size()) } } ================================================ FILE: backend/go/acestep-cpp/cpp/goacestepcpp.cpp ================================================ #include "goacestepcpp.h" #include "ggml-backend.h" #include "audio-io.h" #include "bpe.h" #include "cond-enc.h" #include "dit-sampler.h" #include "dit.h" #include "gguf-weights.h" #include "philox.h" #include "qwen3-enc.h" #include "qwen3-lm.h" #include "request.h" #include "vae.h" #include #include #include #include #include #include #include // Global model contexts (loaded once, reused across requests) static DiTGGML g_dit = {}; static DiTGGMLConfig g_dit_cfg; static VAEGGML g_vae = {}; static bool g_dit_loaded = false; static bool g_vae_loaded = false; static bool g_is_turbo = false; // Silence latent [15000, 64] — read once from DiT GGUF static std::vector g_silence_full; // Paths for 
// per-request loading (text encoder, tokenizer)
static std::string g_text_enc_path;
static std::string g_dit_path;
static std::string g_lm_path;

// Log callback handed to ggml_log_set(): prints "[LEVEL] " followed by the
// message to stderr and flushes. A null message is ignored; `data` is unused.
static void ggml_log_cb(enum ggml_log_level level, const char * log, void * data) {
    const char * level_str;
    if (!log) return;
    switch (level) {
        case GGML_LOG_LEVEL_DEBUG: level_str = "DEBUG"; break;
        case GGML_LOG_LEVEL_INFO: level_str = "INFO"; break;
        case GGML_LOG_LEVEL_WARN: level_str = "WARN"; break;
        case GGML_LOG_LEVEL_ERROR: level_str = "ERROR"; break;
        default: level_str = "?????"; break;
    }
    fprintf(stderr, "[%-5s] ", level_str);
    fputs(log, stderr);
    fflush(stderr);
}

// Loads the long-lived models and remembers the paths needed per request.
//
// The LM, text-encoder and DiT paths are only stored here; the DiT and VAE are
// loaded into the globals g_dit / g_vae, and the "acestep.is_turbo" flag and
// the "silence_latent" tensor are read from the DiT GGUF.
//
// Returns 0 on success, 1 if the DiT fails to load, 2 if the DiT GGUF cannot
// be read or has no "silence_latent" entry.
// NOTE(review): the result of vae_ggml_load() is not inspected, so a VAE load
// failure is not reported from here — confirm whether it can fail.
int load_model(const char * lm_model_path, const char * text_encoder_path, const char * dit_model_path, const char * vae_model_path) {
    ggml_log_set(ggml_log_cb, nullptr);
    ggml_backend_load_all();

    g_lm_path = lm_model_path;
    g_text_enc_path = text_encoder_path;
    g_dit_path = dit_model_path;

    // Load DiT model
    fprintf(stderr, "[acestep-cpp] Loading DiT from %s\n", dit_model_path);
    dit_ggml_init_backend(&g_dit);
    if (!dit_ggml_load(&g_dit, dit_model_path, g_dit_cfg, nullptr, 0.0f)) {
        fprintf(stderr, "[acestep-cpp] FATAL: failed to load DiT from %s\n", dit_model_path);
        return 1;
    }
    g_dit_loaded = true;

    // Read DiT GGUF metadata + silence_latent
    {
        GGUFModel gf = {};
        if (gf_load(&gf, dit_model_path)) {
            g_is_turbo = gf_get_bool(gf, "acestep.is_turbo");
            const void * sl_data = gf_get_data(gf, "silence_latent");
            if (sl_data) {
                // Copies a fixed 15000 * 64 floats; the tensor's actual size is
                // not checked here.
                g_silence_full.resize(15000 * 64);
                memcpy(g_silence_full.data(), sl_data, 15000 * 64 * sizeof(float));
                fprintf(stderr, "[acestep-cpp] silence_latent: [15000, 64] loaded\n");
            } else {
                fprintf(stderr, "[acestep-cpp] FATAL: silence_latent not found in %s\n", dit_model_path);
                gf_close(&gf);
                return 2;
            }
            gf_close(&gf);
        } else {
            fprintf(stderr, "[acestep-cpp] FATAL: cannot read GGUF metadata from %s\n", dit_model_path);
            return 2;
        }
    }

    // Load VAE model
    fprintf(stderr, "[acestep-cpp] Loading VAE from %s\n", vae_model_path);
    vae_ggml_load(&g_vae, vae_model_path);
    g_vae_loaded = true;

    fprintf(stderr, "[acestep-cpp] All models loaded successfully (turbo=%d)\n", g_is_turbo);
    return 0;
}

// Generates audio for one request and writes it to `dst` as a WAV file.
//
// Null `caption`, `lyrics`, `keyscale` and `timesignature` fall back to "",
// "", "N/A" and "4/4"; lyrics are also dropped when `instrumental` is set.
// `bpm` <= 0 is rendered as "N/A", `duration` <= 0 becomes 30 seconds and a
// negative `seed` is replaced by a random one. `temperature` and `threads`
// are accepted but not referenced in this function.
//
// Returns 0 on success, otherwise: 1 models not loaded, 2 requested length
// exceeds 15000 latent frames, 3 tokenizer load failed, 4 text encoder load
// failed, 5 condition encoder load failed, 6 VAE decode failed, 7 WAV write
// failed.
// NOTE(review): the std::vector declarations below carry no template argument
// in this copy of the file (angle-bracket contents appear to have been lost).
int generate_music(const char * caption, const char * lyrics, int bpm, const char * keyscale, const char * timesignature, float duration, float temperature, bool instrumental, int seed, const char * dst, int threads) {
    if (!g_dit_loaded || !g_vae_loaded) {
        fprintf(stderr, "[acestep-cpp] ERROR: models not loaded\n");
        return 1;
    }

    const int FRAMES_PER_SECOND = 25;

    // Defaults
    if (duration <= 0) duration = 30.0f;
    std::string cap_str = caption ? caption : "";
    std::string lyrics_str = (instrumental || !lyrics) ? "" : lyrics;
    std::string ks_str = keyscale ? keyscale : "N/A";
    std::string ts_str = timesignature ? timesignature : "4/4";
    std::string lang_str = "unknown";

    char bpm_str[16];
    if (bpm > 0) {
        snprintf(bpm_str, sizeof(bpm_str), "%d", bpm);
    } else {
        snprintf(bpm_str, sizeof(bpm_str), "N/A");
    }

    // Sampling settings are fixed here; only the guidance scale depends on the
    // is_turbo flag read from the DiT GGUF.
    int num_steps = 8;
    float guidance_scale = g_is_turbo ? 1.0f : 7.0f;
    float shift = 1.0f;

    if (seed < 0) {
        std::random_device rd;
        seed = (int)(rd() & 0x7FFFFFFF);
    }

    // Compute T (latent frames at 25Hz)
    int T = (int)(duration * FRAMES_PER_SECOND);
    // Round T up to a multiple of the DiT patch size.
    T = ((T + g_dit_cfg.patch_size - 1) / g_dit_cfg.patch_size) * g_dit_cfg.patch_size;
    int S = T / g_dit_cfg.patch_size;

    // 15000 is also the number of silence-latent rows copied in load_model().
    if (T > 15000) {
        fprintf(stderr, "[acestep-cpp] ERROR: T=%d exceeds max 15000\n", T);
        return 2;
    }

    int Oc = g_dit_cfg.out_channels; // 64
    int ctx_ch = g_dit_cfg.in_channels - Oc; // 128

    fprintf(stderr, "[acestep-cpp] T=%d, S=%d, duration=%.1fs, seed=%d\n", T, S, duration, seed);

    // 1. Load BPE tokenizer from text encoder GGUF
    BPETokenizer tok;
    if (!load_bpe_from_gguf(&tok, g_text_enc_path.c_str())) {
        fprintf(stderr, "[acestep-cpp] FATAL: failed to load BPE tokenizer\n");
        return 3;
    }

    // 2. Build formatted prompts (matches dit-vae.cpp text2music template)
    std::string instruction = "Fill the audio semantic mask based on the given conditions:";
    char metas[512];
    snprintf(metas, sizeof(metas), "- bpm: %s\n- timesignature: %s\n- keyscale: %s\n- duration: %d seconds\n", bpm_str, ts_str.c_str(), ks_str.c_str(), (int)duration);
    std::string text_str = std::string("# Instruction\n") + instruction + "\n\n" + "# Caption\n" + cap_str + "\n\n" + "# Metas\n" + metas + "<|endoftext|>\n";
    std::string lyric_str = std::string("# Languages\n") + lang_str + "\n\n# Lyric\n" + lyrics_str + "<|endoftext|>";

    // 3. Tokenize
    auto text_ids = bpe_encode(&tok, text_str.c_str(), true);
    auto lyric_ids = bpe_encode(&tok, lyric_str.c_str(), true);
    int S_text = (int)text_ids.size();
    int S_lyric = (int)lyric_ids.size();
    fprintf(stderr, "[acestep-cpp] caption: %d tokens, lyrics: %d tokens\n", S_text, S_lyric);

    // 4. Text encoder forward
    // The text encoder is loaded per request and released after step 6.
    Qwen3GGML text_enc = {};
    qwen3_init_backend(&text_enc);
    if (!qwen3_load_text_encoder(&text_enc, g_text_enc_path.c_str())) {
        fprintf(stderr, "[acestep-cpp] FATAL: failed to load text encoder\n");
        return 4;
    }
    int H_text = text_enc.cfg.hidden_size; // 1024
    std::vector text_hidden(H_text * S_text);
    qwen3_forward(&text_enc, text_ids.data(), S_text, text_hidden.data());
    fprintf(stderr, "[acestep-cpp] TextEncoder forward done\n");

    // 5. Lyric embedding
    std::vector lyric_embed(H_text * S_lyric);
    qwen3_embed_lookup(&text_enc, lyric_ids.data(), S_lyric, lyric_embed.data());

    // 6. Condition encoder
    // Loaded from the DiT GGUF path, per request.
    CondGGML cond = {};
    cond_ggml_init_backend(&cond);
    if (!cond_ggml_load(&cond, g_dit_path.c_str())) {
        fprintf(stderr, "[acestep-cpp] FATAL: failed to load condition encoder\n");
        qwen3_free(&text_enc);
        return 5;
    }

    // The first S_ref rows of the silence latent are used as the reference input.
    const int S_ref = 750;
    std::vector silence_feats(S_ref * 64);
    memcpy(silence_feats.data(), g_silence_full.data(), S_ref * 64 * sizeof(float));

    int enc_S = 0;
    std::vector enc_hidden;
    cond_ggml_forward(&cond, text_hidden.data(), S_text, lyric_embed.data(), S_lyric, silence_feats.data(), S_ref, enc_hidden, &enc_S);
    fprintf(stderr, "[acestep-cpp] ConditionEncoder done, enc_S=%d\n", enc_S);

    qwen3_free(&text_enc);
    cond_ggml_free(&cond);

    // 7. Build context [T, ctx_ch] = silence[64] + mask[64]
    std::vector context(T * ctx_ch);
    for (int t = 0; t < T; t++) {
        const float * src = g_silence_full.data() + t * Oc;
        for (int c = 0; c < Oc; c++) {
            context[t * ctx_ch + c] = src[c];
        }
        // The mask half of every frame is filled with 1.0.
        for (int c = 0; c < Oc; c++) {
            context[t * ctx_ch + Oc + c] = 1.0f;
        }
    }

    // 8. Build schedule
    // Values decrease from 1.0 in steps of 1/num_steps; with shift == 1.0 the
    // shift formula leaves them unchanged.
    std::vector schedule(num_steps);
    for (int i = 0; i < num_steps; i++) {
        float t = 1.0f - (float)i / (float)num_steps;
        schedule[i] = shift * t / (1.0f + (shift - 1.0f) * t);
    }

    // 9. Generate noise (Philox)
    std::vector noise(Oc * T);
    philox_randn((long long)seed, noise.data(), Oc * T, true);

    // 10. DiT generate
    std::vector output(Oc * T);
    fprintf(stderr, "[acestep-cpp] DiT generate: T=%d, steps=%d, guidance=%.1f\n", T, num_steps, guidance_scale);
    dit_ggml_generate(&g_dit, noise.data(), context.data(), enc_hidden.data(), enc_S, T, 1, num_steps, schedule.data(), output.data(), guidance_scale, nullptr, nullptr, -1);
    fprintf(stderr, "[acestep-cpp] DiT generation done\n");

    // 11. VAE decode
    // The buffer holds two channels of up to 1920 samples per latent frame.
    int T_audio_max = T * 1920;
    std::vector audio(2 * T_audio_max);
    int T_audio = vae_ggml_decode_tiled(&g_vae, output.data(), T, audio.data(), T_audio_max, 256, 64);
    if (T_audio < 0) {
        fprintf(stderr, "[acestep-cpp] ERROR: VAE decode failed\n");
        return 6;
    }
    fprintf(stderr, "[acestep-cpp] VAE decode done: %d samples (%.2fs @ 48kHz)\n", T_audio, (float)T_audio / 48000.0f);

    // 12. Peak normalization to -1.0 dB
    {
        float peak = 0.0f;
        int n_samples = 2 * T_audio;
        for (int i = 0; i < n_samples; i++) {
            float a = audio[i] < 0 ? -audio[i] : audio[i];
            if (a > peak) {
                peak = a;
            }
        }
        // Near-silent output (peak <= 1e-6) is left unscaled.
        if (peak > 1e-6f) {
            const float target_amp = powf(10.0f, -1.0f / 20.0f);
            float gain = target_amp / peak;
            for (int i = 0; i < n_samples; i++) {
                audio[i] *= gain;
            }
        }
    }

    // 13. Write WAV output
    if (!audio_write_wav(dst, audio.data(), T_audio, 48000)) {
        fprintf(stderr, "[acestep-cpp] ERROR: failed to write %s\n", dst);
        return 7;
    }
    fprintf(stderr, "[acestep-cpp] Wrote %s: %d samples (%.2fs @ 48kHz stereo)\n", dst, T_audio, (float)T_audio / 48000.0f);

    return 0;
}

================================================
FILE: backend/go/acestep-cpp/cpp/goacestepcpp.h
================================================
#include
#include

// C entry points resolved by name from the Go side ("load_model" and
// "generate_music"). Both return 0 on success and a non-zero code on failure.
extern "C" {
int load_model(const char *lm_model_path, const char *text_encoder_path, const char *dit_model_path, const char *vae_model_path);
int generate_music(const char *caption, const char *lyrics, int bpm, const char *keyscale, const char *timesignature, float duration, float temperature, bool instrumental, int seed, const char *dst, int threads);
}

================================================
FILE: backend/go/acestep-cpp/goacestepcpp.go
================================================
package main

import (
	"fmt"
	"os"
	"path/filepath"
	"strings"

	"github.com/mudler/LocalAI/pkg/grpc/base"
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
)

// Function variables bound at startup to the symbols of the loaded shared
// library; a non-zero return value signals failure.
var (
	CppLoadModel func(lmModelPath, textEncoderPath, ditModelPath, vaeModelPath string) int
	CppGenerateMusic func(caption, lyrics string,
		bpm int, keyscale, timesignature string, duration, temperature float32, instrumental bool, seed int, dst string, threads int) int
)

// AceStepCpp is the backend implementation registered with the gRPC server;
// it forwards model loading and sound generation to the C++ library through
// CppLoadModel and CppGenerateMusic.
type AceStepCpp struct {
	base.SingleThread
}

// Load resolves the four model paths and loads them through CppLoadModel.
//
// opts.ModelFile is the LM model. The text encoder, DiT and VAE are taken from
// opts.Options entries of the form "text_encoder_model:<path>",
// "dit_model:<path>" and "vae_model:<path>"; all three are required. Entries
// without a ':' or with an unknown key are reported on stderr and skipped.
// Relative paths, including the LM path, are joined onto opts.ModelPath.
//
// An error is returned when a required option is missing or when the library
// returns a non-zero code.
func (a *AceStepCpp) Load(opts *pb.ModelOptions) error {
	// ModelFile is the LM model path
	lmModel := opts.ModelFile

	// Base directory (opts.ModelPath) used to resolve relative model paths
	baseDir := opts.ModelPath

	var textEncoderModel, ditModel, vaeModel string

	for _, oo := range opts.Options {
		// Split on the first ':' only, so the value itself may contain ':'.
		parts := strings.SplitN(oo, ":", 2)
		if len(parts) != 2 {
			fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
			continue
		}
		switch parts[0] {
		case "text_encoder_model":
			textEncoderModel = parts[1]
		case "dit_model":
			ditModel = parts[1]
		case "vae_model":
			vaeModel = parts[1]
		default:
			fmt.Fprintf(os.Stderr, "Unrecognized option: %v\n", oo)
		}
	}

	if textEncoderModel == "" {
		return fmt.Errorf("text_encoder_model option is required")
	}
	if ditModel == "" {
		return fmt.Errorf("dit_model option is required")
	}
	if vaeModel == "" {
		return fmt.Errorf("vae_model option is required")
	}

	// Resolve relative paths to the base directory
	// (a path is treated as relative when filepath.IsAbs reports false)
	if !filepath.IsAbs(textEncoderModel) {
		textEncoderModel = filepath.Join(baseDir, textEncoderModel)
	}
	if !filepath.IsAbs(ditModel) {
		ditModel = filepath.Join(baseDir, ditModel)
	}
	if !filepath.IsAbs(vaeModel) {
		vaeModel = filepath.Join(baseDir, vaeModel)
	}

	// Also resolve the lmModel if it's relative
	if !filepath.IsAbs(lmModel) {
		lmModel = filepath.Join(baseDir, lmModel)
	}

	fmt.Fprintf(os.Stderr, "[acestep-cpp] Resolved paths:\n")
	fmt.Fprintf(os.Stderr, " LM Model: %s\n", lmModel)
	fmt.Fprintf(os.Stderr, " Text Encoder: %s\n", textEncoderModel)
	fmt.Fprintf(os.Stderr, " DiT Model: %s\n", ditModel)
	fmt.Fprintf(os.Stderr, " VAE Model: %s\n", vaeModel)

	if ret := CppLoadModel(lmModel, textEncoderModel, ditModel, vaeModel); ret != 0 {
		return fmt.Errorf("failed to load acestep models (error code: %d)", ret)
	}

	return nil
}

// SoundGeneration maps the request onto CppGenerateMusic and writes the result
// to req.GetDst(). The caption falls back to the request text when empty.
// The seed (42) and thread count (4) are fixed here and are not taken from the
// request. A non-zero library return code is turned into an error.
func (a *AceStepCpp) SoundGeneration(req *pb.SoundGenerationRequest) error {
	caption := req.GetCaption()
	if caption == "" {
		caption = req.GetText()
	}

	lyrics := req.GetLyrics()
	bpm := int(req.GetBpm())
	keyscale := req.GetKeyscale()
	timesignature := req.GetTimesignature()
	duration := req.GetDuration()
	temperature := req.GetTemperature()
	instrumental := req.GetInstrumental()
	seed := 42
	threads := 4

	if ret := CppGenerateMusic(caption, lyrics, bpm, keyscale, timesignature, duration, temperature, instrumental, seed, req.GetDst(), threads); ret != 0 {
		return fmt.Errorf("failed to generate music (error code: %d)", ret)
	}

	return nil
}

================================================
FILE: backend/go/acestep-cpp/main.go
================================================
package main

// Note: this is started internally by LocalAI and a server is allocated for each model

import (
	"flag"
	"os"

	"github.com/ebitengine/purego"
	grpc "github.com/mudler/LocalAI/pkg/grpc"
)

var (
	addr = flag.String("addr", "localhost:50051", "the address to connect to")
)

// LibFuncs pairs a pointer to a Go function variable with the name of the
// library symbol it is bound to.
type LibFuncs struct {
	FuncPtr any
	Name string
}

// main opens the shared library named by ACESTEP_LIBRARY (defaulting to the
// fallback variant in the current directory), binds the exported functions to
// the Cpp* variables and then serves the backend on the -addr address. Any
// failure to open the library or start the server panics.
func main() {
	// Get library name from environment variable, default to fallback
	libName := os.Getenv("ACESTEP_LIBRARY")
	if libName == "" {
		libName = "./libgoacestepcpp-fallback.so"
	}

	gosd, err := purego.Dlopen(libName, purego.RTLD_NOW|purego.RTLD_GLOBAL)
	if err != nil {
		panic(err)
	}

	libFuncs := []LibFuncs{
		{&CppLoadModel, "load_model"},
		{&CppGenerateMusic, "generate_music"},
	}

	for _, lf := range libFuncs {
		purego.RegisterLibFunc(lf.FuncPtr, gosd, lf.Name)
	}

	// Flags are parsed only after the library has been opened and bound.
	flag.Parse()

	if err := grpc.StartServer(*addr, &AceStepCpp{}); err != nil {
		panic(err)
	}
}

================================================
FILE: backend/go/acestep-cpp/package.sh
================================================
#!/bin/bash

# Script to copy the appropriate libraries based on architecture
# This script is used in the final stage of the Dockerfile

set -e

CURDIR=$(dirname "$(realpath $0)")
REPO_ROOT="${CURDIR}/../../.."
# Create lib directory
mkdir -p "$CURDIR/package/lib"

cp -avf "$CURDIR/acestep-cpp" "$CURDIR/package/"
# The glob must stay outside the quotes so that it still expands.
cp -fv "$CURDIR"/libgoacestepcpp-*.so "$CURDIR/package/"
cp -fv "$CURDIR/run.sh" "$CURDIR/package/"

# copy_runtime_libs LOADER LIBDIR
# Copies the dynamic loader (as lib/ld.so) and the runtime libraries found in
# LIBDIR into the package, following symlinks. Each library is copied once.
copy_runtime_libs() {
    local loader="$1"
    local libdir="$2"
    local lib

    cp -arfLv "$loader" "$CURDIR/package/lib/ld.so"
    for lib in libc.so.6 libgcc_s.so.1 libstdc++.so.6 libm.so.6 libgomp.so.1 \
               libdl.so.2 librt.so.1 libpthread.so.0; do
        cp -arfLv "$libdir/$lib" "$CURDIR/package/lib/$lib"
    done
}

# Detect architecture and copy appropriate libraries
if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then
    # x86_64 architecture
    echo "Detected x86_64 architecture, copying x86_64 libraries..."
    copy_runtime_libs /lib64/ld-linux-x86-64.so.2 /lib/x86_64-linux-gnu
elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then
    # ARM64 architecture
    echo "Detected ARM64 architecture, copying ARM64 libraries..."
    copy_runtime_libs /lib/ld-linux-aarch64.so.1 /lib/aarch64-linux-gnu
elif [ "$(uname -s)" = "Darwin" ]; then
    echo "Detected Darwin"
else
    echo "Error: Could not detect architecture"
    exit 1
fi

# Package GPU libraries based on BUILD_TYPE
# The GPU library packaging script will detect BUILD_TYPE and copy appropriate GPU libraries
GPU_LIB_SCRIPT="${REPO_ROOT}/scripts/build/package-gpu-libs.sh"
if [ -f "$GPU_LIB_SCRIPT" ]; then
    echo "Packaging GPU libraries for BUILD_TYPE=${BUILD_TYPE:-cpu}..."
    source "$GPU_LIB_SCRIPT" "$CURDIR/package/lib"
    package_gpu_libs
fi

echo "Packaging completed successfully"
ls -liah "$CURDIR/package/"
ls -liah "$CURDIR/package/lib/"

================================================
FILE: backend/go/acestep-cpp/run.sh
================================================
#!/bin/bash
set -ex

# Get the absolute current dir where the script is located
CURDIR=$(dirname "$(realpath "$0")")

cd /

echo "CPU info:"
if [ "$(uname)" != "Darwin" ]; then
    grep -e "model\sname" /proc/cpuinfo | head -1
    grep -e "flags" /proc/cpuinfo | head -1
fi

# Start from the fallback build and upgrade to the most capable variant that
# is both supported by the CPU and present next to this script. The checks run
# in increasing order, so the last match wins.
LIBRARY="$CURDIR/libgoacestepcpp-fallback.so"

if [ "$(uname)" != "Darwin" ]; then
    if grep -q -e "\savx\s" /proc/cpuinfo ; then
        echo "CPU: AVX found OK"
        if [ -e "$CURDIR/libgoacestepcpp-avx.so" ]; then
            LIBRARY="$CURDIR/libgoacestepcpp-avx.so"
        fi
    fi

    if grep -q -e "\savx2\s" /proc/cpuinfo ; then
        echo "CPU: AVX2 found OK"
        if [ -e "$CURDIR/libgoacestepcpp-avx2.so" ]; then
            LIBRARY="$CURDIR/libgoacestepcpp-avx2.so"
        fi
    fi

    # Check avx 512
    if grep -q -e "\savx512f\s" /proc/cpuinfo ; then
        echo "CPU: AVX512F found OK"
        if [ -e "$CURDIR/libgoacestepcpp-avx512.so" ]; then
            LIBRARY="$CURDIR/libgoacestepcpp-avx512.so"
        fi
    fi
fi

export LD_LIBRARY_PATH="$CURDIR/lib:$LD_LIBRARY_PATH"
export ACESTEP_LIBRARY="$LIBRARY"

# If there is a lib/ld.so, use it
if [ -f "$CURDIR/lib/ld.so" ]; then
    echo "Using lib/ld.so"
    echo "Using library: $LIBRARY"
    exec "$CURDIR/lib/ld.so" "$CURDIR/acestep-cpp" "$@"
fi

echo "Using library: $LIBRARY"
exec "$CURDIR/acestep-cpp" "$@"

================================================
FILE: backend/go/acestep-cpp/test.sh
================================================
#!/bin/bash
set -e

CURDIR=$(dirname "$(realpath "$0")")

echo "Running acestep-cpp backend tests..."

# The test requires:
# - ACESTEP_MODEL_DIR: path to directory containing GGUF model files
# - ACESTEP_BINARY: path to the acestep-cpp binary (defaults to ./acestep-cpp)
#
# Tests that require the model will be skipped if ACESTEP_MODEL_DIR is not set
# or the directory does not contain the required model files.

cd "$CURDIR"

# Only auto-download models when ACESTEP_MODEL_DIR is not explicitly set
if [ -z "$ACESTEP_MODEL_DIR" ]; then
    export ACESTEP_MODEL_DIR="./acestep-models"

    if [ ! -d "$ACESTEP_MODEL_DIR" ]; then
        echo "Creating acestep-models directory for tests..."
        mkdir -p "$ACESTEP_MODEL_DIR"
        REPO_ID="Serveurperso/ACE-Step-1.5-GGUF"
        echo "Repository: ${REPO_ID}"
        echo ""

        # Files to download (smallest quantizations for testing)
        FILES=(
            "acestep-5Hz-lm-0.6B-Q8_0.gguf"
            "Qwen3-Embedding-0.6B-Q8_0.gguf"
            "acestep-v15-turbo-Q8_0.gguf"
            "vae-BF16.gguf"
        )

        BASE_URL="https://huggingface.co/${REPO_ID}/resolve/main"

        for file in "${FILES[@]}"; do
            dest="${ACESTEP_MODEL_DIR}/${file}"
            if [ -f "${dest}" ]; then
                echo "  [skip] ${file} (already exists)"
            else
                echo "  [download] ${file}..."
                # --fail makes curl exit non-zero on an HTTP error instead of
                # saving the error page as if it were the model file.
                if ! curl --fail -L -o "${dest}" "${BASE_URL}/${file}" --progress-bar; then
                    rm -f "${dest}"
                    echo "  [error] failed to download ${file}" >&2
                    exit 1
                fi
                echo "  [done] ${file}"
            fi
        done
    fi
fi

# Run Go tests
go test -v -timeout 600s .

echo "All acestep-cpp tests passed."
================================================ FILE: backend/go/llm/llama/llama.go ================================================ package main // This is a wrapper to statisfy the GRPC service interface // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) import ( "fmt" "path/filepath" "github.com/go-skynet/go-llama.cpp" "github.com/mudler/LocalAI/pkg/grpc/base" pb "github.com/mudler/LocalAI/pkg/grpc/proto" ) type LLM struct { base.SingleThread llama *llama.LLama draftModel *llama.LLama } // Free releases GPU resources and frees the llama model // This should be called when the model is being unloaded to properly release VRAM func (llm *LLM) Free() error { if llm.llama != nil { llm.llama.Free() llm.llama = nil } if llm.draftModel != nil { llm.draftModel.Free() llm.draftModel = nil } return nil } func (llm *LLM) Load(opts *pb.ModelOptions) error { ropeFreqBase := float32(10000) ropeFreqScale := float32(1) if opts.RopeFreqBase != 0 { ropeFreqBase = opts.RopeFreqBase } if opts.RopeFreqScale != 0 { ropeFreqScale = opts.RopeFreqScale } llamaOpts := []llama.ModelOption{ llama.WithRopeFreqBase(ropeFreqBase), llama.WithRopeFreqScale(ropeFreqScale), } if opts.NoMulMatQ { llamaOpts = append(llamaOpts, llama.SetMulMatQ(false)) } // Get base path of opts.ModelFile and use the same for lora (assume the same path) basePath := filepath.Dir(opts.ModelFile) if opts.LoraAdapter != "" { llamaOpts = append(llamaOpts, llama.SetLoraAdapter(filepath.Join(basePath, opts.LoraAdapter))) } if opts.LoraBase != "" { llamaOpts = append(llamaOpts, llama.SetLoraBase(filepath.Join(basePath, opts.LoraBase))) } if opts.ContextSize != 0 { llamaOpts = append(llamaOpts, llama.SetContext(int(opts.ContextSize))) } if opts.F16Memory { llamaOpts = append(llamaOpts, llama.EnableF16Memory) } if opts.Embeddings { llamaOpts = append(llamaOpts, llama.EnableEmbeddings) } if opts.Reranking { llamaOpts = append(llamaOpts, llama.EnableReranking) } 
if opts.NGPULayers != 0 { llamaOpts = append(llamaOpts, llama.SetGPULayers(int(opts.NGPULayers))) } llamaOpts = append(llamaOpts, llama.SetMMap(opts.MMap)) llamaOpts = append(llamaOpts, llama.SetMainGPU(opts.MainGPU)) llamaOpts = append(llamaOpts, llama.SetTensorSplit(opts.TensorSplit)) if opts.NBatch != 0 { llamaOpts = append(llamaOpts, llama.SetNBatch(int(opts.NBatch))) } else { llamaOpts = append(llamaOpts, llama.SetNBatch(512)) } if opts.NUMA { llamaOpts = append(llamaOpts, llama.EnableNUMA) } if opts.LowVRAM { llamaOpts = append(llamaOpts, llama.EnabelLowVRAM) } if opts.DraftModel != "" { // https://github.com/ggerganov/llama.cpp/blob/71ca2fad7d6c0ef95ef9944fb3a1a843e481f314/examples/speculative/speculative.cpp#L40 llamaOpts = append(llamaOpts, llama.SetPerplexity(true)) } model, err := llama.New(opts.ModelFile, llamaOpts...) if opts.DraftModel != "" { // opts.DraftModel is relative to opts.ModelFile, so we need to get the basepath of opts.ModelFile if !filepath.IsAbs(opts.DraftModel) { dir := filepath.Dir(opts.ModelFile) opts.DraftModel = filepath.Join(dir, opts.DraftModel) } draftModel, err := llama.New(opts.DraftModel, llamaOpts...) 
if err != nil { return err } llm.draftModel = draftModel } llm.llama = model return err } func buildPredictOptions(opts *pb.PredictOptions) []llama.PredictOption { ropeFreqBase := float32(10000) ropeFreqScale := float32(1) if opts.RopeFreqBase != 0 { ropeFreqBase = opts.RopeFreqBase } if opts.RopeFreqScale != 0 { ropeFreqScale = opts.RopeFreqScale } predictOptions := []llama.PredictOption{ llama.SetTemperature(opts.Temperature), llama.SetTopP(opts.TopP), llama.SetTopK(int(opts.TopK)), llama.SetTokens(int(opts.Tokens)), llama.SetThreads(int(opts.Threads)), llama.WithGrammar(opts.Grammar), llama.SetRopeFreqBase(ropeFreqBase), llama.SetRopeFreqScale(ropeFreqScale), llama.SetNegativePromptScale(opts.NegativePromptScale), llama.SetNegativePrompt(opts.NegativePrompt), } if opts.PromptCacheAll { predictOptions = append(predictOptions, llama.EnablePromptCacheAll) } if opts.PromptCacheRO { predictOptions = append(predictOptions, llama.EnablePromptCacheRO) } // Expected absolute path if opts.PromptCachePath != "" { predictOptions = append(predictOptions, llama.SetPathPromptCache(opts.PromptCachePath)) } if opts.Mirostat != 0 { predictOptions = append(predictOptions, llama.SetMirostat(int(opts.Mirostat))) } if opts.MirostatETA != 0 { predictOptions = append(predictOptions, llama.SetMirostatETA(opts.MirostatETA)) } if opts.MirostatTAU != 0 { predictOptions = append(predictOptions, llama.SetMirostatTAU(opts.MirostatTAU)) } if opts.Debug { predictOptions = append(predictOptions, llama.Debug) } predictOptions = append(predictOptions, llama.SetStopWords(opts.StopPrompts...)) if opts.PresencePenalty != 0 { predictOptions = append(predictOptions, llama.SetPenalty(opts.PresencePenalty)) } if opts.NKeep != 0 { predictOptions = append(predictOptions, llama.SetNKeep(int(opts.NKeep))) } if opts.Batch != 0 { predictOptions = append(predictOptions, llama.SetBatch(int(opts.Batch))) } if opts.F16KV { predictOptions = append(predictOptions, llama.EnableF16KV) } if opts.IgnoreEOS { 
predictOptions = append(predictOptions, llama.IgnoreEOS) } if opts.Seed != 0 { predictOptions = append(predictOptions, llama.SetSeed(int(opts.Seed))) } if opts.NDraft != 0 { predictOptions = append(predictOptions, llama.SetNDraft(int(opts.NDraft))) } //predictOptions = append(predictOptions, llama.SetLogitBias(c.Seed)) predictOptions = append(predictOptions, llama.SetFrequencyPenalty(opts.FrequencyPenalty)) predictOptions = append(predictOptions, llama.SetMlock(opts.MLock)) predictOptions = append(predictOptions, llama.SetMemoryMap(opts.MMap)) predictOptions = append(predictOptions, llama.SetPredictionMainGPU(opts.MainGPU)) predictOptions = append(predictOptions, llama.SetPredictionTensorSplit(opts.TensorSplit)) predictOptions = append(predictOptions, llama.SetTailFreeSamplingZ(opts.TailFreeSamplingZ)) predictOptions = append(predictOptions, llama.SetTypicalP(opts.TypicalP)) return predictOptions } func (llm *LLM) Predict(opts *pb.PredictOptions) (string, error) { if llm.draftModel != nil { return llm.llama.SpeculativeSampling(llm.draftModel, opts.Prompt, buildPredictOptions(opts)...) } return llm.llama.Predict(opts.Prompt, buildPredictOptions(opts)...) } func (llm *LLM) PredictStream(opts *pb.PredictOptions, results chan string) error { predictOptions := buildPredictOptions(opts) predictOptions = append(predictOptions, llama.SetTokenCallback(func(token string) bool { results <- token return true })) go func() { var err error if llm.draftModel != nil { _, err = llm.llama.SpeculativeSampling(llm.draftModel, opts.Prompt, buildPredictOptions(opts)...) } else { _, err = llm.llama.Predict(opts.Prompt, predictOptions...) 
} if err != nil { fmt.Println("err: ", err) } close(results) }() return nil } func (llm *LLM) Embeddings(opts *pb.PredictOptions) ([]float32, error) { predictOptions := buildPredictOptions(opts) if len(opts.EmbeddingTokens) > 0 { tokens := []int{} for _, t := range opts.EmbeddingTokens { tokens = append(tokens, int(t)) } return llm.llama.TokenEmbeddings(tokens, predictOptions...) } return llm.llama.Embeddings(opts.Embeddings, predictOptions...) } func (llm *LLM) TokenizeString(opts *pb.PredictOptions) (pb.TokenizationResponse, error) { predictOptions := buildPredictOptions(opts) l, tokens, err := llm.llama.TokenizeString(opts.Prompt, predictOptions...) if err != nil { return pb.TokenizationResponse{}, err } return pb.TokenizationResponse{ Length: l, Tokens: tokens, }, nil } ================================================ FILE: backend/go/llm/llama/main.go ================================================ package main // GRPC Falcon server // Note: this is started internally by LocalAI and a server is allocated for each model import ( "flag" grpc "github.com/mudler/LocalAI/pkg/grpc" ) var ( addr = flag.String("addr", "localhost:50051", "the address to connect to") ) func main() { flag.Parse() if err := grpc.StartServer(*addr, &LLM{}); err != nil { panic(err) } } ================================================ FILE: backend/go/local-store/Makefile ================================================ GOCMD=go local-store: CGO_ENABLED=0 $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o local-store ./ package: bash package.sh build: local-store package clean: rm -f local-store ================================================ FILE: backend/go/local-store/debug.go ================================================ //go:build debug // +build debug package main import ( "github.com/mudler/xlog" ) func assert(cond bool, msg string) { if !cond { xlog.Fatal().Stack().Msg(msg) } } ================================================ FILE: backend/go/local-store/main.go 
================================================ package main // Note: this is started internally by LocalAI and a server is allocated for each store import ( "flag" "os" grpc "github.com/mudler/LocalAI/pkg/grpc" "github.com/mudler/xlog" ) var ( addr = flag.String("addr", "localhost:50051", "the address to connect to") ) func main() { xlog.SetLogger(xlog.NewLogger(xlog.LogLevel(os.Getenv("LOCALAI_LOG_LEVEL")), os.Getenv("LOCALAI_LOG_FORMAT"))) flag.Parse() if err := grpc.StartServer(*addr, NewStore()); err != nil { panic(err) } } ================================================ FILE: backend/go/local-store/package.sh ================================================ #!/bin/bash # Script to copy the appropriate libraries based on architecture # This script is used in the final stage of the Dockerfile set -e CURDIR=$(dirname "$(realpath $0)") mkdir -p $CURDIR/package cp -avf $CURDIR/local-store $CURDIR/package/ cp -rfv $CURDIR/run.sh $CURDIR/package/ ================================================ FILE: backend/go/local-store/production.go ================================================ //go:build !debug // +build !debug package main func assert(cond bool, msg string) { } ================================================ FILE: backend/go/local-store/run.sh ================================================ #!/bin/bash set -ex CURDIR=$(dirname "$(realpath $0)") exec $CURDIR/local-store "$@" ================================================ FILE: backend/go/local-store/store.go ================================================ package main // This is a wrapper to statisfy the GRPC service interface // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) import ( "container/heap" "errors" "fmt" "math" "slices" "github.com/mudler/LocalAI/pkg/grpc/base" pb "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/xlog" ) type Store struct { base.SingleThread // The sorted keys keys [][]float32 // The sorted 
values values [][]byte // If for every K it holds that ||k||^2 = 1, then we can use the normalized distance functions // TODO: Should we normalize incoming keys if they are not instead? keysAreNormalized bool // The first key decides the length of the keys keyLen int } // TODO: Only used for sorting using Go's builtin implementation. The interfaces are columnar because // that's theoretically best for memory layout and cache locality, but this isn't optimized yet. type Pair struct { Key []float32 Value []byte } func NewStore() *Store { return &Store{ keys: make([][]float32, 0), values: make([][]byte, 0), keysAreNormalized: true, keyLen: -1, } } func compareSlices(k1, k2 []float32) int { assert(len(k1) == len(k2), fmt.Sprintf("compareSlices: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) return slices.Compare(k1, k2) } func hasKey(unsortedSlice [][]float32, target []float32) bool { return slices.ContainsFunc(unsortedSlice, func(k []float32) bool { return compareSlices(k, target) == 0 }) } func findInSortedSlice(sortedSlice [][]float32, target []float32) (int, bool) { return slices.BinarySearchFunc(sortedSlice, target, func(k, t []float32) int { return compareSlices(k, t) }) } func isSortedPairs(kvs []Pair) bool { for i := 1; i < len(kvs); i++ { if compareSlices(kvs[i-1].Key, kvs[i].Key) > 0 { return false } } return true } func isSortedKeys(keys [][]float32) bool { for i := 1; i < len(keys); i++ { if compareSlices(keys[i-1], keys[i]) > 0 { return false } } return true } func sortIntoKeySlicese(keys []*pb.StoresKey) [][]float32 { ks := make([][]float32, len(keys)) for i, k := range keys { ks[i] = k.Floats } slices.SortFunc(ks, compareSlices) assert(len(ks) == len(keys), fmt.Sprintf("len(ks) = %d, len(keys) = %d", len(ks), len(keys))) assert(isSortedKeys(ks), "keys are not sorted") return ks } func (s *Store) Load(opts *pb.ModelOptions) error { if opts.Model != "" { return errors.New("not implemented") } return nil } // Sort the incoming kvs and merge them with the 
existing sorted kvs func (s *Store) StoresSet(opts *pb.StoresSetOptions) error { if len(opts.Keys) == 0 { return fmt.Errorf("no keys to add") } if len(opts.Keys) != len(opts.Values) { return fmt.Errorf("len(keys) = %d, len(values) = %d", len(opts.Keys), len(opts.Values)) } if s.keyLen == -1 { s.keyLen = len(opts.Keys[0].Floats) } else { if len(opts.Keys[0].Floats) != s.keyLen { return fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen) } } kvs := make([]Pair, len(opts.Keys)) for i, k := range opts.Keys { if s.keysAreNormalized && !isNormalized(k.Floats) { s.keysAreNormalized = false var sample []float32 if len(s.keys) > 5 { sample = k.Floats[:5] } else { sample = k.Floats } xlog.Debug("Key is not normalized", "sample", sample) } kvs[i] = Pair{ Key: k.Floats, Value: opts.Values[i].Bytes, } } slices.SortFunc(kvs, func(a, b Pair) int { return compareSlices(a.Key, b.Key) }) assert(len(kvs) == len(opts.Keys), fmt.Sprintf("len(kvs) = %d, len(opts.Keys) = %d", len(kvs), len(opts.Keys))) assert(isSortedPairs(kvs), "keys are not sorted") l := len(kvs) + len(s.keys) merge_ks := make([][]float32, 0, l) merge_vs := make([][]byte, 0, l) i, j := 0, 0 for { if i+j >= l { break } if i >= len(kvs) { merge_ks = append(merge_ks, s.keys[j]) merge_vs = append(merge_vs, s.values[j]) j++ continue } if j >= len(s.keys) { merge_ks = append(merge_ks, kvs[i].Key) merge_vs = append(merge_vs, kvs[i].Value) i++ continue } c := compareSlices(kvs[i].Key, s.keys[j]) if c < 0 { merge_ks = append(merge_ks, kvs[i].Key) merge_vs = append(merge_vs, kvs[i].Value) i++ } else if c > 0 { merge_ks = append(merge_ks, s.keys[j]) merge_vs = append(merge_vs, s.values[j]) j++ } else { merge_ks = append(merge_ks, kvs[i].Key) merge_vs = append(merge_vs, kvs[i].Value) i++ j++ } } assert(len(merge_ks) == l, fmt.Sprintf("len(merge_ks) = %d, l = %d", len(merge_ks), l)) assert(isSortedKeys(merge_ks), "merge keys are not sorted") s.keys = merge_ks s.values = 
merge_vs return nil } func (s *Store) StoresDelete(opts *pb.StoresDeleteOptions) error { if len(opts.Keys) == 0 { return fmt.Errorf("no keys to delete") } if len(opts.Keys) == 0 { return fmt.Errorf("no keys to add") } if s.keyLen == -1 { s.keyLen = len(opts.Keys[0].Floats) } else { if len(opts.Keys[0].Floats) != s.keyLen { return fmt.Errorf("Trying to delete key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen) } } ks := sortIntoKeySlicese(opts.Keys) l := len(s.keys) - len(ks) merge_ks := make([][]float32, 0, l) merge_vs := make([][]byte, 0, l) tail_ks := s.keys tail_vs := s.values for _, k := range ks { j, found := findInSortedSlice(tail_ks, k) if found { merge_ks = append(merge_ks, tail_ks[:j]...) merge_vs = append(merge_vs, tail_vs[:j]...) tail_ks = tail_ks[j+1:] tail_vs = tail_vs[j+1:] } else { assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: t=%d, %v", len(tail_ks), k)) } xlog.Debug("Delete", "found", found, "tailLen", len(tail_ks), "j", j, "mergeKeysLen", len(merge_ks), "mergeValuesLen", len(merge_vs)) } merge_ks = append(merge_ks, tail_ks...) merge_vs = append(merge_vs, tail_vs...) 
assert(len(merge_ks) <= len(s.keys), fmt.Sprintf("len(merge_ks) = %d, len(s.keys) = %d", len(merge_ks), len(s.keys))) s.keys = merge_ks s.values = merge_vs assert(len(s.keys) >= l, fmt.Sprintf("len(s.keys) = %d, l = %d", len(s.keys), l)) assert(isSortedKeys(s.keys), "keys are not sorted") assert(func() bool { for _, k := range ks { if _, found := findInSortedSlice(s.keys, k); found { return false } } return true }(), "Keys to delete still present") if len(s.keys) != l { xlog.Debug("Delete: Some keys not found", "keysLen", len(s.keys), "expectedLen", l) } return nil } func (s *Store) StoresGet(opts *pb.StoresGetOptions) (pb.StoresGetResult, error) { pbKeys := make([]*pb.StoresKey, 0, len(opts.Keys)) pbValues := make([]*pb.StoresValue, 0, len(opts.Keys)) ks := sortIntoKeySlicese(opts.Keys) if len(s.keys) == 0 { xlog.Debug("Get: No keys in store") } if s.keyLen == -1 { s.keyLen = len(opts.Keys[0].Floats) } else { if len(opts.Keys[0].Floats) != s.keyLen { return pb.StoresGetResult{}, fmt.Errorf("Try to get a key with length %d when existing length is %d", len(opts.Keys[0].Floats), s.keyLen) } } tail_k := s.keys tail_v := s.values for i, k := range ks { j, found := findInSortedSlice(tail_k, k) if found { pbKeys = append(pbKeys, &pb.StoresKey{ Floats: k, }) pbValues = append(pbValues, &pb.StoresValue{ Bytes: tail_v[j], }) tail_k = tail_k[j+1:] tail_v = tail_v[j+1:] } else { assert(!hasKey(s.keys, k), fmt.Sprintf("Key exists, but was not found: i=%d, %v", i, k)) } } if len(pbKeys) != len(opts.Keys) { xlog.Debug("Get: Some keys not found", "pbKeysLen", len(pbKeys), "optsKeysLen", len(opts.Keys), "storeKeysLen", len(s.keys)) } return pb.StoresGetResult{ Keys: pbKeys, Values: pbValues, }, nil } func isNormalized(k []float32) bool { var sum float64 for _, v := range k { v64 := float64(v) sum += v64 * v64 } s := math.Sqrt(sum) return s >= 0.99 && s <= 1.01 } // TODO: This we could replace with handwritten SIMD code func normalizedCosineSimilarity(k1, k2 []float32) float32 { 
assert(len(k1) == len(k2), fmt.Sprintf("normalizedCosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) var dot float32 for i := 0; i < len(k1); i++ { dot += k1[i] * k2[i] } assert(dot >= -1.01 && dot <= 1.01, fmt.Sprintf("dot = %f", dot)) // 2.0 * (1.0 - dot) would be the Euclidean distance return dot } type PriorityItem struct { Similarity float32 Key []float32 Value []byte } type PriorityQueue []*PriorityItem func (pq PriorityQueue) Len() int { return len(pq) } func (pq PriorityQueue) Less(i, j int) bool { // Inverted because the most similar should be at the top return pq[i].Similarity < pq[j].Similarity } func (pq PriorityQueue) Swap(i, j int) { pq[i], pq[j] = pq[j], pq[i] } func (pq *PriorityQueue) Push(x any) { item := x.(*PriorityItem) *pq = append(*pq, item) } func (pq *PriorityQueue) Pop() any { old := *pq n := len(old) item := old[n-1] *pq = old[0 : n-1] return item } func (s *Store) StoresFindNormalized(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) { tk := opts.Key.Floats top_ks := make(PriorityQueue, 0, int(opts.TopK)) heap.Init(&top_ks) for i, k := range s.keys { sim := normalizedCosineSimilarity(tk, k) heap.Push(&top_ks, &PriorityItem{ Similarity: sim, Key: k, Value: s.values[i], }) if top_ks.Len() > int(opts.TopK) { heap.Pop(&top_ks) } } similarities := make([]float32, top_ks.Len()) pbKeys := make([]*pb.StoresKey, top_ks.Len()) pbValues := make([]*pb.StoresValue, top_ks.Len()) for i := top_ks.Len() - 1; i >= 0; i-- { item := heap.Pop(&top_ks).(*PriorityItem) similarities[i] = item.Similarity pbKeys[i] = &pb.StoresKey{ Floats: item.Key, } pbValues[i] = &pb.StoresValue{ Bytes: item.Value, } } return pb.StoresFindResult{ Keys: pbKeys, Values: pbValues, Similarities: similarities, }, nil } func cosineSimilarity(k1, k2 []float32, mag1 float64) float32 { assert(len(k1) == len(k2), fmt.Sprintf("cosineSimilarity: len(k1) = %d, len(k2) = %d", len(k1), len(k2))) var dot, mag2 float64 for i := 0; i < len(k1); i++ { dot += float64(k1[i] * 
k2[i]) mag2 += float64(k2[i] * k2[i]) } sim := float32(dot / (mag1 * math.Sqrt(mag2))) assert(sim >= -1.01 && sim <= 1.01, fmt.Sprintf("sim = %f", sim)) return sim } func (s *Store) StoresFindFallback(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) { tk := opts.Key.Floats top_ks := make(PriorityQueue, 0, int(opts.TopK)) heap.Init(&top_ks) var mag1 float64 for _, v := range tk { mag1 += float64(v * v) } mag1 = math.Sqrt(mag1) for i, k := range s.keys { dist := cosineSimilarity(tk, k, mag1) heap.Push(&top_ks, &PriorityItem{ Similarity: dist, Key: k, Value: s.values[i], }) if top_ks.Len() > int(opts.TopK) { heap.Pop(&top_ks) } } similarities := make([]float32, top_ks.Len()) pbKeys := make([]*pb.StoresKey, top_ks.Len()) pbValues := make([]*pb.StoresValue, top_ks.Len()) for i := top_ks.Len() - 1; i >= 0; i-- { item := heap.Pop(&top_ks).(*PriorityItem) similarities[i] = item.Similarity pbKeys[i] = &pb.StoresKey{ Floats: item.Key, } pbValues[i] = &pb.StoresValue{ Bytes: item.Value, } } return pb.StoresFindResult{ Keys: pbKeys, Values: pbValues, Similarities: similarities, }, nil } func (s *Store) StoresFind(opts *pb.StoresFindOptions) (pb.StoresFindResult, error) { tk := opts.Key.Floats if len(tk) != s.keyLen { return pb.StoresFindResult{}, fmt.Errorf("Try to find key with length %d when existing length is %d", len(tk), s.keyLen) } if opts.TopK < 1 { return pb.StoresFindResult{}, fmt.Errorf("opts.TopK = %d, must be >= 1", opts.TopK) } if s.keyLen == -1 { s.keyLen = len(opts.Key.Floats) } else { if len(opts.Key.Floats) != s.keyLen { return pb.StoresFindResult{}, fmt.Errorf("Try to add key with length %d when existing length is %d", len(opts.Key.Floats), s.keyLen) } } if s.keysAreNormalized && isNormalized(tk) { return s.StoresFindNormalized(opts) } else { if s.keysAreNormalized { var sample []float32 if len(s.keys) > 5 { sample = tk[:5] } else { sample = tk } xlog.Debug("Trying to compare non-normalized key with normalized keys", "sample", sample) } return 
s.StoresFindFallback(opts) } } ================================================ FILE: backend/go/opus/Makefile ================================================ GOCMD?=go GO_TAGS?= OPUS_CFLAGS := $(shell pkg-config --cflags opus) OPUS_LIBS := $(shell pkg-config --libs opus) libopusshim.so: csrc/opus_shim.c $(CC) -shared -fPIC -o $@ $< $(OPUS_CFLAGS) $(OPUS_LIBS) opus: libopusshim.so $(GOCMD) build -tags "$(GO_TAGS)" -o opus ./ package: opus bash package.sh build: package clean: rm -f opus libopusshim.so ================================================ FILE: backend/go/opus/codec.go ================================================ package main import ( "errors" "fmt" "os" "path/filepath" "runtime" "sync" "github.com/ebitengine/purego" ) const ( ApplicationVoIP = 2048 ApplicationAudio = 2049 ApplicationRestrictedLowDelay = 2051 ) var ( initOnce sync.Once initErr error opusLib uintptr shimLib uintptr // libopus functions cEncoderCreate func(fs int32, channels int32, application int32, errPtr *int32) uintptr cEncode func(st uintptr, pcm *int16, frameSize int32, data *byte, maxBytes int32) int32 cEncoderDestroy func(st uintptr) cDecoderCreate func(fs int32, channels int32, errPtr *int32) uintptr cDecode func(st uintptr, data *byte, dataLen int32, pcm *int16, frameSize int32, decodeFec int32) int32 cDecoderDestroy func(st uintptr) // shim functions (non-variadic wrappers for opus_encoder_ctl) cSetBitrate func(st uintptr, bitrate int32) int32 cSetComplexity func(st uintptr, complexity int32) int32 ) func loadLib(names []string) (uintptr, error) { var firstErr error for _, name := range names { h, err := purego.Dlopen(name, purego.RTLD_NOW|purego.RTLD_GLOBAL) if err == nil { return h, nil } if firstErr == nil { firstErr = err } } return 0, firstErr } func ensureInit() error { initOnce.Do(func() { initErr = doInit() }) return initErr } const shimHint = "ensure libopus-dev is installed and rebuild, or set OPUS_LIBRARY / OPUS_SHIM_LIBRARY env vars" func doInit() error { 
opusNames := opusSearchPaths() var err error opusLib, err = loadLib(opusNames) if err != nil { return fmt.Errorf("opus: failed to load libopus (%s): %w", shimHint, err) } purego.RegisterLibFunc(&cEncoderCreate, opusLib, "opus_encoder_create") purego.RegisterLibFunc(&cEncode, opusLib, "opus_encode") purego.RegisterLibFunc(&cEncoderDestroy, opusLib, "opus_encoder_destroy") purego.RegisterLibFunc(&cDecoderCreate, opusLib, "opus_decoder_create") purego.RegisterLibFunc(&cDecode, opusLib, "opus_decode") purego.RegisterLibFunc(&cDecoderDestroy, opusLib, "opus_decoder_destroy") shimNames := shimSearchPaths() shimLib, err = loadLib(shimNames) if err != nil { return fmt.Errorf("opus: failed to load libopusshim (%s): %w", shimHint, err) } purego.RegisterLibFunc(&cSetBitrate, shimLib, "opus_shim_encoder_set_bitrate") purego.RegisterLibFunc(&cSetComplexity, shimLib, "opus_shim_encoder_set_complexity") return nil } func opusSearchPaths() []string { var paths []string if env := os.Getenv("OPUS_LIBRARY"); env != "" { paths = append(paths, env) } if exe, err := os.Executable(); err == nil { dir := filepath.Dir(exe) paths = append(paths, filepath.Join(dir, "libopus.so.0"), filepath.Join(dir, "libopus.so")) if runtime.GOOS == "darwin" { paths = append(paths, filepath.Join(dir, "libopus.dylib")) } } paths = append(paths, "libopus.so.0", "libopus.so", "libopus.dylib", "opus.dll") if runtime.GOOS == "darwin" { paths = append(paths, "/opt/homebrew/lib/libopus.dylib", "/usr/local/lib/libopus.dylib", ) } return paths } func shimSearchPaths() []string { var paths []string if env := os.Getenv("OPUS_SHIM_LIBRARY"); env != "" { paths = append(paths, env) } if exe, err := os.Executable(); err == nil { dir := filepath.Dir(exe) paths = append(paths, filepath.Join(dir, "libopusshim.so")) if runtime.GOOS == "darwin" { paths = append(paths, filepath.Join(dir, "libopusshim.dylib")) } } paths = append(paths, "./libopusshim.so", "libopusshim.so") if runtime.GOOS == "darwin" { paths = append(paths, 
"./libopusshim.dylib", "libopusshim.dylib") } return paths } // Encoder wraps a libopus OpusEncoder via purego. type Encoder struct { st uintptr } func NewEncoder(sampleRate, channels, application int) (*Encoder, error) { if err := ensureInit(); err != nil { return nil, err } var opusErr int32 st := cEncoderCreate(int32(sampleRate), int32(channels), int32(application), &opusErr) if opusErr != 0 || st == 0 { return nil, fmt.Errorf("opus_encoder_create failed: error %d", opusErr) } return &Encoder{st: st}, nil } // Encode encodes a frame of PCM int16 samples. It returns the number of bytes // written to out, or a negative error code. func (e *Encoder) Encode(pcm []int16, frameSize int, out []byte) (int, error) { if len(pcm) == 0 || len(out) == 0 { return 0, errors.New("opus encode: empty input or output buffer") } n := cEncode(e.st, &pcm[0], int32(frameSize), &out[0], int32(len(out))) if n < 0 { return 0, fmt.Errorf("opus_encode failed: error %d", n) } return int(n), nil } func (e *Encoder) SetBitrate(bitrate int) error { if ret := cSetBitrate(e.st, int32(bitrate)); ret != 0 { return fmt.Errorf("opus set bitrate: error %d", ret) } return nil } func (e *Encoder) SetComplexity(complexity int) error { if ret := cSetComplexity(e.st, int32(complexity)); ret != 0 { return fmt.Errorf("opus set complexity: error %d", ret) } return nil } func (e *Encoder) Close() { if e.st != 0 { cEncoderDestroy(e.st) e.st = 0 } } // Decoder wraps a libopus OpusDecoder via purego. type Decoder struct { st uintptr } func NewDecoder(sampleRate, channels int) (*Decoder, error) { if err := ensureInit(); err != nil { return nil, err } var opusErr int32 st := cDecoderCreate(int32(sampleRate), int32(channels), &opusErr) if opusErr != 0 || st == 0 { return nil, fmt.Errorf("opus_decoder_create failed: error %d", opusErr) } return &Decoder{st: st}, nil } // Decode decodes an Opus packet into pcm. frameSize is the max number of // samples per channel that pcm can hold. 
Returns the number of decoded samples // per channel. func (d *Decoder) Decode(data []byte, pcm []int16, frameSize int, fec bool) (int, error) { if len(pcm) == 0 { return 0, errors.New("opus decode: empty output buffer") } var dataPtr *byte var dataLen int32 if len(data) > 0 { dataPtr = &data[0] dataLen = int32(len(data)) } decodeFec := int32(0) if fec { decodeFec = 1 } n := cDecode(d.st, dataPtr, dataLen, &pcm[0], int32(frameSize), decodeFec) if n < 0 { return 0, fmt.Errorf("opus_decode failed: error %d", n) } return int(n), nil } func (d *Decoder) Close() { if d.st != 0 { cDecoderDestroy(d.st) d.st = 0 } } // Init eagerly loads the opus libraries, returning any error. // Calling this is optional; the libraries are loaded lazily on first use. func Init() error { return ensureInit() } // Reset allows re-initialization (for testing). func Reset() { initOnce = sync.Once{} initErr = nil opusLib = 0 shimLib = 0 } ================================================ FILE: backend/go/opus/csrc/opus_shim.c ================================================ #include int opus_shim_encoder_set_bitrate(OpusEncoder *st, opus_int32 bitrate) { return opus_encoder_ctl(st, OPUS_SET_BITRATE(bitrate)); } int opus_shim_encoder_set_complexity(OpusEncoder *st, opus_int32 complexity) { return opus_encoder_ctl(st, OPUS_SET_COMPLEXITY(complexity)); } ================================================ FILE: backend/go/opus/main.go ================================================ package main import ( "flag" grpc "github.com/mudler/LocalAI/pkg/grpc" ) var addr = flag.String("addr", "localhost:50051", "the address to connect to") func main() { flag.Parse() if err := grpc.StartServer(*addr, &Opus{}); err != nil { panic(err) } } ================================================ FILE: backend/go/opus/opus.go ================================================ package main import ( "fmt" "sync" "time" "github.com/mudler/LocalAI/pkg/grpc/base" pb "github.com/mudler/LocalAI/pkg/grpc/proto" 
"github.com/mudler/LocalAI/pkg/sound" ) const ( opusSampleRate = 48000 opusChannels = 1 opusFrameSize = 960 // 20ms at 48kHz opusMaxPacketSize = 4000 opusMaxFrameSize = 5760 // 120ms at 48kHz decoderIdleTTL = 60 * time.Second decoderEvictTick = 30 * time.Second ) type cachedDecoder struct { mu sync.Mutex dec *Decoder lastUsed time.Time } type Opus struct { base.Base decodersMu sync.Mutex decoders map[string]*cachedDecoder } func (o *Opus) Load(opts *pb.ModelOptions) error { o.decoders = make(map[string]*cachedDecoder) go o.evictLoop() return Init() } func (o *Opus) evictLoop() { ticker := time.NewTicker(decoderEvictTick) defer ticker.Stop() for range ticker.C { o.decodersMu.Lock() now := time.Now() for id, cd := range o.decoders { if now.Sub(cd.lastUsed) > decoderIdleTTL { cd.dec.Close() delete(o.decoders, id) } } o.decodersMu.Unlock() } } // getOrCreateDecoder returns a cached decoder for the given session ID, // creating one if it doesn't exist yet. func (o *Opus) getOrCreateDecoder(sessionID string) (*cachedDecoder, error) { o.decodersMu.Lock() defer o.decodersMu.Unlock() if cd, ok := o.decoders[sessionID]; ok { cd.lastUsed = time.Now() return cd, nil } dec, err := NewDecoder(opusSampleRate, opusChannels) if err != nil { return nil, err } cd := &cachedDecoder{dec: dec, lastUsed: time.Now()} o.decoders[sessionID] = cd return cd, nil } func (o *Opus) AudioEncode(req *pb.AudioEncodeRequest) (*pb.AudioEncodeResult, error) { enc, err := NewEncoder(opusSampleRate, opusChannels, ApplicationAudio) if err != nil { return nil, fmt.Errorf("opus encoder create: %w", err) } defer enc.Close() if err := enc.SetBitrate(64000); err != nil { return nil, fmt.Errorf("opus set bitrate: %w", err) } if err := enc.SetComplexity(10); err != nil { return nil, fmt.Errorf("opus set complexity: %w", err) } samples := sound.BytesToInt16sLE(req.PcmData) if len(samples) == 0 { return &pb.AudioEncodeResult{ SampleRate: opusSampleRate, SamplesPerFrame: opusFrameSize, }, nil } if req.SampleRate 
!= 0 && int(req.SampleRate) != opusSampleRate { samples = sound.ResampleInt16(samples, int(req.SampleRate), opusSampleRate) } var frames [][]byte packet := make([]byte, opusMaxPacketSize) for offset := 0; offset+opusFrameSize <= len(samples); offset += opusFrameSize { frame := samples[offset : offset+opusFrameSize] n, err := enc.Encode(frame, opusFrameSize, packet) if err != nil { return nil, fmt.Errorf("opus encode: %w", err) } out := make([]byte, n) copy(out, packet[:n]) frames = append(frames, out) } return &pb.AudioEncodeResult{ Frames: frames, SampleRate: opusSampleRate, SamplesPerFrame: opusFrameSize, }, nil } func (o *Opus) AudioDecode(req *pb.AudioDecodeRequest) (*pb.AudioDecodeResult, error) { if len(req.Frames) == 0 { return &pb.AudioDecodeResult{ SampleRate: opusSampleRate, SamplesPerFrame: opusFrameSize, }, nil } // Use a persistent decoder when a session ID is provided so that Opus // prediction state carries across batches. Fall back to a fresh decoder // for backward compatibility. sessionID := req.Options["session_id"] var cd *cachedDecoder var ownedDec *Decoder if sessionID != "" && o.decoders != nil { var err error cd, err = o.getOrCreateDecoder(sessionID) if err != nil { return nil, fmt.Errorf("opus decoder create: %w", err) } cd.mu.Lock() defer cd.mu.Unlock() } else { dec, err := NewDecoder(opusSampleRate, opusChannels) if err != nil { return nil, fmt.Errorf("opus decoder create: %w", err) } ownedDec = dec defer ownedDec.Close() } dec := ownedDec if cd != nil { dec = cd.dec } var allSamples []int16 var samplesPerFrame int32 pcm := make([]int16, opusMaxFrameSize) for _, frame := range req.Frames { n, err := dec.Decode(frame, pcm, opusMaxFrameSize, false) if err != nil { return nil, fmt.Errorf("opus decode: %w", err) } if samplesPerFrame == 0 { samplesPerFrame = int32(n) } allSamples = append(allSamples, pcm[:n]...) 
// generateSineWave returns numSamples samples of a sine tone of freq Hz
// sampled at sampleRate Hz. The peak amplitude is half of math.MaxInt16
// (integer division, i.e. 16383).
func generateSineWave(freq float64, sampleRate, numSamples int) []int16 {
	const amplitude = math.MaxInt16 / 2

	wave := make([]int16, numSamples)
	for idx := 0; idx < numSamples; idx++ {
		seconds := float64(idx) / float64(sampleRate)
		wave[idx] = int16(amplitude * math.Sin(2*math.Pi*freq*seconds))
	}
	return wave
}

// computeRMS returns the root mean square of samples, or 0 for an empty
// slice.
func computeRMS(samples []int16) float64 {
	count := len(samples)
	if count == 0 {
		return 0
	}

	var squares float64
	for idx := 0; idx < count; idx++ {
		value := float64(samples[idx])
		squares += value * value
	}
	return math.Sqrt(squares / float64(count))
}

// estimateFrequency estimates the dominant frequency of samples, in Hz, by
// counting sign changes between neighbouring samples (zero counts as
// non-negative) and dividing by twice the duration. It returns 0 when there
// are fewer than two samples.
func estimateFrequency(samples []int16, sampleRate int) float64 {
	if len(samples) < 2 {
		return 0
	}

	signChanges := 0
	previousNonNegative := samples[0] >= 0
	for _, sample := range samples[1:] {
		currentNonNegative := sample >= 0
		if currentNonNegative != previousNonNegative {
			signChanges++
		}
		previousNonNegative = currentNonNegative
	}

	seconds := float64(len(samples)) / float64(sampleRate)
	return float64(signChanges) / (2 * seconds)
}
func encodeDecodeRoundtrip(o *Opus, pcmBytes []byte, sampleRate int) []int16 { encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{ PcmData: pcmBytes, SampleRate: int32(sampleRate), Channels: 1, }) Expect(err).ToNot(HaveOccurred(), "AudioEncode") if len(encResult.Frames) == 0 { return nil } decResult, err := o.AudioDecode(&pb.AudioDecodeRequest{ Frames: encResult.Frames, }) Expect(err).ToNot(HaveOccurred(), "AudioDecode") return sound.BytesToInt16sLE(decResult.PcmData) } func extractOpusFramesFromOgg(data []byte) [][]byte { var frames [][]byte pos := 0 pageNum := 0 for pos+27 <= len(data) { Expect(string(data[pos:pos+4])).To(Equal("OggS"), fmt.Sprintf("invalid Ogg page at offset %d", pos)) nSegments := int(data[pos+26]) if pos+27+nSegments > len(data) { break } segTable := data[pos+27 : pos+27+nSegments] dataStart := pos + 27 + nSegments var totalDataSize int for _, s := range segTable { totalDataSize += int(s) } if dataStart+totalDataSize > len(data) { break } if pageNum >= 2 { pageData := data[dataStart : dataStart+totalDataSize] offset := 0 var packet []byte for _, segSize := range segTable { packet = append(packet, pageData[offset:offset+int(segSize)]...) 
offset += int(segSize) if segSize < 255 { if len(packet) > 0 { frameCopy := make([]byte, len(packet)) copy(frameCopy, packet) frames = append(frames, frameCopy) } packet = nil } } if len(packet) > 0 { frameCopy := make([]byte, len(packet)) copy(frameCopy, packet) frames = append(frames, frameCopy) } } pos = dataStart + totalDataSize pageNum++ } return frames } func parseTestWAV(data []byte) (pcm []byte, sampleRate int) { if len(data) < 44 || string(data[0:4]) != "RIFF" { return data, 0 } pos := 12 sr := int(binary.LittleEndian.Uint32(data[24:28])) for pos+8 <= len(data) { id := string(data[pos : pos+4]) sz := int(binary.LittleEndian.Uint32(data[pos+4 : pos+8])) if id == "data" { end := pos + 8 + sz if end > len(data) { end = len(data) } return data[pos+8 : end], sr } pos += 8 + sz if sz%2 != 0 { pos++ } } return data[44:], sr } func writeOggOpus(path string, frames [][]byte, sampleRate, channels int) error { f, err := os.Create(path) if err != nil { return err } defer f.Close() serial := uint32(0x4C6F6341) // "LocA" var pageSeq uint32 const preSkip = 312 opusHead := make([]byte, 19) copy(opusHead[0:8], "OpusHead") opusHead[8] = 1 opusHead[9] = byte(channels) binary.LittleEndian.PutUint16(opusHead[10:12], uint16(preSkip)) binary.LittleEndian.PutUint32(opusHead[12:16], uint32(sampleRate)) binary.LittleEndian.PutUint16(opusHead[16:18], 0) opusHead[18] = 0 if err := writeOggPage(f, serial, pageSeq, 0, 0x02, [][]byte{opusHead}); err != nil { return err } pageSeq++ opusTags := make([]byte, 16) copy(opusTags[0:8], "OpusTags") binary.LittleEndian.PutUint32(opusTags[8:12], 0) binary.LittleEndian.PutUint32(opusTags[12:16], 0) if err := writeOggPage(f, serial, pageSeq, 0, 0x00, [][]byte{opusTags}); err != nil { return err } pageSeq++ var granulePos uint64 for i, frame := range frames { granulePos += 960 headerType := byte(0x00) if i == len(frames)-1 { headerType = 0x04 } if err := writeOggPage(f, serial, pageSeq, granulePos, headerType, [][]byte{frame}); err != nil { return 
// oggCRC32 computes the page checksum over header followed by data, using
// oggCRCTable: a table-driven, MSB-first CRC with a zero initial value and no
// final XOR. Splitting the input between header and data does not change the
// result.
func oggCRC32(header, data []byte) uint32 {
	crc := uint32(0)
	for _, part := range [2][]byte{header, data} {
		for _, octet := range part {
			index := byte(crc>>24) ^ octet
			crc = crc<<8 ^ oggCRCTable[index]
		}
	}
	return crc
}

// oggCRCTable holds, for every byte value, the CRC register after feeding
// that byte MSB-first through the polynomial 0x04C11DB7.
var oggCRCTable = func() (table [256]uint32) {
	const polynomial = 0x04C11DB7

	for value := range table {
		register := uint32(value) << 24
		for bit := 0; bit < 8; bit++ {
			topBitSet := register&0x80000000 != 0
			register <<= 1
			if topBitSet {
				register ^= polynomial
			}
		}
		table[value] = register
	}
	return table
}()
return math.Sqrt(harmonicSum/fundPower) * 100 } // --- Opus specs --- var _ = Describe("Opus", func() { var o *Opus BeforeEach(func() { o = &Opus{} Expect(o.Load(&pb.ModelOptions{})).To(Succeed()) }) It("decodes Chrome-like VoIP frames", func() { enc, err := NewEncoder(48000, 1, ApplicationVoIP) Expect(err).ToNot(HaveOccurred()) defer enc.Close() Expect(enc.SetBitrate(32000)).To(Succeed()) Expect(enc.SetComplexity(5)).To(Succeed()) sine := generateSineWave(440, 48000, 48000) packet := make([]byte, 4000) var opusFrames [][]byte for offset := 0; offset+opusFrameSize <= len(sine); offset += opusFrameSize { frame := sine[offset : offset+opusFrameSize] n, err := enc.Encode(frame, opusFrameSize, packet) Expect(err).ToNot(HaveOccurred(), "VoIP encode") out := make([]byte, n) copy(out, packet[:n]) opusFrames = append(opusFrames, out) } result, err := o.AudioDecode(&pb.AudioDecodeRequest{Frames: opusFrames}) Expect(err).ToNot(HaveOccurred()) allDecoded := sound.BytesToInt16sLE(result.PcmData) Expect(allDecoded).ToNot(BeEmpty(), "no decoded samples from VoIP encoder") skip := min(len(allDecoded)/4, 48000*100/1000) tail := allDecoded[skip:] rms := computeRMS(tail) GinkgoWriter.Printf("VoIP/SILK roundtrip: %d decoded samples, RMS=%.1f\n", len(allDecoded), rms) Expect(rms).To(BeNumerically(">=", 50), "VoIP decoded RMS is too low; SILK decoder may be broken") }) It("decodes stereo-encoded Opus with a mono decoder", func() { enc, err := NewEncoder(48000, 2, ApplicationVoIP) Expect(err).ToNot(HaveOccurred()) defer enc.Close() Expect(enc.SetBitrate(32000)).To(Succeed()) mono := generateSineWave(440, 48000, 48000) stereo := make([]int16, len(mono)*2) for i, s := range mono { stereo[i*2] = s stereo[i*2+1] = s } packet := make([]byte, 4000) var opusFrames [][]byte for offset := 0; offset+opusFrameSize*2 <= len(stereo); offset += opusFrameSize * 2 { frame := stereo[offset : offset+opusFrameSize*2] n, err := enc.Encode(frame, opusFrameSize, packet) Expect(err).ToNot(HaveOccurred(), 
"Stereo encode") out := make([]byte, n) copy(out, packet[:n]) opusFrames = append(opusFrames, out) } result, err := o.AudioDecode(&pb.AudioDecodeRequest{Frames: opusFrames}) Expect(err).ToNot(HaveOccurred()) allDecoded := sound.BytesToInt16sLE(result.PcmData) Expect(allDecoded).ToNot(BeEmpty(), "no decoded samples from stereo encoder") skip := min(len(allDecoded)/4, 48000*100/1000) tail := allDecoded[skip:] rms := computeRMS(tail) GinkgoWriter.Printf("Stereo->Mono: %d decoded samples, RMS=%.1f\n", len(allDecoded), rms) Expect(rms).To(BeNumerically(">=", 50), "Stereo->Mono decoded RMS is too low") }) Describe("decoding libopus-encoded audio", func() { var ffmpegPath string var tmpDir string var pcmPath string var sine []int16 BeforeEach(func() { var err error ffmpegPath, err = exec.LookPath("ffmpeg") if err != nil { Skip("ffmpeg not found") } tmpDir = GinkgoT().TempDir() sine = generateSineWave(440, 48000, 48000) pcmBytes := sound.Int16toBytesLE(sine) pcmPath = filepath.Join(tmpDir, "input.raw") Expect(os.WriteFile(pcmPath, pcmBytes, 0644)).To(Succeed()) }) for _, tc := range []struct { name string bitrate string app string }{ {"voip_32k", "32000", "voip"}, {"voip_64k", "64000", "voip"}, {"audio_64k", "64000", "audio"}, {"audio_128k", "128000", "audio"}, } { tc := tc It(tc.name, func() { oggPath := filepath.Join(tmpDir, fmt.Sprintf("libopus_%s_%s.ogg", tc.app, tc.bitrate)) cmd := exec.Command(ffmpegPath, "-y", "-f", "s16le", "-ar", "48000", "-ac", "1", "-i", pcmPath, "-c:a", "libopus", "-b:a", tc.bitrate, "-application", tc.app, "-frame_duration", "20", "-vbr", "on", oggPath, ) out, err := cmd.CombinedOutput() Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("ffmpeg encode: %s", out)) oggData, err := os.ReadFile(oggPath) Expect(err).ToNot(HaveOccurred()) opusFrames := extractOpusFramesFromOgg(oggData) Expect(opusFrames).ToNot(BeEmpty(), "no Opus frames extracted from Ogg container") GinkgoWriter.Printf("Extracted %d Opus frames from libopus encoder (first frame %d 
bytes)\n", len(opusFrames), len(opusFrames[0])) result, err := o.AudioDecode(&pb.AudioDecodeRequest{Frames: opusFrames}) Expect(err).ToNot(HaveOccurred()) allDecoded := sound.BytesToInt16sLE(result.PcmData) Expect(allDecoded).ToNot(BeEmpty(), "no decoded samples from libopus-encoded Opus") skip := min(len(allDecoded)/4, 48000*100/1000) tail := allDecoded[skip:] rms := computeRMS(tail) freq := estimateFrequency(tail, 48000) GinkgoWriter.Printf("libopus->opus-go: %d decoded samples, RMS=%.1f, freq≈%.0f Hz\n", len(allDecoded), rms, freq) Expect(rms).To(BeNumerically(">=", 50), "RMS is too low — opus-go cannot decode libopus output") Expect(freq).To(BeNumerically("~", 440, 30), fmt.Sprintf("frequency %.0f Hz deviates from expected 440 Hz", freq)) }) } }) It("roundtrips at 48kHz", func() { sine := generateSineWave(440, 48000, 48000) pcmBytes := sound.Int16toBytesLE(sine) decoded := encodeDecodeRoundtrip(o, pcmBytes, 48000) Expect(decoded).ToNot(BeEmpty()) decodedSR := 48000 skipDecoded := decodedSR * 50 / 1000 if skipDecoded > len(decoded)/2 { skipDecoded = len(decoded) / 4 } tail := decoded[skipDecoded:] rms := computeRMS(tail) GinkgoWriter.Printf("48kHz roundtrip: %d decoded samples, RMS=%.1f\n", len(decoded), rms) Expect(rms).To(BeNumerically(">=", 50), "decoded audio RMS is too low; signal appears silent") }) It("roundtrips at 16kHz", func() { sine16k := generateSineWave(440, 16000, 16000) pcmBytes := sound.Int16toBytesLE(sine16k) decoded := encodeDecodeRoundtrip(o, pcmBytes, 16000) Expect(decoded).ToNot(BeEmpty()) decoded16k := sound.ResampleInt16(decoded, 48000, 16000) skip := min(len(decoded16k)/4, 16000*50/1000) tail := decoded16k[skip:] rms := computeRMS(tail) GinkgoWriter.Printf("16kHz roundtrip: %d decoded@48k -> %d resampled@16k, RMS=%.1f\n", len(decoded), len(decoded16k), rms) Expect(rms).To(BeNumerically(">=", 50), "decoded audio RMS is too low; signal appears silent") }) It("returns empty frames for empty input", func() { result, err := 
o.AudioEncode(&pb.AudioEncodeRequest{
			PcmData:    []byte{},
			SampleRate: 48000,
			Channels:   1,
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(result.Frames).To(BeEmpty())
	})

	// Input shorter than one 960-sample frame produces no frames and no error.
	It("silently drops sub-frame input", func() {
		sine := generateSineWave(440, 48000, 500) // < 960
		pcmBytes := sound.Int16toBytesLE(sine)

		result, err := o.AudioEncode(&pb.AudioEncodeRequest{
			PcmData:    pcmBytes,
			SampleRate: 48000,
			Channels:   1,
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(result.Frames).To(BeEmpty(), fmt.Sprintf("expected 0 frames for %d samples (< 960)", len(sine)))
	})

	// 2880 samples = 3 x 960, so exactly three frames are expected.
	It("encodes multiple frames", func() {
		sine := generateSineWave(440, 48000, 2880) // exactly 3 frames
		pcmBytes := sound.Int16toBytesLE(sine)

		result, err := o.AudioEncode(&pb.AudioEncodeRequest{
			PcmData:    pcmBytes,
			SampleRate: 48000,
			Channels:   1,
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(result.Frames).To(HaveLen(3))
	})

	// A single encoded frame must decode to either 960 or 480 samples.
	It("produces expected decoded frame size", func() {
		sine := generateSineWave(440, 48000, 960)
		pcmBytes := sound.Int16toBytesLE(sine)

		encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{
			PcmData:    pcmBytes,
			SampleRate: 48000,
			Channels:   1,
		})
		Expect(err).ToNot(HaveOccurred())
		Expect(encResult.Frames).To(HaveLen(1))

		decResult, err := o.AudioDecode(&pb.AudioDecodeRequest{
			Frames: encResult.Frames,
		})
		Expect(err).ToNot(HaveOccurred())

		decoded := sound.BytesToInt16sLE(decResult.PcmData)
		GinkgoWriter.Printf("Encoder input: 960 samples (20ms @ 48kHz)\n")
		// len/48.0 converts a sample count at 48 kHz to milliseconds.
		GinkgoWriter.Printf("Decoder output: %d samples (%.1fms @ 48kHz)\n",
			len(decoded), float64(len(decoded))/48.0)

		Expect(len(decoded)).To(SatisfyAny(Equal(960), Equal(480)),
			fmt.Sprintf("unexpected decoded frame size %d", len(decoded)))
	})

	// 16 kHz PCM in, decoded audio out; the level is checked on the full
	// decoded buffer without skipping any leading samples.
	It("handles the full WebRTC output path", func() {
		sine16k := generateSineWave(440, 16000, 16000)
		pcmBytes := sound.Int16toBytesLE(sine16k)

		decoded := encodeDecodeRoundtrip(o, pcmBytes, 16000)
		Expect(decoded).ToNot(BeEmpty())

		rms := computeRMS(decoded)
		GinkgoWriter.Printf("WebRTC output path: %d decoded samples at 48kHz, RMS=%.1f\n",
len(decoded), rms)
		Expect(rms).To(BeNumerically(">=", 50), "decoded audio RMS is too low")
	})

	// 48 kHz roundtrip followed by two resampling steps (48k -> 24k -> 16k);
	// the signal must still be audible at the end of the chain.
	It("handles the full WebRTC input path", func() {
		sine48k := generateSineWave(440, 48000, 48000)
		pcmBytes := sound.Int16toBytesLE(sine48k)

		decoded48k := encodeDecodeRoundtrip(o, pcmBytes, 48000)
		Expect(decoded48k).ToNot(BeEmpty())

		step24k := sound.ResampleInt16(decoded48k, 48000, 24000)
		webrtcPath := sound.ResampleInt16(step24k, 24000, 16000)

		rms := computeRMS(webrtcPath)
		GinkgoWriter.Printf("WebRTC input path: %d decoded@48k -> %d@24k -> %d@16k, RMS=%.1f\n",
			len(decoded48k), len(step24k), len(webrtcPath), rms)
		Expect(rms).To(BeNumerically(">=", 50), "WebRTC input path signal lost in pipeline")
	})

	// Specs that pin down known limitations rather than desired behaviour.
	Context("bug documentation", func() {
		// 1000 input samples encode to a single 960-sample frame; the
		// remaining 40 samples are not represented in the output.
		It("documents trailing sample loss", func() {
			sine := generateSineWave(440, 48000, 1000)
			pcmBytes := sound.Int16toBytesLE(sine)

			result, err := o.AudioEncode(&pb.AudioEncodeRequest{
				PcmData:    pcmBytes,
				SampleRate: 48000,
				Channels:   1,
			})
			Expect(err).ToNot(HaveOccurred())
			Expect(result.Frames).To(HaveLen(1))

			decResult, err := o.AudioDecode(&pb.AudioDecodeRequest{Frames: result.Frames})
			Expect(err).ToNot(HaveOccurred())
			decoded := sound.BytesToInt16sLE(decResult.PcmData)

			GinkgoWriter.Printf("Input: 1000 samples, Encoded: 1 frame, Decoded: %d samples (40 samples lost)\n", len(decoded))
			Expect(len(decoded)).To(BeNumerically("<=", 960),
				fmt.Sprintf("decoded more samples (%d) than the encoder consumed (960)", len(decoded)))
		})

		// Feeds 24 kHz PCM to the roundtrip helper twice: once labelled as
		// 16 kHz (the "bug" path) and once with the true rate of 24 kHz.
		It("documents TTS sample rate mismatch", func() {
			sine24k := generateSineWave(440, 24000, 24000)
			pcmBytes := sound.Int16toBytesLE(sine24k)

			decodedBug := encodeDecodeRoundtrip(o, pcmBytes, 16000)
			decodedCorrect := encodeDecodeRoundtrip(o, pcmBytes, 24000)

			// Skip at most 4800 samples (100 ms at 48 kHz), capped at a
			// quarter of each output.
			skipBug := min(len(decodedBug)/4, 48000*100/1000)
			skipCorrect := min(len(decodedCorrect)/4, 48000*100/1000)

			bugTail := decodedBug[skipBug:]
			correctTail := decodedCorrect[skipCorrect:]

			bugFreq := estimateFrequency(bugTail, 48000)
			correctFreq := estimateFrequency(correctTail,
48000) GinkgoWriter.Printf("Bug path: %d decoded samples, freq≈%.0f Hz (expected ~660 Hz = 440*1.5)\n", len(decodedBug), bugFreq) GinkgoWriter.Printf("Correct path: %d decoded samples, freq≈%.0f Hz (expected ~440 Hz)\n", len(decodedCorrect), correctFreq) if len(decodedBug) > 0 && len(decodedCorrect) > 0 { ratio := float64(len(decodedBug)) / float64(len(decodedCorrect)) GinkgoWriter.Printf("Sample count ratio (bug/correct): %.2f (expected ~1.5)\n", ratio) Expect(ratio).To(BeNumerically(">=", 1.1), "expected bug path to produce significantly more samples due to wrong resample ratio") } }) }) Context("batch boundary discontinuity", func() { // These tests simulate the exact production pipeline: // Browser encodes → RTP → batch 15 frames (300ms) → decode → resample 48k→16k → append // They test both with and without persistent decoders to verify // that the session_id persistent decoder path works correctly. It("batched decode+resample with persistent decoder matches one-shot", func() { // Encode 3 seconds of 440Hz at 48kHz — enough for 10 batches sine := generateSineWave(440, 48000, 48000*3) pcmBytes := sound.Int16toBytesLE(sine) encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{ PcmData: pcmBytes, SampleRate: 48000, Channels: 1, }) Expect(err).ToNot(HaveOccurred()) GinkgoWriter.Printf("Encoded %d frames (%.0fms)\n", len(encResult.Frames), float64(len(encResult.Frames))*20.0) // Ground truth: decode ALL frames with one decoder, resample in one shot decAll, err := o.AudioDecode(&pb.AudioDecodeRequest{ Frames: encResult.Frames, Options: map[string]string{"session_id": "ground-truth"}, }) Expect(err).ToNot(HaveOccurred()) allSamples := sound.BytesToInt16sLE(decAll.PcmData) oneShotResampled := sound.ResampleInt16(allSamples, 48000, 16000) // Production path: decode in 15-frame batches with persistent decoder, // resample each batch independently, concatenate const framesPerBatch = 15 sessionID := "batch-test" var batchedResampled []int16 batchCount := 0 for i := 0; i 
< len(encResult.Frames); i += framesPerBatch { end := min(i+framesPerBatch, len(encResult.Frames)) decBatch, err := o.AudioDecode(&pb.AudioDecodeRequest{ Frames: encResult.Frames[i:end], Options: map[string]string{"session_id": sessionID}, }) Expect(err).ToNot(HaveOccurred()) batchSamples := sound.BytesToInt16sLE(decBatch.PcmData) batchResampled := sound.ResampleInt16(batchSamples, 48000, 16000) batchedResampled = append(batchedResampled, batchResampled...) batchCount++ } GinkgoWriter.Printf("Decoded in %d batches, oneshot=%d samples, batched=%d samples\n", batchCount, len(oneShotResampled), len(batchedResampled)) // Skip codec startup transient (first 100ms) skip := 16000 * 100 / 1000 oneShotTail := oneShotResampled[skip:] batchedTail := batchedResampled[skip:] minLen := min(len(oneShotTail), len(batchedTail)) // With persistent decoder, batched decode should be nearly identical // to one-shot (only difference is resampler batch boundaries). var maxDiff float64 var sumDiffSq float64 for i := 0; i < minLen; i++ { diff := math.Abs(float64(oneShotTail[i]) - float64(batchedTail[i])) if diff > maxDiff { maxDiff = diff } sumDiffSq += diff * diff } rmsDiff := math.Sqrt(sumDiffSq / float64(minLen)) GinkgoWriter.Printf("Persistent decoder: maxDiff=%.0f, rmsDiff=%.1f\n", maxDiff, rmsDiff) // Tight threshold: with persistent decoder and fixed resampler, // the output should be very close to one-shot Expect(maxDiff).To(BeNumerically("<", 500), "persistent decoder batched path diverges too much from one-shot") Expect(rmsDiff).To(BeNumerically("<", 50), "RMS deviation too high between batched and one-shot") }) It("fresh decoder per batch produces worse quality than persistent", func() { // This test proves the value of persistent decoders by showing // that fresh decoders produce larger deviations at batch boundaries. 
sine := generateSineWave(440, 48000, 48000*2) pcmBytes := sound.Int16toBytesLE(sine) encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{ PcmData: pcmBytes, SampleRate: 48000, Channels: 1, }) Expect(err).ToNot(HaveOccurred()) // Ground truth: one-shot decode decAll, err := o.AudioDecode(&pb.AudioDecodeRequest{ Frames: encResult.Frames, Options: map[string]string{"session_id": "ref"}, }) Expect(err).ToNot(HaveOccurred()) refSamples := sound.BytesToInt16sLE(decAll.PcmData) const framesPerBatch = 15 // Path A: persistent decoder var persistentSamples []int16 for i := 0; i < len(encResult.Frames); i += framesPerBatch { end := min(i+framesPerBatch, len(encResult.Frames)) dec, err := o.AudioDecode(&pb.AudioDecodeRequest{ Frames: encResult.Frames[i:end], Options: map[string]string{"session_id": "persistent"}, }) Expect(err).ToNot(HaveOccurred()) persistentSamples = append(persistentSamples, sound.BytesToInt16sLE(dec.PcmData)...) } // Path B: fresh decoder per batch (no session_id) var freshSamples []int16 for i := 0; i < len(encResult.Frames); i += framesPerBatch { end := min(i+framesPerBatch, len(encResult.Frames)) dec, err := o.AudioDecode(&pb.AudioDecodeRequest{ Frames: encResult.Frames[i:end], }) Expect(err).ToNot(HaveOccurred()) freshSamples = append(freshSamples, sound.BytesToInt16sLE(dec.PcmData)...) 
} // Compare both to reference skip := 48000 * 100 / 1000 refTail := refSamples[skip:] persistentTail := persistentSamples[skip:] freshTail := freshSamples[skip:] minLen := min(len(refTail), min(len(persistentTail), len(freshTail))) var persistentMaxDiff, freshMaxDiff float64 for i := 0; i < minLen; i++ { pd := math.Abs(float64(refTail[i]) - float64(persistentTail[i])) fd := math.Abs(float64(refTail[i]) - float64(freshTail[i])) if pd > persistentMaxDiff { persistentMaxDiff = pd } if fd > freshMaxDiff { freshMaxDiff = fd } } GinkgoWriter.Printf("vs reference: persistent maxDiff=%.0f, fresh maxDiff=%.0f\n", persistentMaxDiff, freshMaxDiff) // Persistent decoder should be closer to reference than fresh Expect(persistentMaxDiff).To(BeNumerically("<=", freshMaxDiff), "persistent decoder should match reference at least as well as fresh decoder") }) It("checks for PCM discontinuities at batch boundaries", func() { // Encode 2 seconds, decode in batches, resample, and check // for anomalous jumps at the exact batch boundaries in the output sine := generateSineWave(440, 48000, 48000*2) pcmBytes := sound.Int16toBytesLE(sine) encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{ PcmData: pcmBytes, SampleRate: 48000, Channels: 1, }) Expect(err).ToNot(HaveOccurred()) const framesPerBatch = 15 sessionID := "boundary-check" var batchedOutput []int16 var batchBoundaries []int // indices where batch boundaries fall in output for i := 0; i < len(encResult.Frames); i += framesPerBatch { end := min(i+framesPerBatch, len(encResult.Frames)) dec, err := o.AudioDecode(&pb.AudioDecodeRequest{ Frames: encResult.Frames[i:end], Options: map[string]string{"session_id": sessionID}, }) Expect(err).ToNot(HaveOccurred()) batchSamples := sound.BytesToInt16sLE(dec.PcmData) batchResampled := sound.ResampleInt16(batchSamples, 48000, 16000) if len(batchedOutput) > 0 { batchBoundaries = append(batchBoundaries, len(batchedOutput)) } batchedOutput = append(batchedOutput, batchResampled...) 
} GinkgoWriter.Printf("Output: %d samples, %d batch boundaries\n", len(batchedOutput), len(batchBoundaries)) // For each batch boundary, check if the sample-to-sample jump // is anomalously large compared to neighboring deltas for bIdx, boundary := range batchBoundaries { if boundary < 10 || boundary+10 >= len(batchedOutput) { continue } jump := math.Abs(float64(batchedOutput[boundary]) - float64(batchedOutput[boundary-1])) // Compute average delta in the 20-sample neighborhood (excluding boundary) var avgDelta float64 count := 0 for i := boundary - 10; i < boundary+10; i++ { if i == boundary-1 || i == boundary { continue } if i+1 < len(batchedOutput) { avgDelta += math.Abs(float64(batchedOutput[i+1]) - float64(batchedOutput[i])) count++ } } if count > 0 { avgDelta /= float64(count) } ratio := 0.0 if avgDelta > 0 { ratio = jump / avgDelta } GinkgoWriter.Printf("Boundary %d (idx %d): jump=%.0f, avg_delta=%.0f, ratio=%.1f\n", bIdx, boundary, jump, avgDelta, ratio) // The boundary jump should not be more than 5x the average // (with codec artifacts, some variation is expected) Expect(jump).To(BeNumerically("<=", avgDelta*5+1), fmt.Sprintf("discontinuity at batch boundary %d: jump=%.0f vs avg=%.0f (ratio=%.1f)", bIdx, jump, avgDelta, ratio)) } }) It("maintains sine wave phase continuity across batches", func() { sine := generateSineWave(440, 48000, 48000*2) // 2 seconds pcmBytes := sound.Int16toBytesLE(sine) encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{ PcmData: pcmBytes, SampleRate: 48000, Channels: 1, }) Expect(err).ToNot(HaveOccurred()) // Decode in batches with persistent decoder, resample each const framesPerBatch = 15 sessionID := "phase-test" var fullOutput []int16 for i := 0; i < len(encResult.Frames); i += framesPerBatch { end := min(i+framesPerBatch, len(encResult.Frames)) dec, err := o.AudioDecode(&pb.AudioDecodeRequest{ Frames: encResult.Frames[i:end], Options: map[string]string{"session_id": sessionID}, }) Expect(err).ToNot(HaveOccurred()) 
samples := sound.BytesToInt16sLE(dec.PcmData) resampled := sound.ResampleInt16(samples, 48000, 16000) fullOutput = append(fullOutput, resampled...) } // Check zero-crossing regularity after startup transient skip := 16000 * 200 / 1000 // skip first 200ms tail := fullOutput[skip:] var crossingPositions []int for i := 1; i < len(tail); i++ { if (tail[i-1] >= 0 && tail[i] < 0) || (tail[i-1] < 0 && tail[i] >= 0) { crossingPositions = append(crossingPositions, i) } } Expect(crossingPositions).ToNot(BeEmpty(), "no zero crossings found") var intervals []float64 for i := 1; i < len(crossingPositions); i++ { intervals = append(intervals, float64(crossingPositions[i]-crossingPositions[i-1])) } var sum float64 for _, v := range intervals { sum += v } mean := sum / float64(len(intervals)) var variance float64 for _, v := range intervals { d := v - mean variance += d * d } stddev := math.Sqrt(variance / float64(len(intervals))) GinkgoWriter.Printf("Zero-crossing intervals: mean=%.2f stddev=%.2f CV=%.3f (expected period ~%.1f)\n", mean, stddev, stddev/mean, 16000.0/440.0/2.0) Expect(stddev / mean).To(BeNumerically("<", 0.15), fmt.Sprintf("irregular zero crossings suggest discontinuity: CV=%.3f", stddev/mean)) // Also check frequency is correct freq := estimateFrequency(tail, 16000) GinkgoWriter.Printf("Estimated frequency: %.0f Hz (expected 440)\n", freq) Expect(freq).To(BeNumerically("~", 440, 20)) }) It("produces identical resampled output for batched vs one-shot resample", func() { // Isolate the resampler from the codec: decode once, then compare // one-shot resample vs batched resample of the same PCM. 
sine := generateSineWave(440, 48000, 48000*3) pcmBytes := sound.Int16toBytesLE(sine) encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{ PcmData: pcmBytes, SampleRate: 48000, Channels: 1, }) Expect(err).ToNot(HaveOccurred()) decResult, err := o.AudioDecode(&pb.AudioDecodeRequest{ Frames: encResult.Frames, Options: map[string]string{"session_id": "resample-test"}, }) Expect(err).ToNot(HaveOccurred()) allSamples := sound.BytesToInt16sLE(decResult.PcmData) // One-shot resample oneShot := sound.ResampleInt16(allSamples, 48000, 16000) // Batched resample (300ms chunks at 48kHz = 14400 samples) batchSize := 48000 * 300 / 1000 var batched []int16 for offset := 0; offset < len(allSamples); offset += batchSize { end := min(offset+batchSize, len(allSamples)) chunk := sound.ResampleInt16(allSamples[offset:end], 48000, 16000) batched = append(batched, chunk...) } Expect(len(batched)).To(Equal(len(oneShot)), fmt.Sprintf("length mismatch: batched=%d oneshot=%d", len(batched), len(oneShot))) // Every sample must be identical — the resampler is deterministic var maxDiff float64 for i := 0; i < len(oneShot); i++ { diff := math.Abs(float64(oneShot[i]) - float64(batched[i])) if diff > maxDiff { maxDiff = diff } } GinkgoWriter.Printf("Resample-only: batched vs one-shot maxDiff=%.0f\n", maxDiff) Expect(maxDiff).To(BeNumerically("==", 0), "batched resample should produce identical output to one-shot resample") }) It("writes WAV files for manual inspection", func() { // This test writes WAV files of the batched vs one-shot pipeline // so you can visually/audibly inspect for discontinuities. 
tmpDir := GinkgoT().TempDir() sine := generateSineWave(440, 48000, 48000*3) // 3 seconds pcmBytes := sound.Int16toBytesLE(sine) encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{ PcmData: pcmBytes, SampleRate: 48000, Channels: 1, }) Expect(err).ToNot(HaveOccurred()) // One-shot path (reference) decAll, err := o.AudioDecode(&pb.AudioDecodeRequest{ Frames: encResult.Frames, Options: map[string]string{"session_id": "wav-ref"}, }) Expect(err).ToNot(HaveOccurred()) refSamples := sound.BytesToInt16sLE(decAll.PcmData) refResampled := sound.ResampleInt16(refSamples, 48000, 16000) // Batched path (production simulation) const framesPerBatch = 15 var batchedResampled []int16 for i := 0; i < len(encResult.Frames); i += framesPerBatch { end := min(i+framesPerBatch, len(encResult.Frames)) dec, err := o.AudioDecode(&pb.AudioDecodeRequest{ Frames: encResult.Frames[i:end], Options: map[string]string{"session_id": "wav-batched"}, }) Expect(err).ToNot(HaveOccurred()) samples := sound.BytesToInt16sLE(dec.PcmData) resampled := sound.ResampleInt16(samples, 48000, 16000) batchedResampled = append(batchedResampled, resampled...) 
} // Write WAV files writeWAV := func(path string, samples []int16, sampleRate int) { dataLen := len(samples) * 2 hdr := make([]byte, 44) copy(hdr[0:4], "RIFF") binary.LittleEndian.PutUint32(hdr[4:8], uint32(36+dataLen)) copy(hdr[8:12], "WAVE") copy(hdr[12:16], "fmt ") binary.LittleEndian.PutUint32(hdr[16:20], 16) // chunk size binary.LittleEndian.PutUint16(hdr[20:22], 1) // PCM binary.LittleEndian.PutUint16(hdr[22:24], 1) // mono binary.LittleEndian.PutUint32(hdr[24:28], uint32(sampleRate)) // sample rate binary.LittleEndian.PutUint32(hdr[28:32], uint32(sampleRate*2)) // byte rate binary.LittleEndian.PutUint16(hdr[32:34], 2) // block align binary.LittleEndian.PutUint16(hdr[34:36], 16) // bits per sample copy(hdr[36:40], "data") binary.LittleEndian.PutUint32(hdr[40:44], uint32(dataLen)) f, err := os.Create(path) Expect(err).ToNot(HaveOccurred()) defer f.Close() _, err = f.Write(hdr) Expect(err).ToNot(HaveOccurred()) _, err = f.Write(sound.Int16toBytesLE(samples)) Expect(err).ToNot(HaveOccurred()) } refPath := filepath.Join(tmpDir, "oneshot_16k.wav") batchedPath := filepath.Join(tmpDir, "batched_16k.wav") writeWAV(refPath, refResampled, 16000) writeWAV(batchedPath, batchedResampled, 16000) GinkgoWriter.Printf("WAV files written for manual inspection:\n") GinkgoWriter.Printf(" Reference: %s\n", refPath) GinkgoWriter.Printf(" Batched: %s\n", batchedPath) GinkgoWriter.Printf(" Ref samples: %d, Batched samples: %d\n", len(refResampled), len(batchedResampled)) }) }) It("produces frames decodable by ffmpeg (cross-library compat)", func() { ffmpegPath, err := exec.LookPath("ffmpeg") if err != nil { Skip("ffmpeg not found") } sine := generateSineWave(440, 48000, 48000) pcmBytes := sound.Int16toBytesLE(sine) result, err := o.AudioEncode(&pb.AudioEncodeRequest{ PcmData: pcmBytes, SampleRate: 48000, Channels: 1, }) Expect(err).ToNot(HaveOccurred()) Expect(result.Frames).ToNot(BeEmpty()) GinkgoWriter.Printf("opus-go produced %d frames (first frame %d bytes)\n", 
len(result.Frames), len(result.Frames[0])) tmpDir := GinkgoT().TempDir() oggPath := filepath.Join(tmpDir, "opus_go_output.ogg") Expect(writeOggOpus(oggPath, result.Frames, 48000, 1)).To(Succeed()) decodedWavPath := filepath.Join(tmpDir, "ffmpeg_decoded.wav") cmd := exec.Command(ffmpegPath, "-y", "-i", oggPath, "-ar", "48000", "-ac", "1", "-c:a", "pcm_s16le", decodedWavPath) out, err := cmd.CombinedOutput() Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("ffmpeg failed to decode opus-go output: %s", out)) decodedData, err := os.ReadFile(decodedWavPath) Expect(err).ToNot(HaveOccurred()) decodedPCM, sr := parseTestWAV(decodedData) Expect(sr).ToNot(BeZero(), "ffmpeg output has no WAV header") decodedSamples := sound.BytesToInt16sLE(decodedPCM) skip := min(len(decodedSamples)/4, sr*100/1000) if skip >= len(decodedSamples) { skip = 0 } tail := decodedSamples[skip:] rms := computeRMS(tail) GinkgoWriter.Printf("ffmpeg decoded opus-go output: %d samples at %dHz, RMS=%.1f\n", len(decodedSamples), sr, rms) Expect(rms).To(BeNumerically(">=", 50), "ffmpeg decoded RMS is too low — opus-go frames are likely incompatible with standard decoders") }) It("delivers audio through a full WebRTC pipeline", func() { const ( toneFreq = 440.0 toneSampleRate = 24000 toneDuration = 1 toneAmplitude = 16000 toneNumSamples = toneSampleRate * toneDuration ) pcm := make([]byte, toneNumSamples*2) for i := 0; i < toneNumSamples; i++ { sample := int16(toneAmplitude * math.Sin(2*math.Pi*toneFreq*float64(i)/float64(toneSampleRate))) binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample)) } encResult, err := o.AudioEncode(&pb.AudioEncodeRequest{ PcmData: pcm, SampleRate: toneSampleRate, Channels: 1, }) Expect(err).ToNot(HaveOccurred()) Expect(encResult.Frames).ToNot(BeEmpty()) GinkgoWriter.Printf("Encoded %d Opus frames from %d PCM samples at %dHz\n", len(encResult.Frames), toneNumSamples, toneSampleRate) // Create sender PeerConnection senderME := &webrtc.MediaEngine{} 
Expect(senderME.RegisterDefaultCodecs()).To(Succeed()) senderAPI := webrtc.NewAPI(webrtc.WithMediaEngine(senderME)) senderPC, err := senderAPI.NewPeerConnection(webrtc.Configuration{}) Expect(err).ToNot(HaveOccurred()) defer senderPC.Close() audioTrack, err := webrtc.NewTrackLocalStaticRTP( webrtc.RTPCodecCapability{ MimeType: webrtc.MimeTypeOpus, ClockRate: 48000, Channels: 2, }, "audio", "test", ) Expect(err).ToNot(HaveOccurred()) rtpSender, err := senderPC.AddTrack(audioTrack) Expect(err).ToNot(HaveOccurred()) go func() { buf := make([]byte, 1500) for { if _, _, err := rtpSender.Read(buf); err != nil { return } } }() // Create receiver PeerConnection receiverME := &webrtc.MediaEngine{} Expect(receiverME.RegisterDefaultCodecs()).To(Succeed()) receiverAPI := webrtc.NewAPI(webrtc.WithMediaEngine(receiverME)) receiverPC, err := receiverAPI.NewPeerConnection(webrtc.Configuration{}) Expect(err).ToNot(HaveOccurred()) defer receiverPC.Close() type receivedPacket struct { seqNum uint16 timestamp uint32 marker bool payload []byte } var ( receivedMu sync.Mutex receivedPackets []receivedPacket trackDone = make(chan struct{}) ) receiverPC.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { defer close(trackDone) for { pkt, _, err := track.ReadRTP() if err != nil { return } payload := make([]byte, len(pkt.Payload)) copy(payload, pkt.Payload) receivedMu.Lock() receivedPackets = append(receivedPackets, receivedPacket{ seqNum: pkt.Header.SequenceNumber, timestamp: pkt.Header.Timestamp, marker: pkt.Header.Marker, payload: payload, }) receivedMu.Unlock() } }) // Exchange SDP offer, err := senderPC.CreateOffer(nil) Expect(err).ToNot(HaveOccurred()) Expect(senderPC.SetLocalDescription(offer)).To(Succeed()) senderGatherDone := webrtc.GatheringCompletePromise(senderPC) Eventually(senderGatherDone, 5*time.Second).Should(BeClosed()) Expect(receiverPC.SetRemoteDescription(*senderPC.LocalDescription())).To(Succeed()) answer, err := receiverPC.CreateAnswer(nil) 
Expect(err).ToNot(HaveOccurred()) Expect(receiverPC.SetLocalDescription(answer)).To(Succeed()) receiverGatherDone := webrtc.GatheringCompletePromise(receiverPC) Eventually(receiverGatherDone, 5*time.Second).Should(BeClosed()) Expect(senderPC.SetRemoteDescription(*receiverPC.LocalDescription())).To(Succeed()) // Wait for connection connected := make(chan struct{}) senderPC.OnConnectionStateChange(func(s webrtc.PeerConnectionState) { if s == webrtc.PeerConnectionStateConnected { select { case <-connected: default: close(connected) } } }) Eventually(connected, 5*time.Second).Should(BeClosed()) // Send test tone via RTP const samplesPerFrame = 960 seqNum := uint16(rand.UintN(65536)) timestamp := rand.Uint32() marker := true ticker := time.NewTicker(20 * time.Millisecond) defer ticker.Stop() for i, frame := range encResult.Frames { pkt := &rtp.Packet{ Header: rtp.Header{ Version: 2, Marker: marker, SequenceNumber: seqNum, Timestamp: timestamp, }, Payload: frame, } seqNum++ timestamp += samplesPerFrame marker = false Expect(audioTrack.WriteRTP(pkt)).To(Succeed(), fmt.Sprintf("WriteRTP frame %d", i)) if i < len(encResult.Frames)-1 { <-ticker.C } } // Wait for packets to arrive time.Sleep(500 * time.Millisecond) senderPC.Close() select { case <-trackDone: case <-time.After(2 * time.Second): } // Decode received Opus frames via the backend receivedMu.Lock() pkts := make([]receivedPacket, len(receivedPackets)) copy(pkts, receivedPackets) receivedMu.Unlock() Expect(pkts).ToNot(BeEmpty(), "no RTP packets received") var receivedFrames [][]byte for _, pkt := range pkts { receivedFrames = append(receivedFrames, pkt.payload) } decResult, err := o.AudioDecode(&pb.AudioDecodeRequest{Frames: receivedFrames}) Expect(err).ToNot(HaveOccurred()) allDecoded := sound.BytesToInt16sLE(decResult.PcmData) Expect(allDecoded).ToNot(BeEmpty(), "no decoded samples") // Analyse RTP packet delivery frameLoss := len(encResult.Frames) - len(pkts) seqGaps := 0 for i := 1; i < len(pkts); i++ { expected 
:= pkts[i-1].seqNum + 1 if pkts[i].seqNum != expected { seqGaps++ } } markerCount := 0 for _, pkt := range pkts { if pkt.marker { markerCount++ } } GinkgoWriter.Println("── RTP Delivery ──") GinkgoWriter.Printf(" Frames sent: %d\n", len(encResult.Frames)) GinkgoWriter.Printf(" Packets recv: %d\n", len(pkts)) GinkgoWriter.Printf(" Frame loss: %d\n", frameLoss) GinkgoWriter.Printf(" Sequence gaps: %d\n", seqGaps) GinkgoWriter.Printf(" Marker packets: %d (expect 1)\n", markerCount) // Audio quality metrics skip := 48000 * 100 / 1000 if skip > len(allDecoded)/2 { skip = len(allDecoded) / 4 } tail := allDecoded[skip:] rms := computeRMS(tail) freq := estimateFrequency(tail, 48000) thd := computeTHD(tail, toneFreq, 48000, 10) GinkgoWriter.Println("── Audio Quality ──") GinkgoWriter.Printf(" Decoded samples: %d (%.1f ms at 48kHz)\n", len(allDecoded), float64(len(allDecoded))/48.0) GinkgoWriter.Printf(" RMS level: %.1f\n", rms) GinkgoWriter.Printf(" Peak frequency: %.0f Hz (expected %.0f Hz)\n", freq, toneFreq) GinkgoWriter.Printf(" THD (h2-h10): %.1f%%\n", thd) Expect(frameLoss).To(BeZero(), "lost frames in localhost transport") Expect(seqGaps).To(BeZero(), "sequence number gaps detected") Expect(markerCount).To(Equal(1), "expected exactly 1 marker packet") Expect(rms).To(BeNumerically(">=", 50), "signal appears silent or severely attenuated") Expect(freq).To(BeNumerically("~", toneFreq, 20), fmt.Sprintf("peak frequency %.0f Hz deviates from expected", freq)) Expect(thd).To(BeNumerically("<", 50), "signal is severely distorted") }) }) ================================================ FILE: backend/go/opus/package.sh ================================================ #!/bin/bash set -e CURDIR=$(dirname "$(realpath $0)") mkdir -p $CURDIR/package/lib cp -avf $CURDIR/opus $CURDIR/package/ cp -avf $CURDIR/run.sh $CURDIR/package/ # Copy the opus shim library cp -avf $CURDIR/libopusshim.so $CURDIR/package/lib/ # Copy system libopus if command -v pkg-config >/dev/null 2>&1 && 
pkg-config --exists opus; then LIBOPUS_DIR=$(pkg-config --variable=libdir opus) cp -avfL $LIBOPUS_DIR/libopus.so* $CURDIR/package/lib/ 2>/dev/null || true fi # Detect architecture and copy appropriate libraries if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then echo "Detected x86_64 architecture, copying x86_64 libraries..." cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6 cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1 cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6 cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6 cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2 cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1 cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0 elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then echo "Detected ARM64 architecture, copying ARM64 libraries..." 
cp -arfLv /lib/ld-linux-aarch64.so.1 "$CURDIR/package/lib/ld.so"
cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 "$CURDIR/package/lib/libc.so.6"
cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 "$CURDIR/package/lib/libgcc_s.so.1"
cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 "$CURDIR/package/lib/libstdc++.so.6"
cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 "$CURDIR/package/lib/libm.so.6"
cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 "$CURDIR/package/lib/libdl.so.2"
cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 "$CURDIR/package/lib/librt.so.1"
cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 "$CURDIR/package/lib/libpthread.so.0"
else
    # Unlike the piper and silero-vad packagers, an unknown architecture is
    # only a warning here; packaging continues without bundled system libraries.
    echo "Warning: Could not detect architecture for system library bundling"
fi

echo "Packaging completed successfully"
ls -liah "$CURDIR/package/"
ls -liah "$CURDIR/package/lib/"

================================================
FILE: backend/go/opus/run.sh
================================================
#!/bin/bash
# Launcher for the packaged opus backend: points the dynamic linker at the
# bundled lib/ directory and starts the binary, forwarding all arguments.
set -ex

# Directory containing this script (realpath resolves symlinks).
CURDIR=$(dirname "$(realpath "$0")")

# Append the inherited LD_LIBRARY_PATH only when it is non-empty. A trailing
# ":" would leave an empty entry, which the dynamic linker interprets as the
# current working directory.
export LD_LIBRARY_PATH="$CURDIR/lib${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
export OPUS_SHIM_LIBRARY="$CURDIR/lib/libopusshim.so"

# If there is a lib/ld.so, use it
if [ -f "$CURDIR/lib/ld.so" ]; then
    echo "Using lib/ld.so"
    exec "$CURDIR/lib/ld.so" "$CURDIR/opus" "$@"
fi

exec "$CURDIR/opus" "$@"

================================================
FILE: backend/go/piper/Makefile
================================================

# go-piper version
PIPER_REPO?=https://github.com/mudler/go-piper
PIPER_VERSION?=e10ca041a885d4a8f3871d52924b47792d5e5aa0

CURRENT_DIR=$(abspath ./)
GOCMD=go

PIPER_CGO_CXXFLAGS+=-I$(CURRENT_DIR)/sources/go-piper/piper/src/cpp -I$(CURRENT_DIR)/sources/go-piper/piper/build/fi/include -I$(CURRENT_DIR)/sources/go-piper/piper/build/pi/include -I$(CURRENT_DIR)/sources/go-piper/piper/build/si/include
PIPER_CGO_LDFLAGS+=-L$(CURRENT_DIR)/sources/go-piper/piper/build/fi/lib -L$(CURRENT_DIR)/sources/go-piper/piper/build/pi/lib -L$(CURRENT_DIR)/sources/go-piper/piper/build/si/lib -lfmt -lspdlog -lucd

# `package`, `build` and `clean` are commands, not files. Declaring them phony
# matters for `package` in particular: package.sh creates a directory named
# package/ next to this Makefile, so without .PHONY make would consider the
# target up to date after the first run and never re-run package.sh.
.PHONY: package build clean

## go-piper
# Fetches the pinned go-piper revision and its submodules.
sources/go-piper:
	mkdir -p sources/go-piper
	cd sources/go-piper && \
	git init && \
	git remote add origin $(PIPER_REPO) && \
	git fetch origin && \
	git checkout $(PIPER_VERSION) && \
	git submodule update --init --recursive --depth 1 --single-branch

sources/go-piper/libpiper_binding.a: sources/go-piper
	$(MAKE) -C sources/go-piper libpiper_binding.a example/main piper.o

# Copies the espeak-ng data directory out of the go-piper build tree.
espeak-ng-data: sources/go-piper sources/go-piper/libpiper_binding.a
	mkdir -p espeak-ng-data
	@cp -rf sources/go-piper/piper-phonemize/pi/share/espeak-ng-data/. espeak-ng-data

# Builds the piper backend binary against the locally checked-out go-piper.
piper: sources/go-piper sources/go-piper/libpiper_binding.a espeak-ng-data
	$(GOCMD) mod edit -replace github.com/mudler/go-piper=$(CURRENT_DIR)/sources/go-piper
	CGO_CXXFLAGS="$(PIPER_CGO_CXXFLAGS)" CGO_LDFLAGS="$(PIPER_CGO_LDFLAGS)" LIBRARY_PATH=$(CURRENT_DIR)/sources/go-piper \
	$(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o piper ./

package:
	bash package.sh

build: piper package

clean:
	rm -f piper

================================================
FILE: backend/go/piper/main.go
================================================
package main

// Note: this is started internally by LocalAI and a server is allocated for each model

import (
	"flag"

	grpc "github.com/mudler/LocalAI/pkg/grpc"
)

var (
	// addr is the address handed to grpc.StartServer; it defaults to
	// localhost:50051 and can be overridden with the -addr flag.
	addr = flag.String("addr", "localhost:50051", "the address to connect to")
)

// main parses the command-line flags and starts the gRPC server with a Piper
// backend. It panics if grpc.StartServer returns an error.
func main() {
	flag.Parse()

	if err := grpc.StartServer(*addr, &Piper{}); err != nil {
		panic(err)
	}
}

================================================
FILE: backend/go/piper/package.sh
================================================
#!/bin/bash

# Script to copy the appropriate libraries based on architecture
# This script is used in the final stage of the Dockerfile

set -e

# Directory containing this script (realpath resolves symlinks).
CURDIR=$(dirname "$(realpath "$0")")

# Create lib directory
mkdir -p "$CURDIR/package/lib"

cp -avf "$CURDIR/piper" "$CURDIR/package/"
cp -avf "$CURDIR/espeak-ng-data" "$CURDIR/package/"
cp -rfv "$CURDIR/run.sh" "$CURDIR/package/"
cp -rfLv "$CURDIR"/sources/go-piper/piper-phonemize/pi/lib/* "$CURDIR/package/lib/"

# Detect architecture and
copy appropriate libraries if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then # x86_64 architecture echo "Detected x86_64 architecture, copying x86_64 libraries..." cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6 cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1 cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6 cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6 cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1 cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1 cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6 cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2 cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1 cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0 elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then # ARM64 architecture echo "Detected ARM64 architecture, copying ARM64 libraries..." 
cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6 cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1 cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6 cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6 cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1 cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1 cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6 cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2 cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1 cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0 else echo "Error: Could not detect architecture" exit 1 fi echo "Packaging completed successfully" ls -liah $CURDIR/package/ ls -liah $CURDIR/package/lib/ ================================================ FILE: backend/go/piper/piper.go ================================================ package main // This is a wrapper to statisfy the GRPC service interface // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) import ( "fmt" "os" "path/filepath" "github.com/mudler/LocalAI/pkg/grpc/base" pb "github.com/mudler/LocalAI/pkg/grpc/proto" piper "github.com/mudler/go-piper" ) type Piper struct { base.SingleThread piper *PiperB } func (sd *Piper) Load(opts *pb.ModelOptions) error { if filepath.Ext(opts.ModelFile) != ".onnx" { return fmt.Errorf("unsupported model type %s (should end with .onnx)", opts.ModelFile) } var err error // Note: the Model here is a path to a directory containing the model files sd.piper, err = New(os.Getenv("ESPEAK_NG_DATA")) return err } func (sd *Piper) TTS(opts *pb.TTSRequest) error { return sd.piper.TTS(opts.Text, opts.Model, opts.Dst) } type PiperB 
struct { assetDir string } func New(assetDir string) (*PiperB, error) { if _, err := os.Stat(assetDir); err != nil { return nil, err } return &PiperB{ assetDir: assetDir, }, nil } func (s *PiperB) TTS(text, model, dst string) error { return piper.TextToWav(text, model, s.assetDir, "", dst) } ================================================ FILE: backend/go/piper/run.sh ================================================ #!/bin/bash set -ex CURDIR=$(dirname "$(realpath $0)") export ESPEAK_NG_DATA=$CURDIR/espeak-ng-data export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH # If there is a lib/ld.so, use it if [ -f $CURDIR/lib/ld.so ]; then echo "Using lib/ld.so" exec $CURDIR/lib/ld.so $CURDIR/piper "$@" fi exec $CURDIR/piper "$@" ================================================ FILE: backend/go/silero-vad/Makefile ================================================ CURRENT_DIR=$(abspath ./) GOCMD=go ONNX_VERSION?=1.20.0 ONNX_ARCH?=x64 ONNX_OS?=linux # Detect if we are running on arm64 ifneq (,$(findstring aarch64,$(shell uname -m))) ONNX_ARCH=aarch64 endif ifeq ($(OS),Darwin) ONNX_OS=osx ifneq (,$(findstring aarch64,$(shell uname -m))) ONNX_ARCH=arm64 else ifneq (,$(findstring arm64,$(shell uname -m))) ONNX_ARCH=arm64 else ONNX_ARCH=x86_64 endif endif sources/onnxruntime: mkdir -p sources/onnxruntime curl -L https://github.com/microsoft/onnxruntime/releases/download/v$(ONNX_VERSION)/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz -o sources/onnxruntime/onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz cd sources/onnxruntime && tar -xvf onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz && rm onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION).tgz cd sources/onnxruntime && mv onnxruntime-$(ONNX_OS)-$(ONNX_ARCH)-$(ONNX_VERSION)/* ./ backend-assets/lib/libonnxruntime.so.1: sources/onnxruntime mkdir -p backend-assets/lib cp -rfLv sources/onnxruntime/lib/* backend-assets/lib/ ifeq ($(OS),Darwin) mv backend-assets/lib/libonnxruntime.$(ONNX_VERSION).dylib 
backend-assets/lib/libonnxruntime.dylib else mv backend-assets/lib/libonnxruntime.so.$(ONNX_VERSION) backend-assets/lib/libonnxruntime.so.1 endif silero-vad: backend-assets/lib/libonnxruntime.so.1 CGO_LDFLAGS="$(CGO_LDFLAGS)" CPATH="$(CPATH):$(CURRENT_DIR)/sources/onnxruntime/include/" LIBRARY_PATH=$(CURRENT_DIR)/backend-assets/lib \ $(GOCMD) build -ldflags "$(LD_FLAGS)" -tags "$(GO_TAGS)" -o silero-vad ./ package: bash package.sh build: silero-vad package clean: rm -f silero-vad ================================================ FILE: backend/go/silero-vad/main.go ================================================ package main // Note: this is started internally by LocalAI and a server is allocated for each model import ( "flag" grpc "github.com/mudler/LocalAI/pkg/grpc" ) var ( addr = flag.String("addr", "localhost:50051", "the address to connect to") ) func main() { flag.Parse() if err := grpc.StartServer(*addr, &VAD{}); err != nil { panic(err) } } ================================================ FILE: backend/go/silero-vad/package.sh ================================================ #!/bin/bash # Script to copy the appropriate libraries based on architecture # This script is used in the final stage of the Dockerfile set -e CURDIR=$(dirname "$(realpath $0)") # Create lib directory mkdir -p $CURDIR/package/lib cp -avf $CURDIR/silero-vad $CURDIR/package/ cp -avf $CURDIR/run.sh $CURDIR/package/ cp -rfLv $CURDIR/backend-assets/lib/* $CURDIR/package/lib/ # Detect architecture and copy appropriate libraries if [ -f "/lib64/ld-linux-x86-64.so.2" ]; then # x86_64 architecture echo "Detected x86_64 architecture, copying x86_64 libraries..." 
cp -arfLv /lib64/ld-linux-x86-64.so.2 $CURDIR/package/lib/ld.so cp -arfLv /lib/x86_64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6 cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1 cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6 cp -arfLv /lib/x86_64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6 cp -arfLv /lib/x86_64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1 cp -arfLv /lib/x86_64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1 cp -arfLv /lib/x86_64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6 cp -arfLv /lib/x86_64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2 cp -arfLv /lib/x86_64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1 cp -arfLv /lib/x86_64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0 elif [ -f "/lib/ld-linux-aarch64.so.1" ]; then # ARM64 architecture echo "Detected ARM64 architecture, copying ARM64 libraries..." cp -arfLv /lib/ld-linux-aarch64.so.1 $CURDIR/package/lib/ld.so cp -arfLv /lib/aarch64-linux-gnu/libc.so.6 $CURDIR/package/lib/libc.so.6 cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1 cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6 cp -arfLv /lib/aarch64-linux-gnu/libm.so.6 $CURDIR/package/lib/libm.so.6 cp -arfLv /lib/aarch64-linux-gnu/libgomp.so.1 $CURDIR/package/lib/libgomp.so.1 cp -arfLv /lib/aarch64-linux-gnu/libgcc_s.so.1 $CURDIR/package/lib/libgcc_s.so.1 cp -arfLv /lib/aarch64-linux-gnu/libstdc++.so.6 $CURDIR/package/lib/libstdc++.so.6 cp -arfLv /lib/aarch64-linux-gnu/libdl.so.2 $CURDIR/package/lib/libdl.so.2 cp -arfLv /lib/aarch64-linux-gnu/librt.so.1 $CURDIR/package/lib/librt.so.1 cp -arfLv /lib/aarch64-linux-gnu/libpthread.so.0 $CURDIR/package/lib/libpthread.so.0 else echo "Error: Could not detect architecture" exit 1 fi echo "Packaging completed successfully" ls -liah $CURDIR/package/ ls -liah $CURDIR/package/lib/ 
================================================ FILE: backend/go/silero-vad/run.sh ================================================ #!/bin/bash set -ex CURDIR=$(dirname "$(realpath $0)") export LD_LIBRARY_PATH=$CURDIR/lib:$LD_LIBRARY_PATH # If there is a lib/ld.so, use it if [ -f $CURDIR/lib/ld.so ]; then echo "Using lib/ld.so" exec $CURDIR/lib/ld.so $CURDIR/silero-vad "$@" fi exec $CURDIR/silero-vad "$@" ================================================ FILE: backend/go/silero-vad/vad.go ================================================ package main // This is a wrapper to statisfy the GRPC service interface // It is meant to be used by the main executable that is the server for the specific backend type (falcon, gpt3, etc) import ( "fmt" "github.com/mudler/LocalAI/pkg/grpc/base" pb "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/streamer45/silero-vad-go/speech" ) type VAD struct { base.SingleThread detector *speech.Detector } func (vad *VAD) Load(opts *pb.ModelOptions) error { v, err := speech.NewDetector(speech.DetectorConfig{ ModelPath: opts.ModelFile, SampleRate: 16000, //WindowSize: 1024, Threshold: 0.5, MinSilenceDurationMs: 100, SpeechPadMs: 30, }) if err != nil { return fmt.Errorf("create silero detector: %w", err) } vad.detector = v return err } func (vad *VAD) VAD(req *pb.VADRequest) (pb.VADResponse, error) { audio := req.Audio if err := vad.detector.Reset(); err != nil { return pb.VADResponse{}, fmt.Errorf("reset: %w", err) } segments, err := vad.detector.Detect(audio) if err != nil { return pb.VADResponse{}, fmt.Errorf("detect: %w", err) } vadSegments := []*pb.VADSegment{} for _, s := range segments { vadSegments = append(vadSegments, &pb.VADSegment{ Start: float32(s.SpeechStartAt), End: float32(s.SpeechEndAt), }) } return pb.VADResponse{ Segments: vadSegments, }, nil } ================================================ FILE: backend/index.yaml ================================================ --- ## metas - &llamacpp name: "llama-cpp" alias: 
"llama-cpp" license: mit icon: https://user-images.githubusercontent.com/1991296/230134379-7181e485-c521-4d23-a0d6-f7b3b61ba524.png description: | LLM inference in C/C++ urls: - https://github.com/ggerganov/llama.cpp tags: - text-to-text - LLM - CPU - GPU - Metal - CUDA - HIP capabilities: default: "cpu-llama-cpp" nvidia: "cuda12-llama-cpp" intel: "intel-sycl-f16-llama-cpp" amd: "rocm-llama-cpp" metal: "metal-llama-cpp" vulkan: "vulkan-llama-cpp" nvidia-l4t: "nvidia-l4t-arm64-llama-cpp" nvidia-cuda-13: "cuda13-llama-cpp" nvidia-cuda-12: "cuda12-llama-cpp" nvidia-l4t-cuda-12: "nvidia-l4t-arm64-llama-cpp" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-llama-cpp" - &whispercpp name: "whisper" alias: "whisper" license: mit icon: https://user-images.githubusercontent.com/1991296/235238348-05d0f6a4-da44-4900-a1de-d0707e75b763.jpeg description: | Port of OpenAI's Whisper model in C/C++ urls: - https://github.com/ggml-org/whisper.cpp tags: - audio-transcription - CPU - GPU - CUDA - HIP capabilities: default: "cpu-whisper" nvidia: "cuda12-whisper" intel: "intel-sycl-f16-whisper" metal: "metal-whisper" amd: "rocm-whisper" vulkan: "vulkan-whisper" nvidia-l4t: "nvidia-l4t-arm64-whisper" nvidia-cuda-13: "cuda13-whisper" nvidia-cuda-12: "cuda12-whisper" nvidia-l4t-cuda-12: "nvidia-l4t-arm64-whisper" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-whisper" - &voxtral name: "voxtral" alias: "voxtral" license: mit description: | Voxtral Realtime 4B Pure C speech-to-text inference engine urls: - https://github.com/mudler/voxtral.c tags: - audio-transcription - CPU - Metal capabilities: default: "cpu-voxtral" metal-darwin-arm64: "metal-voxtral" - &stablediffusionggml name: "stablediffusion-ggml" alias: "stablediffusion-ggml" license: mit icon: https://github.com/leejet/stable-diffusion.cpp/raw/master/assets/cat_with_sd_cpp_42.png description: | Stable Diffusion and Flux in pure C/C++ urls: - https://github.com/leejet/stable-diffusion.cpp tags: - image-generation - CPU - GPU - Metal - CUDA 
- HIP capabilities: default: "cpu-stablediffusion-ggml" nvidia: "cuda12-stablediffusion-ggml" intel: "intel-sycl-f16-stablediffusion-ggml" # amd: "rocm-stablediffusion-ggml" vulkan: "vulkan-stablediffusion-ggml" nvidia-l4t: "nvidia-l4t-arm64-stablediffusion-ggml" metal: "metal-stablediffusion-ggml" nvidia-cuda-13: "cuda13-stablediffusion-ggml" nvidia-cuda-12: "cuda12-stablediffusion-ggml" nvidia-l4t-cuda-12: "nvidia-l4t-arm64-stablediffusion-ggml" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-stablediffusion-ggml" - &rfdetr name: "rfdetr" alias: "rfdetr" license: apache-2.0 icon: https://avatars.githubusercontent.com/u/53104118?s=200&v=4 description: | RF-DETR is a real-time, transformer-based object detection model architecture developed by Roboflow and released under the Apache 2.0 license. RF-DETR is the first real-time model to exceed 60 AP on the Microsoft COCO benchmark alongside competitive performance at base sizes. It also achieves state-of-the-art performance on RF100-VL, an object detection benchmark that measures model domain adaptability to real world problems. RF-DETR is the fastest and most accurate for its size when compared to current real-time object detection models. RF-DETR is small enough to run on the edge using Inference, making it an ideal model for deployments that need both strong accuracy and real-time performance.
urls: - https://github.com/roboflow/rf-detr tags: - object-detection - rfdetr - gpu - cpu capabilities: nvidia: "cuda12-rfdetr" intel: "intel-rfdetr" #amd: "rocm-rfdetr" nvidia-l4t: "nvidia-l4t-arm64-rfdetr" metal: "metal-rfdetr" default: "cpu-rfdetr" nvidia-cuda-13: "cuda13-rfdetr" nvidia-cuda-12: "cuda12-rfdetr" nvidia-l4t-cuda-12: "nvidia-l4t-arm64-rfdetr" - &vllm name: "vllm" license: apache-2.0 urls: - https://github.com/vllm-project/vllm tags: - text-to-text - multimodal - GPTQ - AWQ - AutoRound - INT4 - INT8 - FP8 icon: https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-dark.png description: | vLLM is a fast and easy-to-use library for LLM inference and serving. Originally developed in the Sky Computing Lab at UC Berkeley, vLLM has evolved into a community-driven project with contributions from both academia and industry. vLLM is fast with: State-of-the-art serving throughput Efficient management of attention key and value memory with PagedAttention Continuous batching of incoming requests Fast model execution with CUDA/HIP graph Quantizations: GPTQ, AWQ, AutoRound, INT4, INT8, and FP8 Optimized CUDA kernels, including integration with FlashAttention and FlashInfer Speculative decoding Chunked prefill alias: "vllm" capabilities: nvidia: "cuda12-vllm" amd: "rocm-vllm" intel: "intel-vllm" nvidia-cuda-12: "cuda12-vllm" - &vllm-omni name: "vllm-omni" license: apache-2.0 urls: - https://github.com/vllm-project/vllm-omni tags: - text-to-image - image-generation - text-to-video - video-generation - text-to-speech - TTS - multimodal - LLM icon: https://raw.githubusercontent.com/vllm-project/vllm/main/docs/assets/logos/vllm-logo-text-dark.png description: | vLLM-Omni is a unified interface for multimodal generation with vLLM. It supports image generation (text-to-image, image editing), video generation (text-to-video, image-to-video), text generation with multimodal inputs, and text-to-speech generation. 
Only supports NVIDIA (CUDA) and ROCm platforms. alias: "vllm-omni" capabilities: nvidia: "cuda12-vllm-omni" amd: "rocm-vllm-omni" nvidia-cuda-12: "cuda12-vllm-omni" - &mlx name: "mlx" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx" icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4 urls: - https://github.com/ml-explore/mlx-lm mirrors: - localai/localai-backends:latest-metal-darwin-arm64-mlx license: MIT description: | Run LLMs with MLX tags: - text-to-text - LLM - MLX capabilities: default: "cpu-mlx" nvidia: "cuda12-mlx" metal: "metal-mlx" nvidia-cuda-12: "cuda12-mlx" nvidia-cuda-13: "cuda13-mlx" nvidia-l4t: "nvidia-l4t-mlx" nvidia-l4t-cuda-12: "nvidia-l4t-mlx" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-mlx" - &mlx-vlm name: "mlx-vlm" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-vlm" icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4 urls: - https://github.com/Blaizzy/mlx-vlm mirrors: - localai/localai-backends:latest-metal-darwin-arm64-mlx-vlm license: MIT description: | Run Vision-Language Models with MLX tags: - text-to-text - multimodal - vision-language - LLM - MLX capabilities: default: "cpu-mlx-vlm" nvidia: "cuda12-mlx-vlm" metal: "metal-mlx-vlm" nvidia-cuda-12: "cuda12-mlx-vlm" nvidia-cuda-13: "cuda13-mlx-vlm" nvidia-l4t: "nvidia-l4t-mlx-vlm" nvidia-l4t-cuda-12: "nvidia-l4t-mlx-vlm" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-mlx-vlm" - &mlx-audio name: "mlx-audio" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-audio" icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4 urls: - https://github.com/Blaizzy/mlx-audio mirrors: - localai/localai-backends:latest-metal-darwin-arm64-mlx-audio license: MIT description: | Run Audio Models with MLX tags: - audio-to-text - audio-generation - text-to-audio - LLM - MLX capabilities: default: "cpu-mlx-audio" nvidia: "cuda12-mlx-audio" metal: "metal-mlx-audio" nvidia-cuda-12: "cuda12-mlx-audio" 
nvidia-cuda-13: "cuda13-mlx-audio" nvidia-l4t: "nvidia-l4t-mlx-audio" nvidia-l4t-cuda-12: "nvidia-l4t-mlx-audio" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-mlx-audio" - &mlx-distributed name: "mlx-distributed" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx-distributed" icon: https://avatars.githubusercontent.com/u/102832242?s=200&v=4 urls: - https://github.com/ml-explore/mlx-lm mirrors: - localai/localai-backends:latest-metal-darwin-arm64-mlx-distributed license: MIT description: | Run distributed LLM inference with MLX across multiple Apple Silicon Macs tags: - text-to-text - LLM - MLX - distributed capabilities: default: "cpu-mlx-distributed" nvidia: "cuda12-mlx-distributed" metal: "metal-mlx-distributed" nvidia-cuda-12: "cuda12-mlx-distributed" nvidia-cuda-13: "cuda13-mlx-distributed" nvidia-l4t: "nvidia-l4t-mlx-distributed" nvidia-l4t-cuda-12: "nvidia-l4t-mlx-distributed" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-mlx-distributed" - &rerankers name: "rerankers" alias: "rerankers" capabilities: nvidia: "cuda12-rerankers" intel: "intel-rerankers" amd: "rocm-rerankers" metal: "metal-rerankers" - &transformers name: "transformers" icon: https://avatars.githubusercontent.com/u/25720743?s=200&v=4 alias: "transformers" license: apache-2.0 description: | Transformers acts as the model-definition framework for state-of-the-art machine learning models in text, computer vision, audio, video, and multimodal model, for both inference and training. It centralizes the model definition so that this definition is agreed upon across the ecosystem. transformers is the pivot across frameworks: if a model definition is supported, it will be compatible with the majority of training frameworks (Axolotl, Unsloth, DeepSpeed, FSDP, PyTorch-Lightning, ...), inference engines (vLLM, SGLang, TGI, ...), and adjacent modeling libraries (llama.cpp, mlx, ...) which leverage the model definition from transformers. 
urls: - https://github.com/huggingface/transformers tags: - text-to-text - multimodal capabilities: nvidia: "cuda12-transformers" intel: "intel-transformers" amd: "rocm-transformers" metal: "metal-transformers" nvidia-cuda-13: "cuda13-transformers" nvidia-cuda-12: "cuda12-transformers" - &diffusers name: "diffusers" icon: https://raw.githubusercontent.com/huggingface/diffusers/main/docs/source/en/imgs/diffusers_library.jpg description: | 🤗 Diffusers is the go-to library for state-of-the-art pretrained diffusion models for generating images, audio, and even 3D structures of molecules. Whether you're looking for a simple inference solution or training your own diffusion models, 🤗 Diffusers is a modular toolbox that supports both. urls: - https://github.com/huggingface/diffusers tags: - image-generation - video-generation - diffusion-models license: apache-2.0 alias: "diffusers" capabilities: nvidia: "cuda12-diffusers" intel: "intel-diffusers" amd: "rocm-diffusers" nvidia-l4t: "nvidia-l4t-diffusers" metal: "metal-diffusers" default: "cpu-diffusers" nvidia-cuda-13: "cuda13-diffusers" nvidia-cuda-12: "cuda12-diffusers" nvidia-l4t-cuda-12: "nvidia-l4t-arm64-diffusers" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-diffusers" - &ace-step name: "ace-step" description: | ACE-Step 1.5 is an open-source music generation model. It supports simple mode (natural language description) and advanced mode (caption, lyrics, think, bpm, keyscale, etc.). Uses in-process acestep (LLMHandler for metadata, DiT for audio). 
urls: - https://github.com/ace-step/ACE-Step-1.5 tags: - music-generation - sound-generation alias: "ace-step" capabilities: nvidia: "cuda12-ace-step" intel: "intel-ace-step" amd: "rocm-ace-step" metal: "metal-ace-step" default: "cpu-ace-step" nvidia-cuda-13: "cuda13-ace-step" nvidia-cuda-12: "cuda12-ace-step" - !!merge <<: *ace-step name: "ace-step-development" capabilities: nvidia: "cuda12-ace-step-development" intel: "intel-ace-step-development" amd: "rocm-ace-step-development" metal: "metal-ace-step-development" default: "cpu-ace-step-development" nvidia-cuda-13: "cuda13-ace-step-development" nvidia-cuda-12: "cuda12-ace-step-development" - &acestepcpp name: "acestep-cpp" description: | ACE-Step 1.5 C++ backend using GGML. Native C++ implementation of ACE-Step music generation with GPU support through GGML backends. Generates stereo 48kHz audio from text descriptions and optional lyrics via a two-stage pipeline: text-to-code (ace-qwen3 LLM) + code-to-audio (DiT-VAE). urls: - https://github.com/ace-step/acestep.cpp tags: - music-generation - sound-generation alias: "acestep-cpp" capabilities: default: "cpu-acestep-cpp" nvidia: "cuda12-acestep-cpp" nvidia-cuda-13: "cuda13-acestep-cpp" nvidia-cuda-12: "cuda12-acestep-cpp" intel: "intel-sycl-f16-acestep-cpp" metal: "metal-acestep-cpp" amd: "rocm-acestep-cpp" vulkan: "vulkan-acestep-cpp" nvidia-l4t: "nvidia-l4t-arm64-acestep-cpp" nvidia-l4t-cuda-12: "nvidia-l4t-arm64-acestep-cpp" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-acestep-cpp" - &faster-whisper icon: https://avatars.githubusercontent.com/u/1520500?s=200&v=4 description: | faster-whisper is a reimplementation of OpenAI's Whisper model using CTranslate2, which is a fast inference engine for Transformer models. This implementation is up to 4 times faster than openai/whisper for the same accuracy while using less memory. The efficiency can be further improved with 8-bit quantization on both CPU and GPU. 
urls: - https://github.com/SYSTRAN/faster-whisper tags: - speech-to-text - Whisper license: MIT name: "faster-whisper" capabilities: nvidia: "cuda12-faster-whisper" intel: "intel-faster-whisper" amd: "rocm-faster-whisper" metal: "metal-faster-whisper" nvidia-cuda-13: "cuda13-faster-whisper" nvidia-cuda-12: "cuda12-faster-whisper" - &moonshine description: | Moonshine is a fast, accurate, and efficient speech-to-text transcription model using ONNX Runtime. It provides real-time transcription capabilities with support for multiple model sizes and GPU acceleration. urls: - https://github.com/moonshine-ai/moonshine tags: - speech-to-text - transcription - ONNX license: MIT name: "moonshine" alias: "moonshine" capabilities: nvidia: "cuda12-moonshine" metal: "metal-moonshine" default: "cpu-moonshine" nvidia-cuda-13: "cuda13-moonshine" nvidia-cuda-12: "cuda12-moonshine" - &whisperx description: | WhisperX provides fast automatic speech recognition with word-level timestamps, speaker diarization, and forced alignment. Built on faster-whisper and pyannote-audio for high-accuracy transcription with speaker identification. urls: - https://github.com/m-bain/whisperX tags: - speech-to-text - diarization - whisperx license: BSD-4-Clause name: "whisperx" capabilities: nvidia: "cuda12-whisperx" amd: "rocm-whisperx" metal: "metal-whisperx" default: "cpu-whisperx" nvidia-cuda-13: "cuda13-whisperx" nvidia-cuda-12: "cuda12-whisperx" - &kokoro icon: https://avatars.githubusercontent.com/u/166769057?v=4 description: | Kokoro is an open-weight TTS model with 82 million parameters. Despite its lightweight architecture, it delivers comparable quality to larger models while being significantly faster and more cost-efficient. With Apache-licensed weights, Kokoro can be deployed anywhere from production environments to personal projects. 
urls: - https://huggingface.co/hexgrad/Kokoro-82M - https://github.com/hexgrad/kokoro tags: - text-to-speech - TTS - LLM license: apache-2.0 alias: "kokoro" name: "kokoro" capabilities: nvidia: "cuda12-kokoro" intel: "intel-kokoro" amd: "rocm-kokoro" nvidia-l4t: "nvidia-l4t-kokoro" metal: "metal-kokoro" nvidia-cuda-13: "cuda13-kokoro" nvidia-cuda-12: "cuda12-kokoro" nvidia-l4t-cuda-12: "nvidia-l4t-arm64-kokoro" - &coqui urls: - https://github.com/idiap/coqui-ai-TTS description: | 🐸 Coqui TTS is a library for advanced Text-to-Speech generation. 🚀 Pretrained models in +1100 languages. 🛠️ Tools for training new models and fine-tuning existing models in any language. 📚 Utilities for dataset analysis and curation. tags: - text-to-speech - TTS license: mpl-2.0 name: "coqui" alias: "coqui" capabilities: nvidia: "cuda12-coqui" intel: "intel-coqui" amd: "rocm-coqui" metal: "metal-coqui" nvidia-cuda-13: "cuda13-coqui" nvidia-cuda-12: "cuda12-coqui" icon: https://avatars.githubusercontent.com/u/1338804?s=200&v=4 - &outetts urls: - https://github.com/OuteAI/outetts description: | OuteTTS is an open-weight text-to-speech model from OuteAI (OuteAI/OuteTTS-0.3-1B). Supports custom speaker voices via audio path or default speakers. tags: - text-to-speech - TTS license: apache-2.0 name: "outetts" alias: "outetts" capabilities: default: "cpu-outetts" nvidia-cuda-12: "cuda12-outetts" - &chatterbox urls: - https://github.com/resemble-ai/chatterbox description: | Resemble AI's first production-grade open source TTS model. Licensed under MIT, Chatterbox has been benchmarked against leading closed-source systems like ElevenLabs, and is consistently preferred in side-by-side evaluations. Whether you're working on memes, videos, games, or AI agents, Chatterbox brings your content to life. It's also the first open source TTS model to support emotion exaggeration control, a powerful feature that makes your voices stand out. 
tags: - text-to-speech - TTS license: MIT icon: https://avatars.githubusercontent.com/u/49844015?s=200&v=4 name: "chatterbox" alias: "chatterbox" capabilities: nvidia: "cuda12-chatterbox" metal: "metal-chatterbox" default: "cpu-chatterbox" nvidia-l4t: "nvidia-l4t-arm64-chatterbox" nvidia-cuda-13: "cuda13-chatterbox" nvidia-cuda-12: "cuda12-chatterbox" nvidia-l4t-cuda-12: "nvidia-l4t-arm64-chatterbox" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-chatterbox" - &vibevoice urls: - https://github.com/microsoft/VibeVoice description: | VibeVoice-Realtime is a real-time text-to-speech model that generates natural-sounding speech. tags: - text-to-speech - TTS license: mit name: "vibevoice" alias: "vibevoice" capabilities: nvidia: "cuda12-vibevoice" intel: "intel-vibevoice" amd: "rocm-vibevoice" nvidia-l4t: "nvidia-l4t-vibevoice" metal: "metal-vibevoice" default: "cpu-vibevoice" nvidia-cuda-13: "cuda13-vibevoice" nvidia-cuda-12: "cuda12-vibevoice" nvidia-l4t-cuda-12: "nvidia-l4t-vibevoice" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-vibevoice" icon: https://avatars.githubusercontent.com/u/6154722?s=200&v=4 - &qwen-tts urls: - https://github.com/QwenLM/Qwen3-TTS description: | Qwen3-TTS is a high-quality text-to-speech model supporting custom voice, voice design, and voice cloning. tags: - text-to-speech - TTS license: apache-2.0 name: "qwen-tts" alias: "qwen-tts" capabilities: nvidia: "cuda12-qwen-tts" intel: "intel-qwen-tts" amd: "rocm-qwen-tts" nvidia-l4t: "nvidia-l4t-qwen-tts" metal: "metal-qwen-tts" default: "cpu-qwen-tts" nvidia-cuda-13: "cuda13-qwen-tts" nvidia-cuda-12: "cuda12-qwen-tts" nvidia-l4t-cuda-12: "nvidia-l4t-qwen-tts" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-qwen-tts" icon: https://cdn-avatars.huggingface.co/v1/production/uploads/620760a26e3b7210c2ff1943/-s1gyJfvbE1RgO5iBeNOi.png - &fish-speech urls: - https://github.com/fishaudio/fish-speech description: | Fish Speech is a high-quality text-to-speech model supporting voice cloning via reference audio. 
tags: - text-to-speech - TTS - voice-cloning license: apache-2.0 name: "fish-speech" alias: "fish-speech" capabilities: nvidia: "cuda12-fish-speech" intel: "intel-fish-speech" amd: "rocm-fish-speech" nvidia-l4t: "nvidia-l4t-fish-speech" metal: "metal-fish-speech" default: "cpu-fish-speech" nvidia-cuda-13: "cuda13-fish-speech" nvidia-cuda-12: "cuda12-fish-speech" nvidia-l4t-cuda-12: "nvidia-l4t-fish-speech" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-fish-speech" icon: https://avatars.githubusercontent.com/u/148526220?s=200&v=4 - &faster-qwen3-tts urls: - https://github.com/andimarafioti/faster-qwen3-tts - https://pypi.org/project/faster-qwen3-tts/ description: | Real-time Qwen3-TTS inference using CUDA graph capture. Voice clone only; requires NVIDIA GPU with CUDA. tags: - text-to-speech - TTS - voice-clone license: apache-2.0 name: "faster-qwen3-tts" alias: "faster-qwen3-tts" capabilities: nvidia: "cuda12-faster-qwen3-tts" default: "cuda12-faster-qwen3-tts" nvidia-cuda-13: "cuda13-faster-qwen3-tts" nvidia-cuda-12: "cuda12-faster-qwen3-tts" nvidia-l4t: "nvidia-l4t-faster-qwen3-tts" nvidia-l4t-cuda-12: "nvidia-l4t-faster-qwen3-tts" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-faster-qwen3-tts" icon: https://cdn-avatars.huggingface.co/v1/production/uploads/620760a26e3b7210c2ff1943/-s1gyJfvbE1RgO5iBeNOi.png - &qwen-asr urls: - https://github.com/QwenLM/Qwen3-ASR description: | Qwen3-ASR is an automatic speech recognition model supporting multiple languages and batch inference. 
tags: - speech-recognition - ASR license: apache-2.0 name: "qwen-asr" alias: "qwen-asr" capabilities: nvidia: "cuda12-qwen-asr" intel: "intel-qwen-asr" amd: "rocm-qwen-asr" nvidia-l4t: "nvidia-l4t-qwen-asr" metal: "metal-qwen-asr" default: "cpu-qwen-asr" nvidia-cuda-13: "cuda13-qwen-asr" nvidia-cuda-12: "cuda12-qwen-asr" nvidia-l4t-cuda-12: "nvidia-l4t-qwen-asr" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-qwen-asr" icon: https://cdn-avatars.huggingface.co/v1/production/uploads/620760a26e3b7210c2ff1943/-s1gyJfvbE1RgO5iBeNOi.png - &nemo urls: - https://github.com/NVIDIA/NeMo description: | NVIDIA NeMo Toolkit for ASR provides state-of-the-art automatic speech recognition models including Parakeet models for various languages and use cases. tags: - speech-recognition - ASR - NVIDIA license: apache-2.0 name: "nemo" alias: "nemo" capabilities: nvidia: "cuda12-nemo" intel: "intel-nemo" amd: "rocm-nemo" metal: "metal-nemo" default: "cpu-nemo" nvidia-cuda-13: "cuda13-nemo" nvidia-cuda-12: "cuda12-nemo" icon: https://www.nvidia.com/favicon.ico - &voxcpm urls: - https://github.com/ModelBest/VoxCPM description: | VoxCPM is an innovative end-to-end TTS model from ModelBest, designed to generate highly expressive speech. tags: - text-to-speech - TTS license: mit name: "voxcpm" alias: "voxcpm" capabilities: nvidia: "cuda12-voxcpm" intel: "intel-voxcpm" amd: "rocm-voxcpm" metal: "metal-voxcpm" default: "cpu-voxcpm" nvidia-cuda-13: "cuda13-voxcpm" nvidia-cuda-12: "cuda12-voxcpm" icon: https://avatars.githubusercontent.com/u/6154722?s=200&v=4 - &pocket-tts urls: - https://github.com/kyutai-labs/pocket-tts description: | Pocket TTS is a lightweight text-to-speech model designed to run efficiently on CPUs.
tags: - text-to-speech - TTS license: mit name: "pocket-tts" alias: "pocket-tts" capabilities: nvidia: "cuda12-pocket-tts" intel: "intel-pocket-tts" amd: "rocm-pocket-tts" nvidia-l4t: "nvidia-l4t-pocket-tts" metal: "metal-pocket-tts" default: "cpu-pocket-tts" nvidia-cuda-13: "cuda13-pocket-tts" nvidia-cuda-12: "cuda12-pocket-tts" nvidia-l4t-cuda-12: "nvidia-l4t-pocket-tts" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-pocket-tts" icon: https://avatars.githubusercontent.com/u/151010778?s=200&v=4 - &piper name: "piper" uri: "quay.io/go-skynet/local-ai-backends:latest-piper" icon: https://github.com/OHF-Voice/piper1-gpl/raw/main/etc/logo.png urls: - https://github.com/rhasspy/piper - https://github.com/mudler/go-piper mirrors: - localai/localai-backends:latest-piper license: MIT description: | A fast, local neural text to speech system tags: - text-to-speech - TTS - &opus name: "opus" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-opus" urls: - https://opus-codec.org/ mirrors: - localai/localai-backends:latest-cpu-opus license: BSD-3-Clause description: | Opus audio codec backend for encoding and decoding audio. Required for WebRTC transport in the Realtime API. tags: - audio-codec - opus - WebRTC - realtime - CPU - &silero-vad name: "silero-vad" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-silero-vad" icon: https://user-images.githubusercontent.com/12515440/89997349-b3523080-dc94-11ea-9906-ca2e8bc50535.png urls: - https://github.com/snakers4/silero-vad mirrors: - localai/localai-backends:latest-cpu-silero-vad description: | Silero VAD: pre-trained enterprise-grade Voice Activity Detector. Silero VAD is a voice activity detection model that can be used to detect whether a given audio contains speech or not. 
tags: - voice-activity-detection - VAD - silero-vad - CPU - &local-store name: "local-store" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-local-store" mirrors: - localai/localai-backends:latest-cpu-local-store urls: - https://github.com/mudler/LocalAI description: | Local Store is a local-first, self-hosted, and open-source vector database. tags: - vector-database - local-first - open-source - CPU license: MIT - &kitten-tts name: "kitten-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-kitten-tts" mirrors: - localai/localai-backends:latest-kitten-tts urls: - https://github.com/KittenML/KittenTTS description: | Kitten TTS is a text-to-speech model that can generate speech from text. tags: - text-to-speech - TTS license: apache-2.0 - &neutts name: "neutts" urls: - https://github.com/neuphonic/neutts-air description: | NeuTTS Air is the world’s first super-realistic, on-device, TTS speech language model with instant voice cloning. Built off a 0.5B LLM backbone, NeuTTS Air brings natural-sounding speech, real-time performance, built-in security and speaker cloning to your local device - unlocking a new category of embedded voice agents, assistants, toys, and compliance-safe apps. 
tags: - text-to-speech - TTS license: apache-2.0 capabilities: default: "cpu-neutts" nvidia: "cuda12-neutts" amd: "rocm-neutts" nvidia-cuda-12: "cuda12-neutts" - !!merge <<: *neutts name: "neutts-development" capabilities: default: "cpu-neutts-development" nvidia: "cuda12-neutts-development" amd: "rocm-neutts-development" nvidia-cuda-12: "cuda12-neutts-development" - !!merge <<: *llamacpp name: "llama-cpp-development" capabilities: default: "cpu-llama-cpp-development" nvidia: "cuda12-llama-cpp-development" intel: "intel-sycl-f16-llama-cpp-development" amd: "rocm-llama-cpp-development" metal: "metal-llama-cpp-development" vulkan: "vulkan-llama-cpp-development" nvidia-l4t: "nvidia-l4t-arm64-llama-cpp-development" nvidia-cuda-13: "cuda13-llama-cpp-development" nvidia-cuda-12: "cuda12-llama-cpp-development" nvidia-l4t-cuda-12: "nvidia-l4t-arm64-llama-cpp-development" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-llama-cpp-development" - !!merge <<: *neutts name: "cpu-neutts" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-neutts" mirrors: - localai/localai-backends:latest-cpu-neutts - !!merge <<: *neutts name: "cuda12-neutts" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-neutts" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-neutts - !!merge <<: *neutts name: "rocm-neutts" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-neutts" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-neutts - !!merge <<: *neutts name: "cpu-neutts-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-neutts" mirrors: - localai/localai-backends:master-cpu-neutts - !!merge <<: *neutts name: "cuda12-neutts-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-neutts" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-neutts - !!merge <<: *neutts name: "rocm-neutts-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-neutts" mirrors: - 
localai/localai-backends:master-gpu-rocm-hipblas-neutts - !!merge <<: *mlx name: "mlx-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx" mirrors: - localai/localai-backends:master-metal-darwin-arm64-mlx - !!merge <<: *mlx-vlm name: "mlx-vlm-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-vlm" mirrors: - localai/localai-backends:master-metal-darwin-arm64-mlx-vlm - !!merge <<: *mlx-audio name: "mlx-audio-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-audio" mirrors: - localai/localai-backends:master-metal-darwin-arm64-mlx-audio - !!merge <<: *mlx-distributed name: "mlx-distributed-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-mlx-distributed" mirrors: - localai/localai-backends:master-metal-darwin-arm64-mlx-distributed ## mlx - !!merge <<: *mlx name: "cpu-mlx" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-mlx" mirrors: - localai/localai-backends:latest-cpu-mlx - !!merge <<: *mlx name: "cpu-mlx-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-mlx" mirrors: - localai/localai-backends:master-cpu-mlx - !!merge <<: *mlx name: "cuda12-mlx" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-mlx" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-mlx - !!merge <<: *mlx name: "cuda12-mlx-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-mlx" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-mlx - !!merge <<: *mlx name: "cuda13-mlx" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-mlx" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-mlx - !!merge <<: *mlx name: "cuda13-mlx-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-mlx" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-mlx - !!merge <<: *mlx name: "nvidia-l4t-mlx" uri: 
"quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-mlx" mirrors: - localai/localai-backends:latest-nvidia-l4t-mlx - !!merge <<: *mlx name: "nvidia-l4t-mlx-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-mlx" mirrors: - localai/localai-backends:master-nvidia-l4t-mlx - !!merge <<: *mlx name: "cuda13-nvidia-l4t-arm64-mlx" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-mlx" mirrors: - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-mlx - !!merge <<: *mlx name: "cuda13-nvidia-l4t-arm64-mlx-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-mlx" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-mlx ## mlx-vlm - !!merge <<: *mlx-vlm name: "cpu-mlx-vlm" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-mlx-vlm" mirrors: - localai/localai-backends:latest-cpu-mlx-vlm - !!merge <<: *mlx-vlm name: "cpu-mlx-vlm-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-mlx-vlm" mirrors: - localai/localai-backends:master-cpu-mlx-vlm - !!merge <<: *mlx-vlm name: "cuda12-mlx-vlm" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-mlx-vlm" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-mlx-vlm - !!merge <<: *mlx-vlm name: "cuda12-mlx-vlm-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-mlx-vlm" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-mlx-vlm - !!merge <<: *mlx-vlm name: "cuda13-mlx-vlm" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-mlx-vlm" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-mlx-vlm - !!merge <<: *mlx-vlm name: "cuda13-mlx-vlm-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-mlx-vlm" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-mlx-vlm - !!merge <<: *mlx-vlm name: "nvidia-l4t-mlx-vlm" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-mlx-vlm" mirrors: - 
localai/localai-backends:latest-nvidia-l4t-mlx-vlm - !!merge <<: *mlx-vlm name: "nvidia-l4t-mlx-vlm-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-mlx-vlm" mirrors: - localai/localai-backends:master-nvidia-l4t-mlx-vlm - !!merge <<: *mlx-vlm name: "cuda13-nvidia-l4t-arm64-mlx-vlm" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-mlx-vlm" mirrors: - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-mlx-vlm - !!merge <<: *mlx-vlm name: "cuda13-nvidia-l4t-arm64-mlx-vlm-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-vlm" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-vlm ## mlx-audio - !!merge <<: *mlx-audio name: "cpu-mlx-audio" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-mlx-audio" mirrors: - localai/localai-backends:latest-cpu-mlx-audio - !!merge <<: *mlx-audio name: "cpu-mlx-audio-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-mlx-audio" mirrors: - localai/localai-backends:master-cpu-mlx-audio - !!merge <<: *mlx-audio name: "cuda12-mlx-audio" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-mlx-audio" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-mlx-audio - !!merge <<: *mlx-audio name: "cuda12-mlx-audio-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-mlx-audio" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-mlx-audio - !!merge <<: *mlx-audio name: "cuda13-mlx-audio" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-mlx-audio" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-mlx-audio - !!merge <<: *mlx-audio name: "cuda13-mlx-audio-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-mlx-audio" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-mlx-audio - !!merge <<: *mlx-audio name: "nvidia-l4t-mlx-audio" uri: 
"quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-mlx-audio" mirrors: - localai/localai-backends:latest-nvidia-l4t-mlx-audio - !!merge <<: *mlx-audio name: "nvidia-l4t-mlx-audio-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-mlx-audio" mirrors: - localai/localai-backends:master-nvidia-l4t-mlx-audio - !!merge <<: *mlx-audio name: "cuda13-nvidia-l4t-arm64-mlx-audio" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-mlx-audio" mirrors: - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-mlx-audio - !!merge <<: *mlx-audio name: "cuda13-nvidia-l4t-arm64-mlx-audio-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-audio" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-audio ## mlx-distributed - !!merge <<: *mlx-distributed name: "cpu-mlx-distributed" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-mlx-distributed" mirrors: - localai/localai-backends:latest-cpu-mlx-distributed - !!merge <<: *mlx-distributed name: "cpu-mlx-distributed-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-mlx-distributed" mirrors: - localai/localai-backends:master-cpu-mlx-distributed - !!merge <<: *mlx-distributed name: "cuda12-mlx-distributed" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-mlx-distributed" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-mlx-distributed - !!merge <<: *mlx-distributed name: "cuda12-mlx-distributed-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-mlx-distributed" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-mlx-distributed - !!merge <<: *mlx-distributed name: "cuda13-mlx-distributed" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-mlx-distributed" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-mlx-distributed - !!merge <<: *mlx-distributed name: "cuda13-mlx-distributed-development" uri: 
"quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-mlx-distributed" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-mlx-distributed - !!merge <<: *mlx-distributed name: "nvidia-l4t-mlx-distributed" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-mlx-distributed" mirrors: - localai/localai-backends:latest-nvidia-l4t-mlx-distributed - !!merge <<: *mlx-distributed name: "nvidia-l4t-mlx-distributed-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-mlx-distributed" mirrors: - localai/localai-backends:master-nvidia-l4t-mlx-distributed - !!merge <<: *mlx-distributed name: "cuda13-nvidia-l4t-arm64-mlx-distributed" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-mlx-distributed" mirrors: - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-mlx-distributed - !!merge <<: *mlx-distributed name: "cuda13-nvidia-l4t-arm64-mlx-distributed-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-distributed" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-mlx-distributed - !!merge <<: *kitten-tts name: "kitten-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-kitten-tts" mirrors: - localai/localai-backends:master-kitten-tts - !!merge <<: *kitten-tts name: "metal-kitten-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-kitten-tts" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-kitten-tts - !!merge <<: *kitten-tts name: "metal-kitten-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-kitten-tts" mirrors: - localai/localai-backends:master-metal-darwin-arm64-kitten-tts - !!merge <<: *local-store name: "local-store-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-local-store" mirrors: - localai/localai-backends:master-cpu-local-store - !!merge <<: *local-store name: "metal-local-store" uri: 
"quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-local-store" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-local-store - !!merge <<: *local-store name: "metal-local-store-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-local-store" mirrors: - localai/localai-backends:master-metal-darwin-arm64-local-store - !!merge <<: *opus name: "opus-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-opus" mirrors: - localai/localai-backends:master-cpu-opus - !!merge <<: *opus name: "metal-opus" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-opus" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-opus - !!merge <<: *opus name: "metal-opus-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-opus" mirrors: - localai/localai-backends:master-metal-darwin-arm64-opus - !!merge <<: *silero-vad name: "silero-vad-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-silero-vad" mirrors: - localai/localai-backends:master-cpu-silero-vad - !!merge <<: *silero-vad name: "metal-silero-vad" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-silero-vad" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-silero-vad - !!merge <<: *silero-vad name: "metal-silero-vad-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-silero-vad" mirrors: - localai/localai-backends:master-metal-darwin-arm64-silero-vad - !!merge <<: *piper name: "piper-development" uri: "quay.io/go-skynet/local-ai-backends:master-piper" mirrors: - localai/localai-backends:master-piper - !!merge <<: *piper name: "metal-piper" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-piper" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-piper - !!merge <<: *piper name: "metal-piper-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-piper" mirrors: - 
localai/localai-backends:master-metal-darwin-arm64-piper ## llama-cpp - !!merge <<: *llamacpp name: "nvidia-l4t-arm64-llama-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-llama-cpp" mirrors: - localai/localai-backends:latest-nvidia-l4t-arm64-llama-cpp - !!merge <<: *llamacpp name: "nvidia-l4t-arm64-llama-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-llama-cpp" mirrors: - localai/localai-backends:master-nvidia-l4t-arm64-llama-cpp - !!merge <<: *llamacpp name: "cuda13-nvidia-l4t-arm64-llama-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-llama-cpp" mirrors: - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-llama-cpp - !!merge <<: *llamacpp name: "cuda13-nvidia-l4t-arm64-llama-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-llama-cpp" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-llama-cpp - !!merge <<: *llamacpp name: "cpu-llama-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-llama-cpp" mirrors: - localai/localai-backends:latest-cpu-llama-cpp - !!merge <<: *llamacpp name: "cpu-llama-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-llama-cpp" mirrors: - localai/localai-backends:master-cpu-llama-cpp - !!merge <<: *llamacpp name: "cuda12-llama-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-llama-cpp" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-llama-cpp - !!merge <<: *llamacpp name: "rocm-llama-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-llama-cpp" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-llama-cpp - !!merge <<: *llamacpp name: "intel-sycl-f32-llama-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-llama-cpp" mirrors: - localai/localai-backends:latest-gpu-intel-sycl-f32-llama-cpp - !!merge <<: *llamacpp name: "intel-sycl-f16-llama-cpp" uri: 
"quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-llama-cpp" mirrors: - localai/localai-backends:latest-gpu-intel-sycl-f16-llama-cpp - !!merge <<: *llamacpp name: "vulkan-llama-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-llama-cpp" mirrors: - localai/localai-backends:latest-gpu-vulkan-llama-cpp - !!merge <<: *llamacpp name: "vulkan-llama-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-llama-cpp" mirrors: - localai/localai-backends:master-gpu-vulkan-llama-cpp - !!merge <<: *llamacpp name: "metal-llama-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-llama-cpp" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-llama-cpp - !!merge <<: *llamacpp name: "metal-llama-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-llama-cpp" mirrors: - localai/localai-backends:master-metal-darwin-arm64-llama-cpp - !!merge <<: *llamacpp name: "cuda12-llama-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-llama-cpp" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-llama-cpp - !!merge <<: *llamacpp name: "rocm-llama-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-llama-cpp" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-llama-cpp - !!merge <<: *llamacpp name: "intel-sycl-f32-llama-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-llama-cpp" mirrors: - localai/localai-backends:master-gpu-intel-sycl-f32-llama-cpp - !!merge <<: *llamacpp name: "intel-sycl-f16-llama-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-llama-cpp" mirrors: - localai/localai-backends:master-gpu-intel-sycl-f16-llama-cpp - !!merge <<: *llamacpp name: "cuda13-llama-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-llama-cpp" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-llama-cpp - 
!!merge <<: *llamacpp name: "cuda13-llama-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-llama-cpp" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-llama-cpp ## whisper - !!merge <<: *whispercpp name: "nvidia-l4t-arm64-whisper" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-whisper" mirrors: - localai/localai-backends:latest-nvidia-l4t-arm64-whisper - !!merge <<: *whispercpp name: "nvidia-l4t-arm64-whisper-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-whisper" mirrors: - localai/localai-backends:master-nvidia-l4t-arm64-whisper - !!merge <<: *whispercpp name: "cuda13-nvidia-l4t-arm64-whisper" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-whisper" mirrors: - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-whisper - !!merge <<: *whispercpp name: "cuda13-nvidia-l4t-arm64-whisper-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-whisper" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-whisper - !!merge <<: *whispercpp name: "cpu-whisper" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-whisper" mirrors: - localai/localai-backends:latest-cpu-whisper - !!merge <<: *whispercpp name: "metal-whisper" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-whisper" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-whisper - !!merge <<: *whispercpp name: "metal-whisper-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-whisper" mirrors: - localai/localai-backends:master-metal-darwin-arm64-whisper - !!merge <<: *whispercpp name: "cpu-whisper-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-whisper" mirrors: - localai/localai-backends:master-cpu-whisper - !!merge <<: *whispercpp name: "cuda12-whisper" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-whisper" mirrors: - 
localai/localai-backends:latest-gpu-nvidia-cuda-12-whisper - !!merge <<: *whispercpp name: "rocm-whisper" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-whisper" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-whisper - !!merge <<: *whispercpp name: "intel-sycl-f32-whisper" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-whisper" mirrors: - localai/localai-backends:latest-gpu-intel-sycl-f32-whisper - !!merge <<: *whispercpp name: "intel-sycl-f16-whisper" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-whisper" mirrors: - localai/localai-backends:latest-gpu-intel-sycl-f16-whisper - !!merge <<: *whispercpp name: "vulkan-whisper" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-whisper" mirrors: - localai/localai-backends:latest-gpu-vulkan-whisper - !!merge <<: *whispercpp name: "vulkan-whisper-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-whisper" mirrors: - localai/localai-backends:master-gpu-vulkan-whisper - !!merge <<: *whispercpp name: "cuda12-whisper-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-whisper" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-whisper - !!merge <<: *whispercpp name: "rocm-whisper-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-whisper" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-whisper - !!merge <<: *whispercpp name: "intel-sycl-f32-whisper-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-whisper" mirrors: -
localai/localai-backends:master-gpu-intel-sycl-f32-whisper - !!merge <<: *whispercpp name: "intel-sycl-f16-whisper-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-whisper" mirrors: - localai/localai-backends:master-gpu-intel-sycl-f16-whisper - !!merge <<: *whispercpp name: "cuda13-whisper" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-whisper" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-whisper - !!merge <<: *whispercpp name: "cuda13-whisper-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-whisper" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-whisper ## stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "cpu-stablediffusion-ggml" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-stablediffusion-ggml" mirrors: - localai/localai-backends:latest-cpu-stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "cpu-stablediffusion-ggml-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-stablediffusion-ggml" mirrors: - localai/localai-backends:master-cpu-stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "metal-stablediffusion-ggml" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-stablediffusion-ggml" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "metal-stablediffusion-ggml-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-stablediffusion-ggml" mirrors: - localai/localai-backends:master-metal-darwin-arm64-stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "vulkan-stablediffusion-ggml" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-stablediffusion-ggml" mirrors: - localai/localai-backends:latest-gpu-vulkan-stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "vulkan-stablediffusion-ggml-development" uri: 
"quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-stablediffusion-ggml" mirrors: - localai/localai-backends:master-gpu-vulkan-stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "cuda12-stablediffusion-ggml" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-stablediffusion-ggml" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "intel-sycl-f32-stablediffusion-ggml" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-stablediffusion-ggml" mirrors: - localai/localai-backends:latest-gpu-intel-sycl-f32-stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "intel-sycl-f16-stablediffusion-ggml" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-stablediffusion-ggml" mirrors: - localai/localai-backends:latest-gpu-intel-sycl-f16-stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "cuda12-stablediffusion-ggml-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-stablediffusion-ggml" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "intel-sycl-f32-stablediffusion-ggml-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-stablediffusion-ggml" mirrors: - localai/localai-backends:master-gpu-intel-sycl-f32-stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "intel-sycl-f16-stablediffusion-ggml-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-stablediffusion-ggml" mirrors: - localai/localai-backends:master-gpu-intel-sycl-f16-stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "nvidia-l4t-arm64-stablediffusion-ggml-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-stablediffusion-ggml" mirrors: - localai/localai-backends:master-nvidia-l4t-arm64-stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "nvidia-l4t-arm64-stablediffusion-ggml" uri:
"quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-stablediffusion-ggml" mirrors: - localai/localai-backends:latest-nvidia-l4t-arm64-stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "cuda13-nvidia-l4t-arm64-stablediffusion-ggml" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-stablediffusion-ggml" mirrors: - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "cuda13-nvidia-l4t-arm64-stablediffusion-ggml-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-stablediffusion-ggml" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "cuda13-stablediffusion-ggml" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-stablediffusion-ggml" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-stablediffusion-ggml - !!merge <<: *stablediffusionggml name: "cuda13-stablediffusion-ggml-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-stablediffusion-ggml" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-stablediffusion-ggml # vllm - !!merge <<: *vllm name: "vllm-development" capabilities: nvidia: "cuda12-vllm-development" amd: "rocm-vllm-development" intel: "intel-vllm-development" - !!merge <<: *vllm name: "cuda12-vllm" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-vllm" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-vllm - !!merge <<: *vllm name: "rocm-vllm" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-vllm" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-vllm - !!merge <<: *vllm name: "intel-vllm" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-vllm" mirrors: - localai/localai-backends:latest-gpu-intel-vllm - !!merge <<: *vllm name: "cuda12-vllm-development" uri: 
"quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-vllm" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-vllm - !!merge <<: *vllm name: "rocm-vllm-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-vllm" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-vllm - !!merge <<: *vllm name: "intel-vllm-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-vllm" mirrors: - localai/localai-backends:master-gpu-intel-vllm # vllm-omni - !!merge <<: *vllm-omni name: "vllm-omni-development" capabilities: nvidia: "cuda12-vllm-omni-development" amd: "rocm-vllm-omni-development" nvidia-cuda-12: "cuda12-vllm-omni-development" - !!merge <<: *vllm-omni name: "cuda12-vllm-omni" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-vllm-omni" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-vllm-omni - !!merge <<: *vllm-omni name: "rocm-vllm-omni" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-vllm-omni" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-vllm-omni - !!merge <<: *vllm-omni name: "cuda12-vllm-omni-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-vllm-omni" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-vllm-omni - !!merge <<: *vllm-omni name: "rocm-vllm-omni-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-vllm-omni" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-vllm-omni # rfdetr - !!merge <<: *rfdetr name: "rfdetr-development" capabilities: nvidia: "cuda12-rfdetr-development" intel: "intel-rfdetr-development" #amd: "rocm-rfdetr-development" nvidia-l4t: "nvidia-l4t-arm64-rfdetr-development" metal: "metal-rfdetr-development" default: "cpu-rfdetr-development" nvidia-cuda-13: "cuda13-rfdetr-development" - !!merge <<: *rfdetr name: "cuda12-rfdetr" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-rfdetr" mirrors: - 
localai/localai-backends:latest-gpu-nvidia-cuda-12-rfdetr - !!merge <<: *rfdetr name: "intel-rfdetr" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-rfdetr" mirrors: - localai/localai-backends:latest-gpu-intel-rfdetr # - !!merge <<: *rfdetr # name: "rocm-rfdetr" # uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-hipblas-rfdetr" # mirrors: # - localai/localai-backends:latest-gpu-hipblas-rfdetr - !!merge <<: *rfdetr name: "nvidia-l4t-arm64-rfdetr" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-rfdetr" mirrors: - localai/localai-backends:latest-nvidia-l4t-arm64-rfdetr - !!merge <<: *rfdetr name: "nvidia-l4t-arm64-rfdetr-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-rfdetr" mirrors: - localai/localai-backends:master-nvidia-l4t-arm64-rfdetr - !!merge <<: *rfdetr name: "cpu-rfdetr" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-rfdetr" mirrors: - localai/localai-backends:latest-cpu-rfdetr - !!merge <<: *rfdetr name: "cuda12-rfdetr-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-rfdetr" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-rfdetr - !!merge <<: *rfdetr name: "intel-rfdetr-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-rfdetr" mirrors: - localai/localai-backends:master-gpu-intel-rfdetr # - !!merge <<: *rfdetr # name: "rocm-rfdetr-development" # uri: "quay.io/go-skynet/local-ai-backends:master-gpu-hipblas-rfdetr" # mirrors: # - localai/localai-backends:master-gpu-hipblas-rfdetr - !!merge <<: *rfdetr name: "cpu-rfdetr-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-rfdetr" mirrors: - localai/localai-backends:master-cpu-rfdetr - !!merge <<: *rfdetr name: "intel-rfdetr" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-rfdetr" mirrors: - localai/localai-backends:latest-gpu-intel-rfdetr - !!merge <<: *rfdetr name: "cuda13-rfdetr" uri: 
"quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-rfdetr" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-rfdetr - !!merge <<: *rfdetr name: "cuda13-rfdetr-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-rfdetr" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-rfdetr - !!merge <<: *rfdetr name: "metal-rfdetr" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-rfdetr" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-rfdetr - !!merge <<: *rfdetr name: "metal-rfdetr-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-rfdetr" mirrors: - localai/localai-backends:master-metal-darwin-arm64-rfdetr ## Rerankers - !!merge <<: *rerankers name: "rerankers-development" capabilities: nvidia: "cuda12-rerankers-development" intel: "intel-rerankers-development" amd: "rocm-rerankers-development" metal: "metal-rerankers-development" nvidia-cuda-13: "cuda13-rerankers-development" - !!merge <<: *rerankers name: "cuda12-rerankers" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-rerankers" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-rerankers - !!merge <<: *rerankers name: "intel-rerankers" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-rerankers" mirrors: - localai/localai-backends:latest-gpu-intel-rerankers - !!merge <<: *rerankers name: "rocm-rerankers" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-rerankers" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-rerankers - !!merge <<: *rerankers name: "cuda12-rerankers-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-rerankers" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-rerankers - !!merge <<: *rerankers name: "rocm-rerankers-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-rerankers" mirrors: - 
localai/localai-backends:master-gpu-rocm-hipblas-rerankers - !!merge <<: *rerankers name: "intel-rerankers-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-rerankers" mirrors: - localai/localai-backends:master-gpu-intel-rerankers - !!merge <<: *rerankers name: "cuda13-rerankers" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-rerankers" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-rerankers - !!merge <<: *rerankers name: "cuda13-rerankers-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-rerankers" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-rerankers - !!merge <<: *rerankers name: "metal-rerankers" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-rerankers" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-rerankers - !!merge <<: *rerankers name: "metal-rerankers-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-rerankers" mirrors: - localai/localai-backends:master-metal-darwin-arm64-rerankers ## Transformers - !!merge <<: *transformers name: "transformers-development" capabilities: nvidia: "cuda12-transformers-development" intel: "intel-transformers-development" amd: "rocm-transformers-development" metal: "metal-transformers-development" nvidia-cuda-13: "cuda13-transformers-development" - !!merge <<: *transformers name: "cuda12-transformers" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-transformers" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-transformers - !!merge <<: *transformers name: "rocm-transformers" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-transformers" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-transformers - !!merge <<: *transformers name: "intel-transformers" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-transformers" mirrors: - localai/localai-backends:latest-gpu-intel-transformers - 
!!merge <<: *transformers name: "cuda12-transformers-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-transformers" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-transformers - !!merge <<: *transformers name: "rocm-transformers-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-transformers" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-transformers - !!merge <<: *transformers name: "intel-transformers-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-transformers" mirrors: - localai/localai-backends:master-gpu-intel-transformers - !!merge <<: *transformers name: "cuda13-transformers" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-transformers" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-transformers - !!merge <<: *transformers name: "cuda13-transformers-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-transformers" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-transformers - !!merge <<: *transformers name: "metal-transformers" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-transformers" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-transformers - !!merge <<: *transformers name: "metal-transformers-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-transformers" mirrors: - localai/localai-backends:master-metal-darwin-arm64-transformers ## Diffusers - !!merge <<: *diffusers name: "diffusers-development" capabilities: nvidia: "cuda12-diffusers-development" intel: "intel-diffusers-development" amd: "rocm-diffusers-development" nvidia-l4t: "nvidia-l4t-diffusers-development" metal: "metal-diffusers-development" default: "cpu-diffusers-development" nvidia-cuda-13: "cuda13-diffusers-development" - !!merge <<: *diffusers name: "cpu-diffusers" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-diffusers" 
mirrors: - localai/localai-backends:latest-cpu-diffusers - !!merge <<: *diffusers name: "cpu-diffusers-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-diffusers" mirrors: - localai/localai-backends:master-cpu-diffusers - !!merge <<: *diffusers name: "nvidia-l4t-diffusers" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-diffusers" mirrors: - localai/localai-backends:latest-nvidia-l4t-diffusers - !!merge <<: *diffusers name: "nvidia-l4t-diffusers-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-diffusers" mirrors: - localai/localai-backends:master-nvidia-l4t-diffusers - !!merge <<: *diffusers name: "cuda13-nvidia-l4t-arm64-diffusers" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-diffusers" mirrors: - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-diffusers - !!merge <<: *diffusers name: "cuda13-nvidia-l4t-arm64-diffusers-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-diffusers" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-diffusers - !!merge <<: *diffusers name: "cuda12-diffusers" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-diffusers" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-diffusers - !!merge <<: *diffusers name: "rocm-diffusers" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-diffusers" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-diffusers - !!merge <<: *diffusers name: "intel-diffusers" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-diffusers" mirrors: - localai/localai-backends:latest-gpu-intel-diffusers - !!merge <<: *diffusers name: "cuda12-diffusers-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-diffusers" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-diffusers - !!merge <<: *diffusers name: "rocm-diffusers-development" uri: 
"quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-diffusers" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-diffusers - !!merge <<: *diffusers name: "intel-diffusers-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-diffusers" mirrors: - localai/localai-backends:master-gpu-intel-diffusers - !!merge <<: *diffusers name: "cuda13-diffusers" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-diffusers" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-diffusers - !!merge <<: *diffusers name: "cuda13-diffusers-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-diffusers" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-diffusers - !!merge <<: *diffusers name: "metal-diffusers" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-diffusers" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-diffusers - !!merge <<: *diffusers name: "metal-diffusers-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-diffusers" mirrors: - localai/localai-backends:master-metal-darwin-arm64-diffusers ## ace-step - !!merge <<: *ace-step name: "cpu-ace-step" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-ace-step" mirrors: - localai/localai-backends:latest-cpu-ace-step - !!merge <<: *ace-step name: "cpu-ace-step-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-ace-step" mirrors: - localai/localai-backends:master-cpu-ace-step - !!merge <<: *ace-step name: "cuda12-ace-step" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-ace-step" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-ace-step - !!merge <<: *ace-step name: "cuda12-ace-step-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-ace-step" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-ace-step - !!merge <<: *ace-step name: "cuda13-ace-step" uri: 
"quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-ace-step" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-ace-step - !!merge <<: *ace-step name: "cuda13-ace-step-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-ace-step" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-ace-step - !!merge <<: *ace-step name: "rocm-ace-step" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-ace-step" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-ace-step - !!merge <<: *ace-step name: "rocm-ace-step-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-ace-step" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-ace-step - !!merge <<: *ace-step name: "intel-ace-step" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-ace-step" mirrors: - localai/localai-backends:latest-gpu-intel-ace-step - !!merge <<: *ace-step name: "intel-ace-step-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-ace-step" mirrors: - localai/localai-backends:master-gpu-intel-ace-step - !!merge <<: *ace-step name: "metal-ace-step" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-ace-step" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-ace-step - !!merge <<: *ace-step name: "metal-ace-step-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-ace-step" mirrors: - localai/localai-backends:master-metal-darwin-arm64-ace-step ## acestep-cpp - !!merge <<: *acestepcpp name: "nvidia-l4t-arm64-acestep-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-acestep-cpp" mirrors: - localai/localai-backends:latest-nvidia-l4t-arm64-acestep-cpp - !!merge <<: *acestepcpp name: "nvidia-l4t-arm64-acestep-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-acestep-cpp" mirrors: - localai/localai-backends:master-nvidia-l4t-arm64-acestep-cpp - !!merge <<: *acestepcpp 
name: "cuda13-nvidia-l4t-arm64-acestep-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-acestep-cpp" mirrors: - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-acestep-cpp - !!merge <<: *acestepcpp name: "cuda13-nvidia-l4t-arm64-acestep-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-acestep-cpp" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-acestep-cpp - !!merge <<: *acestepcpp name: "cpu-acestep-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-acestep-cpp" mirrors: - localai/localai-backends:latest-cpu-acestep-cpp - !!merge <<: *acestepcpp name: "metal-acestep-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-acestep-cpp" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-acestep-cpp - !!merge <<: *acestepcpp name: "metal-acestep-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-acestep-cpp" mirrors: - localai/localai-backends:master-metal-darwin-arm64-acestep-cpp - !!merge <<: *acestepcpp name: "cpu-acestep-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-acestep-cpp" mirrors: - localai/localai-backends:master-cpu-acestep-cpp - !!merge <<: *acestepcpp name: "cuda12-acestep-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-acestep-cpp" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-acestep-cpp - !!merge <<: *acestepcpp name: "rocm-acestep-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-acestep-cpp" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-acestep-cpp - !!merge <<: *acestepcpp name: "intel-sycl-f32-acestep-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-acestep-cpp" mirrors: - localai/localai-backends:latest-gpu-intel-sycl-f32-acestep-cpp - !!merge <<: *acestepcpp name: "intel-sycl-f16-acestep-cpp" uri: 
"quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-acestep-cpp" mirrors: - localai/localai-backends:latest-gpu-intel-sycl-f16-acestep-cpp - !!merge <<: *acestepcpp name: "vulkan-acestep-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-acestep-cpp" mirrors: - localai/localai-backends:latest-gpu-vulkan-acestep-cpp - !!merge <<: *acestepcpp name: "vulkan-acestep-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-vulkan-acestep-cpp" mirrors: - localai/localai-backends:master-gpu-vulkan-acestep-cpp - !!merge <<: *acestepcpp name: "cuda12-acestep-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-acestep-cpp" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-acestep-cpp - !!merge <<: *acestepcpp name: "rocm-acestep-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-acestep-cpp" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-acestep-cpp - !!merge <<: *acestepcpp name: "intel-sycl-f32-acestep-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f32-acestep-cpp" mirrors: - localai/localai-backends:master-gpu-intel-sycl-f32-acestep-cpp - !!merge <<: *acestepcpp name: "intel-sycl-f16-acestep-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-sycl-f16-acestep-cpp" mirrors: - localai/localai-backends:master-gpu-intel-sycl-f16-acestep-cpp - !!merge <<: *acestepcpp name: "cuda13-acestep-cpp" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-acestep-cpp" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-acestep-cpp - !!merge <<: *acestepcpp name: "cuda13-acestep-cpp-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-acestep-cpp" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-acestep-cpp ## kokoro - !!merge <<: *kokoro name: "kokoro-development" capabilities: nvidia: "cuda12-kokoro-development" intel: 
"intel-kokoro-development" amd: "rocm-kokoro-development" nvidia-l4t: "nvidia-l4t-kokoro-development" metal: "metal-kokoro-development" - !!merge <<: *kokoro name: "cuda12-kokoro-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-kokoro" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-kokoro - !!merge <<: *kokoro name: "rocm-kokoro-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-kokoro" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-kokoro - !!merge <<: *kokoro name: "intel-kokoro" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-kokoro" mirrors: - localai/localai-backends:latest-gpu-intel-kokoro - !!merge <<: *kokoro name: "intel-kokoro-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-kokoro" mirrors: - localai/localai-backends:master-gpu-intel-kokoro - !!merge <<: *kokoro name: "nvidia-l4t-kokoro" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-kokoro" mirrors: - localai/localai-backends:latest-nvidia-l4t-kokoro - !!merge <<: *kokoro name: "nvidia-l4t-kokoro-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-kokoro" mirrors: - localai/localai-backends:master-nvidia-l4t-kokoro - !!merge <<: *kokoro name: "cuda12-kokoro" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-kokoro" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-kokoro - !!merge <<: *kokoro name: "rocm-kokoro" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-kokoro" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-kokoro - !!merge <<: *kokoro name: "cuda13-kokoro" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-kokoro" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-kokoro - !!merge <<: *kokoro name: "cuda13-kokoro-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-kokoro" mirrors: - 
localai/localai-backends:master-gpu-nvidia-cuda-13-kokoro - !!merge <<: *kokoro name: "metal-kokoro" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-kokoro" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-kokoro - !!merge <<: *kokoro name: "metal-kokoro-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-kokoro" mirrors: - localai/localai-backends:master-metal-darwin-arm64-kokoro ## faster-whisper - !!merge <<: *faster-whisper name: "faster-whisper-development" capabilities: nvidia: "cuda12-faster-whisper-development" intel: "intel-faster-whisper-development" amd: "rocm-faster-whisper-development" metal: "metal-faster-whisper-development" nvidia-cuda-13: "cuda13-faster-whisper-development" - !!merge <<: *faster-whisper name: "cuda12-faster-whisper-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-faster-whisper" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-faster-whisper - !!merge <<: *faster-whisper name: "rocm-faster-whisper-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-faster-whisper" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-faster-whisper - !!merge <<: *faster-whisper name: "intel-faster-whisper" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-faster-whisper" mirrors: - localai/localai-backends:latest-gpu-intel-faster-whisper - !!merge <<: *faster-whisper name: "intel-faster-whisper-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-faster-whisper" mirrors: - localai/localai-backends:master-gpu-intel-faster-whisper - !!merge <<: *faster-whisper name: "cuda13-faster-whisper" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-faster-whisper" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-faster-whisper - !!merge <<: *faster-whisper name: "cuda13-faster-whisper-development" uri: 
"quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-faster-whisper" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-faster-whisper - !!merge <<: *faster-whisper name: "metal-faster-whisper" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-faster-whisper" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-faster-whisper - !!merge <<: *faster-whisper name: "metal-faster-whisper-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-faster-whisper" mirrors: - localai/localai-backends:master-metal-darwin-arm64-faster-whisper ## moonshine - !!merge <<: *moonshine name: "moonshine-development" capabilities: nvidia: "cuda12-moonshine-development" default: "cpu-moonshine-development" nvidia-cuda-13: "cuda13-moonshine-development" nvidia-cuda-12: "cuda12-moonshine-development" - !!merge <<: *moonshine name: "cpu-moonshine" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-moonshine" mirrors: - localai/localai-backends:latest-cpu-moonshine - !!merge <<: *moonshine name: "cpu-moonshine-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-moonshine" mirrors: - localai/localai-backends:master-cpu-moonshine - !!merge <<: *moonshine name: "cuda12-moonshine" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-moonshine" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-moonshine - !!merge <<: *moonshine name: "cuda12-moonshine-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-moonshine" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-moonshine - !!merge <<: *moonshine name: "cuda13-moonshine" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-moonshine" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-moonshine - !!merge <<: *moonshine name: "cuda13-moonshine-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-moonshine" mirrors: - 
localai/localai-backends:master-gpu-nvidia-cuda-13-moonshine - !!merge <<: *moonshine name: "metal-moonshine" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-moonshine" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-moonshine - !!merge <<: *moonshine name: "metal-moonshine-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-moonshine" mirrors: - localai/localai-backends:master-metal-darwin-arm64-moonshine ## whisperx - !!merge <<: *whisperx name: "whisperx-development" capabilities: nvidia: "cuda12-whisperx-development" amd: "rocm-whisperx-development" metal: "metal-whisperx-development" default: "cpu-whisperx-development" nvidia-cuda-13: "cuda13-whisperx-development" nvidia-cuda-12: "cuda12-whisperx-development" - !!merge <<: *whisperx name: "cpu-whisperx" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-whisperx" mirrors: - localai/localai-backends:latest-cpu-whisperx - !!merge <<: *whisperx name: "cpu-whisperx-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-whisperx" mirrors: - localai/localai-backends:master-cpu-whisperx - !!merge <<: *whisperx name: "cuda12-whisperx" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-whisperx" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-whisperx - !!merge <<: *whisperx name: "cuda12-whisperx-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-whisperx" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-whisperx - !!merge <<: *whisperx name: "rocm-whisperx" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-whisperx" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-whisperx - !!merge <<: *whisperx name: "rocm-whisperx-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-whisperx" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-whisperx - !!merge <<: *whisperx name: "cuda13-whisperx" uri: 
"quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-whisperx" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-whisperx - !!merge <<: *whisperx name: "cuda13-whisperx-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-whisperx" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-whisperx - !!merge <<: *whisperx name: "metal-whisperx" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-whisperx" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-whisperx - !!merge <<: *whisperx name: "metal-whisperx-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-whisperx" mirrors: - localai/localai-backends:master-metal-darwin-arm64-whisperx ## coqui - !!merge <<: *coqui name: "coqui-development" capabilities: nvidia: "cuda12-coqui-development" intel: "intel-coqui-development" amd: "rocm-coqui-development" metal: "metal-coqui-development" - !!merge <<: *coqui name: "cuda12-coqui" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-coqui" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-coqui - !!merge <<: *coqui name: "cuda12-coqui-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-coqui" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-coqui - !!merge <<: *coqui name: "rocm-coqui-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-coqui" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-coqui - !!merge <<: *coqui name: "intel-coqui" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-coqui" mirrors: - localai/localai-backends:latest-gpu-intel-coqui - !!merge <<: *coqui name: "intel-coqui-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-coqui" mirrors: - localai/localai-backends:master-gpu-intel-coqui - !!merge <<: *coqui name: "rocm-coqui" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-coqui" mirrors: - 
localai/localai-backends:latest-gpu-rocm-hipblas-coqui - !!merge <<: *coqui name: "metal-coqui" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-coqui" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-coqui - !!merge <<: *coqui name: "metal-coqui-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-coqui" mirrors: - localai/localai-backends:master-metal-darwin-arm64-coqui ## outetts - !!merge <<: *outetts name: "outetts-development" capabilities: default: "cpu-outetts-development" nvidia-cuda-12: "cuda12-outetts-development" - !!merge <<: *outetts name: "cpu-outetts" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-outetts" mirrors: - localai/localai-backends:latest-cpu-outetts - !!merge <<: *outetts name: "cpu-outetts-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-outetts" mirrors: - localai/localai-backends:master-cpu-outetts - !!merge <<: *outetts name: "cuda12-outetts" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-outetts" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-outetts - !!merge <<: *outetts name: "cuda12-outetts-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-outetts" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-outetts ## chatterbox - !!merge <<: *chatterbox name: "chatterbox-development" capabilities: nvidia: "cuda12-chatterbox-development" metal: "metal-chatterbox-development" default: "cpu-chatterbox-development" nvidia-l4t: "nvidia-l4t-arm64-chatterbox" nvidia-cuda-13: "cuda13-chatterbox-development" nvidia-cuda-12: "cuda12-chatterbox-development" nvidia-l4t-cuda-12: "nvidia-l4t-arm64-chatterbox" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-chatterbox-development" - !!merge <<: *chatterbox name: "cpu-chatterbox" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-chatterbox" mirrors: - localai/localai-backends:latest-cpu-chatterbox - !!merge <<: *chatterbox name: 
"cpu-chatterbox-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-chatterbox" mirrors: - localai/localai-backends:master-cpu-chatterbox - !!merge <<: *chatterbox name: "nvidia-l4t-arm64-chatterbox" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-arm64-chatterbox" mirrors: - localai/localai-backends:latest-nvidia-l4t-arm64-chatterbox - !!merge <<: *chatterbox name: "nvidia-l4t-arm64-chatterbox-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-arm64-chatterbox" mirrors: - localai/localai-backends:master-nvidia-l4t-arm64-chatterbox - !!merge <<: *chatterbox name: "metal-chatterbox" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-chatterbox" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-chatterbox - !!merge <<: *chatterbox name: "metal-chatterbox-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-chatterbox" mirrors: - localai/localai-backends:master-metal-darwin-arm64-chatterbox - !!merge <<: *chatterbox name: "cuda12-chatterbox-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-chatterbox" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-chatterbox - !!merge <<: *chatterbox name: "cuda12-chatterbox" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-chatterbox" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-chatterbox - !!merge <<: *chatterbox name: "cuda13-chatterbox" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-chatterbox" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-chatterbox - !!merge <<: *chatterbox name: "cuda13-chatterbox-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-chatterbox" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-chatterbox - !!merge <<: *chatterbox name: "cuda13-nvidia-l4t-arm64-chatterbox" uri: 
"quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-chatterbox" mirrors: - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-chatterbox - !!merge <<: *chatterbox name: "cuda13-nvidia-l4t-arm64-chatterbox-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-chatterbox" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-chatterbox ## vibevoice - !!merge <<: *vibevoice name: "vibevoice-development" capabilities: nvidia: "cuda12-vibevoice-development" intel: "intel-vibevoice-development" amd: "rocm-vibevoice-development" nvidia-l4t: "nvidia-l4t-vibevoice-development" metal: "metal-vibevoice-development" default: "cpu-vibevoice-development" nvidia-cuda-13: "cuda13-vibevoice-development" nvidia-cuda-12: "cuda12-vibevoice-development" nvidia-l4t-cuda-12: "nvidia-l4t-vibevoice-development" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-vibevoice-development" - !!merge <<: *vibevoice name: "cpu-vibevoice" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-vibevoice" mirrors: - localai/localai-backends:latest-cpu-vibevoice - !!merge <<: *vibevoice name: "cpu-vibevoice-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-vibevoice" mirrors: - localai/localai-backends:master-cpu-vibevoice - !!merge <<: *vibevoice name: "cuda12-vibevoice" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-vibevoice" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-vibevoice - !!merge <<: *vibevoice name: "cuda12-vibevoice-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-vibevoice" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-vibevoice - !!merge <<: *vibevoice name: "cuda13-vibevoice" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-vibevoice" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-vibevoice - !!merge <<: *vibevoice name: "cuda13-vibevoice-development" uri: 
"quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-vibevoice" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-vibevoice - !!merge <<: *vibevoice name: "intel-vibevoice" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-vibevoice" mirrors: - localai/localai-backends:latest-gpu-intel-vibevoice - !!merge <<: *vibevoice name: "intel-vibevoice-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-vibevoice" mirrors: - localai/localai-backends:master-gpu-intel-vibevoice - !!merge <<: *vibevoice name: "rocm-vibevoice" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-vibevoice" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-vibevoice - !!merge <<: *vibevoice name: "rocm-vibevoice-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-vibevoice" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-vibevoice - !!merge <<: *vibevoice name: "nvidia-l4t-vibevoice" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-vibevoice" mirrors: - localai/localai-backends:latest-nvidia-l4t-vibevoice - !!merge <<: *vibevoice name: "nvidia-l4t-vibevoice-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-vibevoice" mirrors: - localai/localai-backends:master-nvidia-l4t-vibevoice - !!merge <<: *vibevoice name: "cuda13-nvidia-l4t-arm64-vibevoice" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-vibevoice" mirrors: - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-vibevoice - !!merge <<: *vibevoice name: "cuda13-nvidia-l4t-arm64-vibevoice-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-vibevoice" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-vibevoice - !!merge <<: *vibevoice name: "metal-vibevoice" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-vibevoice" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-vibevoice - !!merge <<: 
*vibevoice name: "metal-vibevoice-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-vibevoice" mirrors: - localai/localai-backends:master-metal-darwin-arm64-vibevoice ## qwen-tts - !!merge <<: *qwen-tts name: "qwen-tts-development" capabilities: nvidia: "cuda12-qwen-tts-development" intel: "intel-qwen-tts-development" amd: "rocm-qwen-tts-development" nvidia-l4t: "nvidia-l4t-qwen-tts-development" metal: "metal-qwen-tts-development" default: "cpu-qwen-tts-development" nvidia-cuda-13: "cuda13-qwen-tts-development" nvidia-cuda-12: "cuda12-qwen-tts-development" nvidia-l4t-cuda-12: "nvidia-l4t-qwen-tts-development" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-qwen-tts-development" - !!merge <<: *qwen-tts name: "cpu-qwen-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-qwen-tts" mirrors: - localai/localai-backends:latest-cpu-qwen-tts - !!merge <<: *qwen-tts name: "cpu-qwen-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-qwen-tts" mirrors: - localai/localai-backends:master-cpu-qwen-tts - !!merge <<: *qwen-tts name: "cuda12-qwen-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-qwen-tts" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-qwen-tts - !!merge <<: *qwen-tts name: "cuda12-qwen-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-qwen-tts" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-qwen-tts - !!merge <<: *qwen-tts name: "cuda13-qwen-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-qwen-tts" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-qwen-tts - !!merge <<: *qwen-tts name: "cuda13-qwen-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-qwen-tts" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-qwen-tts - !!merge <<: *qwen-tts name: "intel-qwen-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-qwen-tts" mirrors: - 
localai/localai-backends:latest-gpu-intel-qwen-tts - !!merge <<: *qwen-tts name: "intel-qwen-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-qwen-tts" mirrors: - localai/localai-backends:master-gpu-intel-qwen-tts - !!merge <<: *qwen-tts name: "rocm-qwen-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-qwen-tts" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-qwen-tts - !!merge <<: *qwen-tts name: "rocm-qwen-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-qwen-tts" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-qwen-tts - !!merge <<: *qwen-tts name: "nvidia-l4t-qwen-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-qwen-tts" mirrors: - localai/localai-backends:latest-nvidia-l4t-qwen-tts - !!merge <<: *qwen-tts name: "nvidia-l4t-qwen-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-qwen-tts" mirrors: - localai/localai-backends:master-nvidia-l4t-qwen-tts - !!merge <<: *qwen-tts name: "cuda13-nvidia-l4t-arm64-qwen-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-qwen-tts" mirrors: - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-qwen-tts - !!merge <<: *qwen-tts name: "cuda13-nvidia-l4t-arm64-qwen-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-qwen-tts" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-qwen-tts - !!merge <<: *qwen-tts name: "metal-qwen-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-qwen-tts" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-qwen-tts - !!merge <<: *qwen-tts name: "metal-qwen-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-qwen-tts" mirrors: - localai/localai-backends:master-metal-darwin-arm64-qwen-tts ## fish-speech - !!merge <<: *fish-speech name: "fish-speech-development" capabilities: nvidia: 
"cuda12-fish-speech-development" intel: "intel-fish-speech-development" amd: "rocm-fish-speech-development" nvidia-l4t: "nvidia-l4t-fish-speech-development" metal: "metal-fish-speech-development" default: "cpu-fish-speech-development" nvidia-cuda-13: "cuda13-fish-speech-development" nvidia-cuda-12: "cuda12-fish-speech-development" nvidia-l4t-cuda-12: "nvidia-l4t-fish-speech-development" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-fish-speech-development" - !!merge <<: *fish-speech name: "cpu-fish-speech" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-fish-speech" mirrors: - localai/localai-backends:latest-cpu-fish-speech - !!merge <<: *fish-speech name: "cpu-fish-speech-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-fish-speech" mirrors: - localai/localai-backends:master-cpu-fish-speech - !!merge <<: *fish-speech name: "cuda12-fish-speech" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-fish-speech" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-fish-speech - !!merge <<: *fish-speech name: "cuda12-fish-speech-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-fish-speech" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-fish-speech - !!merge <<: *fish-speech name: "cuda13-fish-speech" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-fish-speech" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-fish-speech - !!merge <<: *fish-speech name: "cuda13-fish-speech-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-fish-speech" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-fish-speech - !!merge <<: *fish-speech name: "intel-fish-speech" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-fish-speech" mirrors: - localai/localai-backends:latest-gpu-intel-fish-speech - !!merge <<: *fish-speech name: "intel-fish-speech-development" uri: 
"quay.io/go-skynet/local-ai-backends:master-gpu-intel-fish-speech" mirrors: - localai/localai-backends:master-gpu-intel-fish-speech - !!merge <<: *fish-speech name: "rocm-fish-speech" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-fish-speech" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-fish-speech - !!merge <<: *fish-speech name: "rocm-fish-speech-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-fish-speech" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-fish-speech - !!merge <<: *fish-speech name: "nvidia-l4t-fish-speech" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-fish-speech" mirrors: - localai/localai-backends:latest-nvidia-l4t-fish-speech - !!merge <<: *fish-speech name: "nvidia-l4t-fish-speech-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-fish-speech" mirrors: - localai/localai-backends:master-nvidia-l4t-fish-speech - !!merge <<: *fish-speech name: "cuda13-nvidia-l4t-arm64-fish-speech" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-fish-speech" mirrors: - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-fish-speech - !!merge <<: *fish-speech name: "cuda13-nvidia-l4t-arm64-fish-speech-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-fish-speech" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-fish-speech - !!merge <<: *fish-speech name: "metal-fish-speech" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-fish-speech" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-fish-speech - !!merge <<: *fish-speech name: "metal-fish-speech-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-fish-speech" mirrors: - localai/localai-backends:master-metal-darwin-arm64-fish-speech ## faster-qwen3-tts - !!merge <<: *faster-qwen3-tts name: "faster-qwen3-tts-development" capabilities: nvidia: 
"cuda12-faster-qwen3-tts-development" default: "cuda12-faster-qwen3-tts-development" nvidia-cuda-13: "cuda13-faster-qwen3-tts-development" nvidia-cuda-12: "cuda12-faster-qwen3-tts-development" nvidia-l4t: "nvidia-l4t-faster-qwen3-tts-development" nvidia-l4t-cuda-12: "nvidia-l4t-faster-qwen3-tts-development" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-faster-qwen3-tts-development" - !!merge <<: *faster-qwen3-tts name: "cuda12-faster-qwen3-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-faster-qwen3-tts" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-faster-qwen3-tts - !!merge <<: *faster-qwen3-tts name: "cuda12-faster-qwen3-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-faster-qwen3-tts" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-faster-qwen3-tts - !!merge <<: *faster-qwen3-tts name: "cuda13-faster-qwen3-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-faster-qwen3-tts" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-faster-qwen3-tts - !!merge <<: *faster-qwen3-tts name: "cuda13-faster-qwen3-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-faster-qwen3-tts" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-faster-qwen3-tts - !!merge <<: *faster-qwen3-tts name: "nvidia-l4t-faster-qwen3-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-faster-qwen3-tts" mirrors: - localai/localai-backends:latest-nvidia-l4t-faster-qwen3-tts - !!merge <<: *faster-qwen3-tts name: "nvidia-l4t-faster-qwen3-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-faster-qwen3-tts" mirrors: - localai/localai-backends:master-nvidia-l4t-faster-qwen3-tts - !!merge <<: *faster-qwen3-tts name: "cuda13-nvidia-l4t-arm64-faster-qwen3-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-faster-qwen3-tts" mirrors: - 
localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-faster-qwen3-tts - !!merge <<: *faster-qwen3-tts name: "cuda13-nvidia-l4t-arm64-faster-qwen3-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-faster-qwen3-tts" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-faster-qwen3-tts ## qwen-asr - !!merge <<: *qwen-asr name: "qwen-asr-development" capabilities: nvidia: "cuda12-qwen-asr-development" intel: "intel-qwen-asr-development" amd: "rocm-qwen-asr-development" nvidia-l4t: "nvidia-l4t-qwen-asr-development" metal: "metal-qwen-asr-development" default: "cpu-qwen-asr-development" nvidia-cuda-13: "cuda13-qwen-asr-development" nvidia-cuda-12: "cuda12-qwen-asr-development" nvidia-l4t-cuda-12: "nvidia-l4t-qwen-asr-development" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-qwen-asr-development" - !!merge <<: *qwen-asr name: "cpu-qwen-asr" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-qwen-asr" mirrors: - localai/localai-backends:latest-cpu-qwen-asr - !!merge <<: *qwen-asr name: "cpu-qwen-asr-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-qwen-asr" mirrors: - localai/localai-backends:master-cpu-qwen-asr - !!merge <<: *qwen-asr name: "cuda12-qwen-asr" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-qwen-asr" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-qwen-asr - !!merge <<: *qwen-asr name: "cuda12-qwen-asr-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-qwen-asr" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-qwen-asr - !!merge <<: *qwen-asr name: "cuda13-qwen-asr" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-qwen-asr" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-qwen-asr - !!merge <<: *qwen-asr name: "cuda13-qwen-asr-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-qwen-asr" mirrors: - 
localai/localai-backends:master-gpu-nvidia-cuda-13-qwen-asr - !!merge <<: *qwen-asr name: "intel-qwen-asr" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-qwen-asr" mirrors: - localai/localai-backends:latest-gpu-intel-qwen-asr - !!merge <<: *qwen-asr name: "intel-qwen-asr-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-qwen-asr" mirrors: - localai/localai-backends:master-gpu-intel-qwen-asr - !!merge <<: *qwen-asr name: "rocm-qwen-asr" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-qwen-asr" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-qwen-asr - !!merge <<: *qwen-asr name: "rocm-qwen-asr-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-qwen-asr" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-qwen-asr - !!merge <<: *qwen-asr name: "nvidia-l4t-qwen-asr" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-qwen-asr" mirrors: - localai/localai-backends:latest-nvidia-l4t-qwen-asr - !!merge <<: *qwen-asr name: "nvidia-l4t-qwen-asr-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-qwen-asr" mirrors: - localai/localai-backends:master-nvidia-l4t-qwen-asr - !!merge <<: *qwen-asr name: "cuda13-nvidia-l4t-arm64-qwen-asr" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-qwen-asr" mirrors: - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-qwen-asr - !!merge <<: *qwen-asr name: "cuda13-nvidia-l4t-arm64-qwen-asr-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-qwen-asr" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-qwen-asr - !!merge <<: *qwen-asr name: "metal-qwen-asr" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-qwen-asr" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-qwen-asr - !!merge <<: *qwen-asr name: "metal-qwen-asr-development" uri: 
"quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-qwen-asr" mirrors: - localai/localai-backends:master-metal-darwin-arm64-qwen-asr ## nemo - !!merge <<: *nemo name: "nemo-development" capabilities: nvidia: "cuda12-nemo-development" intel: "intel-nemo-development" amd: "rocm-nemo-development" metal: "metal-nemo-development" default: "cpu-nemo-development" nvidia-cuda-13: "cuda13-nemo-development" nvidia-cuda-12: "cuda12-nemo-development" - !!merge <<: *nemo name: "cpu-nemo" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-nemo" mirrors: - localai/localai-backends:latest-cpu-nemo - !!merge <<: *nemo name: "cpu-nemo-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-nemo" mirrors: - localai/localai-backends:master-cpu-nemo - !!merge <<: *nemo name: "cuda12-nemo" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-nemo" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-nemo - !!merge <<: *nemo name: "cuda12-nemo-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-nemo" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-nemo - !!merge <<: *nemo name: "cuda13-nemo" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-nemo" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-nemo - !!merge <<: *nemo name: "cuda13-nemo-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-nemo" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-nemo - !!merge <<: *nemo name: "intel-nemo" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-nemo" mirrors: - localai/localai-backends:latest-gpu-intel-nemo - !!merge <<: *nemo name: "intel-nemo-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-nemo" mirrors: - localai/localai-backends:master-gpu-intel-nemo - !!merge <<: *nemo name: "rocm-nemo" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-nemo" mirrors: - 
localai/localai-backends:latest-gpu-rocm-hipblas-nemo - !!merge <<: *nemo name: "rocm-nemo-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-nemo" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-nemo - !!merge <<: *nemo name: "metal-nemo" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-nemo" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-nemo - !!merge <<: *nemo name: "metal-nemo-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-nemo" mirrors: - localai/localai-backends:master-metal-darwin-arm64-nemo ## voxcpm - !!merge <<: *voxcpm name: "voxcpm-development" capabilities: nvidia: "cuda12-voxcpm-development" intel: "intel-voxcpm-development" amd: "rocm-voxcpm-development" metal: "metal-voxcpm-development" default: "cpu-voxcpm-development" nvidia-cuda-13: "cuda13-voxcpm-development" nvidia-cuda-12: "cuda12-voxcpm-development" - !!merge <<: *voxcpm name: "cpu-voxcpm" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-voxcpm" mirrors: - localai/localai-backends:latest-cpu-voxcpm - !!merge <<: *voxcpm name: "cpu-voxcpm-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-voxcpm" mirrors: - localai/localai-backends:master-cpu-voxcpm - !!merge <<: *voxcpm name: "cuda12-voxcpm" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-voxcpm" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-voxcpm - !!merge <<: *voxcpm name: "cuda12-voxcpm-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-voxcpm" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-voxcpm - !!merge <<: *voxcpm name: "cuda13-voxcpm" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-voxcpm" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-voxcpm - !!merge <<: *voxcpm name: "cuda13-voxcpm-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-voxcpm" mirrors: - 
localai/localai-backends:master-gpu-nvidia-cuda-13-voxcpm - !!merge <<: *voxcpm name: "intel-voxcpm" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-voxcpm" mirrors: - localai/localai-backends:latest-gpu-intel-voxcpm - !!merge <<: *voxcpm name: "intel-voxcpm-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-voxcpm" mirrors: - localai/localai-backends:master-gpu-intel-voxcpm - !!merge <<: *voxcpm name: "rocm-voxcpm" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-voxcpm" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-voxcpm - !!merge <<: *voxcpm name: "rocm-voxcpm-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-voxcpm" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-voxcpm - !!merge <<: *voxcpm name: "metal-voxcpm" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-voxcpm" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-voxcpm - !!merge <<: *voxcpm name: "metal-voxcpm-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-voxcpm" mirrors: - localai/localai-backends:master-metal-darwin-arm64-voxcpm ## pocket-tts - !!merge <<: *pocket-tts name: "pocket-tts-development" capabilities: nvidia: "cuda12-pocket-tts-development" intel: "intel-pocket-tts-development" amd: "rocm-pocket-tts-development" nvidia-l4t: "nvidia-l4t-pocket-tts-development" metal: "metal-pocket-tts-development" default: "cpu-pocket-tts-development" nvidia-cuda-13: "cuda13-pocket-tts-development" nvidia-cuda-12: "cuda12-pocket-tts-development" nvidia-l4t-cuda-12: "nvidia-l4t-pocket-tts-development" nvidia-l4t-cuda-13: "cuda13-nvidia-l4t-arm64-pocket-tts-development" - !!merge <<: *pocket-tts name: "cpu-pocket-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-pocket-tts" mirrors: - localai/localai-backends:latest-cpu-pocket-tts - !!merge <<: *pocket-tts name: "cpu-pocket-tts-development" uri: 
"quay.io/go-skynet/local-ai-backends:master-cpu-pocket-tts" mirrors: - localai/localai-backends:master-cpu-pocket-tts - !!merge <<: *pocket-tts name: "cuda12-pocket-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-pocket-tts" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-12-pocket-tts - !!merge <<: *pocket-tts name: "cuda12-pocket-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-12-pocket-tts" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-12-pocket-tts - !!merge <<: *pocket-tts name: "cuda13-pocket-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-pocket-tts" mirrors: - localai/localai-backends:latest-gpu-nvidia-cuda-13-pocket-tts - !!merge <<: *pocket-tts name: "cuda13-pocket-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-nvidia-cuda-13-pocket-tts" mirrors: - localai/localai-backends:master-gpu-nvidia-cuda-13-pocket-tts - !!merge <<: *pocket-tts name: "intel-pocket-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-pocket-tts" mirrors: - localai/localai-backends:latest-gpu-intel-pocket-tts - !!merge <<: *pocket-tts name: "intel-pocket-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-intel-pocket-tts" mirrors: - localai/localai-backends:master-gpu-intel-pocket-tts - !!merge <<: *pocket-tts name: "rocm-pocket-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-pocket-tts" mirrors: - localai/localai-backends:latest-gpu-rocm-hipblas-pocket-tts - !!merge <<: *pocket-tts name: "rocm-pocket-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-gpu-rocm-hipblas-pocket-tts" mirrors: - localai/localai-backends:master-gpu-rocm-hipblas-pocket-tts - !!merge <<: *pocket-tts name: "nvidia-l4t-pocket-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-pocket-tts" mirrors: - localai/localai-backends:latest-nvidia-l4t-pocket-tts - !!merge <<: *pocket-tts name: 
"nvidia-l4t-pocket-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-pocket-tts" mirrors: - localai/localai-backends:master-nvidia-l4t-pocket-tts - !!merge <<: *pocket-tts name: "cuda13-nvidia-l4t-arm64-pocket-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-nvidia-l4t-cuda-13-arm64-pocket-tts" mirrors: - localai/localai-backends:latest-nvidia-l4t-cuda-13-arm64-pocket-tts - !!merge <<: *pocket-tts name: "cuda13-nvidia-l4t-arm64-pocket-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-nvidia-l4t-cuda-13-arm64-pocket-tts" mirrors: - localai/localai-backends:master-nvidia-l4t-cuda-13-arm64-pocket-tts - !!merge <<: *pocket-tts name: "metal-pocket-tts" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-pocket-tts" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-pocket-tts - !!merge <<: *pocket-tts name: "metal-pocket-tts-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-pocket-tts" mirrors: - localai/localai-backends:master-metal-darwin-arm64-pocket-tts ## voxtral - !!merge <<: *voxtral name: "cpu-voxtral" uri: "quay.io/go-skynet/local-ai-backends:latest-cpu-voxtral" mirrors: - localai/localai-backends:latest-cpu-voxtral - !!merge <<: *voxtral name: "cpu-voxtral-development" uri: "quay.io/go-skynet/local-ai-backends:master-cpu-voxtral" mirrors: - localai/localai-backends:master-cpu-voxtral - !!merge <<: *voxtral name: "metal-voxtral" uri: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-voxtral" mirrors: - localai/localai-backends:latest-metal-darwin-arm64-voxtral - !!merge <<: *voxtral name: "metal-voxtral-development" uri: "quay.io/go-skynet/local-ai-backends:master-metal-darwin-arm64-voxtral" mirrors: - localai/localai-backends:master-metal-darwin-arm64-voxtral ================================================ FILE: backend/python/README.md ================================================ # Python Backends for LocalAI This directory contains 
Python-based AI backends for LocalAI, providing support for various AI models and hardware acceleration targets. ## Overview The Python backends use a unified build system based on `libbackend.sh` that provides: - **Automatic virtual environment management** with support for both `uv` and `pip` - **Hardware-specific dependency installation** (CPU, CUDA, Intel, MLX, etc.) - **Portable Python support** for standalone deployments - **Consistent backend execution** across different environments ## Available Backends ### Core AI Models - **transformers** - Hugging Face Transformers framework (PyTorch-based) - **vllm** - High-performance LLM inference engine - **mlx** - Apple Silicon optimized ML framework ### Audio & Speech - **coqui** - Coqui TTS models - **faster-whisper** - Fast Whisper speech recognition - **kitten-tts** - Lightweight TTS - **mlx-audio** - Apple Silicon audio processing - **chatterbox** - TTS model - **kokoro** - TTS models ### Computer Vision - **diffusers** - Stable Diffusion and image generation - **mlx-vlm** - Vision-language models for Apple Silicon - **rfdetr** - Object detection models ### Specialized - **rerankers** - Text reranking models ## Quick Start ### Prerequisites - Python 3.10+ (default: 3.10.18) - `uv` package manager (recommended) or `pip` - Appropriate hardware drivers for your target (CUDA, Intel, etc.) 
### Installation Each backend can be installed individually: ```bash # Navigate to a specific backend cd backend/python/transformers # Install dependencies make transformers # or bash install.sh # Run the backend make run # or bash run.sh ``` ### Using the Unified Build System The `libbackend.sh` script provides consistent commands across all backends: ```bash # Source the library in your backend script source $(dirname $0)/../common/libbackend.sh # Install requirements (automatically handles hardware detection) installRequirements # Start the backend server startBackend $@ # Run tests runUnittests ``` ## Hardware Targets The build system automatically detects and configures for different hardware: - **CPU** - Standard CPU-only builds - **CUDA** - NVIDIA GPU acceleration (supports CUDA 12/13) - **Intel** - Intel XPU/GPU optimization - **MLX** - Apple Silicon (M1/M2/M3) optimization - **HIP** - AMD GPU acceleration ### Target-Specific Requirements Backends can specify hardware-specific dependencies: - `requirements.txt` - Base requirements - `requirements-cpu.txt` - CPU-specific packages - `requirements-cublas12.txt` - CUDA 12 packages - `requirements-cublas13.txt` - CUDA 13 packages - `requirements-intel.txt` - Intel-optimized packages - `requirements-hipblas.txt` - AMD ROCm/HIP packages - `requirements-mps.txt` - Apple Silicon packages ## Configuration Options ### Environment Variables - `PYTHON_VERSION` - Python version (default: 3.10) - `PYTHON_PATCH` - Python patch version (default: 18) - `BUILD_TYPE` - Force specific build target - `USE_PIP` - Use pip instead of uv (default: false) - `PORTABLE_PYTHON` - Enable portable Python builds - `LIMIT_TARGETS` - Restrict backend to specific targets ### Example: CUDA 12 Only Backend ```bash # In your backend script LIMIT_TARGETS="cublas12" source $(dirname $0)/../common/libbackend.sh ``` ### Example: Intel-Optimized Backend ```bash # In your backend script LIMIT_TARGETS="intel" source $(dirname $0)/../common/libbackend.sh ``` ## Development ### Adding a New Backend 1.
Create a new directory in `backend/python/` 2. Copy the template structure from `common/template/` 3. Implement your `backend.py` with the required gRPC interface 4. Add appropriate requirements files for your target hardware 5. Use `libbackend.sh` for consistent build and execution ### Testing ```bash # Run backend tests make test # or bash test.sh ``` ### Building ```bash # Install dependencies make # Clean build artifacts make clean ``` ## Architecture Each backend follows a consistent structure: ``` backend-name/ ├── backend.py # Main backend implementation ├── requirements.txt # Base dependencies ├── requirements-*.txt # Hardware-specific dependencies ├── install.sh # Installation script ├── run.sh # Execution script ├── test.sh # Test script ├── Makefile # Build targets └── test.py # Unit tests ``` ## Troubleshooting ### Common Issues 1. **Missing dependencies**: Ensure all requirements files are properly configured 2. **Hardware detection**: Check that `BUILD_TYPE` matches your system 3. **Python version**: Verify Python 3.10+ is available 4. **Virtual environment**: Use `ensureVenv` to create/activate environments ## Contributing When adding new backends or modifying existing ones: 1. Follow the established directory structure 2. Use `libbackend.sh` for consistent behavior 3. Include appropriate requirements files for all target hardware 4. Add comprehensive tests 5. 
Update this README if adding new backend types ================================================ FILE: backend/python/ace-step/Makefile ================================================ .DEFAULT_GOAL := install .PHONY: install install: bash install.sh .PHONY: protogen-clean protogen-clean: $(RM) backend_pb2_grpc.py backend_pb2.py .PHONY: clean clean: protogen-clean rm -rf venv __pycache__ test: install bash test.sh ================================================ FILE: backend/python/ace-step/backend.py ================================================ #!/usr/bin/env python3 """ LocalAI ACE-Step Backend gRPC backend for ACE-Step 1.5 music generation. Aligns with upstream acestep API: - LoadModel: initializes AceStepHandler (DiT) and LLMHandler, parses Options. - SoundGeneration: uses create_sample (simple mode), format_sample (optional), then generate_music from acestep.inference. Writes first output to request.dst. - Fail hard: no fallback WAV on error; exceptions propagate to gRPC. """ from concurrent import futures import argparse import shutil import signal import sys import os import tempfile import backend_pb2 import backend_pb2_grpc import grpc from acestep.inference import ( GenerationParams, GenerationConfig, generate_music, create_sample, format_sample, ) from acestep.handler import AceStepHandler from acestep.llm_inference import LLMHandler from acestep.model_downloader import ensure_lm_model _ONE_DAY_IN_SECONDS = 60 * 60 * 24 MAX_WORKERS = int(os.environ.get("PYTHON_GRPC_MAX_WORKERS", "1")) # Model name -> HuggingFace/ModelScope repo (from upstream api_server.py) MODEL_REPO_MAPPING = { "acestep-v15-turbo": "ACE-Step/Ace-Step1.5", "acestep-5Hz-lm-0.6B": "ACE-Step/Ace-Step1.5", "acestep-5Hz-lm-1.7B": "ACE-Step/Ace-Step1.5", "vae": "ACE-Step/Ace-Step1.5", "Qwen3-Embedding-0.6B": "ACE-Step/Ace-Step1.5", "acestep-v15-base": "ACE-Step/acestep-v15-base", "acestep-v15-sft": "ACE-Step/acestep-v15-sft", "acestep-v15-turbo-shift3": 
"ACE-Step/acestep-v15-turbo-shift3", "acestep-5Hz-lm-4B": "ACE-Step/acestep-5Hz-lm-4B", } DEFAULT_REPO_ID = "ACE-Step/Ace-Step1.5" def _is_float(s): try: float(s) return True except (ValueError, TypeError): return False def _is_int(s): try: int(s) return True except (ValueError, TypeError): return False def _parse_timesteps(s): if s is None or (isinstance(s, str) and not s.strip()): return None if isinstance(s, (list, tuple)): return [float(x) for x in s] try: return [float(x.strip()) for x in str(s).split(",") if x.strip()] except (ValueError, TypeError): return None def _parse_options(opts_list): """Parse repeated 'key:value' options into a dict. Coerce numeric and bool.""" out = {} for opt in opts_list or []: if ":" not in opt: continue key, value = opt.split(":", 1) key = key.strip() value = value.strip() if _is_int(value): out[key] = int(value) elif _is_float(value): out[key] = float(value) elif value.lower() in ("true", "false"): out[key] = value.lower() == "true" else: out[key] = value return out def _generate_audio_sync(servicer, payload, dst_path): """ Run full ACE-Step pipeline using acestep.inference: - If sample_mode/sample_query: create_sample() for caption/lyrics/metadata. - If use_format and caption/lyrics: format_sample(). - Build GenerationParams and GenerationConfig, then generate_music(). Writes the first generated audio to dst_path. Raises on failure. 
""" opts = servicer.options dit_handler = servicer.dit_handler llm_handler = servicer.llm_handler for key, value in opts.items(): if key not in payload: payload[key] = value def _opt(name, default): return opts.get(name, default) lm_temperature = _opt("temperature", 0.85) lm_cfg_scale = _opt("lm_cfg_scale", _opt("cfg_scale", 2.0)) lm_top_k = opts.get("top_k") lm_top_p = _opt("top_p", 0.9) if lm_top_p is not None and lm_top_p >= 1.0: lm_top_p = None inference_steps = _opt("inference_steps", 8) guidance_scale = _opt("guidance_scale", 7.0) batch_size = max(1, int(_opt("batch_size", 1))) use_simple = bool(payload.get("sample_query") or payload.get("text")) sample_mode = use_simple and (payload.get("thinking") or payload.get("sample_mode")) sample_query = (payload.get("sample_query") or payload.get("text") or "").strip() use_format = bool(payload.get("use_format")) caption = (payload.get("prompt") or payload.get("caption") or "").strip() lyrics = (payload.get("lyrics") or "").strip() vocal_language = (payload.get("vocal_language") or "en").strip() instrumental = bool(payload.get("instrumental")) bpm = payload.get("bpm") key_scale = (payload.get("key_scale") or "").strip() time_signature = (payload.get("time_signature") or "").strip() audio_duration = payload.get("audio_duration") if audio_duration is not None: try: audio_duration = float(audio_duration) except (TypeError, ValueError): audio_duration = None if sample_mode and llm_handler and getattr(llm_handler, "llm_initialized", False): parsed_language = None if sample_query: for hint in ("english", "en", "chinese", "zh", "japanese", "ja"): if hint in sample_query.lower(): parsed_language = "en" if hint == "english" or hint == "en" else hint break vocal_lang = vocal_language if vocal_language and vocal_language != "unknown" else parsed_language sample_result = create_sample( llm_handler=llm_handler, query=sample_query or "NO USER INPUT", instrumental=instrumental, vocal_language=vocal_lang, temperature=lm_temperature, 
top_k=lm_top_k, top_p=lm_top_p, use_constrained_decoding=True, ) if not sample_result.success: raise RuntimeError(f"create_sample failed: {sample_result.error or sample_result.status_message}") caption = sample_result.caption or caption lyrics = sample_result.lyrics or lyrics bpm = sample_result.bpm key_scale = sample_result.keyscale or key_scale time_signature = sample_result.timesignature or time_signature if sample_result.duration is not None: audio_duration = sample_result.duration if getattr(sample_result, "language", None): vocal_language = sample_result.language if use_format and (caption or lyrics) and llm_handler and getattr(llm_handler, "llm_initialized", False): user_metadata = {} if bpm is not None: user_metadata["bpm"] = bpm if audio_duration is not None and float(audio_duration) > 0: user_metadata["duration"] = int(audio_duration) if key_scale: user_metadata["keyscale"] = key_scale if time_signature: user_metadata["timesignature"] = time_signature if vocal_language and vocal_language != "unknown": user_metadata["language"] = vocal_language format_result = format_sample( llm_handler=llm_handler, caption=caption, lyrics=lyrics, user_metadata=user_metadata if user_metadata else None, temperature=lm_temperature, top_k=lm_top_k, top_p=lm_top_p, use_constrained_decoding=True, ) if format_result.success: caption = format_result.caption or caption lyrics = format_result.lyrics or lyrics if format_result.duration is not None: audio_duration = format_result.duration if format_result.bpm is not None: bpm = format_result.bpm if format_result.keyscale: key_scale = format_result.keyscale if format_result.timesignature: time_signature = format_result.timesignature if getattr(format_result, "language", None): vocal_language = format_result.language thinking = bool(payload.get("thinking")) use_cot_metas = not sample_mode params = GenerationParams( task_type=payload.get("task_type", "text2music"), instruction=payload.get("instruction", "Fill the audio semantic mask 
based on the given conditions:"), reference_audio=payload.get("reference_audio_path"), src_audio=payload.get("src_audio_path"), audio_codes=payload.get("audio_code_string", ""), caption=caption, lyrics=lyrics, instrumental=instrumental or (not lyrics or str(lyrics).strip().lower() in ("[inst]", "[instrumental]")), vocal_language=vocal_language or "unknown", bpm=bpm, keyscale=key_scale, timesignature=time_signature, duration=float(audio_duration) if audio_duration and float(audio_duration) > 0 else -1.0, inference_steps=inference_steps, seed=int(payload.get("seed", -1)), guidance_scale=guidance_scale, use_adg=bool(payload.get("use_adg")), cfg_interval_start=float(payload.get("cfg_interval_start", 0.0)), cfg_interval_end=float(payload.get("cfg_interval_end", 1.0)), shift=float(payload.get("shift", 1.0)), infer_method=(payload.get("infer_method") or "ode").strip(), timesteps=_parse_timesteps(payload.get("timesteps")), repainting_start=float(payload.get("repainting_start", 0.0)), repainting_end=float(payload.get("repainting_end", -1)) if payload.get("repainting_end") is not None else -1, audio_cover_strength=float(payload.get("audio_cover_strength", 1.0)), thinking=thinking, lm_temperature=lm_temperature, lm_cfg_scale=lm_cfg_scale, lm_top_k=lm_top_k or 0, lm_top_p=lm_top_p if lm_top_p is not None and lm_top_p < 1.0 else 0.9, lm_negative_prompt=payload.get("lm_negative_prompt", "NO USER INPUT"), use_cot_metas=use_cot_metas, use_cot_caption=bool(payload.get("use_cot_caption", True)), use_cot_language=bool(payload.get("use_cot_language", True)), use_constrained_decoding=True, ) config = GenerationConfig( batch_size=batch_size, allow_lm_batch=bool(payload.get("allow_lm_batch", False)), use_random_seed=bool(payload.get("use_random_seed", True)), seeds=payload.get("seeds"), lm_batch_chunk_size=max(1, int(payload.get("lm_batch_chunk_size", 8))), constrained_decoding_debug=bool(payload.get("constrained_decoding_debug")), audio_format=(payload.get("audio_format") or 
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """gRPC servicer exposing ACE-Step music generation.

    ``LoadModel`` creates the DiT handler (and, unless disabled through the
    ``init_lm`` option, the LM handler); ``SoundGeneration`` and ``TTS`` build a
    payload dict and hand it to ``_generate_audio_sync``. Generation errors are
    not caught here, so they propagate to gRPC.
    """

    def __init__(self):
        self.model_path = None
        self.model_dir = None
        self.checkpoint_dir = None
        self.project_root = None
        self.options = {}
        self.dit_handler = None
        self.llm_handler = None

    def Health(self, request, context):
        """Liveness probe: always replies with b"OK"."""
        return backend_pb2.Reply(message=b"OK")

    def LoadModel(self, request, context):
        """Resolve the model directory and initialize the DiT and optional LM handlers.

        Returns a ``Result`` with ``success=False`` when DiT initialization fails
        or any exception is raised; an LM initialization failure is only logged
        and leaves ``self.llm_handler`` as ``None``.
        """
        try:
            self.options = _parse_options(list(getattr(request, "Options", []) or []))
            model_path = getattr(request, "ModelPath", None) or ""
            model_name = (request.Model or "").strip()
            model_file = (getattr(request, "ModelFile", None) or "").strip()

            # Model dir: where we store checkpoints (always under LocalAI models path, never backend dir)
            if model_path and model_name:
                model_dir = os.path.join(model_path, model_name)
            elif model_file:
                model_dir = model_file
            else:
                model_dir = os.path.abspath(model_name or ".")
            self.model_dir = model_dir
            self.checkpoint_dir = os.path.join(model_dir, "checkpoints")
            self.project_root = model_dir
            self.model_path = os.path.join(self.checkpoint_dir, model_name or os.path.basename(model_dir.rstrip("/\\")))

            config_path = model_name or os.path.basename(model_dir.rstrip("/\\"))
            os.makedirs(self.checkpoint_dir, exist_ok=True)

            self.dit_handler = AceStepHandler()
            # Patch handler so it uses our model dir instead of site-packages/checkpoints
            self.dit_handler._get_project_root = lambda: self.project_root
            device = self.options.get("device", "auto")
            # Options are normally coerced by _parse_options; strings such as "yes" still reach here.
            use_flash = self.options.get("use_flash_attention", True)
            if isinstance(use_flash, str):
                use_flash = str(use_flash).lower() in ("1", "true", "yes")
            offload = self.options.get("offload_to_cpu", False)
            if isinstance(offload, str):
                offload = str(offload).lower() in ("1", "true", "yes")
            status_msg, ok = self.dit_handler.initialize_service(
                project_root=self.project_root,
                config_path=config_path,
                device=device,
                use_flash_attention=use_flash,
                compile_model=False,
                offload_to_cpu=offload,
                offload_dit_to_cpu=bool(self.options.get("offload_dit_to_cpu", False)),
            )
            if not ok:
                return backend_pb2.Result(success=False, message=f"DiT init failed: {status_msg}")

            self.llm_handler = None
            if self.options.get("init_lm", True):
                lm_model = self.options.get("lm_model_path", "acestep-5Hz-lm-0.6B")
                # Ensure LM model is downloaded before initializing
                try:
                    from pathlib import Path
                    lm_success, lm_msg = ensure_lm_model(
                        model_name=lm_model,
                        checkpoints_dir=Path(self.checkpoint_dir),
                        prefer_source=None,  # Auto-detect HuggingFace vs ModelScope
                    )
                    if not lm_success:
                        print(f"[ace-step] Warning: LM model download failed: {lm_msg}", file=sys.stderr)
                        # Continue anyway - LLM initialization will fail gracefully
                    else:
                        print(f"[ace-step] LM model ready: {lm_msg}", file=sys.stderr)
                except Exception as e:
                    print(f"[ace-step] Warning: LM model download check failed: {e}", file=sys.stderr)
                    # Continue anyway - LLM initialization will fail gracefully

                self.llm_handler = LLMHandler()
                lm_backend = (self.options.get("lm_backend") or "vllm").strip().lower()
                if lm_backend not in ("vllm", "pt"):
                    lm_backend = "vllm"
                lm_status, lm_ok = self.llm_handler.initialize(
                    checkpoint_dir=self.checkpoint_dir,
                    lm_model_path=lm_model,
                    backend=lm_backend,
                    device=device,
                    offload_to_cpu=offload,
                    dtype=getattr(self.dit_handler, "dtype", None),
                )
                if not lm_ok:
                    # The LM is optional: drop the handler and keep serving with DiT only.
                    self.llm_handler = None
                    print(f"[ace-step] LM init failed (optional): {lm_status}", file=sys.stderr)

            print(f"[ace-step] LoadModel: model={self.model_path}, options={list(self.options.keys())}", file=sys.stderr)
            return backend_pb2.Result(success=True, message="Model loaded successfully")
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"LoadModel error: {err}")

    def SoundGeneration(self, request, context):
        """Generate audio into ``request.dst``.

        A non-empty ``request.text`` selects simple mode (the text becomes the
        sample query); otherwise caption/lyrics and the optional metadata fields
        are forwarded. Returns a failed ``Result`` when ``dst`` is missing;
        exceptions from ``_generate_audio_sync`` are not caught.
        """
        if not request.dst:
            return backend_pb2.Result(success=False, message="request.dst is required")

        # Python protobuf messages expose fields as plain attributes; there are no
        # Go-style GetLanguage()/GetCaption() accessors, so read the attributes only.
        vocal_language = request.language or "en"

        use_simple = bool(request.text)
        if use_simple:
            payload = {
                "sample_query": request.text or "",
                "sample_mode": True,
                "thinking": True,
                "vocal_language": vocal_language,
                "instrumental": request.instrumental if request.HasField("instrumental") else False,
            }
        else:
            payload = {
                "prompt": request.caption or request.text or "",
                "lyrics": request.lyrics or "",
                "thinking": request.think if request.HasField("think") else False,
                "vocal_language": vocal_language,
            }
            # NOTE(review): the nesting of these optional fields under the advanced
            # branch was reconstructed from a whitespace-collapsed source - confirm.
            if request.HasField("bpm"):
                payload["bpm"] = request.bpm
            if request.HasField("keyscale") and request.keyscale:
                payload["key_scale"] = request.keyscale
            if request.HasField("timesignature") and request.timesignature:
                payload["time_signature"] = request.timesignature
            if request.HasField("duration") and request.duration:
                payload["audio_duration"] = int(request.duration)
            if request.src:
                payload["src_audio_path"] = request.src

        _generate_audio_sync(self, payload, request.dst)
        return backend_pb2.Result(success=True, message="Sound generated successfully")

    def TTS(self, request, context):
        """Music-generation fallback for TTS requests: the text is used as the sample query."""
        if not request.dst:
            return backend_pb2.Result(success=False, message="request.dst is required")

        payload = {
            "sample_query": request.text,
            "sample_mode": True,
            "thinking": False,
            "vocal_language": request.language or "en",
            "instrumental": False,
        }
        _generate_audio_sync(self, payload, request.dst)
        return backend_pb2.Result(success=True, message="TTS (music fallback) generated successfully")
fi fi ================================================ FILE: backend/python/ace-step/requirements-cpu.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cpu torch torchaudio torchvision # Core dependencies transformers>=4.51.0,<4.58.0 diffusers gradio matplotlib>=3.7.5 scipy>=1.10.1 soundfile>=0.13.1 loguru>=0.7.3 einops>=0.8.1 accelerate>=1.12.0 fastapi>=0.110.0 uvicorn[standard]>=0.27.0 numba>=0.63.1 vector-quantize-pytorch>=1.27.15 torchcodec>=0.9.1 torchao modelscope ================================================ FILE: backend/python/ace-step/requirements-cublas12.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu128 torch torchaudio torchvision # Core dependencies transformers>=4.51.0,<4.58.0 diffusers gradio>=6.5.1 matplotlib>=3.7.5 scipy>=1.10.1 soundfile>=0.13.1 loguru>=0.7.3 einops>=0.8.1 accelerate>=1.12.0 fastapi>=0.110.0 uvicorn[standard]>=0.27.0 numba>=0.63.1 vector-quantize-pytorch>=1.27.15 torchcodec>=0.9.1 torchao modelscope ================================================ FILE: backend/python/ace-step/requirements-cublas13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch torchaudio torchvision # Core dependencies transformers>=4.51.0,<4.58.0 diffusers gradio>=6.5.1 matplotlib>=3.7.5 scipy>=1.10.1 soundfile>=0.13.1 loguru>=0.7.3 einops>=0.8.1 accelerate>=1.12.0 fastapi>=0.110.0 uvicorn[standard]>=0.27.0 numba>=0.63.1 vector-quantize-pytorch>=1.27.15 torchcodec>=0.9.1 torchao modelscope ================================================ FILE: backend/python/ace-step/requirements-hipblas.txt ================================================ --extra-index-url https://download.pytorch.org/whl/rocm6.4 torch==2.8.0+rocm6.4 torchaudio torchvision # Core dependencies transformers>=4.51.0,<4.58.0 diffusers gradio>=6.5.1 matplotlib>=3.7.5 scipy>=1.10.1 soundfile>=0.13.1 loguru>=0.7.3 
einops>=0.8.1 accelerate>=1.12.0 fastapi>=0.110.0 uvicorn[standard]>=0.27.0 numba>=0.63.1 vector-quantize-pytorch>=1.27.15 torchcodec>=0.9.1 torchao modelscope ================================================ FILE: backend/python/ace-step/requirements-intel.txt ================================================ --extra-index-url https://download.pytorch.org/whl/xpu torch torchaudio torchvision # Core dependencies transformers>=4.51.0,<4.58.0 diffusers gradio matplotlib>=3.7.5 scipy>=1.10.1 soundfile>=0.13.1 loguru>=0.7.3 einops>=0.8.1 accelerate>=1.12.0 fastapi>=0.110.0 uvicorn[standard]>=0.27.0 numba>=0.63.1 vector-quantize-pytorch>=1.27.15 torchcodec>=0.9.1 torchao modelscope # LoRA Training dependencies (optional) peft>=0.7.0 lightning>=2.0.0 ================================================ FILE: backend/python/ace-step/requirements-l4t13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch torchaudio torchvision # Core dependencies transformers>=4.51.0,<4.58.0 diffusers gradio>=6.5.1 matplotlib>=3.7.5 scipy>=1.10.1 soundfile>=0.13.1 loguru>=0.7.3 einops>=0.8.1 accelerate>=1.12.0 fastapi>=0.110.0 uvicorn[standard]>=0.27.0 numba>=0.63.1 vector-quantize-pytorch>=1.27.15 torchcodec>=0.9.1 torchao modelscope ================================================ FILE: backend/python/ace-step/requirements-mps.txt ================================================ torch torchaudio torchvision # Core dependencies transformers>=4.51.0,<4.58.0 diffusers gradio matplotlib>=3.7.5 scipy>=1.10.1 soundfile>=0.13.1 loguru>=0.7.3 einops>=0.8.1 accelerate>=1.12.0 fastapi>=0.110.0 uvicorn[standard]>=0.27.0 numba>=0.63.1 vector-quantize-pytorch>=1.27.15 torchcodec>=0.9.1 torchao modelscope # LoRA Training dependencies (optional) peft>=0.7.0 lightning>=2.0.0 ================================================ FILE: backend/python/ace-step/requirements.txt ================================================ setuptools grpcio==1.76.0 
class TestACEStepBackend(unittest.TestCase):
    """Exercise the Health, LoadModel and SoundGeneration RPCs over a channel to localhost:$BACKEND_PORT."""

    @classmethod
    def setUpClass(cls):
        # One channel and stub are shared by every test in the class.
        backend_port = os.environ.get("BACKEND_PORT", "50051")
        target = f"localhost:{backend_port}"
        cls.channel = grpc.insecure_channel(target)
        cls.stub = backend_pb2_grpc.BackendStub(cls.channel)

    @classmethod
    def tearDownClass(cls):
        cls.channel.close()

    def test_health(self):
        reply = self.stub.Health(backend_pb2.HealthMessage())
        self.assertEqual(reply.message, b"OK")

    def test_load_model(self):
        model_options = backend_pb2.ModelOptions(Model="ace-step-test")
        reply = self.stub.LoadModel(model_options)
        self.assertTrue(reply.success, reply.message)

    def test_sound_generation_minimal(self):
        # Reserve a destination path; the file is removed again in the finally block.
        handle = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        with handle:
            dst = handle.name
        try:
            request = backend_pb2.SoundGenerationRequest(
                text="upbeat pop song",
                model="ace-step-test",
                dst=dst,
            )
            reply = self.stub.SoundGeneration(request)
            self.assertTrue(reply.success, reply.message)
            self.assertTrue(os.path.exists(dst), f"Output file not created: {dst}")
            self.assertGreater(os.path.getsize(dst), 0)
        finally:
            if os.path.exists(dst):
                os.unlink(dst)
def is_float(s):
    """Check if a value can be converted to float.

    Returns False instead of raising for inputs ``float()`` rejects, including
    non-string values such as ``None`` (TypeError) as well as malformed text
    (ValueError).
    """
    try:
        float(s)
    except (ValueError, TypeError):
        return False
    return True


def is_int(s):
    """Check if a value can be converted to int.

    Returns False instead of raising for inputs ``int()`` rejects, including
    non-string values such as ``None`` (TypeError) as well as malformed text
    (ValueError).
    """
    try:
        int(s)
    except (ValueError, TypeError):
        return False
    return True
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    BackendServicer is the class that implements the gRPC service
    for Chatterbox TTS (Health, LoadModel and TTS).
    """

    def Health(self, request, context):
        """Liveness probe: always replies with "OK"."""
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """Pick a device, parse ``request.Options`` and load the Chatterbox model.

        Options arrive as ``optname:optvalue`` strings and are stored in
        ``self.options`` (int, float and bool values are coerced). Returns a
        failed ``Result`` when CUDA was requested but is unavailable, or when
        loading the model raises.
        """
        # Get device
        # device = "cuda" if request.CUDA else "cpu"
        if torch.cuda.is_available():
            print("CUDA is available", file=sys.stderr)
            device = "cuda"
        else:
            print("CUDA is not available", file=sys.stderr)
            device = "cpu"
        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        if mps_available:
            device = "mps"
        if not torch.cuda.is_available() and request.CUDA:
            return backend_pb2.Result(success=False, message="CUDA is not available")

        # The options are a list of strings in this form optname:optvalue
        # We are storing all the options in a dict so we can use it later when
        # generating the audio
        self.options = {}
        for opt in request.Options:
            if ":" not in opt:
                continue
            # Split on the first colon only so values may themselves contain ':'.
            key, value = opt.split(":", 1)
            # Try int before float: every int literal also parses as a float, so
            # the reverse order would turn all integers into floats.
            if is_int(value):
                value = int(value)
            elif is_float(value):
                value = float(value)
            elif value.lower() in ["true", "false"]:
                value = value.lower() == "true"
            self.options[key] = value

        self.AudioPath = None

        if os.path.isabs(request.AudioPath):
            self.AudioPath = request.AudioPath
        elif request.AudioPath and request.ModelFile != "" and not os.path.isabs(request.AudioPath):
            # Resolve a relative AudioPath against the directory of ModelFile
            modelFileBase = os.path.dirname(request.ModelFile)
            self.AudioPath = os.path.join(modelFileBase, request.AudioPath)

        try:
            print("Preparing models, please wait", file=sys.stderr)
            if "multilingual" in self.options:
                # The key only selects the model class; it is removed so it is
                # not forwarded to generate() later.
                del self.options["multilingual"]
                self.model = ChatterboxMultilingualTTS.from_pretrained(device=device)
            else:
                self.model = ChatterboxTTS.from_pretrained(device=device)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def TTS(self, request, context):
        """Synthesize ``request.text`` into ``request.dst``.

        Text longer than 250 characters is split at word boundaries, generated
        chunk by chunk into temporary WAV files and merged. Any exception is
        reported as a failed ``Result``.
        """
        try:
            kwargs = {}
            if "language" in self.options:
                kwargs["language_id"] = self.options["language"]
            if self.AudioPath is not None:
                kwargs["audio_prompt_path"] = self.AudioPath

            # add options to kwargs
            # NOTE(review): this also forwards the raw "language" option next to
            # language_id - confirm that generate() accepts it.
            kwargs.update(self.options)

            # Check if text exceeds 250 characters
            # (chatterbox does not support long text)
            # https://github.com/resemble-ai/chatterbox/issues/60
            # https://github.com/resemble-ai/chatterbox/issues/110
            if len(request.text) > 250:
                # Split text at word boundaries
                text_chunks = split_text_at_word_boundary(request.text, max_length=250)
                print(f"Splitting text into chunks of 250 characters: {len(text_chunks)}", file=sys.stderr)

                temp_audio_files = []
                try:
                    for chunk in text_chunks:
                        wav = self.model.generate(chunk, **kwargs)

                        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
                        temp_file.close()
                        # Track the path before saving so a failed save is still cleaned up.
                        temp_audio_files.append(temp_file.name)
                        ta.save(temp_file.name, wav, self.model.sr)

                    # Merge all audio files
                    merge_audio_files(temp_audio_files, request.dst, self.model.sr)
                finally:
                    # Remove whatever chunk files are still on disk, whether the
                    # merge succeeded, failed, or never ran.
                    for temp_path in temp_audio_files:
                        if os.path.exists(temp_path):
                            os.remove(temp_path)
            else:
                # Generate audio using ChatterboxTTS for short text
                wav = self.model.generate(request.text, **kwargs)
                # Save the generated audio
                ta.save(request.dst, wav, self.model.sr)

        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(success=True)
Shutting down...") server.stop(0) sys.exit(0) # Set the signal handlers for SIGINT and SIGTERM signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: while True: time.sleep(_ONE_DAY_IN_SECONDS) except KeyboardInterrupt: server.stop(0) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the gRPC server.") parser.add_argument( "--addr", default="localhost:50051", help="The address to bind the server to." ) args = parser.parse_args() serve(args.addr) ================================================ FILE: backend/python/chatterbox/install.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi # This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links. # This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match. 
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index # the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index if [ "x${BUILD_PROFILE}" == "xintel" ]; then EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match" fi EXTRA_PIP_INSTALL_FLAGS+=" --no-build-isolation" if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then USE_PIP=true fi installRequirements ================================================ FILE: backend/python/chatterbox/requirements-cpu.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cpu accelerate torch torchaudio numpy>=1.24.0,<1.26.0 transformers # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster #chatterbox-tts==0.1.4 ================================================ FILE: backend/python/chatterbox/requirements-cublas12.txt ================================================ torch torchaudio transformers numpy>=1.24.0,<1.26.0 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster accelerate ================================================ FILE: backend/python/chatterbox/requirements-cublas13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch torchaudio transformers numpy>=1.24.0,<1.26.0 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster accelerate ================================================ FILE: backend/python/chatterbox/requirements-hipblas.txt ================================================ --extra-index-url https://download.pytorch.org/whl/rocm6.4 torch==2.9.1+rocm6.4 torchaudio==2.9.1+rocm6.4 transformers numpy>=1.24.0,<1.26.0 # 
https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster accelerate ================================================ FILE: backend/python/chatterbox/requirements-install.txt ================================================ # Build dependencies needed for packages installed from source (e.g., git dependencies) # When using --no-build-isolation, these must be installed in the venv first wheel setuptools packaging ================================================ FILE: backend/python/chatterbox/requirements-intel.txt ================================================ --extra-index-url https://download.pytorch.org/whl/xpu torch torchaudio transformers numpy>=1.24.0,<1.26.0 # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster accelerate oneccl_bind_pt==2.3.100+xpu optimum[openvino] setuptools ================================================ FILE: backend/python/chatterbox/requirements-l4t12.txt ================================================ --extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu126/ torch torchaudio transformers numpy>=1.24.0,<1.26.0 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster accelerate ================================================ FILE: backend/python/chatterbox/requirements-l4t13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch torchaudio transformers numpy>=1.24.0,<1.26.0 chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster accelerate ================================================ FILE: backend/python/chatterbox/requirements-mps.txt ================================================ torch torchaudio accelerate numpy>=1.24.0,<1.26.0 transformers # https://github.com/mudler/LocalAI/pull/6240#issuecomment-3329518289 
chatterbox-tts@git+https://git@github.com/mudler/chatterbox.git@faster ================================================ FILE: backend/python/chatterbox/requirements.txt ================================================ grpcio==1.71.0 protobuf certifi packaging setuptools poetry ================================================ FILE: backend/python/chatterbox/run.sh ================================================ #!/bin/bash backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi startBackend $@ ================================================ FILE: backend/python/chatterbox/test.py ================================================ """ A test script to test the gRPC service """ import unittest import subprocess import time import backend_pb2 import backend_pb2_grpc import grpc class TestBackendServicer(unittest.TestCase): """ TestBackendServicer is the class that tests the gRPC service """ def setUp(self): """ This method sets up the gRPC service by starting the server """ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"]) time.sleep(30) def tearDown(self) -> None: """ This method tears down the gRPC service by terminating the server """ self.service.terminate() self.service.wait() def test_server_startup(self): """ This method tests if the server starts up successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.Health(backend_pb2.HealthMessage()) self.assertEqual(response.message, b'OK') except Exception as err: print(err) self.fail("Server failed to start") finally: self.tearDown() def test_load_model(self): """ This method tests if the model is loaded successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = 
stub.LoadModel(backend_pb2.ModelOptions()) print(response) self.assertTrue(response.success) self.assertEqual(response.message, "Model loaded successfully") except Exception as err: print(err) self.fail("LoadModel service failed") finally: self.tearDown() def test_tts(self): """ This method tests if the embeddings are generated successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions()) self.assertTrue(response.success) tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story") tts_response = stub.TTS(tts_request) self.assertIsNotNone(tts_response) except Exception as err: print(err) self.fail("TTS service failed") finally: self.tearDown() ================================================ FILE: backend/python/chatterbox/test.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi runUnittests ================================================ FILE: backend/python/common/libbackend.sh ================================================ #!/usr/bin/env bash set -euo pipefail # # use the library by adding the following line to a script: # source $(dirname $0)/../common/libbackend.sh # # If you want to limit what targets a backend can be used on, set the variable LIMIT_TARGETS to a # space separated list of valid targets BEFORE sourcing the library, for example to only allow a backend # to be used on CUDA and CPU backends: # # LIMIT_TARGETS="cublas cpu" # source $(dirname $0)/../common/libbackend.sh # # You can use any valid BUILD_TYPE or BUILD_PROFILE, if you need to limit a backend to CUDA 12 only: # # LIMIT_TARGETS="cublas12" # source $(dirname $0)/../common/libbackend.sh # # You can switch between uv (conda-like) and pip 
# installation methods by setting USE_PIP:
# USE_PIP=true source $(dirname $0)/../common/libbackend.sh
#
# ===================== user-configurable defaults =====================
PYTHON_VERSION="${PYTHON_VERSION:-3.10}" # e.g. 3.10 / 3.11 / 3.12 / 3.13
PYTHON_PATCH="${PYTHON_PATCH:-18}" # e.g. 18 -> 3.10.18 ; 13 -> 3.11.13
PY_STANDALONE_TAG="${PY_STANDALONE_TAG:-20250818}" # release tag date
# Enable/disable bundling of a portable Python build
PORTABLE_PYTHON="${PORTABLE_PYTHON:-false}"

# If you want to fully pin the filename (including tuned CPU targets), set:
# PORTABLE_PY_FILENAME="cpython-3.10.18+20250818-x86_64_v3-unknown-linux-gnu-install_only.tar.gz"
: "${PORTABLE_PY_FILENAME:=}"
: "${PORTABLE_PY_SHA256:=}" # optional; if set we verify the download
# =====================================================================

# Default to uv if USE_PIP is not set
if [ "x${USE_PIP:-}" == "x" ]; then
    USE_PIP=false
fi

# ----------------------- helpers -----------------------
# _is_musl returns 0 when the system C library looks like musl, 1 otherwise.
function _is_musl() {
    # detect musl (Alpine, etc)
    if command -v ldd >/dev/null 2>&1; then
        # ldd may exit non-zero for --version (musl's ldd is expected to -- TODO confirm
        # on Alpine). Under `set -o pipefail` a direct `ldd --version | grep` would then
        # count as failed even when grep matched, so capture the output and test it
        # without a pipeline.
        local ldd_out
        ldd_out="$(ldd --version 2>&1 || true)"
        case "${ldd_out}" in
            *[Mm][Uu][Ss][Ll]*) return 0 ;;
        esac
    fi
    # busybox-ish fallback
    if command -v getconf >/dev/null 2>&1; then
        getconf GNU_LIBC_VERSION >/dev/null 2>&1 || return 0
    fi
    return 1
}

# _triple prints the python-build-standalone target triple for this machine
# (e.g. x86_64-unknown-linux-gnu) on stdout. On an unsupported OS/arch it prints
# a diagnostic on stderr and exits 1.
function _triple() {
    local os="" arch="" libc="gnu"
    case "$(uname -s)" in
        Linux*) os="unknown-linux" ;;
        Darwin*) os="apple-darwin" ;;
        MINGW*|MSYS*|CYGWIN*) os="pc-windows-msvc" ;; # best-effort for Git Bash
        # stderr, not stdout: callers capture stdout as the triple
        *) echo "Unsupported OS $(uname -s)" >&2; exit 1;;
    esac
    case "$(uname -m)" in
        x86_64) arch="x86_64" ;;
        aarch64|arm64) arch="aarch64" ;;
        armv7l) arch="armv7" ;;
        i686|i386) arch="i686" ;;
        ppc64le) arch="ppc64le" ;;
        s390x) arch="s390x" ;;
        riscv64) arch="riscv64" ;;
        *) echo "Unsupported arch $(uname -m)" >&2; exit 1;;
    esac
    if [[ "$os" == "unknown-linux" ]]; then
        if _is_musl; then
            libc="musl"
        else
            libc="gnu"
        fi
        echo "${arch}-${os}-${libc}"
    else
        echo "${arch}-${os}"
    fi
}

function _portable_dir() {
    echo "${EDIR}/python"
}

function _portable_bin() {
    # python-build-standalone puts python in ./bin
    echo "$(_portable_dir)/bin"
}

function _portable_python() {
    if [ -x "$(_portable_bin)/python3" ]; then
        echo "$(_portable_bin)/python3"
    else
        echo "$(_portable_bin)/python"
    fi
}

# macOS loader env for the portable CPython
_macosPortableEnv() {
    if [ "$(uname -s)" = "Darwin" ]; then
        export DYLD_LIBRARY_PATH="$(_portable_dir)/lib${DYLD_LIBRARY_PATH:+:${DYLD_LIBRARY_PATH}}"
        export DYLD_FALLBACK_LIBRARY_PATH="$(_portable_dir)/lib${DYLD_FALLBACK_LIBRARY_PATH:+:${DYLD_FALLBACK_LIBRARY_PATH}}"
    fi
}

# Good hygiene on macOS for downloaded/extracted trees
_unquarantinePortablePython() {
    if [ "$(uname -s)" = "Darwin" ]; then
        command -v xattr >/dev/null 2>&1 && xattr -dr com.apple.quarantine "$(_portable_dir)" || true
    fi
}

# ------------------ ### PORTABLE PYTHON ------------------
# ensurePortablePython makes sure a python-build-standalone interpreter exists under
# $(_portable_dir): if one is already there it only sets up the macOS loader env,
# otherwise it downloads, optionally checksum-verifies, extracts and smoke-tests it.
function ensurePortablePython() {
    local pdir pbin pyexe
    pdir="$(_portable_dir)"
    pbin="$(_portable_bin)"

    if [ -x "${pbin}/python3" ] || [ -x "${pbin}/python" ]; then
        _macosPortableEnv
        return 0
    fi

    mkdir -p "${pdir}"
    # Declare and assign separately: `local x="$(cmd)"` always returns 0, which
    # would hide an "unsupported platform" failure from _triple.
    local triple
    triple="$(_triple)" || return 1

    local full_ver="${PYTHON_VERSION}.${PYTHON_PATCH}"
    local fn=""
    if [ -n "${PORTABLE_PY_FILENAME}" ]; then
        fn="${PORTABLE_PY_FILENAME}"
    else
        # generic asset name: cpython-<ver>+<tag>-<triple>-install_only.tar.gz
        fn="cpython-${full_ver}+${PY_STANDALONE_TAG}-${triple}-install_only.tar.gz"
    fi

    local url="https://github.com/astral-sh/python-build-standalone/releases/download/${PY_STANDALONE_TAG}/${fn}"
    local tmp="${pdir}/${fn}"
    echo "Downloading portable Python: ${fn}"
    # curl with retries; fall back to wget if needed
    if command -v curl >/dev/null 2>&1; then
        curl -L --fail --retry 3 --retry-delay 1 -o "${tmp}" "${url}"
    else
        wget -O "${tmp}" "${url}"
    fi

    if [ -n "${PORTABLE_PY_SHA256}" ]; then
        if command -v sha256sum >/dev/null 2>&1; then
            echo "${PORTABLE_PY_SHA256}  ${tmp}" | sha256sum -c -
        else
            # stock macOS has no sha256sum; shasum understands the same check format
            echo "${PORTABLE_PY_SHA256}  ${tmp}" | shasum -a 256 -c -
        fi
    fi

    echo "Extracting ${fn} -> ${pdir}"
    # always a .tar.gz (we purposely choose install_only)
    tar -xzf "${tmp}" -C "${pdir}"
    rm -f "${tmp}"

    # Some archives nest a directory; if so, flatten to ${pdir}
    # Find the first dir with a 'bin/python*'
    # (-maxdepth is a global option and must precede the tests)
    local inner
    inner="$(find "${pdir}" -maxdepth 3 -type f -path "*/bin/python*" 2>/dev/null | head -n1 || true)"
    if [ -n "${inner}" ]; then
        local inner_root
        inner_root="$(dirname "$(dirname "${inner}")")" # .../bin -> root
        if [ "${inner_root}" != "${pdir}" ]; then
            # move contents up one level
            shopt -s dotglob
            mv "${inner_root}/"* "${pdir}/"
            rm -rf "${inner_root}"
            shopt -u dotglob
        fi
    fi

    _unquarantinePortablePython
    _macosPortableEnv
    # Make sure it's runnable
    pyexe="$(_portable_python)"
    "${pyexe}" -V
}

# init handles the setup of the library
function init() {
    BACKEND_NAME=${PWD##*/}
    MY_DIR=$(realpath "$(dirname "$0")")
    BUILD_PROFILE=$(getBuildProfile)

    EDIR=${MY_DIR}
    if [ "x${ENV_DIR:-}" != "x" ]; then
        EDIR=${ENV_DIR}
    fi

    if [ ! -z "${LIMIT_TARGETS:-}" ]; then
        isValidTarget=$(checkTargets ${LIMIT_TARGETS})
        if [ ${isValidTarget} != true ]; then
            echo "${BACKEND_NAME} can only be used on the following targets: ${LIMIT_TARGETS}"
            exit 0
        fi
    fi

    echo "Initializing libbackend for ${BACKEND_NAME}"
}

# getBuildProfile will inspect the system to determine which build profile is appropriate.
# It prints one of the following:
# - ${BUILD_TYPE}${CUDA_MAJOR_VERSION} when BUILD_TYPE is cublas or l4t (e.g. cublas12, cublas13, l4t13),
#   or the bare BUILD_TYPE when CUDA_MAJOR_VERSION is not set
# - intel when /opt/intel exists
# - any other non-empty BUILD_TYPE unchanged (e.g. hipblas)
# - cpu when none of the above applies
function getBuildProfile() {
    if [ x"${BUILD_TYPE:-}" == "xcublas" ] || [ x"${BUILD_TYPE:-}" == "xl4t" ]; then
        if [ ! -z "${CUDA_MAJOR_VERSION:-}" ]; then
            echo ${BUILD_TYPE}${CUDA_MAJOR_VERSION}
        else
            echo ${BUILD_TYPE}
        fi
        return 0
    fi

    if [ -d "/opt/intel" ]; then
        echo "intel"
        return 0
    fi

    if [ -n "${BUILD_TYPE:-}" ]; then
        echo ${BUILD_TYPE}
        return 0
    fi

    echo "cpu"
}

# Make the venv relocatable:
# - rewrite venv/bin/python{,3} to relative symlinks into $(_portable_dir)
# - normalize entrypoint shebangs to /usr/bin/env python3
# - optionally update pyvenv.cfg to point to the portable Python directory (only at runtime)
# Usage: _makeVenvPortable [--update-pyvenv-cfg]
_makeVenvPortable() {
    local update_pyvenv_cfg=false
    if [ "${1:-}" = "--update-pyvenv-cfg" ]; then
        update_pyvenv_cfg=true
    fi

    local venv_dir="${EDIR}/venv"
    local vbin="${venv_dir}/bin"

    [ -d "${vbin}" ] || return 0

    # macOS/BSD sed needs a backup suffix for -i; GNU sed doesn't. Detect once
    # and reuse for both the pyvenv.cfg and the shebang rewrites below.
    local sed_i=(sed -i)
    if ! sed --version >/dev/null 2>&1; then
        sed_i=(sed -i '')
    fi

    # 1) Replace python symlinks with relative ones to ../../python/bin/python3
    # (venv/bin -> venv -> EDIR -> python/bin)
    local rel_py='../../python/bin/python3'

    local name
    for name in python3 python; do
        if [ -e "${vbin}/${name}" ] || [ -L "${vbin}/${name}" ]; then
            rm -f "${vbin}/${name}"
        fi
    done
    ln -s "${rel_py}" "${vbin}/python3"
    ln -s "python3" "${vbin}/python"

    # 2) Update pyvenv.cfg to point to the portable Python directory (only at runtime)
    # Use absolute path resolved at runtime so it works when the venv is copied
    if [ "$update_pyvenv_cfg" = "true" ]; then
        local pyvenv_cfg="${venv_dir}/pyvenv.cfg"
        if [ -f "${pyvenv_cfg}" ]; then
            local portable_dir
            portable_dir="$(_portable_dir)"
            # Resolve to absolute path - this ensures it works when the backend is copied
            # Only resolve if the directory exists (it should if ensurePortablePython was called)
            if [ -d "${portable_dir}" ]; then
                portable_dir="$(cd "${portable_dir}" && pwd)"
            else
                # Fallback to relative path if directory doesn't exist yet
                portable_dir="../python"
            fi

            # Update the home field in pyvenv.cfg
            # Handle both absolute paths (starting with /) and relative paths
            if grep -q "^home = " "${pyvenv_cfg}"; then
                "${sed_i[@]}" "s|^home = .*|home = ${portable_dir}|" "${pyvenv_cfg}"
            else
                # If home field doesn't exist, add it
                echo "home = ${portable_dir}" >> "${pyvenv_cfg}"
            fi
        fi
    fi

    # 3) Rewrite shebangs of entry points to use env, so the venv is relocatable
    # Only touch text files that start with #! and reference the current venv.
    local ve_abs="${vbin}/python"
    local f
    for f in "${vbin}"/*; do
        [ -f "$f" ] || continue
        # Fast path: check first two bytes (#!)
        head -c2 "$f" 2>/dev/null | grep -q '^#!' || continue
        # Only rewrite if the shebang mentions the (absolute) venv python
        if head -n1 "$f" | grep -Fq "${ve_abs}"; then
            "${sed_i[@]}" '1s|^#!.*$|#!/usr/bin/env python3|' "$f"
            chmod +x "$f" 2>/dev/null || true
        fi
    done
}

# ensureVenv makes sure that the venv for the backend both exists, and is activated.
#
# This function is idempotent, so you can call it as many times as you want and it will
# always result in an activated virtual environment
function ensureVenv() {
    local interpreter=""

    if [ "x${PORTABLE_PYTHON}" == "xtrue" ] || [ -e "$(_portable_python)" ]; then
        echo "Using portable Python"
        ensurePortablePython
        interpreter="$(_portable_python)"
    else
        # Prefer system python${PYTHON_VERSION}, else python3, else fall back to bundled
        if command -v python${PYTHON_VERSION} >/dev/null 2>&1; then
            interpreter="python${PYTHON_VERSION}"
        elif command -v python3 >/dev/null 2>&1; then
            interpreter="python3"
        else
            echo "No suitable system Python found, bootstrapping portable build..."
            ensurePortablePython
            interpreter="$(_portable_python)"
        fi
    fi

    if [ ! -d "${EDIR}/venv" ]; then
        if [ "x${USE_PIP}" == "xtrue" ]; then
            "${interpreter}" -m venv --copies "${EDIR}/venv"
            source "${EDIR}/venv/bin/activate"
            "${interpreter}" -m pip install --upgrade pip
        else
            if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
                uv venv --python "${interpreter}" "${EDIR}/venv"
            else
                uv venv --python "${PYTHON_VERSION}" "${EDIR}/venv"
            fi
        fi
        if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
            # During install, only update symlinks and shebangs, not pyvenv.cfg
            _makeVenvPortable
        fi
    fi

    # We call it here to make sure that when we source a venv we can still use python as expected
    if [ -x "$(_portable_python)" ]; then
        _macosPortableEnv
    fi

    if [ "x${VIRTUAL_ENV:-}" != "x${EDIR}/venv" ]; then
        source "${EDIR}/venv/bin/activate"
    fi
}

# runProtogen installs grpcio-tools into the venv and generates the Python gRPC
# stubs for backend.proto inside ${EDIR}
function runProtogen() {
    ensureVenv
    if [ "x${USE_PIP}" == "xtrue" ]; then
        pip install grpcio-tools
    else
        uv pip install grpcio-tools
    fi
    pushd "${EDIR}" >/dev/null
    # use the venv python (ensures correct interpreter & sys.path)
    python -m grpc_tools.protoc -I../../ -I./ --python_out=. --grpc_python_out=. backend.proto
    popd >/dev/null
}

# installRequirements looks for several requirements files and if they exist runs the install for them in order
#
# - requirements-install.txt
# - requirements.txt
# - requirements-${BUILD_TYPE}.txt
# - requirements-${BUILD_PROFILE}.txt
#
# BUILD_PROFILE is a more specific version of BUILD_TYPE, ex: cuda-12 or cuda-13
# it can also include some options that we do not have BUILD_TYPES for, ex: intel
#
# NOTE: for BUILD_PROFILE==intel, this function does NOT automatically use the Intel python package index.
# you may want to add the following line to a requirements-intel.txt if you use one:
#
# --index-url https://download.pytorch.org/whl/xpu
#
# If you need to add extra flags into the pip install command you can do so by setting the variable EXTRA_PIP_INSTALL_FLAGS
# before calling installRequirements.
# For example:
#
# source $(dirname $0)/../common/libbackend.sh
# EXTRA_PIP_INSTALL_FLAGS="--no-build-isolation"
# installRequirements
function installRequirements() {
    ensureVenv
    declare -a requirementFiles=(
        "${EDIR}/requirements-install.txt"
        "${EDIR}/requirements.txt"
        "${EDIR}/requirements-${BUILD_TYPE:-}.txt"
    )

    if [ "x${BUILD_TYPE:-}" != "x${BUILD_PROFILE}" ]; then
        requirementFiles+=("${EDIR}/requirements-${BUILD_PROFILE}.txt")
    fi

    # With no BUILD_TYPE, fall back to the cpu requirements. When the profile is
    # already "cpu" the file was queued just above, so don't install it twice.
    if [ "x${BUILD_TYPE:-}" == "x" ] && [ "x${BUILD_PROFILE}" != "xcpu" ]; then
        requirementFiles+=("${EDIR}/requirements-cpu.txt")
    fi

    requirementFiles+=("${EDIR}/requirements-after.txt")

    if [ "x${BUILD_TYPE:-}" != "x${BUILD_PROFILE}" ]; then
        requirementFiles+=("${EDIR}/requirements-${BUILD_PROFILE}-after.txt")
    fi

    # This is needed to build wheels that e.g. depends on Python.h
    if [ "x${PORTABLE_PYTHON}" == "xtrue" ]; then
        # Only add the separator when there is an existing value: an empty
        # element in C_INCLUDE_PATH makes the compiler search the current directory.
        export C_INCLUDE_PATH="${C_INCLUDE_PATH:+${C_INCLUDE_PATH}:}$(_portable_dir)/include/python${PYTHON_VERSION}"
    fi

    local reqFile
    # Quoted expansion keeps paths containing spaces intact
    for reqFile in "${requirementFiles[@]}"; do
        if [ -f "${reqFile}" ]; then
            echo "starting requirements install for ${reqFile}"
            if [ "x${USE_PIP}" == "xtrue" ]; then
                pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --requirement "${reqFile}"
            else
                uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} --requirement "${reqFile}"
            fi
            echo "finished requirements install for ${reqFile}"
        fi
    done

    runProtogen
}

# startBackend discovers and runs the backend GRPC server
#
# You can specify a specific backend file to execute by setting BACKEND_FILE before calling startBackend.
# example:
#
# source ../common/libbackend.sh
# BACKEND_FILE="${MY_DIR}/source/backend.py"
# startBackend $@
#
# valid filenames for autodiscovered backend servers are:
# - server.py
# - backend.py
# - ${BACKEND_NAME}.py
#
# On success this function does not return (the server replaces the shell via exec).
# If no backend file can be found it prints an error on stderr and returns 1.
function startBackend() {
    ensureVenv
    # Update pyvenv.cfg before running to ensure paths are correct for current location
    # This is critical when the backend position is dynamic (e.g., copied from container)
    if [ "x${PORTABLE_PYTHON}" == "xtrue" ] || [ -x "$(_portable_python)" ]; then
        _makeVenvPortable --update-pyvenv-cfg
    fi

    # Set up GPU library paths if a lib directory exists
    # This allows backends to include their own GPU libraries (CUDA, ROCm, etc.)
    if [ -d "${EDIR}/lib" ]; then
        export LD_LIBRARY_PATH="${EDIR}/lib:${LD_LIBRARY_PATH:-}"
        echo "Added ${EDIR}/lib to LD_LIBRARY_PATH for GPU libraries"
    fi

    if [ ! -z "${BACKEND_FILE:-}" ]; then
        exec "${EDIR}/venv/bin/python" "${BACKEND_FILE}" "$@"
    elif [ -e "${MY_DIR}/server.py" ]; then
        exec "${EDIR}/venv/bin/python" "${MY_DIR}/server.py" "$@"
    elif [ -e "${MY_DIR}/backend.py" ]; then
        exec "${EDIR}/venv/bin/python" "${MY_DIR}/backend.py" "$@"
    elif [ -e "${MY_DIR}/${BACKEND_NAME}.py" ]; then
        exec "${EDIR}/venv/bin/python" "${MY_DIR}/${BACKEND_NAME}.py" "$@"
    fi

    # Nothing was exec'd: fail loudly instead of reporting success without a server
    echo "no backend server file found for ${BACKEND_NAME} in ${MY_DIR}" >&2
    return 1
}

# runUnittests discovers and runs python unittests
#
# You can specify a specific test file to use by setting TEST_FILE before calling runUnittests.
# example:
#
# source ../common/libbackend.sh
# TEST_FILE="${MY_DIR}/source/test.py"
# runUnittests $@
#
# by default a file named test.py in the backends directory will be used
function runUnittests() {
    ensureVenv
    if [ ! -z "${TEST_FILE:-}" ]; then
        local testDir testFile
        testDir=$(dirname "$(realpath "${TEST_FILE}")")
        testFile=$(basename "${TEST_FILE}")
        pushd "${testDir}" >/dev/null
        python -m unittest "${testFile}"
        popd >/dev/null
    elif [ -f "${MY_DIR}/test.py" ]; then
        pushd "${MY_DIR}" >/dev/null
        python -m unittest test.py
        popd >/dev/null
    else
        echo "no tests defined for ${BACKEND_NAME}"
    fi
}

##################################################################################
# Below here are helper functions not intended to be used outside of the library #
##################################################################################

# checkTargets determines if the current BUILD_TYPE or BUILD_PROFILE is in a list of valid targets
# It prints "true" or "false" on stdout; the targets are taken from the arguments.
function checkTargets() {
    # Intentionally unquoted: a single "a b c" argument is split into separate targets,
    # matching how LIMIT_TARGETS is documented (space separated list).
    local -a candidates=( $* )
    local target

    for target in "${candidates[@]}"; do
        if [ "x${BUILD_TYPE:-}" == "x${target}" ]; then
            echo true; return 0
        fi
        if [ "x${BUILD_PROFILE}" == "x${target}" ]; then
            echo true; return 0
        fi
    done
    echo false
}

init

================================================
FILE: backend/python/common/template/Makefile
================================================
.DEFAULT_GOAL := install

.PHONY: install
install:
	bash install.sh

.PHONY: protogen-clean
protogen-clean:
	$(RM) backend_pb2_grpc.py backend_pb2.py

.PHONY: clean
clean: protogen-clean
	rm -rf venv __pycache__

================================================
FILE: backend/python/common/template/backend.py
================================================
#!/usr/bin/env python3
import grpc
import backend_pb2
import backend_pb2_grpc

================================================
FILE: backend/python/common/template/install.sh
================================================
#!/bin/bash
set -e

backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
    EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
fi

installRequirements

================================================
FILE: backend/python/common/template/protogen.sh
================================================
#!/bin/bash
set -e

backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

runProtogen

================================================
FILE: backend/python/common/template/requirements-hipblas.txt
================================================
--extra-index-url https://download.pytorch.org/whl/rocm6.4
torch

================================================
FILE: backend/python/common/template/requirements-intel.txt
================================================
--extra-index-url https://download.pytorch.org/whl/xpu
torch==2.8.0
oneccl_bind_pt==2.8.0+xpu
optimum[openvino]

================================================
FILE: backend/python/common/template/requirements.txt
================================================
grpcio==1.78.1
protobuf
grpcio-tools

================================================
FILE: backend/python/common/template/run.sh
================================================
#!/bin/bash
backend_dir=$(dirname $0)
if [ -d $backend_dir/common ]; then
    source $backend_dir/common/libbackend.sh
else
    source $backend_dir/../common/libbackend.sh
fi

# Quote "$@" so arguments containing spaces reach the backend unchanged
startBackend "$@"

================================================
FILE: backend/python/common/template/test.sh
================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi runUnittests ================================================ FILE: backend/python/coqui/Makefile ================================================ .PHONY: coqui coqui: bash install.sh .PHONY: run run: coqui @echo "Running coqui..." bash run.sh @echo "coqui run." .PHONY: test test: coqui @echo "Testing coqui..." bash test.sh @echo "coqui tested." .PHONY: protogen-clean protogen-clean: $(RM) backend_pb2_grpc.py backend_pb2.py .PHONY: clean clean: protogen-clean rm -rf venv __pycache__ ================================================ FILE: backend/python/coqui/README.md ================================================ # Creating a separate environment for coqui project ``` make coqui ``` # Testing the gRPC server ``` make test ``` ================================================ FILE: backend/python/coqui/backend.py ================================================ #!/usr/bin/env python3 """ This is an extra gRPC server of LocalAI for Coqui TTS """ from concurrent import futures import time import argparse import signal import sys import os import backend_pb2 import backend_pb2_grpc import torch from TTS.api import TTS import grpc _ONE_DAY_IN_SECONDS = 60 * 60 * 24 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) COQUI_LANGUAGE = os.environ.get('COQUI_LANGUAGE', None) # Implement the BackendServicer class with the service methods class BackendServicer(backend_pb2_grpc.BackendServicer): """ BackendServicer is the class that implements the gRPC service """ def Health(self, request, context): return backend_pb2.Reply(message=bytes("OK", 'utf-8')) def LoadModel(self, request, context): # Get device # device = "cuda" if request.CUDA else "cpu" if 
torch.cuda.is_available(): print("CUDA is available", file=sys.stderr) device = "cuda" else: print("CUDA is not available", file=sys.stderr) device = "cpu" mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() if mps_available: device = "mps" if not torch.cuda.is_available() and request.CUDA: return backend_pb2.Result(success=False, message="CUDA is not available") self.AudioPath = None # List available 🐸TTS models print(TTS().list_models()) if os.path.isabs(request.AudioPath): self.AudioPath = request.AudioPath elif request.AudioPath and request.ModelFile != "" and not os.path.isabs(request.AudioPath): # get base path of modelFile modelFileBase = os.path.dirname(request.ModelFile) # modify LoraAdapter to be relative to modelFileBase self.AudioPath = os.path.join(modelFileBase, request.AudioPath) try: print("Preparing models, please wait", file=sys.stderr) self.tts = TTS(request.Model).to(device) except Exception as err: return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") # Implement your logic here for the LoadModel service # Replace this with your desired response return backend_pb2.Result(message="Model loaded successfully", success=True) def TTS(self, request, context): try: # if model is multilingual add language from request or env as fallback lang = request.language or COQUI_LANGUAGE if lang == "": lang = None if self.tts.is_multi_lingual and lang is None: return backend_pb2.Result(success=False, message=f"Model is multi-lingual, but no language was provided") # if model is multi-speaker, use speaker_wav or the speaker_id from request.voice if self.tts.is_multi_speaker and self.AudioPath is None and request.voice is None: return backend_pb2.Result(success=False, message=f"Model is multi-speaker, but no speaker was provided") if self.tts.is_multi_speaker and request.voice is not None: self.tts.tts_to_file(text=request.text, speaker=request.voice, language=lang, file_path=request.dst) else: 
self.tts.tts_to_file(text=request.text, speaker_wav=self.AudioPath, language=lang, file_path=request.dst) except Exception as err: return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return backend_pb2.Result(success=True) def serve(address): server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ]) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() print("Server started. Listening on: " + address, file=sys.stderr) # Define the signal handler function def signal_handler(sig, frame): print("Received termination signal. Shutting down...") server.stop(0) sys.exit(0) # Set the signal handlers for SIGINT and SIGTERM signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: while True: time.sleep(_ONE_DAY_IN_SECONDS) except KeyboardInterrupt: server.stop(0) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the gRPC server.") parser.add_argument( "--addr", default="localhost:50051", help="The address to bind the server to." ) args = parser.parse_args() serve(args.addr) ================================================ FILE: backend/python/coqui/install.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi # This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links. # This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match. 
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index # the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index if [ "x${BUILD_PROFILE}" == "xintel" ]; then EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match" fi installRequirements ================================================ FILE: backend/python/coqui/requirements-cpu.txt ================================================ transformers==4.48.3 accelerate torch==2.4.1 coqui-tts ================================================ FILE: backend/python/coqui/requirements-cublas12.txt ================================================ torch==2.4.1 torchaudio==2.4.1 transformers==4.48.3 accelerate coqui-tts ================================================ FILE: backend/python/coqui/requirements-hipblas.txt ================================================ --extra-index-url https://download.pytorch.org/whl/rocm6.4 torch==2.8.0+rocm6.4 torchaudio==2.8.0+rocm6.4 transformers==4.48.3 accelerate coqui-tts ================================================ FILE: backend/python/coqui/requirements-intel.txt ================================================ --extra-index-url https://download.pytorch.org/whl/xpu torch==2.8.0+xpu torchaudio==2.8.0+xpu optimum[openvino] setuptools transformers==4.48.3 accelerate coqui-tts ================================================ FILE: backend/python/coqui/requirements-mps.txt ================================================ torch==2.7.1 transformers==4.48.3 accelerate coqui-tts ================================================ FILE: backend/python/coqui/requirements.txt ================================================ grpcio==1.78.1 protobuf certifi packaging==24.1 ================================================ FILE: backend/python/coqui/run.sh ================================================ #!/bin/bash backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source 
$backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi startBackend $@ ================================================ FILE: backend/python/coqui/test.py ================================================ """ A test script to test the gRPC service """ import unittest import subprocess import time import backend_pb2 import backend_pb2_grpc import grpc class TestBackendServicer(unittest.TestCase): """ TestBackendServicer is the class that tests the gRPC service """ def setUp(self): """ This method sets up the gRPC service by starting the server """ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"]) time.sleep(30) def tearDown(self) -> None: """ This method tears down the gRPC service by terminating the server """ self.service.terminate() self.service.wait() def test_server_startup(self): """ This method tests if the server starts up successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.Health(backend_pb2.HealthMessage()) self.assertEqual(response.message, b'OK') except Exception as err: print(err) self.fail("Server failed to start") finally: self.tearDown() def test_load_model(self): """ This method tests if the model is loaded successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits")) print(response) self.assertTrue(response.success) self.assertEqual(response.message, "Model loaded successfully") except Exception as err: print(err) self.fail("LoadModel service failed") finally: self.tearDown() def test_tts(self): """ This method tests if the embeddings are generated successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = 
stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits")) self.assertTrue(response.success) tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story") tts_response = stub.TTS(tts_request) self.assertIsNotNone(tts_response) except Exception as err: print(err) self.fail("TTS service failed") finally: self.tearDown() ================================================ FILE: backend/python/coqui/test.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi runUnittests ================================================ FILE: backend/python/diffusers/Makefile ================================================ export CONDA_ENV_PATH = "diffusers.yml" ifeq ($(BUILD_TYPE), hipblas) export CONDA_ENV_PATH = "diffusers-rocm.yml" endif # Intel GPU are supposed to have dependencies installed in the main python # environment, so we skip conda installation for SYCL builds. # https://github.com/intel/intel-extension-for-pytorch/issues/538 ifneq (,$(findstring sycl,$(BUILD_TYPE))) export SKIP_CONDA=1 endif .PHONY: diffusers diffusers: bash install.sh .PHONY: run run: diffusers @echo "Running diffusers..." bash run.sh @echo "Diffusers run." test: diffusers bash test.sh .PHONY: protogen-clean protogen-clean: $(RM) backend_pb2_grpc.py backend_pb2.py .PHONY: clean clean: protogen-clean rm -rf venv __pycache__ ================================================ FILE: backend/python/diffusers/README.md ================================================ # LocalAI Diffusers Backend This backend provides gRPC access to Hugging Face diffusers pipelines with dynamic pipeline loading. 
## Creating a separate environment for the diffusers project ``` make diffusers ``` ## Dynamic Pipeline Loader The diffusers backend includes a dynamic pipeline loader (`diffusers_dynamic_loader.py`) that automatically discovers and loads diffusers pipelines at runtime. This eliminates the need for per-pipeline conditional statements - new pipelines added to diffusers become available automatically without code changes. ### How It Works 1. **Pipeline Discovery**: On first use, the loader scans the `diffusers` package to find all classes that inherit from `DiffusionPipeline`. 2. **Registry Caching**: Discovery results are cached for the lifetime of the process to avoid repeated scanning. 3. **Task Aliases**: The loader automatically derives task aliases from class names (e.g., "text-to-image", "image-to-image", "inpainting") without hardcoding. 4. **Multiple Resolution Methods**: Pipelines can be resolved by: - Exact class name (e.g., `StableDiffusionPipeline`) - Task alias (e.g., `text-to-image`, `img2img`) - Model ID (uses HuggingFace Hub to infer pipeline type) ### Usage Examples ```python from diffusers_dynamic_loader import ( load_diffusers_pipeline, get_available_pipelines, get_available_tasks, resolve_pipeline_class, discover_diffusers_classes, get_available_classes, ) # List all available pipelines pipelines = get_available_pipelines() print(f"Available pipelines: {pipelines[:10]}...") # List all task aliases tasks = get_available_tasks() print(f"Available tasks: {tasks}") # Resolve a pipeline class by name cls = resolve_pipeline_class(class_name="StableDiffusionPipeline") # Resolve by task alias cls = resolve_pipeline_class(task="stable-diffusion") # Load and instantiate a pipeline pipe = load_diffusers_pipeline( class_name="StableDiffusionPipeline", model_id="runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16 ) # Load from single file pipe = load_diffusers_pipeline( class_name="StableDiffusionPipeline", model_id="/path/to/model.safetensors", 
from_single_file=True, torch_dtype=torch.float16 ) # Discover other diffusers classes (schedulers, models, etc.) schedulers = discover_diffusers_classes("SchedulerMixin") print(f"Available schedulers: {list(schedulers.keys())[:5]}...") # Get list of available scheduler classes scheduler_list = get_available_classes("SchedulerMixin") ``` ### Generic Class Discovery The dynamic loader can discover not just pipelines but any class type from diffusers: ```python # Discover all scheduler classes schedulers = discover_diffusers_classes("SchedulerMixin") # Discover all model classes models = discover_diffusers_classes("ModelMixin") # Get a sorted list of available classes scheduler_names = get_available_classes("SchedulerMixin") ``` ### Special Pipeline Handling Most pipelines are loaded dynamically through `load_diffusers_pipeline()`. Only pipelines requiring truly custom initialization logic are handled explicitly: - `FluxTransformer2DModel`: Requires quantization and custom transformer loading (cannot use dynamic loader) - `WanPipeline` / `WanImageToVideoPipeline`: Uses dynamic loader with special VAE (float32 dtype) - `SanaPipeline`: Uses dynamic loader with post-load dtype conversion for VAE/text encoder - `StableVideoDiffusionPipeline`: Uses dynamic loader with CPU offload handling - `VideoDiffusionPipeline`: Alias for DiffusionPipeline with video flags All other pipelines (StableDiffusionPipeline, StableDiffusionXLPipeline, FluxPipeline, etc.) are loaded purely through the dynamic loader. ### Error Handling When a pipeline cannot be resolved, the loader provides helpful error messages listing available pipelines and tasks: ``` ValueError: Unknown pipeline class 'NonExistentPipeline'. Available pipelines: AnimateDiffPipeline, AnimateDiffVideoToVideoPipeline, ... 
``` ## Environment Variables | Variable | Default | Description | |----------|---------|-------------| | `COMPEL` | `0` | Enable Compel for prompt weighting | | `SD_EMBED` | `0` | Enable sd_embed for prompt weighting | | `XPU` | `0` | Enable Intel XPU support | | `CLIPSKIP` | `1` | Enable CLIP skip support | | `SAFETENSORS` | `1` | Use safetensors format | | `CHUNK_SIZE` | `8` | Decode chunk size for video | | `FPS` | `7` | Video frames per second | | `DISABLE_CPU_OFFLOAD` | `0` | Disable CPU offload | | `FRAMES` | `64` | Number of video frames | | `BFL_REPO` | `ChuckMcSneed/FLUX.1-dev` | Flux base repo | | `PYTHON_GRPC_MAX_WORKERS` | `1` | Max gRPC workers | ## Running Tests ```bash ./test.sh ``` The test suite includes: - Unit tests for the dynamic loader (`test_dynamic_loader.py`) - Integration tests for the gRPC backend (`test.py`) ================================================ FILE: backend/python/diffusers/backend.py ================================================ #!/usr/bin/env python3 """ LocalAI Diffusers Backend This backend provides gRPC access to diffusers pipelines with dynamic pipeline loading. New pipelines added to diffusers become available automatically without code changes. 
""" from concurrent import futures import traceback import argparse from collections import defaultdict from enum import Enum import signal import sys import time import os from PIL import Image import torch import backend_pb2 import backend_pb2_grpc import grpc # Import dynamic loader for pipeline discovery from diffusers_dynamic_loader import ( get_pipeline_registry, resolve_pipeline_class, get_available_pipelines, load_diffusers_pipeline, ) # Import specific items still needed for special cases and safety checker from diffusers import DiffusionPipeline, ControlNetModel from diffusers import FluxPipeline, FluxTransformer2DModel, AutoencoderKLWan from diffusers.pipelines.stable_diffusion import safety_checker from diffusers.utils import load_image, export_to_video from compel import Compel, ReturnedEmbeddingsType from optimum.quanto import freeze, qfloat8, quantize from transformers import T5EncoderModel from safetensors.torch import load_file # Try to import sd_embed - it might not always be available try: from sd_embed.embedding_funcs import ( get_weighted_text_embeddings_sd15, get_weighted_text_embeddings_sdxl, get_weighted_text_embeddings_sd3, get_weighted_text_embeddings_flux1, ) SD_EMBED_AVAILABLE = True except ImportError: get_weighted_text_embeddings_sd15 = None get_weighted_text_embeddings_sdxl = None get_weighted_text_embeddings_sd3 = None get_weighted_text_embeddings_flux1 = None SD_EMBED_AVAILABLE = False # Import LTX-2 specific utilities from diffusers.pipelines.ltx2.export_utils import encode_video as ltx2_encode_video from diffusers import LTX2VideoTransformer3DModel, GGUFQuantizationConfig _ONE_DAY_IN_SECONDS = 60 * 60 * 24 COMPEL = os.environ.get("COMPEL", "0") == "1" SD_EMBED = os.environ.get("SD_EMBED", "0") == "1" # Warn if SD_EMBED is enabled but the module is not available if SD_EMBED and not SD_EMBED_AVAILABLE: print("WARNING: SD_EMBED is enabled but sd_embed module is not available. 
def is_float(s):
    """Return True if ``s`` can be converted to float, False otherwise.

    Non-string inputs that float() rejects (e.g. None) yield False instead of
    raising.
    """
    try:
        float(s)
    except (TypeError, ValueError):
        return False
    return True


def is_int(s):
    """Return True if ``s`` can be converted to int, False otherwise.

    Non-string inputs that int() rejects (e.g. None) yield False instead of
    raising.
    """
    try:
        int(s)
    except (TypeError, ValueError):
        return False
    return True


# The scheduler list mapping was taken from here: https://github.com/neggles/animatediff-cli/blob/6f336f5f4b5e38e85d7f06f1744ef42d0a45f2a7/src/animatediff/schedulers.py#L39
# Credits to https://github.com/neggles
# See https://github.com/huggingface/diffusers/issues/4167 for more details on sched mapping from A1111
class DiffusionScheduler(str, Enum):
    """Scheduler names accepted by get_scheduler().

    Members prefixed with ``k_`` are the Karras-sigma variants of the
    scheduler named by the rest of the value.
    """
    ddim = "ddim"  # DDIM
    pndm = "pndm"  # PNDM
    heun = "heun"  # Heun
    unipc = "unipc"  # UniPC
    euler = "euler"  # Euler
    euler_a = "euler_a"  # Euler a

    lms = "lms"  # LMS
    k_lms = "k_lms"  # LMS Karras

    dpm_2 = "dpm_2"  # DPM2
    k_dpm_2 = "k_dpm_2"  # DPM2 Karras

    dpm_2_a = "dpm_2_a"  # DPM2 a
    k_dpm_2_a = "k_dpm_2_a"  # DPM2 a Karras

    dpmpp_2m = "dpmpp_2m"  # DPM++ 2M
    k_dpmpp_2m = "k_dpmpp_2m"  # DPM++ 2M Karras

    dpmpp_sde = "dpmpp_sde"  # DPM++ SDE
    k_dpmpp_sde = "k_dpmpp_sde"  # DPM++ SDE Karras

    dpmpp_2m_sde = "dpmpp_2m_sde"  # DPM++ 2M SDE
    k_dpmpp_2m_sde = "k_dpmpp_2m_sde"  # DPM++ 2M SDE Karras


def get_scheduler(name: str, config: dict = None):
    """Instantiate the diffusers scheduler identified by ``name``.

    Args:
        name: A DiffusionScheduler value. A leading ``k_`` selects the same
            scheduler with ``use_karras_sigmas`` enabled.
        config: Base scheduler configuration. It is copied, never modified,
            so the caller's mapping stays untouched.

    Returns:
        The scheduler built via ``sched_class.from_config(...)``.

    Raises:
        ValueError: if ``name`` is not a known scheduler.
    """
    # Work on a private copy: the original mutated both a shared default dict
    # and whatever mapping the caller handed in.
    config = dict(config) if config else {}

    is_karras = name.startswith("k_")
    if is_karras:
        # Remove exactly the "k_" prefix. str.lstrip("k_") would strip any
        # run of leading 'k'/'_' characters, not just the prefix.
        name = name[len("k_"):]
        config["use_karras_sigmas"] = True

    if name == DiffusionScheduler.ddim:
        sched_class = DDIMScheduler
    elif name == DiffusionScheduler.pndm:
        sched_class = PNDMScheduler
    elif name == DiffusionScheduler.heun:
        sched_class = HeunDiscreteScheduler
    elif name == DiffusionScheduler.unipc:
        sched_class = UniPCMultistepScheduler
    elif name == DiffusionScheduler.euler:
        sched_class = EulerDiscreteScheduler
    elif name == DiffusionScheduler.euler_a:
        sched_class = EulerAncestralDiscreteScheduler
    elif name == DiffusionScheduler.lms:
        sched_class = LMSDiscreteScheduler
    elif name == DiffusionScheduler.dpm_2:
        # Equivalent to DPM2 in K-Diffusion
        sched_class = KDPM2DiscreteScheduler
    elif name == DiffusionScheduler.dpm_2_a:
        # Equivalent to `DPM2 a` in K-Diffusion
        sched_class = KDPM2AncestralDiscreteScheduler
    elif name == DiffusionScheduler.dpmpp_2m:
        # Equivalent to `DPM++ 2M` in K-Diffusion
        sched_class = DPMSolverMultistepScheduler
        config["algorithm_type"] = "dpmsolver++"
        config["solver_order"] = 2
    elif name == DiffusionScheduler.dpmpp_sde:
        # Equivalent to `DPM++ SDE` in K-Diffusion
        sched_class = DPMSolverSinglestepScheduler
    elif name == DiffusionScheduler.dpmpp_2m_sde:
        # Equivalent to `DPM++ 2M SDE` in K-Diffusion
        sched_class = DPMSolverMultistepScheduler
        config["algorithm_type"] = "sde-dpmsolver++"
    else:
        raise ValueError(f"Invalid scheduler '{'k_' if is_karras else ''}{name}'")

    return sched_class.from_config(config)
# ================================================================ # FluxTransformer2DModel - requires quantization and custom transformer loading if pipeline_type == "FluxTransformer2DModel": dtype = torch.bfloat16 bfl_repo = os.environ.get("BFL_REPO", "ChuckMcSneed/FLUX.1-dev") transformer = FluxTransformer2DModel.from_single_file(modelFile, torch_dtype=dtype, device_map=device_map) quantize(transformer, weights=qfloat8) freeze(transformer) text_encoder_2 = T5EncoderModel.from_pretrained(bfl_repo, subfolder="text_encoder_2", torch_dtype=dtype, device_map=device_map) quantize(text_encoder_2, weights=qfloat8) freeze(text_encoder_2) pipe = FluxPipeline.from_pretrained(bfl_repo, transformer=None, text_encoder_2=None, torch_dtype=dtype, device_map=device_map) pipe.transformer = transformer pipe.text_encoder_2 = text_encoder_2 if request.LowVRAM: pipe.enable_model_cpu_offload() return pipe # WanPipeline - requires special VAE with float32 dtype if pipeline_type == "WanPipeline": vae = AutoencoderKLWan.from_pretrained( request.Model, subfolder="vae", torch_dtype=torch.float32, device_map=device_map ) pipe = load_diffusers_pipeline( class_name="WanPipeline", model_id=request.Model, vae=vae, torch_dtype=torchType, device_map=device_map ) self.txt2vid = True return pipe # WanImageToVideoPipeline - requires special VAE with float32 dtype if pipeline_type == "WanImageToVideoPipeline": vae = AutoencoderKLWan.from_pretrained( request.Model, subfolder="vae", torch_dtype=torch.float32, device_map=device_map ) pipe = load_diffusers_pipeline( class_name="WanImageToVideoPipeline", model_id=request.Model, vae=vae, torch_dtype=torchType, device_map=device_map ) self.img2vid = True return pipe # SanaPipeline - requires special VAE and text encoder dtype conversion if pipeline_type == "SanaPipeline": pipe = load_diffusers_pipeline( class_name="SanaPipeline", model_id=request.Model, variant="bf16", torch_dtype=torch.bfloat16, device_map=device_map ) pipe.vae.to(torch.bfloat16) 
pipe.text_encoder.to(torch.bfloat16) return pipe # VideoDiffusionPipeline - alias for DiffusionPipeline with txt2vid flag if pipeline_type == "VideoDiffusionPipeline": self.txt2vid = True pipe = load_diffusers_pipeline( class_name="DiffusionPipeline", model_id=request.Model, torch_dtype=torchType, device_map=device_map ) return pipe # StableVideoDiffusionPipeline - needs img2vid flag and CPU offload if pipeline_type == "StableVideoDiffusionPipeline": self.img2vid = True pipe = load_diffusers_pipeline( class_name="StableVideoDiffusionPipeline", model_id=request.Model, torch_dtype=torchType, variant=variant, device_map=device_map ) if not DISABLE_CPU_OFFLOAD: pipe.enable_model_cpu_offload() return pipe # LTX2ImageToVideoPipeline - needs img2vid flag, CPU offload, and special handling if pipeline_type == "LTX2ImageToVideoPipeline": self.img2vid = True self.ltx2_pipeline = True # Check if loading from single file (GGUF) if fromSingleFile and LTX2VideoTransformer3DModel is not None: _, single_file_ext = os.path.splitext(modelFile) if single_file_ext == ".gguf": # Load transformer from single GGUF file with quantization transformer_kwargs = {} quantization_config = GGUFQuantizationConfig(compute_dtype=torchType) transformer_kwargs["quantization_config"] = quantization_config transformer = LTX2VideoTransformer3DModel.from_single_file( modelFile, config=request.Model, # Use request.Model as the config/model_id subfolder="transformer", device_map=device_map, **transformer_kwargs, ) # Load pipeline with custom transformer pipe = load_diffusers_pipeline( class_name="LTX2ImageToVideoPipeline", model_id=request.Model, transformer=transformer, torch_dtype=torchType, device_map=device_map, ) else: # Single file but not GGUF - use standard single file loading pipe = load_diffusers_pipeline( class_name="LTX2ImageToVideoPipeline", model_id=modelFile, from_single_file=True, torch_dtype=torchType, device_map=device_map, ) else: # Standard loading from pretrained pipe = 
load_diffusers_pipeline( class_name="LTX2ImageToVideoPipeline", model_id=request.Model, torch_dtype=torchType, variant=variant, device_map=device_map ) if not DISABLE_CPU_OFFLOAD: pipe.enable_model_cpu_offload() return pipe # LTX2Pipeline - text-to-video pipeline, needs txt2vid flag, CPU offload, and special handling if pipeline_type == "LTX2Pipeline": self.txt2vid = True self.ltx2_pipeline = True # Check if loading from single file (GGUF) if fromSingleFile and LTX2VideoTransformer3DModel is not None: _, single_file_ext = os.path.splitext(modelFile) if single_file_ext == ".gguf": # Load transformer from single GGUF file with quantization transformer_kwargs = {} quantization_config = GGUFQuantizationConfig(compute_dtype=torchType) transformer_kwargs["quantization_config"] = quantization_config transformer = LTX2VideoTransformer3DModel.from_single_file( modelFile, config=request.Model, # Use request.Model as the config/model_id subfolder="transformer", device_map=device_map, **transformer_kwargs, ) # Load pipeline with custom transformer pipe = load_diffusers_pipeline( class_name="LTX2Pipeline", model_id=request.Model, transformer=transformer, torch_dtype=torchType, device_map=device_map, ) else: # Single file but not GGUF - use standard single file loading pipe = load_diffusers_pipeline( class_name="LTX2Pipeline", model_id=modelFile, from_single_file=True, torch_dtype=torchType, device_map=device_map, ) else: # Standard loading from pretrained pipe = load_diffusers_pipeline( class_name="LTX2Pipeline", model_id=request.Model, torch_dtype=torchType, variant=variant, device_map=device_map ) if not DISABLE_CPU_OFFLOAD: pipe.enable_model_cpu_offload() return pipe # ================================================================ # Dynamic pipeline loading - the default path for most pipelines # Uses the dynamic loader to instantiate any pipeline by class name # ================================================================ # Build kwargs for dynamic loading load_kwargs 
def LoadModel(self, request, context):
    """
    Load a diffusers pipeline and configure it according to the request.

    Parses ``request.Options`` (``name:value`` strings) into ``self.options``,
    picks the torch dtype, loads the pipeline through ``_load_pipeline``,
    then applies the optional scheduler, Compel, ControlNet and LoRA settings
    and finally moves the pipeline to the selected device.

    Args:
        request: The gRPC model options message.
        context: The gRPC call context (unused).

    Returns:
        backend_pb2.Result with success=True once the model is loaded, or
        success=False and the error text if anything raised.
    """
    try:
        print(f"Loading model {request.Model}...", file=sys.stderr)
        print(f"Request {request}", file=sys.stderr)
        torchType = torch.float32
        variant = None

        if request.F16Memory:
            torchType = torch.float16
            variant = "fp16"

        # The options are a list of strings in this form optname:optvalue
        # We are storing all the options in a dict so we can use it later when
        # generating the images
        self.options = {}
        for opt in request.Options:
            if ":" not in opt:
                continue
            # Split on the first ':' only so values that contain ':' themselves
            # (paths, URLs) do not break the unpacking.
            key, value = opt.split(":", 1)
            # Try int before float: every int literal is also a valid float,
            # so testing float first would turn e.g. "8" into 8.0.
            if is_int(value):
                value = int(value)
            elif is_float(value):
                value = float(value)
            elif value.lower() in ["true", "false"]:
                value = value.lower() == "true"
            self.options[key] = value

        # From options, extract if present "torch_dtype" and set it to the appropriate type
        if "torch_dtype" in self.options:
            requested_dtype = self.options.pop("torch_dtype")
            if requested_dtype == "fp16":
                torchType = torch.float16
            elif requested_dtype == "bf16":
                torchType = torch.bfloat16
            elif requested_dtype == "fp32":
                torchType = torch.float32

        print(f"Options: {self.options}", file=sys.stderr)

        local = False
        modelFile = request.Model

        self.cfg_scale = 7
        self.PipelineType = request.PipelineType

        if request.CFGScale != 0:
            self.cfg_scale = request.CFGScale

        # Prefer the explicit ModelFile when it exists on disk
        if request.ModelFile != "":
            if os.path.exists(request.ModelFile):
                local = True
                modelFile = request.ModelFile

        fromSingleFile = request.Model.startswith("http") or request.Model.startswith("/") or local
        self.img2vid = False
        self.txt2vid = False
        self.ltx2_pipeline = False

        print(f"LoadModel: PipelineType from request: {request.PipelineType}", file=sys.stderr)

        # Determine device_map for multi-GPU support based on TensorParallelSize
        # When TensorParallelSize > 1, use device_map='auto' to distribute model across GPUs
        device_map = None
        if hasattr(request, 'TensorParallelSize') and request.TensorParallelSize > 1:
            device_map = "auto"
            print(f"LoadModel: Multi-GPU mode enabled with TensorParallelSize={request.TensorParallelSize}, using device_map='auto'", file=sys.stderr)

        # Load pipeline using dynamic loader
        # Special cases that require custom initialization are handled first
        self.pipe = self._load_pipeline(
            request=request,
            modelFile=modelFile,
            fromSingleFile=fromSingleFile,
            torchType=torchType,
            variant=variant,
            device_map=device_map
        )

        print(f"LoadModel: After loading - ltx2_pipeline: {self.ltx2_pipeline}, img2vid: {self.img2vid}, txt2vid: {self.txt2vid}, PipelineType: {self.PipelineType}", file=sys.stderr)

        if CLIPSKIP and request.CLIPSkip != 0:
            self.clip_skip = request.CLIPSkip
        else:
            self.clip_skip = 0

        # torch_dtype needs to be customized. float16 for GPU, float32 for CPU
        # TODO: this needs to be customized
        if request.SchedulerType != "":
            self.pipe.scheduler = get_scheduler(request.SchedulerType, self.pipe.scheduler.config)

        if COMPEL:
            self.compel = Compel(
                tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
                text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
                returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
                requires_pooled=[False, True]
            )

        if request.ControlNet:
            self.controlnet = ControlNetModel.from_pretrained(
                request.ControlNet, torch_dtype=torchType, variant=variant, device_map=device_map
            )
            self.pipe.controlnet = self.controlnet
        else:
            self.controlnet = None

        if request.LoraAdapter and not os.path.isabs(request.LoraAdapter):
            # modify LoraAdapter to be relative to modelFileBase
            request.LoraAdapter = os.path.join(request.ModelPath, request.LoraAdapter)

        # Device precedence as written: mps (if available) > xpu > cuda > cpu
        device = "cpu" if not request.CUDA else "cuda"
        if XPU:
            device = "xpu"
        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        if mps_available:
            device = "mps"
        self.device = device

        if request.LoraAdapter:
            # Check if its a local file and not a directory ( we load lora differently for a safetensor file )
            if os.path.exists(request.LoraAdapter) and not os.path.isdir(request.LoraAdapter):
                self.pipe.load_lora_weights(request.LoraAdapter)
            else:
                self.pipe.unet.load_attn_procs(request.LoraAdapter)

        if len(request.LoraAdapters) > 0:
            adapters_name = []
            for i, adapter in enumerate(request.LoraAdapters):
                if not os.path.isabs(adapter):
                    adapter = os.path.join(request.ModelPath, adapter)
                self.pipe.load_lora_weights(adapter, adapter_name=f"adapter_{i}")
                adapters_name.append(f"adapter_{i}")

            adapters_weights = list(request.LoraScales)
            self.pipe.set_adapters(adapters_name, adapter_weights=adapters_weights)

        # Only move pipeline to device if NOT using device_map
        # device_map handles device placement automatically
        if device_map is None and device != "cpu":
            self.pipe.to(device)
            if self.controlnet:
                self.controlnet.to(device)

    except Exception as err:
        return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

    return backend_pb2.Result(message="Model loaded successfully", success=True)
def GenerateImage(self, request, context):
    """
    Generate an image (or, for video pipelines, a video) and write it to
    ``request.dst``.

    The pipeline call arguments are assembled from ``self.options`` (set in
    LoadModel) and then overridden by the per-request values. When the loaded
    pipeline was flagged as img2vid or txt2vid the method exports a video and
    returns early; otherwise it renders a single image, optionally using
    Compel or sd_embed prompt embeddings.

    Args:
        request: The gRPC image generation request.
        context: The gRPC call context (unused).

    Returns:
        backend_pb2.Result describing the outcome.
    """
    prompt = request.positive_prompt

    steps = 1
    if request.step != 0:
        steps = request.step

    # create a dictionary of values for the parameters
    options = {
        "num_inference_steps": steps,
    }

    if hasattr(request, 'negative_prompt') and request.negative_prompt != "":
        options["negative_prompt"] = request.negative_prompt

    # Handle image source: prioritize RefImages over request.src
    image_src = None
    if hasattr(request, 'ref_images') and request.ref_images and len(request.ref_images) > 0:
        # Use the first reference image if available
        image_src = request.ref_images[0]
        print(f"Using reference image: {image_src}", file=sys.stderr)
    elif request.src != "":
        # Fall back to request.src if no ref_images
        image_src = request.src
        print(f"Using source image: {image_src}", file=sys.stderr)
    else:
        print("No image source provided", file=sys.stderr)

    if image_src and not self.controlnet and not self.img2vid:
        image = Image.open(image_src)
        options["image"] = image
    elif self.controlnet and image_src:
        pose_image = load_image(image_src)
        options["image"] = pose_image

    if CLIPSKIP and self.clip_skip != 0:
        options["clip_skip"] = self.clip_skip

    # Start from the model-level options, then let per-request values win
    kwargs = {}
    kwargs.update(self.options)
    kwargs.update(options)

    # Set seed
    if request.seed > 0:
        kwargs["generator"] = torch.Generator(device=self.device).manual_seed(
            request.seed
        )

    if self.PipelineType == "FluxPipeline":
        kwargs["max_sequence_length"] = 256

    if request.width:
        kwargs["width"] = request.width

    if request.height:
        kwargs["height"] = request.height

    if self.PipelineType == "FluxTransformer2DModel":
        kwargs["output_type"] = "pil"
        kwargs["generator"] = torch.Generator("cpu").manual_seed(0)

    if self.img2vid:
        # Load the conditioning image
        if image_src:
            image = load_image(image_src)
        else:
            # Fallback to request.src for img2vid if no ref_images
            image = load_image(request.src)
        image = image.resize((1024, 576))

        generator = torch.manual_seed(request.seed)
        # CHUNK_SIZE and FPS are read from the environment as strings; convert
        # them to int here, the same way FRAMES is converted below.
        frames = self.pipe(image, guidance_scale=self.cfg_scale, decode_chunk_size=int(CHUNK_SIZE), generator=generator).frames[0]
        export_to_video(frames, request.dst, fps=int(FPS))
        return backend_pb2.Result(message="Media generated successfully", success=True)

    if self.txt2vid:
        video_frames = self.pipe(prompt, guidance_scale=self.cfg_scale, num_inference_steps=steps, num_frames=int(FRAMES)).frames
        export_to_video(video_frames, request.dst)
        return backend_pb2.Result(message="Media generated successfully", success=True)

    print(f"Generating image with {kwargs=}", file=sys.stderr)
    if COMPEL:
        conditioning, pooled = self.compel.build_conditioning_tensor(prompt)
        kwargs["prompt_embeds"] = conditioning
        kwargs["pooled_prompt_embeds"] = pooled
        # pass the kwargs dictionary to the self.pipe method
        image = self.pipe(
            guidance_scale=self.cfg_scale,
            **kwargs
        ).images[0]
    elif SD_EMBED and SD_EMBED_AVAILABLE:
        negative_prompt = request.negative_prompt if hasattr(request, 'negative_prompt') else None
        if self.PipelineType == "StableDiffusionPipeline":
            (
                kwargs["prompt_embeds"],
                kwargs["negative_prompt_embeds"],
            ) = get_weighted_text_embeddings_sd15(
                pipe = self.pipe,
                prompt = prompt,
                neg_prompt = negative_prompt,
            )
        if self.PipelineType == "StableDiffusionXLPipeline":
            (
                kwargs["prompt_embeds"],
                kwargs["negative_prompt_embeds"],
                kwargs["pooled_prompt_embeds"],
                kwargs["negative_pooled_prompt_embeds"],
            ) = get_weighted_text_embeddings_sdxl(
                pipe = self.pipe,
                prompt = prompt,
                neg_prompt = negative_prompt
            )
        if self.PipelineType == "StableDiffusion3Pipeline":
            (
                kwargs["prompt_embeds"],
                kwargs["negative_prompt_embeds"],
                kwargs["pooled_prompt_embeds"],
                kwargs["negative_pooled_prompt_embeds"],
            ) = get_weighted_text_embeddings_sd3(
                pipe = self.pipe,
                prompt = prompt,
                neg_prompt = negative_prompt
            )
        if self.PipelineType == "FluxTransformer2DModel":
            (
                kwargs["prompt_embeds"],
                kwargs["pooled_prompt_embeds"],
            ) = get_weighted_text_embeddings_flux1(
                pipe = self.pipe,
                prompt = prompt,
            )

        # None of the sd_embed helpers applies to this pipeline type: fall
        # back to the plain prompt instead of calling the pipeline with
        # neither a prompt nor embeddings.
        if "prompt_embeds" not in kwargs and "prompt" not in kwargs:
            kwargs["prompt"] = prompt

        image = self.pipe(
            guidance_scale=self.cfg_scale,
            **kwargs
        ).images[0]
    else:
        # pass the kwargs dictionary to the self.pipe method
        image = self.pipe(
            prompt,
            guidance_scale=self.cfg_scale,
            **kwargs
        ).images[0]

    # save the result
    image.save(request.dst)

    return backend_pb2.Result(message="Media generated", success=True)
num_inference_steps: {num_inference_steps}", file=sys.stderr) # Prepare generation parameters kwargs = { "prompt": prompt, "negative_prompt": request.negative_prompt if request.negative_prompt else "", "height": request.height if request.height > 0 else 720, "width": request.width if request.width > 0 else 1280, "num_frames": num_frames, "guidance_scale": cfg_scale, "num_inference_steps": num_inference_steps, } # Add custom options from self.options (including guidance_scale_2 if specified) kwargs.update(self.options) # Set seed if provided if request.seed > 0: kwargs["generator"] = torch.Generator(device=self.device).manual_seed(request.seed) # Handle start and end images for video generation if request.start_image: kwargs["start_image"] = load_image(request.start_image) if request.end_image: kwargs["end_image"] = load_image(request.end_image) print(f"Generating video with {kwargs=}", file=sys.stderr) print(f"GenerateVideo: Pipeline type: {self.PipelineType}, ltx2_pipeline flag: {self.ltx2_pipeline}", file=sys.stderr) # Generate video frames based on pipeline type if self.ltx2_pipeline or self.PipelineType in ["LTX2Pipeline", "LTX2ImageToVideoPipeline"]: # LTX-2 generation with audio (supports both text-to-video and image-to-video) # Determine if this is text-to-video (no image) or image-to-video (has image) has_image = bool(request.start_image) # Remove image-related parameters that might have been added earlier kwargs.pop("start_image", None) kwargs.pop("end_image", None) # LTX2ImageToVideoPipeline uses 'image' parameter for image-to-video # LTX2Pipeline (text-to-video) doesn't need an image parameter if has_image: # Image-to-video: use 'image' parameter if self.PipelineType == "LTX2ImageToVideoPipeline": image = load_image(request.start_image) kwargs["image"] = image print(f"LTX-2: Using image-to-video mode with image", file=sys.stderr) else: # If pipeline type is LTX2Pipeline but we have an image, we can't do image-to-video return 
backend_pb2.Result(success=False, message="LTX2Pipeline does not support image-to-video. Use LTX2ImageToVideoPipeline for image-to-video generation.") else: # Text-to-video: no image parameter needed # Ensure no image-related kwargs are present kwargs.pop("image", None) print(f"LTX-2: Using text-to-video mode (no image)", file=sys.stderr) # LTX-2 uses 'frame_rate' instead of 'fps' frame_rate = float(fps) kwargs["frame_rate"] = frame_rate # LTX-2 requires output_type="np" and return_dict=False kwargs["output_type"] = "np" kwargs["return_dict"] = False # Generate video and audio print(f"LTX-2: Generating with kwargs: {kwargs}", file=sys.stderr) try: video, audio = self.pipe(**kwargs) print(f"LTX-2: Generated video shape: {video.shape}, audio shape: {audio.shape}", file=sys.stderr) except Exception as e: print(f"LTX-2: Error during pipe() call: {e}", file=sys.stderr) traceback.print_exc() return backend_pb2.Result(success=False, message=f"Error generating video with LTX-2 pipeline: {e}") # Convert video to uint8 format video = (video * 255).round().astype("uint8") video = torch.from_numpy(video) print(f"LTX-2: Converting video, shape after conversion: {video.shape}", file=sys.stderr) print(f"LTX-2: Audio sample rate: {self.pipe.vocoder.config.output_sampling_rate}", file=sys.stderr) print(f"LTX-2: Output path: {request.dst}", file=sys.stderr) # Use LTX-2's encode_video function which handles audio try: ltx2_encode_video( video[0], fps=frame_rate, audio=audio[0].float().cpu(), audio_sample_rate=self.pipe.vocoder.config.output_sampling_rate, output_path=request.dst, ) # Verify file was created and has content import os if os.path.exists(request.dst): file_size = os.path.getsize(request.dst) print(f"LTX-2: Video file created successfully, size: {file_size} bytes", file=sys.stderr) if file_size == 0: return backend_pb2.Result(success=False, message=f"Video file was created but is empty (0 bytes). 
Check LTX-2 encode_video function.") else: return backend_pb2.Result(success=False, message=f"Video file was not created at {request.dst}") except Exception as e: print(f"LTX-2: Error encoding video: {e}", file=sys.stderr) traceback.print_exc() return backend_pb2.Result(success=False, message=f"Error encoding video: {e}") return backend_pb2.Result(message="Video generated successfully", success=True) elif self.PipelineType == "WanPipeline": # WAN2.2 text-to-video generation output = self.pipe(**kwargs) frames = output.frames[0] # WAN2.2 returns frames in this format elif self.PipelineType == "WanImageToVideoPipeline": # WAN2.2 image-to-video generation if request.start_image: # Load and resize the input image according to WAN2.2 requirements image = load_image(request.start_image) # Use request dimensions or defaults, but respect WAN2.2 constraints request_height = request.height if request.height > 0 else 480 request_width = request.width if request.width > 0 else 832 max_area = request_height * request_width aspect_ratio = image.height / image.width mod_value = self.pipe.vae_scale_factor_spatial * self.pipe.transformer.config.patch_size[1] height = round((max_area * aspect_ratio) ** 0.5 / mod_value) * mod_value width = round((max_area / aspect_ratio) ** 0.5 / mod_value) * mod_value image = image.resize((width, height)) kwargs["image"] = image kwargs["height"] = height kwargs["width"] = width output = self.pipe(**kwargs) frames = output.frames[0] elif self.img2vid: # Generic image-to-video generation if request.start_image: image = load_image(request.start_image) image = image.resize((request.width if request.width > 0 else 1024, request.height if request.height > 0 else 576)) kwargs["image"] = image output = self.pipe(**kwargs) frames = output.frames[0] elif self.txt2vid: # Generic text-to-video generation output = self.pipe(**kwargs) frames = output.frames[0] else: print(f"GenerateVideo: Pipeline {self.PipelineType} does not match any known video pipeline 
handler", file=sys.stderr) return backend_pb2.Result(success=False, message=f"Pipeline {self.PipelineType} does not support video generation") # Export video (for non-LTX-2 pipelines) print(f"GenerateVideo: Exporting video to {request.dst} with fps={fps}", file=sys.stderr) export_to_video(frames, request.dst, fps=fps) # Verify file was created import os if os.path.exists(request.dst): file_size = os.path.getsize(request.dst) print(f"GenerateVideo: Video file created, size: {file_size} bytes", file=sys.stderr) if file_size == 0: return backend_pb2.Result(success=False, message=f"Video file was created but is empty (0 bytes)") else: return backend_pb2.Result(success=False, message=f"Video file was not created at {request.dst}") return backend_pb2.Result(message="Video generated successfully", success=True) except Exception as err: print(f"Error generating video: {err}", file=sys.stderr) traceback.print_exc() return backend_pb2.Result(success=False, message=f"Error generating video: {err}") def serve(address): server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ]) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() print("Server started. Listening on: " + address, file=sys.stderr) # Define the signal handler function def signal_handler(sig, frame): print("Received termination signal. 
def serve(address):
    """
    Start the gRPC backend server on ``address`` and block until interrupted.

    The server is created with 50MB limits for sent and received messages,
    registers a ``BackendServicer``, and listens on an insecure port.
    SIGINT and SIGTERM stop the server immediately and exit the process.

    Args:
        address: The ``host:port`` string to bind the server to.
    """
    max_message_bytes = 50 * 1024 * 1024  # 50MB
    channel_options = [
        ('grpc.max_message_length', max_message_bytes),
        ('grpc.max_send_message_length', max_message_bytes),
        ('grpc.max_receive_message_length', max_message_bytes),
    ]
    worker_pool = futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)
    server = grpc.server(worker_pool, options=channel_options)
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    def _shutdown(signum, stack_frame):
        # Stop without a grace period, then terminate the process.
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Same handler for both termination signals.
    for termination_signal in (signal.SIGINT, signal.SIGTERM):
        signal.signal(termination_signal, _shutdown)

    # Keep the main thread alive; the server runs on worker threads.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == "__main__":
    cli = argparse.ArgumentParser(description="Run the gRPC server.")
    cli.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    serve(cli.parse_args().addr)
def _camel_to_kebab(name: str) -> str:
    """
    Convert CamelCase to kebab-case.

    Runs of consecutive capitals stay together ("XL" -> "xl") and a digit
    stays attached to the word that precedes it ("Img2Img" -> "img2-img").

    Examples:
        StableDiffusionPipeline -> stable-diffusion-pipeline
        StableDiffusionXLImg2ImgPipeline -> stable-diffusion-xl-img2-img-pipeline
    """
    # Insert hyphen before a capitalized word (but not at the start)
    s1 = re.sub('(.)([A-Z][a-z]+)', r'\1-\2', name)
    # Insert hyphen before uppercase letters following lowercase letters or numbers
    s2 = re.sub('([a-z0-9])([A-Z])', r'\1-\2', s1)
    return s2.lower()


def _extract_task_keywords(class_name: str) -> List[str]:
    """
    Extract task-related keywords from a pipeline class name.

    This function derives useful task aliases from the class name without
    hardcoding per-pipeline branches.

    Args:
        class_name: A pipeline class name such as "StableDiffusionImg2ImgPipeline".

    Returns:
        A duplicate-free list of aliases in a stable order: task aliases for
        every pattern found in the name first, then (for names ending in
        "Pipeline") the kebab-case base name and its two-word family name.
        No default task is assumed when no pattern matches.
    """
    aliases: List[str] = []
    name_lower = class_name.lower()

    # Direct task mappings based on common patterns in class names
    task_patterns = {
        'text2image': ['text-to-image', 'txt2img', 'text2image'],
        'texttoimage': ['text-to-image', 'txt2img', 'text2image'],
        'txt2img': ['text-to-image', 'txt2img', 'text2image'],
        'img2img': ['image-to-image', 'img2img', 'image2image'],
        'image2image': ['image-to-image', 'img2img', 'image2image'],
        'imagetoimage': ['image-to-image', 'img2img', 'image2image'],
        'img2video': ['image-to-video', 'img2vid', 'img2video'],
        'imagetovideo': ['image-to-video', 'img2vid', 'img2video'],
        'text2video': ['text-to-video', 'txt2vid', 'text2video'],
        'texttovideo': ['text-to-video', 'txt2vid', 'text2video'],
        'inpaint': ['inpainting', 'inpaint'],
        'depth2img': ['depth-to-image', 'depth2img'],
        'depthtoimage': ['depth-to-image', 'depth2img'],
        'controlnet': ['controlnet', 'control-net'],
        'upscale': ['upscaling', 'upscale', 'super-resolution'],
        'superresolution': ['upscaling', 'upscale', 'super-resolution'],
    }

    # Check for each pattern in the class name
    for pattern, task_aliases in task_patterns.items():
        if pattern in name_lower:
            aliases.extend(task_aliases)

    # Also detect general pipeline types from the class name structure
    # E.g., StableDiffusionPipeline -> stable-diffusion, FluxPipeline -> flux
    suffix = 'Pipeline'
    if class_name.endswith(suffix):
        kebab_name = _camel_to_kebab(class_name[:-len(suffix)])
        # A class literally named "Pipeline" has an empty base name. An empty
        # alias is a substring of every task string, so never emit it.
        if kebab_name:
            aliases.append(kebab_name)
            # Extract model family name (e.g., "stable-diffusion" from "stable-diffusion-xl-img2-img")
            parts = kebab_name.split('-')
            if len(parts) >= 2:
                aliases.append('-'.join(parts[:2]))

    # Remove duplicates while keeping first-seen order so results are deterministic
    return list(dict.fromkeys(aliases))
def discover_diffusers_classes(
    base_class_name: str,
    include_base: bool = True
) -> Dict[str, Type]:
    """
    Discover all subclasses of a given base class from diffusers.

    This function provides a generic way to discover any type of diffusers
    class, not just pipelines. It can be used to discover schedulers, models,
    processors, and other components.

    The scan result is cached per ``base_class_name``. The cache always holds
    the registry that includes the base class; ``include_base=False`` returns
    a filtered copy, so the flag is honored on every call regardless of what
    an earlier call asked for.

    Args:
        base_class_name: Name of the base class to search for subclasses
            (e.g., "DiffusionPipeline", "SchedulerMixin", "ModelMixin")
        include_base: Whether to include the base class itself in results

    Returns:
        Dict mapping class names to class objects. With ``include_base=True``
        this is the cached dict itself (the same object on every call).

    Raises:
        ValueError: If the base class cannot be found in ``diffusers`` or in
            its ``schedulers``, ``models`` or ``pipelines`` submodules.

    Examples:
        # Discover all pipeline classes
        pipelines = discover_diffusers_classes("DiffusionPipeline")

        # Discover all scheduler classes
        schedulers = discover_diffusers_classes("SchedulerMixin")

        # Discover all model classes
        models = discover_diffusers_classes("ModelMixin")

        # Discover AutoPipeline classes
        auto_pipelines = discover_diffusers_classes("AutoPipelineForText2Image")
    """
    global _class_registries

    registry = _class_registries.get(base_class_name)
    if registry is None:
        import diffusers

        # Try to get the base class from diffusers
        base_class = None
        try:
            base_class = getattr(diffusers, base_class_name)
        except AttributeError:
            # Try to find in submodules
            for submodule in ['schedulers', 'models', 'pipelines']:
                try:
                    module = importlib.import_module(f'diffusers.{submodule}')
                    if hasattr(module, base_class_name):
                        base_class = getattr(module, base_class_name)
                        break
                except (ImportError, ModuleNotFoundError):
                    continue

        if base_class is None:
            raise ValueError(f"Could not find base class '{base_class_name}' in diffusers")

        # The cached registry always contains the base class under its own name.
        registry = {base_class_name: base_class}

        # Scan diffusers module for subclasses
        for attr_name in dir(diffusers):
            try:
                attr = getattr(diffusers, attr_name)
                if isinstance(attr, type) and issubclass(attr, base_class):
                    registry[attr_name] = attr
            except (ImportError, AttributeError, TypeError, RuntimeError, ModuleNotFoundError):
                # Attribute access on diffusers may fail for optional components; skip those.
                continue

        # Cache the results
        _class_registries[base_class_name] = registry

    if include_base:
        return registry

    # Exclude every name bound to the base class itself, without touching the cache.
    base_class = registry[base_class_name]
    return {name: cls for name, cls in registry.items() if cls is not base_class}


def get_available_classes(base_class_name: str) -> List[str]:
    """
    Get a sorted list of all discovered class names for a given base class.

    Args:
        base_class_name: Name of the base class (e.g., "SchedulerMixin")

    Returns:
        Sorted list of discovered class names (the base class included)
    """
    return sorted(discover_diffusers_classes(base_class_name).keys())
def _discover_pipelines() -> Tuple[Dict[str, Type], Dict[str, List[str]]]:
    """
    Discover all subclasses of DiffusionPipeline from diffusers.

    This function uses the generic discover_diffusers_classes() internally
    and adds pipeline-specific task alias generation. It also includes
    AutoPipeline classes which are special utility classes for automatic
    pipeline selection.

    Returns:
        A tuple of (pipeline_registry, task_aliases) where:
        - pipeline_registry: Dict mapping class names to class objects
        - task_aliases: Dict mapping task aliases to lists of class names
    """
    # Work on a copy: the dict returned by discover_diffusers_classes() is its
    # cache entry, and adding AutoPipeline classes to it would make them show
    # up as "DiffusionPipeline subclasses" for every other caller.
    pipeline_registry = dict(discover_diffusers_classes("DiffusionPipeline", include_base=True))

    # Also add AutoPipeline classes - these are special utility classes that are
    # NOT subclasses of DiffusionPipeline but are commonly used
    import diffusers
    auto_pipeline_classes = [
        "AutoPipelineForText2Image",
        "AutoPipelineForImage2Image",
        "AutoPipelineForInpainting",
    ]
    for cls_name in auto_pipeline_classes:
        # Missing attribute means the class is not available in this version of diffusers
        cls = getattr(diffusers, cls_name, None)
        if cls is not None:
            pipeline_registry[cls_name] = cls

    # Generate task aliases for pipelines
    task_aliases: Dict[str, List[str]] = {}
    for attr_name in pipeline_registry:
        if attr_name == "DiffusionPipeline":
            continue  # Skip base class for alias generation

        for alias in _extract_task_keywords(attr_name):
            class_names = task_aliases.setdefault(alias, [])
            if attr_name not in class_names:
                class_names.append(attr_name)

    return pipeline_registry, task_aliases


def get_pipeline_registry() -> Dict[str, Type]:
    """
    Get the cached pipeline registry.

    Returns a dictionary mapping pipeline class names to their class objects.
    The registry is built on first access and cached for subsequent calls.
    """
    global _pipeline_registry, _task_aliases
    if _pipeline_registry is None:
        # Both caches are filled by the same discovery pass.
        _pipeline_registry, _task_aliases = _discover_pipelines()
    return _pipeline_registry


def get_task_aliases() -> Dict[str, List[str]]:
    """
    Get the cached task aliases dictionary.

    Returns a dictionary mapping task aliases (e.g., "text-to-image") to lists
    of pipeline class names that support that task.
    """
    global _pipeline_registry, _task_aliases
    if _task_aliases is None:
        _pipeline_registry, _task_aliases = _discover_pipelines()
    return _task_aliases


def get_available_pipelines() -> List[str]:
    """
    Get a sorted list of all discovered pipeline class names.

    Returns:
        List of pipeline class names available for loading.
    """
    return sorted(get_pipeline_registry().keys())


def get_available_tasks() -> List[str]:
    """
    Get a sorted list of all available task aliases.

    Returns:
        List of task aliases (e.g., ["text-to-image", "image-to-image", ...])
    """
    return sorted(get_task_aliases().keys())
def resolve_pipeline_class(
    class_name: Optional[str] = None,
    task: Optional[str] = None,
    model_id: Optional[str] = None
) -> Type:
    """
    Pick the pipeline class to use from a class name, a task, or a model id.

    The first argument that is provided (in this order) decides the outcome:
    1. class_name: exact registry lookup, then a case-insensitive match
    2. task: exact alias lookup (underscores treated as hyphens), then the
       first alias that contains the task or is contained in it
    3. model_id: the HuggingFace Hub pipeline tag (from the model info, then
       from its card data) mapped through the task aliases; when that yields
       nothing or the lookup fails, the generic DiffusionPipeline

    Args:
        class_name: Exact pipeline class name (e.g., "StableDiffusionPipeline")
        task: Task alias (e.g., "text-to-image", "img2img")
        model_id: HuggingFace model ID (e.g., "runwayml/stable-diffusion-v1-5")

    Returns:
        The resolved pipeline class.

    Raises:
        ValueError: If class_name or task is given but unknown, or if no
            argument is provided at all.
    """
    known_pipelines = get_pipeline_registry()
    task_index = get_task_aliases()

    def _first_class_for(tag):
        # First pipeline registered for an alias, or None when the alias is unknown/empty.
        class_names = task_index.get(tag)
        if class_names:
            return known_pipelines[class_names[0]]
        return None

    # 1. Direct class name lookup
    if class_name:
        if class_name in known_pipelines:
            return known_pipelines[class_name]

        # Try case-insensitive match
        wanted = class_name.lower()
        for candidate_name, candidate_cls in known_pipelines.items():
            if candidate_name.lower() == wanted:
                return candidate_cls

        raise ValueError(
            f"Unknown pipeline class '{class_name}'. "
            f"Available pipelines: {', '.join(sorted(known_pipelines.keys())[:20])}..."
        )

    # 2. Task alias lookup
    if task:
        normalized_task = task.lower().replace('_', '-')
        exact = _first_class_for(normalized_task)
        if exact is not None:
            return exact

        # Try partial matching
        for alias, class_names in task_index.items():
            overlaps = normalized_task in alias or alias in normalized_task
            if overlaps and class_names:
                return known_pipelines[class_names[0]]

        raise ValueError(
            f"Unknown task '{task}'. "
            f"Available tasks: {', '.join(sorted(task_index.keys())[:20])}..."
        )

    # 3. Try to infer from HuggingFace Hub
    if model_id:
        try:
            from huggingface_hub import model_info
            hub_info = model_info(model_id)

            # Check pipeline_tag
            hub_tag = getattr(hub_info, 'pipeline_tag', None)
            if hub_tag:
                resolved = _first_class_for(hub_tag.lower().replace('_', '-'))
                if resolved is not None:
                    return resolved

            # Check model card for hints
            card = getattr(hub_info, 'cardData', None)
            if card and 'pipeline_tag' in card:
                resolved = _first_class_for(card['pipeline_tag'].lower().replace('_', '-'))
                if resolved is not None:
                    return resolved

        except (ImportError, KeyError, AttributeError, ValueError, OSError):
            # Best effort only:
            # - ImportError: huggingface_hub not available
            # - KeyError / AttributeError / ValueError: missing or invalid model data
            # - OSError: network or file access issues
            pass

        # Fallback: DiffusionPipeline.from_pretrained auto-detects the concrete
        # pipeline. Prefer the registry entry, with the import as a safety net.
        from diffusers import DiffusionPipeline
        return known_pipelines.get('DiffusionPipeline', DiffusionPipeline)

    raise ValueError(
        "Must provide at least one of: class_name, task, or model_id. "
        f"Available pipelines: {', '.join(sorted(known_pipelines.keys())[:20])}... "
        f"Available tasks: {', '.join(sorted(task_index.keys())[:20])}..."
    )
def load_diffusers_pipeline(
    class_name: Optional[str] = None,
    task: Optional[str] = None,
    model_id: Optional[str] = None,
    from_single_file: bool = False,
    **kwargs
) -> Any:
    """
    Resolve a pipeline class and instantiate it for ``model_id``.

    The class is chosen by resolve_pipeline_class() from class_name, task or
    model_id, then built with from_pretrained() or, when requested,
    from_single_file().

    Args:
        class_name: Exact pipeline class name (e.g., "StableDiffusionPipeline")
        task: Task alias (e.g., "text-to-image", "img2img")
        model_id: HuggingFace model ID or local path
        from_single_file: If True, use from_single_file() instead of from_pretrained()
        **kwargs: Additional arguments passed to from_pretrained() or from_single_file()

    Returns:
        An instantiated pipeline object.

    Raises:
        ValueError: If no pipeline could be resolved or model_id is missing.
        RuntimeError: If loading fails for any reason, including a class that
            has no from_single_file(); the original error is chained as the cause.

    Examples:
        # Load by class name
        pipe = load_diffusers_pipeline(
            class_name="StableDiffusionPipeline",
            model_id="runwayml/stable-diffusion-v1-5",
            torch_dtype=torch.float16
        )

        # Load by task
        pipe = load_diffusers_pipeline(
            task="text-to-image",
            model_id="runwayml/stable-diffusion-v1-5",
            torch_dtype=torch.float16
        )

        # Load from single file
        pipeline = load_diffusers_pipeline(
            class_name="StableDiffusionPipeline",
            model_id="/path/to/model.safetensors",
            from_single_file=True,
            torch_dtype=torch.float16
        )
    """
    # Resolve first so an unknown class/task is reported before the model_id check
    pipeline_class = resolve_pipeline_class(
        class_name=class_name,
        task=task,
        model_id=model_id
    )

    # A class alone is not enough: there must be something to load
    if model_id is None:
        raise ValueError("model_id is required to load a pipeline")

    try:
        if not from_single_file:
            return pipeline_class.from_pretrained(model_id, **kwargs)

        if not hasattr(pipeline_class, 'from_single_file'):
            raise ValueError(
                f"Pipeline class {pipeline_class.__name__} does not support from_single_file(). "
                f"Use from_pretrained() instead."
            )
        return pipeline_class.from_single_file(model_id, **kwargs)
    except Exception as load_error:
        # Provide helpful error message
        known_names = get_available_pipelines()
        raise RuntimeError(
            f"Failed to load pipeline '{pipeline_class.__name__}' from '{model_id}': {load_error}\n"
            f"Available pipelines: {', '.join(known_names[:20])}..."
        ) from load_error


def get_pipeline_info(class_name: str) -> Dict[str, Any]:
    """
    Describe a registered pipeline class.

    Args:
        class_name: The pipeline class name (must match a registry key exactly)

    Returns:
        Dictionary with:
        - name: Class name
        - aliases: Task aliases whose class list contains this pipeline
        - supports_single_file: Whether from_single_file() is available
        - docstring: First 200 characters of the class docstring, or None

    Raises:
        ValueError: If class_name is not in the pipeline registry.
    """
    known_pipelines = get_pipeline_registry()
    task_index = get_task_aliases()

    if class_name not in known_pipelines:
        raise ValueError(f"Unknown pipeline: {class_name}")

    pipeline_cls = known_pipelines[class_name]
    summary = pipeline_cls.__doc__[:200] if pipeline_cls.__doc__ else None

    return {
        'name': class_name,
        # Every alias that lists this pipeline, in alias-registry order
        'aliases': [alias for alias, members in task_index.items() if class_name in members],
        'supports_single_file': hasattr(pipeline_cls, 'from_single_file'),
        'docstring': summary,
    }
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index # the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index if [ "x${BUILD_PROFILE}" == "xintel" ]; then EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match" fi if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then USE_PIP=true fi # Use python 3.12 for l4t if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then PYTHON_VERSION="3.12" PYTHON_PATCH="12" PY_STANDALONE_TAG="20251120" fi installRequirements ================================================ FILE: backend/python/diffusers/requirements-cpu.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cpu git+https://github.com/huggingface/diffusers opencv-python transformers torchvision==0.22.1 accelerate compel git+https://github.com/xhinker/sd_embed peft sentencepiece torch==2.7.1 optimum-quanto ftfy ================================================ FILE: backend/python/diffusers/requirements-cublas12.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu121 git+https://github.com/huggingface/diffusers opencv-python transformers torchvision accelerate compel git+https://github.com/xhinker/sd_embed peft sentencepiece torch ftfy optimum-quanto ================================================ FILE: backend/python/diffusers/requirements-cublas13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 git+https://github.com/huggingface/diffusers opencv-python transformers torchvision accelerate compel git+https://github.com/xhinker/sd_embed peft sentencepiece torch ftfy optimum-quanto ================================================ FILE: backend/python/diffusers/requirements-hipblas.txt ================================================ --extra-index-url https://download.pytorch.org/whl/rocm6.4 torch==2.8.0+rocm6.4 
torchvision==0.23.0+rocm6.4 git+https://github.com/huggingface/diffusers opencv-python transformers accelerate compel peft sentencepiece optimum-quanto ftfy ================================================ FILE: backend/python/diffusers/requirements-intel.txt ================================================ --extra-index-url https://download.pytorch.org/whl/xpu torch torchvision optimum[openvino] setuptools git+https://github.com/huggingface/diffusers opencv-python transformers accelerate compel git+https://github.com/xhinker/sd_embed peft sentencepiece optimum-quanto ftfy ================================================ FILE: backend/python/diffusers/requirements-l4t12.txt ================================================ --extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/ torch git+https://github.com/huggingface/diffusers transformers accelerate compel peft optimum-quanto numpy<2 sentencepiece torchvision ftfy ================================================ FILE: backend/python/diffusers/requirements-l4t13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch git+https://github.com/huggingface/diffusers transformers accelerate compel peft optimum-quanto numpy<2 sentencepiece torchvision ftfy chardet ================================================ FILE: backend/python/diffusers/requirements-mps.txt ================================================ torch==2.7.1 torchvision==0.22.1 git+https://github.com/huggingface/diffusers opencv-python transformers accelerate compel peft sentencepiece optimum-quanto ftfy ================================================ FILE: backend/python/diffusers/requirements.txt ================================================ setuptools grpcio==1.76.0 pillow protobuf certifi av ================================================ FILE: backend/python/diffusers/run.sh ================================================ #!/bin/bash backend_dir=$(dirname $0) if [ -d 
@unittest.skipUnless(GRPC_AVAILABLE, "gRPC modules not available")
class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service

    unittest runs setUp() before and tearDown() after every test method, so
    the tests must not call them again: a second setUp() would start another
    server on the same address and overwrite ``self.service``, leaving the
    first subprocess running with nothing left to kill it.
    """
    def setUp(self):
        """
        This method sets up the gRPC service by starting the server
        """
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by terminating the server
        """
        self.service.kill()
        self.service.wait()

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        # Give the server started by setUp() time to come up
        time.sleep(20)
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        # Give the server started by setUp() time to come up
        time.sleep(20)
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="Lykon/dreamshaper-8"))
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")

    def test(self):
        """
        This method tests if the backend can generate images
        """
        # Give the server started by setUp() time to come up
        time.sleep(20)
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="Lykon/dreamshaper-8"))
                print(response.message)
                self.assertTrue(response.success)
                image_req = backend_pb2.GenerateImageRequest(positive_prompt="cat", width=16, height=16, dst="test.jpg")
                re = stub.GenerateImage(image_req)
                self.assertTrue(re.success)
        except Exception as err:
            print(err)
            self.fail("Image gen service failed")
loader._extract_task_keywords("StableDiffusionPipeline") self.assertIn("stable-diffusion", aliases) # Test img2img detection aliases = loader._extract_task_keywords("StableDiffusionImg2ImgPipeline") self.assertIn("image-to-image", aliases) self.assertIn("img2img", aliases) # Test inpainting detection aliases = loader._extract_task_keywords("StableDiffusionInpaintPipeline") self.assertIn("inpainting", aliases) self.assertIn("inpaint", aliases) # Test depth2img detection aliases = loader._extract_task_keywords("StableDiffusionDepth2ImgPipeline") self.assertIn("depth-to-image", aliases) def test_discover_pipelines_finds_known_classes(self): """Test that pipeline discovery finds at least one known pipeline class.""" registry = loader.get_pipeline_registry() # Check that the registry is not empty self.assertGreater(len(registry), 0, "Pipeline registry should not be empty") # Check for known pipeline classes known_pipelines = [ "StableDiffusionPipeline", "DiffusionPipeline", ] for pipeline_name in known_pipelines: with self.subTest(pipeline=pipeline_name): self.assertIn( pipeline_name, registry, f"Expected to find {pipeline_name} in registry" ) def test_discover_pipelines_caches_results(self): """Test that pipeline discovery results are cached.""" # Get registry twice registry1 = loader.get_pipeline_registry() registry2 = loader.get_pipeline_registry() # Should be the same object (cached) self.assertIs(registry1, registry2, "Registry should be cached") def test_get_available_pipelines(self): """Test getting list of available pipelines.""" available = loader.get_available_pipelines() # Should return a list self.assertIsInstance(available, list) # Should contain known pipelines self.assertIn("StableDiffusionPipeline", available) self.assertIn("DiffusionPipeline", available) # Should be sorted self.assertEqual(available, sorted(available)) def test_get_available_tasks(self): """Test getting list of available task aliases.""" tasks = loader.get_available_tasks() # Should 
return a list self.assertIsInstance(tasks, list) # Should be sorted self.assertEqual(tasks, sorted(tasks)) def test_resolve_pipeline_class_by_name(self): """Test resolving pipeline class by exact name.""" cls = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline") self.assertEqual(cls, StableDiffusionPipeline) def test_resolve_pipeline_class_by_name_case_insensitive(self): """Test that class name resolution is case-insensitive.""" cls1 = loader.resolve_pipeline_class(class_name="StableDiffusionPipeline") cls2 = loader.resolve_pipeline_class(class_name="stablediffusionpipeline") self.assertEqual(cls1, cls2) def test_resolve_pipeline_class_by_task(self): """Test resolving pipeline class by task alias.""" # Get the registry to find available tasks aliases = loader.get_task_aliases() # Test with a common task that should be available if "stable-diffusion" in aliases: cls = loader.resolve_pipeline_class(task="stable-diffusion") self.assertIsNotNone(cls) def test_resolve_pipeline_class_unknown_name_raises(self): """Test that resolving unknown class name raises ValueError with helpful message.""" with self.assertRaises(ValueError) as ctx: loader.resolve_pipeline_class(class_name="NonExistentPipeline") # Check that error message includes available pipelines error_msg = str(ctx.exception) self.assertIn("Unknown pipeline class", error_msg) self.assertIn("Available pipelines", error_msg) def test_resolve_pipeline_class_unknown_task_raises(self): """Test that resolving unknown task raises ValueError with helpful message.""" with self.assertRaises(ValueError) as ctx: loader.resolve_pipeline_class(task="nonexistent-task-xyz") # Check that error message includes available tasks error_msg = str(ctx.exception) self.assertIn("Unknown task", error_msg) self.assertIn("Available tasks", error_msg) def test_resolve_pipeline_class_no_params_raises(self): """Test that calling with no parameters raises helpful ValueError.""" with self.assertRaises(ValueError) as ctx: 
loader.resolve_pipeline_class() error_msg = str(ctx.exception) self.assertIn("Must provide at least one of", error_msg) def test_get_pipeline_info(self): """Test getting pipeline information.""" info = loader.get_pipeline_info("StableDiffusionPipeline") self.assertEqual(info['name'], "StableDiffusionPipeline") self.assertIsInstance(info['aliases'], list) self.assertIsInstance(info['supports_single_file'], bool) def test_get_pipeline_info_unknown_raises(self): """Test that getting info for unknown pipeline raises ValueError.""" with self.assertRaises(ValueError) as ctx: loader.get_pipeline_info("NonExistentPipeline") self.assertIn("Unknown pipeline", str(ctx.exception)) def test_discover_diffusers_classes_pipelines(self): """Test generic class discovery for DiffusionPipeline.""" classes = loader.discover_diffusers_classes("DiffusionPipeline") # Should return a dict self.assertIsInstance(classes, dict) # Should contain known pipeline classes self.assertIn("DiffusionPipeline", classes) self.assertIn("StableDiffusionPipeline", classes) def test_discover_diffusers_classes_caches_results(self): """Test that class discovery results are cached.""" classes1 = loader.discover_diffusers_classes("DiffusionPipeline") classes2 = loader.discover_diffusers_classes("DiffusionPipeline") # Should be the same object (cached) self.assertIs(classes1, classes2) def test_discover_diffusers_classes_exclude_base(self): """Test discovering classes without base class.""" classes = loader.discover_diffusers_classes("DiffusionPipeline", include_base=False) # Should still contain subclasses self.assertIn("StableDiffusionPipeline", classes) def test_get_available_classes(self): """Test getting list of available classes for a base class.""" classes = loader.get_available_classes("DiffusionPipeline") # Should return a sorted list self.assertIsInstance(classes, list) self.assertEqual(classes, sorted(classes)) # Should contain known classes self.assertIn("StableDiffusionPipeline", classes) class 
TestDiffusersDynamicLoaderWithMocks(unittest.TestCase): """Test cases using mocks to test edge cases.""" def test_load_pipeline_requires_model_id(self): """Test that load_diffusers_pipeline requires model_id.""" with self.assertRaises(ValueError) as ctx: loader.load_diffusers_pipeline(class_name="StableDiffusionPipeline") self.assertIn("model_id is required", str(ctx.exception)) def test_resolve_with_model_id_uses_diffusion_pipeline_fallback(self): """Test that resolving with only model_id falls back to DiffusionPipeline.""" # When model_id is provided, if hub lookup is not successful, # should fall back to DiffusionPipeline. # This tests the fallback behavior - the actual hub lookup may succeed # or fail depending on network, but the fallback path should work. cls = loader.resolve_pipeline_class(model_id="some/nonexistent/model") self.assertEqual(cls, DiffusionPipeline) @unittest.skipUnless(GRPC_AVAILABLE, "gRPC modules not available") class TestGenerateImageOptionsKwargsMerge(unittest.TestCase): """Test that GenerateImage merges the options dict into pipeline kwargs. The options dict holds image (PIL), negative_prompt, and num_inference_steps. Without the merge, img2img pipelines never receive the source image and fail with 'Input is in incorrect format'. 
""" def test_options_merged_into_pipeline_kwargs(self): from backend import BackendServicer from PIL import Image import tempfile, os svc = BackendServicer.__new__(BackendServicer) # Minimal attributes the method reads svc.pipe = MagicMock() svc.pipe.return_value.images = [Image.new("RGB", (4, 4))] svc.cfg_scale = 7.5 svc.controlnet = None svc.img2vid = False svc.txt2vid = False svc.clip_skip = 0 svc.PipelineType = "StableDiffusionImg2ImgPipeline" svc.options = {} # Create a tiny source image for the request's src field src_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) Image.new("RGB", (4, 4), color="red").save(src_file, format="PNG") src_file.close() dst_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) dst_file.close() try: request = MagicMock() request.positive_prompt = "a test prompt" request.negative_prompt = "bad quality" request.step = 10 request.seed = 0 request.width = 0 request.height = 0 request.src = src_file.name request.ref_images = [] request.dst = dst_file.name svc.GenerateImage(request, context=None) # The pipeline must have been called with the image kwarg svc.pipe.assert_called_once() _, call_kwargs = svc.pipe.call_args self.assertIn("image", call_kwargs, "source image must be passed to pipeline via kwargs") self.assertIn("negative_prompt", call_kwargs, "negative_prompt must be passed to pipeline via kwargs") self.assertEqual(call_kwargs["num_inference_steps"], 10) finally: os.unlink(src_file.name) os.unlink(dst_file.name) ================================================ FILE: backend/python/diffusers/test.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi runUnittests ================================================ FILE: backend/python/faster-qwen3-tts/Makefile ================================================ .PHONY: faster-qwen3-tts 
faster-qwen3-tts: bash install.sh .PHONY: run run: faster-qwen3-tts @echo "Running faster-qwen3-tts..." bash run.sh @echo "faster-qwen3-tts run." .PHONY: test test: faster-qwen3-tts @echo "Testing faster-qwen3-tts..." bash test.sh @echo "faster-qwen3-tts tested." .PHONY: protogen-clean protogen-clean: $(RM) backend_pb2_grpc.py backend_pb2.py .PHONY: clean clean: protogen-clean rm -rf venv __pycache__ ================================================ FILE: backend/python/faster-qwen3-tts/backend.py ================================================ #!/usr/bin/env python3 """ gRPC server of LocalAI for Faster Qwen3-TTS (CUDA graph capture, voice clone only). """ from concurrent import futures import time import argparse import signal import sys import os import traceback import backend_pb2 import backend_pb2_grpc import torch import soundfile as sf import grpc def is_float(s): try: float(s) return True except ValueError: return False def is_int(s): try: int(s) return True except ValueError: return False _ONE_DAY_IN_SECONDS = 60 * 60 * 24 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) class BackendServicer(backend_pb2_grpc.BackendServicer): def Health(self, request, context): return backend_pb2.Reply(message=bytes("OK", 'utf-8')) def LoadModel(self, request, context): if not torch.cuda.is_available(): return backend_pb2.Result( success=False, message="faster-qwen3-tts requires NVIDIA GPU with CUDA" ) self.options = {} for opt in request.Options: if ":" not in opt: continue key, value = opt.split(":", 1) if is_float(value): value = float(value) elif is_int(value): value = int(value) elif value.lower() in ["true", "false"]: value = value.lower() == "true" self.options[key] = value model_path = request.Model or "Qwen/Qwen3-TTS-12Hz-0.6B-Base" self.audio_path = request.AudioPath if hasattr(request, 'AudioPath') and request.AudioPath else None self.model_file = request.ModelFile if hasattr(request, 'ModelFile') and request.ModelFile else None 
self.model_path = request.ModelPath if hasattr(request, 'ModelPath') and request.ModelPath else None from faster_qwen3_tts import FasterQwen3TTS print(f"Loading model from: {model_path}", file=sys.stderr) try: self.model = FasterQwen3TTS.from_pretrained(model_path) except Exception as e: print(f"[ERROR] Loading model: {type(e).__name__}: {e}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) return backend_pb2.Result(success=False, message=str(e)) print(f"Model loaded successfully: {model_path}", file=sys.stderr) return backend_pb2.Result(message="Model loaded successfully", success=True) def _get_ref_audio_path(self, request): if not self.audio_path: return None if os.path.isabs(self.audio_path): return self.audio_path if self.model_file: model_file_base = os.path.dirname(self.model_file) ref_path = os.path.join(model_file_base, self.audio_path) if os.path.exists(ref_path): return ref_path if self.model_path: ref_path = os.path.join(self.model_path, self.audio_path) if os.path.exists(ref_path): return ref_path return self.audio_path def TTS(self, request, context): try: if not request.dst: return backend_pb2.Result( success=False, message="dst (output path) is required" ) text = request.text.strip() if not text: return backend_pb2.Result( success=False, message="Text is empty" ) language = request.language if hasattr(request, 'language') and request.language else None if not language or language == "": language = "English" ref_audio = self._get_ref_audio_path(request) if not ref_audio: return backend_pb2.Result( success=False, message="AudioPath is required for voice clone (set in LoadModel)" ) ref_text = self.options.get("ref_text") if not ref_text and hasattr(request, 'ref_text') and request.ref_text: ref_text = request.ref_text if not ref_text: return backend_pb2.Result( success=False, message="ref_text is required for voice clone (set via LoadModel Options, e.g. 
ref_text:Your reference transcript)" ) chunk_size = self.options.get("chunk_size") generation_kwargs = {} if chunk_size is not None: generation_kwargs["chunk_size"] = int(chunk_size) audio_list, sr = self.model.generate_voice_clone( text=text, language=language, ref_audio=ref_audio, ref_text=ref_text, **generation_kwargs ) if audio_list is None or (isinstance(audio_list, list) and len(audio_list) == 0): return backend_pb2.Result( success=False, message="No audio output generated" ) audio_data = audio_list[0] if isinstance(audio_list, list) else audio_list sf.write(request.dst, audio_data, sr) print(f"Saved output to {request.dst}", file=sys.stderr) except Exception as err: print(f"Error in TTS: {err}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return backend_pb2.Result(success=True) def serve(address): server = grpc.server( futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ ('grpc.max_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024), ] ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() print("Server started. Listening on: " + address, file=sys.stderr) def signal_handler(sig, frame): print("Received termination signal. 
Shutting down...") server.stop(0) sys.exit(0) signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: while True: time.sleep(_ONE_DAY_IN_SECONDS) except KeyboardInterrupt: server.stop(0) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the gRPC server.") parser.add_argument("--addr", default="localhost:50051", help="The address to bind the server to.") args = parser.parse_args() serve(args.addr) ================================================ FILE: backend/python/faster-qwen3-tts/install.sh ================================================ #!/bin/bash set -e EXTRA_PIP_INSTALL_FLAGS="--no-build-isolation" backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi installRequirements ================================================ FILE: backend/python/faster-qwen3-tts/requirements-cublas12.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu121 torch torchaudio faster-qwen3-tts ================================================ FILE: backend/python/faster-qwen3-tts/requirements-cublas13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch torchaudio faster-qwen3-tts ================================================ FILE: backend/python/faster-qwen3-tts/requirements-install.txt ================================================ setuptools ================================================ FILE: backend/python/faster-qwen3-tts/requirements-l4t12.txt ================================================ --extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/ torch torchaudio faster-qwen3-tts ================================================ FILE: backend/python/faster-qwen3-tts/requirements-l4t13.txt ================================================ --extra-index-url 
https://download.pytorch.org/whl/cu130 torch torchaudio faster-qwen3-tts ================================================ FILE: backend/python/faster-qwen3-tts/requirements.txt ================================================ grpcio==1.71.0 protobuf certifi packaging==24.1 soundfile setuptools six anyio sox ================================================ FILE: backend/python/faster-qwen3-tts/run.sh ================================================ #!/bin/bash backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi startBackend $@ ================================================ FILE: backend/python/faster-qwen3-tts/test.py ================================================ """ Tests for the faster-qwen3-tts gRPC backend. """ import unittest import subprocess import time import os import sys import tempfile import backend_pb2 import backend_pb2_grpc import grpc class TestBackendServicer(unittest.TestCase): def setUp(self): self.service = subprocess.Popen( ["python3", "backend.py", "--addr", "localhost:50052"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, cwd=os.path.dirname(os.path.abspath(__file__)), ) time.sleep(15) def tearDown(self): self.service.terminate() try: self.service.communicate(timeout=5) except subprocess.TimeoutExpired: self.service.kill() self.service.communicate() def test_health(self): with grpc.insecure_channel("localhost:50052") as channel: stub = backend_pb2_grpc.BackendStub(channel) reply = stub.Health(backend_pb2.HealthMessage(), timeout=5.0) self.assertEqual(reply.message, b"OK") def test_load_model_requires_cuda(self): with grpc.insecure_channel("localhost:50052") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel( backend_pb2.ModelOptions( Model="Qwen/Qwen3-TTS-12Hz-0.6B-Base", CUDA=True, ), timeout=10.0, ) self.assertFalse(response.success) @unittest.skipUnless( 
__import__("torch").cuda.is_available(), "faster-qwen3-tts TTS requires CUDA", ) def test_tts(self): import soundfile as sf try: with grpc.insecure_channel("localhost:50052") as channel: stub = backend_pb2_grpc.BackendStub(channel) ref_audio = tempfile.NamedTemporaryFile(suffix='.wav', delete=False) ref_audio.close() try: sr = 22050 duration = 1.0 samples = int(sr * duration) sf.write(ref_audio.name, [0.0] * samples, sr) response = stub.LoadModel( backend_pb2.ModelOptions( Model="Qwen/Qwen3-TTS-12Hz-0.6B-Base", AudioPath=ref_audio.name, Options=["ref_text:Hello world"], ), timeout=600.0, ) self.assertTrue(response.success, response.message) with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as out: output_path = out.name try: tts_response = stub.TTS( backend_pb2.TTSRequest( text="Test output.", dst=output_path, language="English", ), timeout=120.0, ) self.assertTrue(tts_response.success, tts_response.message) self.assertTrue(os.path.exists(output_path)) self.assertGreater(os.path.getsize(output_path), 0) finally: if os.path.exists(output_path): os.unlink(output_path) finally: if os.path.exists(ref_audio.name): os.unlink(ref_audio.name) except Exception as err: self.fail(f"TTS test failed: {err}") if __name__ == "__main__": unittest.main() ================================================ FILE: backend/python/faster-qwen3-tts/test.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi runUnittests ================================================ FILE: backend/python/faster-whisper/Makefile ================================================ .DEFAULT_GOAL := install .PHONY: install install: bash install.sh .PHONY: protogen-clean protogen-clean: $(RM) backend_pb2_grpc.py backend_pb2.py .PHONY: clean clean: protogen-clean rm -rf venv __pycache__ ================================================ 
FILE: backend/python/faster-whisper/backend.py ================================================ #!/usr/bin/env python3 """ This is an extra gRPC server of LocalAI for Faster Whisper TTS """ from concurrent import futures import time import argparse import signal import sys import os import backend_pb2 import backend_pb2_grpc import torch from faster_whisper import WhisperModel import grpc _ONE_DAY_IN_SECONDS = 60 * 60 * 24 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) COQUI_LANGUAGE = os.environ.get('COQUI_LANGUAGE', None) # Implement the BackendServicer class with the service methods class BackendServicer(backend_pb2_grpc.BackendServicer): """ BackendServicer is the class that implements the gRPC service """ def Health(self, request, context): return backend_pb2.Reply(message=bytes("OK", 'utf-8')) def LoadModel(self, request, context): device = "cpu" # Get device # device = "cuda" if request.CUDA else "cpu" if request.CUDA: device = "cuda" mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() if mps_available: device = "mps" try: print("Preparing models, please wait", file=sys.stderr) self.model = WhisperModel(request.Model, device=device, compute_type="default") except Exception as err: return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") # Implement your logic here for the LoadModel service # Replace this with your desired response return backend_pb2.Result(message="Model loaded successfully", success=True) def AudioTranscription(self, request, context): resultSegments = [] text = "" try: segments, info = self.model.transcribe(request.dst, beam_size=5, condition_on_previous_text=False) id = 0 for segment in segments: print("[%.2fs -> %.2fs] %s" % (segment.start, segment.end, segment.text)) resultSegments.append(backend_pb2.TranscriptSegment(id=id, start=int(segment.start)*1e9, 
end=int(segment.end)*1e9, text=segment.text)) text += segment.text id += 1 except Exception as err: print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr) raise err return backend_pb2.TranscriptResult(segments=resultSegments, text=text) def serve(address): server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ]) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() print("Server started. Listening on: " + address, file=sys.stderr) # Define the signal handler function def signal_handler(sig, frame): print("Received termination signal. Shutting down...") server.stop(0) sys.exit(0) # Set the signal handlers for SIGINT and SIGTERM signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: while True: time.sleep(_ONE_DAY_IN_SECONDS) except KeyboardInterrupt: server.stop(0) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the gRPC server.") parser.add_argument( "--addr", default="localhost:50051", help="The address to bind the server to." ) args = parser.parse_args() serve(args.addr) ================================================ FILE: backend/python/faster-whisper/install.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi # This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links. # This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match. 
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index # the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index if [ "x${BUILD_PROFILE}" == "xintel" ]; then EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match" fi installRequirements ================================================ FILE: backend/python/faster-whisper/protogen.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto ================================================ FILE: backend/python/faster-whisper/requirements-cpu.txt ================================================ faster-whisper opencv-python accelerate compel peft sentencepiece torch==2.4.1 optimum-quanto ================================================ FILE: backend/python/faster-whisper/requirements-cublas12.txt ================================================ torch==2.4.1 faster-whisper opencv-python accelerate compel peft sentencepiece optimum-quanto ================================================ FILE: backend/python/faster-whisper/requirements-cublas13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch==2.9.1 faster-whisper opencv-python accelerate compel peft sentencepiece optimum-quanto ================================================ FILE: backend/python/faster-whisper/requirements-hipblas.txt ================================================ --extra-index-url https://download.pytorch.org/whl/rocm6.4 torch faster-whisper ================================================ FILE: backend/python/faster-whisper/requirements-intel.txt ================================================ --extra-index-url 
https://download.pytorch.org/whl/xpu torch optimum[openvino] faster-whisper ================================================ FILE: backend/python/faster-whisper/requirements-mps.txt ================================================ torch==2.7.1 faster-whisper opencv-python accelerate compel peft sentencepiece optimum-quanto ================================================ FILE: backend/python/faster-whisper/requirements.txt ================================================ grpcio==1.71.0 protobuf grpcio-tools ================================================ FILE: backend/python/faster-whisper/run.sh ================================================ #!/bin/bash backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi startBackend $@ ================================================ FILE: backend/python/faster-whisper/test.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi runUnittests ================================================ FILE: backend/python/fish-speech/Makefile ================================================ .PHONY: fish-speech fish-speech: bash install.sh .PHONY: run run: fish-speech @echo "Running fish-speech..." bash run.sh @echo "fish-speech run." .PHONY: test test: fish-speech @echo "Testing fish-speech..." bash test.sh @echo "fish-speech tested." 
.PHONY: protogen-clean protogen-clean: $(RM) backend_pb2_grpc.py backend_pb2.py .PHONY: clean clean: protogen-clean rm -rf venv __pycache__ ================================================ FILE: backend/python/fish-speech/backend.py ================================================ #!/usr/bin/env python3 """ This is an extra gRPC server of LocalAI for fish-speech TTS """ from concurrent import futures import time import argparse import signal import sys import os import traceback import backend_pb2 import backend_pb2_grpc import torch import soundfile as sf import numpy as np import json import grpc def is_float(s): """Check if a string can be converted to float.""" try: float(s) return True except ValueError: return False def is_int(s): """Check if a string can be converted to int.""" try: int(s) return True except ValueError: return False _ONE_DAY_IN_SECONDS = 60 * 60 * 24 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1 MAX_WORKERS = int(os.environ.get("PYTHON_GRPC_MAX_WORKERS", "1")) # Implement the BackendServicer class with the service methods class BackendServicer(backend_pb2_grpc.BackendServicer): """ BackendServicer is the class that implements the gRPC service """ def Health(self, request, context): return backend_pb2.Reply(message=bytes("OK", "utf-8")) def LoadModel(self, request, context): try: # Get device if torch.cuda.is_available(): print("CUDA is available", file=sys.stderr) device = "cuda" else: print("CUDA is not available", file=sys.stderr) device = "cpu" mps_available = ( hasattr(torch.backends, "mps") and torch.backends.mps.is_available() ) if mps_available: device = "mps" if not torch.cuda.is_available() and request.CUDA: return backend_pb2.Result(success=False, message="CUDA is not available") # Validate mps availability if requested if device == "mps" and not torch.backends.mps.is_available(): print("Warning: MPS not available. 
Falling back to CPU.", file=sys.stderr) device = "cpu" self.device = device self._torch_device = torch.device(device) options = request.Options # empty dict self.options = {} # The options are a list of strings in this form optname:optvalue for opt in options: if ":" not in opt: continue key, value = opt.split(":", 1) if is_float(value): value = float(value) elif is_int(value): value = int(value) elif value.lower() in ["true", "false"]: value = value.lower() == "true" self.options[key] = value # Parse voices configuration from options self.voices = {} if "voices" in self.options: try: voices_data = self.options["voices"] if isinstance(voices_data, str): voices_list = json.loads(voices_data) else: voices_list = voices_data for voice_entry in voices_list: if not isinstance(voice_entry, dict): print( f"[WARNING] Invalid voice entry (not a dict): {voice_entry}", file=sys.stderr, ) continue name = voice_entry.get("name") audio = voice_entry.get("audio") ref_text = voice_entry.get("ref_text", "") if not name or not isinstance(name, str): print( f"[WARNING] Voice entry missing required 'name' field: {voice_entry}", file=sys.stderr, ) continue if not audio or not isinstance(audio, str): print( f"[WARNING] Voice entry missing required 'audio' field: {voice_entry}", file=sys.stderr, ) continue self.voices[name] = {"audio": audio, "ref_text": ref_text} print( f"[INFO] Registered voice '{name}' with audio: {audio}", file=sys.stderr, ) print(f"[INFO] Loaded {len(self.voices)} voice(s)", file=sys.stderr) except json.JSONDecodeError as e: print(f"[ERROR] Failed to parse voices JSON: {e}", file=sys.stderr) except Exception as e: print( f"[ERROR] Error processing voices configuration: {e}", file=sys.stderr, ) print(traceback.format_exc(), file=sys.stderr) # Store AudioPath, ModelFile, and ModelPath from LoadModel request self.audio_path = ( request.AudioPath if hasattr(request, "AudioPath") and request.AudioPath else None ) self.model_file = ( request.ModelFile if hasattr(request, 
"ModelFile") and request.ModelFile else None ) self.model_path = ( request.ModelPath if hasattr(request, "ModelPath") and request.ModelPath else None ) # Get model path from request model_path = request.Model if not model_path: model_path = "fishaudio/s2-pro" # If model_path looks like a HuggingFace repo ID (e.g. "fishaudio/fish-speech-1.5"), # download it locally first since fish-speech expects a local directory if "/" in model_path and not os.path.exists(model_path): from huggingface_hub import snapshot_download print( f"Downloading model from HuggingFace: {model_path}", file=sys.stderr, ) model_path = snapshot_download(repo_id=model_path) print(f"Model downloaded to: {model_path}", file=sys.stderr) # Determine precision if device in ("mps", "cpu"): precision = torch.float32 else: precision = torch.bfloat16 # Whether to use torch.compile compile_model = self.options.get("compile", False) print( f"Using device: {device}, precision: {precision}, compile: {compile_model}", file=sys.stderr, ) print(f"Loading model from: {model_path}", file=sys.stderr) # Import fish-speech modules from fish_speech.inference_engine import TTSInferenceEngine from fish_speech.models.dac.inference import load_model as load_decoder_model from fish_speech.models.text2semantic.inference import ( launch_thread_safe_queue, ) # Determine decoder checkpoint path # The codec model is typically at /codec.pth decoder_checkpoint = self.options.get("decoder_checkpoint", None) if not decoder_checkpoint: # Try common locations if os.path.isdir(model_path): candidate = os.path.join(model_path, "codec.pth") if os.path.exists(candidate): decoder_checkpoint = candidate # Launch LLaMA queue (runs in daemon thread) print("Launching LLaMA queue...", file=sys.stderr) llama_queue = launch_thread_safe_queue( checkpoint_path=model_path, device=device, precision=precision, compile=compile_model, ) # Load DAC decoder decoder_config = self.options.get("decoder_config", "modded_dac_vq") if not decoder_checkpoint: 
def _get_ref_audio_path(self, voice_name=None):
    """Resolve the reference audio file used for voice cloning.

    A voice registered in ``self.voices`` under ``voice_name`` takes
    precedence; otherwise the single ``self.audio_path`` stored by LoadModel
    is used (legacy single-voice mode).

    Args:
        voice_name: Optional name of a registered voice.

    Returns:
        ``None`` when no registered voice matches and no ``audio_path`` is
        set. Otherwise the selected path: unchanged if it is absolute, else
        the first existing candidate among ``dirname(model_file)/<path>`` and
        ``model_path/<path>``, else the raw relative path as configured.
    """
    # Pick which configured path we are resolving.
    if voice_name and voice_name in self.voices:
        wanted = self.voices[voice_name]["audio"]
    elif self.audio_path:
        wanted = self.audio_path
    else:
        return None

    if os.path.isabs(wanted):
        return wanted

    # Directories to probe, in priority order: next to the model file first,
    # then the model path.
    base_dirs = []
    if self.model_file:
        base_dirs.append(os.path.dirname(self.model_file))
    if self.model_path:
        base_dirs.append(self.model_path)

    for base in base_dirs:
        resolved = os.path.join(base, wanted)
        if os.path.exists(resolved):
            return resolved

    # Nothing matched on disk: hand back the path exactly as configured.
    return wanted
self.engine.inference(tts_request): if result.code == "final": sample_rate, audio_data = result.audio elif result.code == "error": error_msg = str(result.error) if result.error else "Unknown error" print(f"[ERROR] TTS inference error: {error_msg}", file=sys.stderr) return backend_pb2.Result( success=False, message=f"TTS inference error: {error_msg}" ) generation_duration = time.time() - start_time if audio_data is None or sample_rate is None: return backend_pb2.Result( success=False, message="No audio output generated" ) # Ensure audio_data is a numpy array if not isinstance(audio_data, np.ndarray): audio_data = np.array(audio_data) audio_duration = len(audio_data) / sample_rate if sample_rate > 0 else 0 print( f"[INFO] TTS generation completed: {generation_duration:.2f}s, " f"audio_duration={audio_duration:.2f}s, sample_rate={sample_rate}", file=sys.stderr, flush=True, ) # Save output sf.write(request.dst, audio_data, sample_rate) print(f"Saved {audio_duration:.2f}s audio to {request.dst}", file=sys.stderr) except Exception as err: print(f"Error in TTS: {err}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) return backend_pb2.Result( success=False, message=f"Unexpected {err=}, {type(err)=}" ) return backend_pb2.Result(success=True) def serve(address): server = grpc.server( futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ ("grpc.max_message_length", 50 * 1024 * 1024), # 50MB ("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB ("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB ], ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() print("Server started. Listening on: " + address, file=sys.stderr) # Define the signal handler function def signal_handler(sig, frame): print("Received termination signal. 
#!/bin/bash
# Install script for the fish-speech backend: installs the Python requirements,
# fetches the upstream fish-speech sources and installs them in editable mode.
set -e

EXTRA_PIP_INSTALL_FLAGS="--no-build-isolation"

# Quote "$0" / "$backend_dir" so the script also works from paths with spaces.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# fish-speech uses pyrootutils which requires a .project-root marker
touch "${backend_dir}/.project-root"

installRequirements

# Clone fish-speech source (the pip package doesn't include inference modules)
FISH_SPEECH_DIR="${EDIR}/fish-speech-src"
FISH_SPEECH_REPO="https://github.com/fishaudio/fish-speech.git"
FISH_SPEECH_BRANCH="main"

if [ ! -d "${FISH_SPEECH_DIR}" ]; then
    echo "Cloning fish-speech source..."
    git clone --depth 1 --branch "${FISH_SPEECH_BRANCH}" "${FISH_SPEECH_REPO}" "${FISH_SPEECH_DIR}"
else
    echo "Updating fish-speech source..."
    # A previous run patched pyproject.toml (see the sed call below); restore it
    # first so the local modification cannot block the pull.
    # `git -C` is used instead of `cd ... && git pull && cd -`: in an && list a
    # failing pull is ignored by `set -e` and would also leave the script
    # running from inside the checkout. Updating stays best-effort, but a
    # failure is now reported instead of being silent.
    if ! { git -C "${FISH_SPEECH_DIR}" checkout -- pyproject.toml \
        && git -C "${FISH_SPEECH_DIR}" pull; }; then
        echo "WARNING: could not update fish-speech source, using the existing checkout" >&2
    fi
fi

# Remove pyaudio from fish-speech deps — it's only used by the upstream client tool
# (tools/api_client.py) for speaker playback, not by our gRPC backend server.
# It requires native portaudio libs which aren't available on all build environments.
sed -i.bak '/"pyaudio"/d' "${FISH_SPEECH_DIR}/pyproject.toml"

# Install fish-speech deps from source (without the package itself since we use PYTHONPATH)
ensureVenv
# EXTRA_PIP_INSTALL_FLAGS is intentionally unquoted: it may hold several flags.
if [ "x${USE_PIP}" == "xtrue" ]; then
    pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -e "${FISH_SPEECH_DIR}"
else
    uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -e "${FISH_SPEECH_DIR}"
fi

# fish-speech transitive deps (wandb, tensorboard) may downgrade protobuf to 3.x
# but our generated backend_pb2.py requires protobuf 5+
ensureVenv
if [ "x${USE_PIP}" == "xtrue" ]; then
    pip install "protobuf>=5.29.0"
else
    uv pip install "protobuf>=5.29.0"
fi
torch==2.7.1+rocm6.3 torchaudio==2.7.1+rocm6.3 ================================================ FILE: backend/python/fish-speech/requirements-intel.txt ================================================ --extra-index-url https://download.pytorch.org/whl/xpu torch torchaudio ================================================ FILE: backend/python/fish-speech/requirements-l4t12.txt ================================================ --extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/ torch torchaudio ================================================ FILE: backend/python/fish-speech/requirements-l4t13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch torchaudio ================================================ FILE: backend/python/fish-speech/requirements-mps.txt ================================================ torch torchaudio ================================================ FILE: backend/python/fish-speech/requirements.txt ================================================ grpcio==1.71.0 protobuf certifi packaging==24.1 soundfile setuptools six scipy numpy ================================================ FILE: backend/python/fish-speech/run.sh ================================================ #!/bin/bash backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi startBackend $@ ================================================ FILE: backend/python/fish-speech/test.py ================================================ """ A test script to test the gRPC service """ import signal import threading import unittest import subprocess import time import os import sys import tempfile import backend_pb2 import backend_pb2_grpc import grpc BACKEND_LOG = "/tmp/fish-speech-backend.log" def _dump_backend_log(): """Print backend log — call before exiting so CI always shows it.""" if os.path.exists(BACKEND_LOG): with 
def _tail_log(path, stop_event, interval=10):
    """Periodically echo newly appended backend-log text to stderr.

    Intended to run in a background thread. Each round waits on
    ``stop_event`` for up to ``interval`` seconds, then prints whatever was
    written to ``path`` since the previous round, prefixed with
    ``[backend log] ``. The loop ends once ``stop_event`` is set; a log file
    that does not exist (yet) is silently skipped.
    """
    offset = 0  # read position reached by the previous round
    while True:
        if stop_event.is_set():
            break
        stop_event.wait(interval)
        try:
            with open(path, "r") as log_file:
                log_file.seek(offset)
                chunk = log_file.read()
                if chunk:
                    sys.stderr.write(f"[backend log] {chunk}")
                    sys.stderr.flush()
                offset = log_file.tell()
        except FileNotFoundError:
            # The backend has not created its log file yet; retry next round.
            continue
def tearDown(self) -> None:
    """
    Stop the log tailer and the backend process started by setUp, then
    print the backend log.
    """
    # Ask the tailing thread to finish and give it a moment to exit.
    self._log_stop.set()
    self._log_thread.join(timeout=2)

    # Request termination first; force-kill if the backend has not exited
    # within 5 seconds.
    backend = self.service
    backend.terminate()
    try:
        backend.wait(timeout=5)
    except subprocess.TimeoutExpired:
        backend.kill()
        backend.wait()

    # Close our handle on the log file before dumping its content.
    self.backend_log.close()
    _dump_backend_log()
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    gRPC servicer that exposes KittenTTS through the LocalAI backend API.
    """

    def Health(self, request, context):
        """Health probe: always replies with the bytes ``OK``."""
        return backend_pb2.Reply(message="OK".encode("utf-8"))

    def LoadModel(self, request, context):
        """Instantiate the KittenTTS model named by ``request.Model``.

        Also records ``self.AudioPath``: an absolute ``request.AudioPath`` is
        kept as is, a relative one is joined to the directory of
        ``request.ModelFile`` when that field is set, otherwise it stays None.

        Returns a ``Result`` with ``success=False`` and the exception details
        if the model cannot be created.
        """
        self.AudioPath = None

        # List available KittenTTS voices
        print("Available KittenTTS voices: expr-voice-2-m, expr-voice-2-f, expr-voice-3-m, expr-voice-3-f, expr-voice-4-m, expr-voice-4-f, expr-voice-5-m, expr-voice-5-f")

        audio_path = request.AudioPath
        if os.path.isabs(audio_path):
            self.AudioPath = audio_path
        elif audio_path and request.ModelFile != "":
            # Relative audio path: resolve it next to the model file.
            self.AudioPath = os.path.join(os.path.dirname(request.ModelFile), audio_path)

        try:
            print("Preparing KittenTTS model, please wait", file=sys.stderr)
            # Fall back to the nano model when the request names none.
            self.tts = KittenTTS(request.Model or "KittenML/kitten-tts-nano-0.1")
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def TTS(self, request, context):
        """Synthesize ``request.text`` and write the audio to ``request.dst``.

        ``request.voice`` selects the speaker (``expr-voice-2-f`` when empty).
        The audio is written with a 24000 Hz sample rate. Any exception is
        reported through a ``Result`` with ``success=False``.
        """
        try:
            # KittenTTS has no language parameter; only the voice is forwarded.
            speaker = request.voice or "expr-voice-2-f"
            waveform = self.tts.generate(request.text, voice=speaker)
            sf.write(request.dst, waveform, 24000)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        return backend_pb2.Result(success=True)
* 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ]) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() print("Server started. Listening on: " + address, file=sys.stderr) # Define the signal handler function def signal_handler(sig, frame): print("Received termination signal. Shutting down...") server.stop(0) sys.exit(0) # Set the signal handlers for SIGINT and SIGTERM signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: while True: time.sleep(_ONE_DAY_IN_SECONDS) except KeyboardInterrupt: server.stop(0) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the gRPC server.") parser.add_argument( "--addr", default="localhost:50051", help="The address to bind the server to." ) args = parser.parse_args() serve(args.addr) ================================================ FILE: backend/python/kitten-tts/install.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi # This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links. # This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match. 
class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service.

    unittest runs setUp() before and tearDown() after every test method, so
    the tests must not call them again: doing so spawned a second backend
    process per test and left the first one running, still bound to the port.
    """

    def setUp(self):
        """
        This method sets up the gRPC service by starting the server
        """
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
        # Fixed grace period for the backend to start listening.
        time.sleep(30)

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by terminating the server
        """
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits"))
                print(response)
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")

    def test_tts(self):
        """
        This method tests if the TTS request is served after loading the model
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="tts_models/en/vctk/vits"))
                self.assertTrue(response.success)
                tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
                tts_response = stub.TTS(tts_request)
                self.assertIsNotNone(tts_response)
        except Exception as err:
            print(err)
            self.fail("TTS service failed")
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    BackendServicer is the class that implements the gRPC service for Kokoro TTS
    """

    def Health(self, request, context):
        """Health probe: always replies with the bytes ``OK``."""
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """Parse the model options and create the Kokoro pipeline.

        ``request.Options`` is a list of ``optname:optvalue`` strings; entries
        without a colon are ignored. The ``lang_code`` option selects the
        pipeline language and defaults to ``KOKORO_LANG_CODE``.

        Returns a ``Result`` with ``success=False`` and the exception details
        if the pipeline cannot be created.
        """
        try:
            print("Preparing Kokoro TTS pipeline, please wait", file=sys.stderr)

            # Store all options in a dict so they can be looked up later.
            self.options = {}
            for opt in request.Options:
                if ":" not in opt:
                    continue
                # Split on the first colon only, so values that themselves
                # contain ':' (paths, URLs, ...) no longer raise ValueError.
                key, value = opt.split(":", 1)
                self.options[key] = value

            # Initialize Kokoro pipeline with language code
            lang_code = self.options.get("lang_code", KOKORO_LANG_CODE)
            self.pipeline = KPipeline(lang_code=lang_code)
            print(f"Kokoro TTS pipeline loaded with language code: {lang_code}", file=sys.stderr)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        return backend_pb2.Result(message="Kokoro TTS pipeline loaded successfully", success=True)

    def TTS(self, request, context):
        """Synthesize ``request.text`` and write the audio to ``request.dst``.

        ``request.voice`` selects the voice (``af_heart`` when empty). All
        segments yielded by the pipeline are concatenated and written with a
        24000 Hz sample rate. Failures are reported through a ``Result`` with
        ``success=False``.
        """
        try:
            # Get voice from request, default to 'af_heart' if not specified
            voice = request.voice if request.voice else 'af_heart'

            # Generate audio using Kokoro pipeline
            generator = self.pipeline(request.text, voice=voice)

            # Collect every audio segment produced by the pipeline.
            speechs = []
            for i, (gs, ps, audio) in enumerate(generator):
                speechs.append(audio)
                print(f"Generated audio segment {i}: gs={gs}, ps={ps}", file=sys.stderr)

            # torch.cat() rejects an empty list; report that case explicitly
            # instead of surfacing a cryptic tensor error.
            if not speechs:
                return backend_pb2.Result(success=False, message="No audio generated for the given text")

            # Merge the audio segments and write them to the destination
            speech = torch.cat(speechs, dim=0)
            sf.write(request.dst, speech, 24000)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        return backend_pb2.Result(success=True)
#!/bin/bash
# Install script for the kokoro backend.
set -e

# Quote "$0" / "$backend_dir" so the script also works from paths with spaces.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
    EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
fi

# The l4t12 profile installs with pip instead of uv.
if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
    USE_PIP=true
fi

installRequirements
class TestBackendServicer(unittest.TestCase):
    """
    Integration tests for the kokoro gRPC backend.

    unittest invokes setUp() before and tearDown() after every test method, so
    the tests must not call them again: doing so spawns a second server on the
    same address and overwrites (and therefore never terminates) the first
    subprocess handle.
    """
    def setUp(self):
        """
        Start the backend server in a subprocess and give it time to come up.
        """
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
        time.sleep(30)

    def tearDown(self) -> None:
        """
        Terminate the backend server subprocess and wait for it to exit.
        """
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """
        The Health RPC answers b'OK' once the server is running.
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")

    def test_load_model(self):
        """
        LoadModel succeeds and reports that the Kokoro pipeline was loaded.
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(language="a"))
                print(response)
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Kokoro TTS pipeline loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")

    def test_tts(self):
        """
        After loading the pipeline, a TTS request completes successfully.
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(language="a"))
                self.assertTrue(response.success)
                tts_request = backend_pb2.TTSRequest(
                    text="Kokoro is an open-weight TTS model with 82 million parameters.",
                    voice="af_heart",
                    dst="test_output.wav"
                )
                tts_response = stub.TTS(tts_request)
                self.assertIsNotNone(tts_response)
                self.assertTrue(tts_response.success)
        except Exception as err:
            print(err)
            self.fail("TTS service failed")
_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))


def is_float(s):
    """Check if a string can be converted to float."""
    try:
        float(s)
        return True
    except ValueError:
        return False


def is_int(s):
    """Check if a string can be converted to int."""
    try:
        int(s)
        return True
    except ValueError:
        return False


# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    A gRPC servicer that implements the Backend service defined in backend.proto.
    """

    def Health(self, request, context):
        """
        Returns a health check message.

        Args:
            request: The health check request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The health check reply.
        """
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    async def LoadModel(self, request, context):
        """
        Loads a language model using MLX.

        Parses ``request.Options`` ("name:value" strings) into ``self.options``,
        builds the tokenizer config, loads model and tokenizer, and creates the
        LRU prompt cache used by Predict/PredictStream.

        Args:
            request: The load model request.
            context: The gRPC context.

        Returns:
            backend_pb2.Result: The load model result.
        """
        try:
            print(f"Loading MLX model: {request.Model}", file=sys.stderr)
            print(f"Request: {request}", file=sys.stderr)

            # Parse options like in the diffusers backend
            options = request.Options
            self.options = {}

            # The options are a list of strings in this form optname:optvalue
            # We store all the options in a dict for later use
            for opt in options:
                if ":" not in opt:
                    continue
                key, value = opt.split(":", 1)  # Split only on first colon to handle values with colons

                # Convert numeric values to appropriate types.
                # int must be tried BEFORE float: every int literal also parses as a
                # float, so the opposite order turns "40" into 40.0 and integer
                # options (top_k, seed, max_kv_size, ...) would never be ints.
                if is_int(value):
                    value = int(value)
                elif is_float(value):
                    value = float(value)
                elif value.lower() in ["true", "false"]:
                    value = value.lower() == "true"

                self.options[key] = value

            print(f"Options: {self.options}", file=sys.stderr)

            # Build tokenizer config for MLX using options
            tokenizer_config = {}

            # Handle trust_remote_code from request or options
            if request.TrustRemoteCode or self.options.get("trust_remote_code", False):
                tokenizer_config["trust_remote_code"] = True

            # Handle EOS token from options
            if "eos_token" in self.options:
                tokenizer_config["eos_token"] = self.options["eos_token"]

            # Handle other tokenizer config options
            for key in ["pad_token", "bos_token", "unk_token", "sep_token", "cls_token", "mask_token"]:
                if key in self.options:
                    tokenizer_config[key] = self.options[key]

            # Load model and tokenizer using MLX
            if tokenizer_config:
                print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr)
                self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config)
            else:
                self.model, self.tokenizer = load(request.Model)

            # Initialize thread-safe LRU prompt cache for efficient generation
            max_cache_entries = self.options.get("max_cache_entries", 10)
            self.max_kv_size = self.options.get("max_kv_size", None)
            self.model_key = request.Model
            self.lru_cache = ThreadSafeLRUPromptCache(
                max_size=max_cache_entries,
                can_trim_fn=can_trim_prompt_cache,
                trim_fn=trim_prompt_cache,
            )

        except Exception as err:
            print(f"Error loading MLX model {err=}, {type(err)=}", file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Error loading MLX model: {err}")

        print("MLX model loaded successfully", file=sys.stderr)
        return backend_pb2.Result(message="MLX model loaded successfully", success=True)

    def _acquire_prompt_cache(self, cache_key):
        """
        Pick the KV cache to generate with and the tokens that still have to be fed.

        Args:
            cache_key: Token IDs of the full prompt.

        Returns:
            tuple: (prompt_cache, remaining_tokens). remaining_tokens is only empty
            when cache_key itself is empty.
        """
        prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(
            self.model_key, cache_key
        )
        if prompt_cache is None:
            # Miss: start from an empty cache and process the whole prompt.
            return make_prompt_cache(self.model, self.max_kv_size), list(cache_key)

        if remaining_tokens:
            return prompt_cache, remaining_tokens

        # Exact hit: the cache already holds every prompt token, but generation
        # needs at least one input token. Feeding the full prompt again would
        # append it a second time to the cached context, so instead roll the
        # cache back by one token and re-feed only the last one.
        # NOTE(review): relies on trim_prompt_cache returning the number of
        # tokens actually trimmed - confirm against the installed mlx_lm.
        if cache_key and can_trim_prompt_cache(prompt_cache) \
                and trim_prompt_cache(prompt_cache, 1) == 1:
            return prompt_cache, cache_key[-1:]

        # Cache cannot be rolled back: discard it and recompute from scratch.
        return make_prompt_cache(self.model, self.max_kv_size), list(cache_key)

    async def Predict(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters using MLX.

        Uses thread-safe LRU prompt cache for efficient prefix reuse across requests.

        Args:
            request: The predict request.
            context: The gRPC context.

        Returns:
            backend_pb2.Reply: The predict result.
        """
        prompt_cache = None
        cache_key = None

        try:
            # Prepare the prompt and tokenize for cache key
            prompt_text = self._prepare_prompt(request)
            cache_key = self._get_tokens_from_prompt(prompt_text)

            # Fetch nearest cache (exact, shorter prefix, or create new)
            prompt_cache, remaining_tokens = self._acquire_prompt_cache(cache_key)

            # Build generation parameters using request attributes and options
            max_tokens, sampler_params = self._build_generation_params(request)

            print(f"Generating text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)

            # Create sampler with parameters
            sampler = make_sampler(**sampler_params)

            # Use stream_generate to track generated tokens for cache key
            generated_text = []
            for response in stream_generate(
                self.model,
                self.tokenizer,
                prompt=remaining_tokens,
                max_tokens=max_tokens,
                sampler=sampler,
                prompt_cache=prompt_cache,
            ):
                generated_text.append(response.text)
                cache_key.append(response.token)

            # Insert completed cache
            self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)

            return backend_pb2.Reply(message=bytes(''.join(generated_text), encoding='utf-8'))

        except Exception as e:
            print(f"Error in MLX Predict: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Generation failed: {str(e)}")
            return backend_pb2.Reply(message=bytes("", encoding='utf-8'))

    def Embedding(self, request, context):
        """
        A gRPC method that calculates embeddings for a given sentence.

        Note: MLX-LM doesn't support embeddings directly. This method returns an error.

        Args:
            request: An EmbeddingRequest object that contains the request parameters.
            context: A grpc.ServicerContext object that provides information about the RPC.

        Returns:
            An empty EmbeddingResult; the RPC status is set to UNIMPLEMENTED.
        """
        print("Embeddings not supported in MLX backend", file=sys.stderr)
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details("Embeddings are not supported in the MLX backend.")
        return backend_pb2.EmbeddingResult()

    async def PredictStream(self, request, context):
        """
        Generates text based on the given prompt and sampling parameters, and streams the results using MLX.

        Uses thread-safe LRU prompt cache for efficient prefix reuse across requests.

        Args:
            request: The predict stream request.
            context: The gRPC context.

        Yields:
            backend_pb2.Reply: Streaming predict results.
        """
        prompt_cache = None
        cache_key = None

        try:
            # Prepare the prompt and tokenize for cache key
            prompt_text = self._prepare_prompt(request)
            cache_key = self._get_tokens_from_prompt(prompt_text)

            # Fetch nearest cache (exact, shorter prefix, or create new)
            prompt_cache, remaining_tokens = self._acquire_prompt_cache(cache_key)

            # Build generation parameters using request attributes and options
            max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512)

            print(f"Streaming text with MLX - max_tokens: {max_tokens}, cache_hit: {len(remaining_tokens) < len(cache_key)}", file=sys.stderr)

            # Create sampler with parameters
            sampler = make_sampler(**sampler_params)

            # Stream text generation using MLX with proper parameters
            for response in stream_generate(
                self.model,
                self.tokenizer,
                prompt=remaining_tokens,
                max_tokens=max_tokens,
                sampler=sampler,
                prompt_cache=prompt_cache,
            ):
                cache_key.append(response.token)
                yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8'))

        except Exception as e:
            print(f"Error in MLX PredictStream: {e}", file=sys.stderr)
            context.set_code(grpc.StatusCode.INTERNAL)
            context.set_details(f"Streaming generation failed: {str(e)}")
            yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))

        finally:
            # Always insert cache, even on interruption
            if prompt_cache is not None and cache_key is not None:
                try:
                    self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache)
                except Exception as e:
                    print(f"Error inserting cache: {e}", file=sys.stderr)

    def _prepare_prompt(self, request):
        """
        Prepare the prompt for MLX generation, handling chat templates if needed.

        Args:
            request: The gRPC request containing prompt and message information.

        Returns:
            str: The prepared prompt.
        """
        # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template
        if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
            # Convert gRPC messages to the format expected by apply_chat_template
            messages = []
            for msg in request.Messages:
                messages.append({"role": msg.role, "content": msg.content})

            prompt = self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )
            return prompt
        else:
            return request.Prompt

    def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]:
        """
        Tokenize prompt text for cache key generation.

        Args:
            prompt_text: The prompt string to tokenize.

        Returns:
            List[int]: List of token IDs.
        """
        tokens = self.tokenizer.encode(prompt_text)
        if hasattr(tokens, 'tolist'):
            return tokens.tolist()
        return list(tokens)

    def _build_generation_params(self, request, default_max_tokens=200):
        """
        Build generation parameters from request attributes and options.

        Values found in ``self.options`` (set by LoadModel) take precedence over
        the per-request values. Seeding the MLX RNG is a side effect of this call.

        Args:
            request: The gRPC request.
            default_max_tokens: Default max_tokens if not specified.

        Returns:
            tuple: (max_tokens, sampler_params dict)
        """
        # Extract max_tokens
        max_tokens = getattr(request, 'Tokens', default_max_tokens)
        if max_tokens == 0:
            max_tokens = default_max_tokens

        # Extract sampler parameters from request attributes
        temp = getattr(request, 'Temperature', 0.0)
        if temp == 0.0:
            temp = 0.6  # Default temperature

        top_p = getattr(request, 'TopP', 0.0)
        if top_p == 0.0:
            top_p = 1.0  # Default top_p

        min_p = getattr(request, 'MinP', 0.0)
        # min_p default of 0.0 means disabled (no filtering)

        top_k = getattr(request, 'TopK', 0)
        # top_k default of 0 means disabled (no filtering)

        # Initialize sampler parameters
        sampler_params = {
            'temp': temp,
            'top_p': top_p,
            'min_p': min_p,
            'top_k': top_k,
            'xtc_threshold': 0.0,
            'xtc_probability': 0.0,
        }

        # Add seed if specified
        seed = getattr(request, 'Seed', 0)
        if seed != 0:
            mx.random.seed(seed)

        # Override with options if available
        if hasattr(self, 'options'):
            # Max tokens from options
            if 'max_tokens' in self.options:
                max_tokens = self.options['max_tokens']

            # Sampler parameters from options
            sampler_option_mapping = {
                'temp': 'temp',
                'temperature': 'temp',  # alias
                'top_p': 'top_p',
                'min_p': 'min_p',
                'top_k': 'top_k',
                'xtc_threshold': 'xtc_threshold',
                'xtc_probability': 'xtc_probability',
            }

            for option_key, param_key in sampler_option_mapping.items():
                if option_key in self.options:
                    sampler_params[param_key] = self.options[option_key]

            # Handle seed from options
            if 'seed' in self.options:
                mx.random.seed(self.options['seed'])

        # Special tokens for XTC sampling (if tokenizer has eos_token_ids)
        xtc_special_tokens = []
        if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids:
            xtc_special_tokens = list(self.tokenizer.eos_token_ids)
        elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None:
            xtc_special_tokens = [self.tokenizer.eos_token_id]

        # Add newline token if available
        try:
            newline_tokens = self.tokenizer.encode("\n")
            xtc_special_tokens.extend(newline_tokens)
        except Exception:
            # Best effort: skip if encoding fails. Catch Exception rather than
            # using a bare except so KeyboardInterrupt/SystemExit still propagate.
            pass

        sampler_params['xtc_special_tokens'] = xtc_special_tokens

        return max_tokens, sampler_params


async def serve(address):
    """
    Run the asyncio gRPC server on ``address`` until it is terminated.

    SIGINT and SIGTERM trigger a shutdown with a 5 second grace period.

    Args:
        address: The "host:port" string to bind the insecure port to.
    """
    # Start asyncio gRPC server
    server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
            ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
        ])
    # Add the servicer to the server
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    # Bind the server to the address
    server.add_insecure_port(address)

    # Gracefully shutdown the server on SIGTERM or SIGINT.
    # We are inside a coroutine, so the running loop is the one to register on.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(
            sig, lambda: asyncio.ensure_future(server.stop(5))
        )

    # Start the server
    await server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)
    # Wait for the server to be terminated
    await server.wait_for_termination()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    asyncio.run(serve(args.addr))
@dataclass
class CacheEntry:
    """A cache entry with reference counting."""

    # The stored KV cache object(s) for one token sequence.
    prompt_cache: List[Any]
    # Incremented on every insert of the same sequence, decremented on extract.
    count: int


@dataclass
class SearchResult:
    """Result of searching the cache trie."""

    # Model identifier the search was performed for.
    model: Any
    # Token tuple of an entry matching the query exactly, else None.
    exact: Optional[List[int]]
    # Token tuple of the longest cached proper prefix (2+ tokens), else None.
    shorter: Optional[List[int]]
    # Token tuple of a cached sequence extending the matched prefix, else None.
    longer: Optional[List[int]]
    # Number of leading query tokens that were found in the trie.
    common_prefix: int


class ThreadSafeLRUPromptCache:
    """
    Thread-safe LRU cache with prefix matching for prompt KV caches.

    This cache stores KV caches keyed by token sequences and supports:
    - Exact match: Return the cache for the exact token sequence
    - Shorter prefix match: Return a cache for a prefix of the tokens
    - Longer prefix match: If a longer sequence is cached and can be trimmed
    - LRU eviction: When max_size is exceeded, evict least recently used

    The public methods (fetch_nearest_cache, insert_cache, clear, __len__) take
    an internal threading.Lock; the underscore-prefixed helpers assume the
    caller already holds it.

    Args:
        max_size: Maximum number of cache entries (default: 10)
        can_trim_fn: Optional function to check if a cache can be trimmed
        trim_fn: Optional function to trim a cache
    """

    def __init__(
        self,
        max_size: int = 10,
        can_trim_fn: Optional[Any] = None,
        trim_fn: Optional[Any] = None,
    ):
        self.max_size = max_size
        # model -> trie of nested dicts keyed by token; an entry lives under "cache".
        self._cache = {}
        # (model, tokens_tuple) pairs, least recently inserted first.
        self._lru = deque()
        self._lock = threading.Lock()

        # Optional trim functions (for longer prefix reuse)
        self._can_trim_fn = can_trim_fn
        self._trim_fn = trim_fn

    def _search(self, model, tokens: List[int]) -> SearchResult:
        """
        Search the cache for a prompt cache. Return exact or close match.

        The cache is organized as a trie where each node is keyed by a token.
        This allows efficient prefix matching.

        Args:
            model: Model identifier.
            tokens: Query token sequence (list or tuple).

        Returns:
            SearchResult with at most one of exact/shorter/longer relevant to
            the caller; all None means nothing usable is cached.
        """
        if model not in self._cache:
            return SearchResult(model, None, None, None, 0)

        current = self._cache[model]
        last_cache_index = -1
        index = 0

        # Traverse the trie following the token sequence
        while index < len(tokens) and tokens[index] in current:
            current = current[tokens[index]]
            if "cache" in current:
                last_cache_index = index
            index += 1

        # Exact match - no need to search for longer or shorter caches
        if last_cache_index == len(tokens) - 1:
            # For an empty query both sides are -1 without anything having been
            # found, so it only counts as a hit when the root really stores an entry.
            if len(tokens) > 0 or "cache" in current:
                return SearchResult(model, tuple(tokens), None, None, 0)
            return SearchResult(model, None, None, None, 0)

        # Find the shorter cache (a prefix that has a cache)
        # Note: Uses > 0 (not >= 0) to match upstream mlx_lm/server.py behavior.
        # Single-token prefixes are not matched, which allows longer cached
        # sequences to be preferred for trimming. This is acceptable because
        # real prompts with chat templates are always many tokens.
        shorter = None
        if last_cache_index > 0:
            shorter = tuple(tokens[: last_cache_index + 1])

        # Check for caches that are longer than our token sequence
        longer = None
        common_prefix = index
        if index > 0 and last_cache_index <= 0:
            best = None
            stack = [(current, [])]
            while stack:
                node, extra = stack.pop()
                if "cache" in node:
                    if best is None or len(extra) < len(best):
                        best = extra
                else:
                    for tok in node:
                        stack.append((node[tok], extra + [tok]))
            if best is not None:
                # Convert both halves so a tuple query does not hit tuple + list.
                longer = tuple(tokens[:index]) + tuple(best)

        return SearchResult(model, None, shorter, longer, common_prefix)

    def _get(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
        """Get a cache entry by traversing the trie (KeyError if absent)."""
        current = self._cache[model]
        for tok in tokens:
            current = current[tok]
        return current["cache"]

    def _delete(self, model, tokens: Tuple[int, ...]) -> None:
        """Delete a cache entry and clean up empty trie nodes."""
        path = [self._cache[model]]
        for tok in tokens:
            path.append(path[-1][tok])
        del path[-1]["cache"]

        # Clean up empty nodes bottom-up
        for i in reversed(range(len(tokens))):
            d_prev, d, t = path[i], path[i + 1], tokens[i]
            if len(d) > 0:
                break
            del d_prev[t]

    def _extract(self, model, tokens: Tuple[int, ...]) -> CacheEntry:
        """
        Extract a cache entry for exclusive use.

        If the entry has count > 1, deep copy and decrement.
        If count == 1, remove from cache entirely.
        """
        cache_entry = self._get(model, tokens)

        if cache_entry.count == 1:
            self._delete(model, tokens)
            self._lru.remove((model, tokens))
            return cache_entry

        cache_entry.count -= 1
        return CacheEntry(
            copy.deepcopy(cache_entry.prompt_cache),
            1,
        )

    def fetch_nearest_cache(
        self, model, tokens: List[int]
    ) -> Tuple[Optional[List[Any]], List[int]]:
        """
        Fetch the nearest cache for the given token sequence.

        Thread-safe. Returns (cache, remaining_tokens) where:
        - cache: The KV cache to use (or None if no cache found)
        - remaining_tokens: Tokens that still need to be processed

        Exact and shorter-prefix hits hand the entry over to the caller
        (removing it once its count drops to zero); a longer-prefix hit returns
        a trimmed deep copy and leaves the stored entry in place.

        Args:
            model: Model identifier (used to namespace caches)
            tokens: The full token sequence for the prompt

        Returns:
            Tuple of (prompt_cache, remaining_tokens)
        """
        with self._lock:
            result = self._search(model, tokens)

            # Exact match - extract and return
            if result.exact is not None:
                cache_entry = self._extract(result.model, result.exact)
                return cache_entry.prompt_cache, []

            # Shorter prefix match - extract and return remaining
            if result.shorter is not None:
                cache_entry = self._extract(result.model, result.shorter)
                prefix_len = len(result.shorter)
                return cache_entry.prompt_cache, list(tokens[prefix_len:])

            # Longer prefix match - try to trim if possible
            if result.longer is not None and self._can_trim_fn is not None:
                cache_entry = self._get(result.model, result.longer)
                if self._can_trim_fn(cache_entry.prompt_cache):
                    # Deep copy and trim
                    trimmed_cache = copy.deepcopy(cache_entry.prompt_cache)
                    # Always leave at least one token for the caller to process.
                    prefix = min(len(tokens) - 1, result.common_prefix)
                    num_to_trim = len(result.longer) - prefix
                    if self._trim_fn is not None:
                        self._trim_fn(trimmed_cache, num_to_trim)
                    return trimmed_cache, list(tokens[prefix:])

            # No match found
            return None, list(tokens)

    def insert_cache(
        self, model, tokens: List[int], prompt_cache: List[Any]
    ) -> None:
        """
        Insert a cache entry after generation completes.

        Thread-safe. Handles LRU eviction if max_size is exceeded. Inserting a
        sequence that is already stored keeps the existing cache object and
        only bumps its reference count and recency.

        Args:
            model: Model identifier (used to namespace caches)
            tokens: The full token sequence (prompt + generated)
            prompt_cache: The KV cache to store
        """
        with self._lock:
            tokens_tuple = tuple(tokens)

            if model not in self._cache:
                self._cache[model] = {}
            current = self._cache[model]

            # Build trie path
            for tok in tokens_tuple:
                if tok not in current:
                    current[tok] = {}
                current = current[tok]

            # Update or create entry
            if "cache" in current:
                current["cache"].count += 1
                self._lru.remove((model, tokens_tuple))
            else:
                current["cache"] = CacheEntry(prompt_cache, 1)

            # Update LRU order
            self._lru.append((model, tokens_tuple))

            # Evict if over capacity
            if len(self._lru) > self.max_size:
                evict_model, evict_tokens = self._lru.popleft()
                self._delete(evict_model, evict_tokens)

    def clear(self) -> None:
        """Clear all cache entries. Thread-safe."""
        with self._lock:
            self._cache.clear()
            self._lru.clear()

    def __len__(self) -> int:
        """Return the number of cache entries. Thread-safe."""
        with self._lock:
            return len(self._lru)
class TestBackendServicer(unittest.TestCase):
    """
    Integration tests for the MLX gRPC backend.

    unittest invokes setUp() before and tearDown() after every test method, so
    the tests must not call them again: doing so spawns a second server on the
    same address and overwrites (and therefore never terminates) the first
    subprocess handle.
    """
    def setUp(self):
        """Start the backend server in a subprocess and give it time to come up."""
        self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"])
        time.sleep(10)

    def tearDown(self) -> None:
        """Terminate the backend server subprocess and wait for it to exit."""
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """
        This method tests if the server answers the Health RPC
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)
                self.assertEqual(response.message, "MLX model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")

    def test_text(self):
        """
        This method tests if text is generated successfully
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)
                req = backend_pb2.PredictOptions(Prompt="The capital of France is")
                resp = stub.Predict(req)
                self.assertIsNotNone(resp.message)
        except Exception as err:
            print(err)
            self.fail("text service failed")

    def test_sampling_params(self):
        """
        This method tests if all sampling parameters are correctly processed
        NOTE: this does NOT test for correctness, just that we received a compatible response
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)

                req = backend_pb2.PredictOptions(
                    Prompt="The capital of France is",
                    TopP=0.8,
                    Tokens=50,
                    Temperature=0.7,
                    TopK=40,
                    PresencePenalty=0.1,
                    FrequencyPenalty=0.2,
                    MinP=0.05,
                    Seed=42,
                    StopPrompts=["\n"],
                    IgnoreEOS=True,
                )
                resp = stub.Predict(req)
                self.assertIsNotNone(resp.message)
        except Exception as err:
            print(err)
            self.fail("sampling params service failed")

    def test_embedding(self):
        """
        This method tests that the Embedding RPC is rejected as unsupported.

        The MLX backend's Embedding handler always sets the status to
        UNIMPLEMENTED, so the client call is expected to raise; no model needs
        to be loaded for that.
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.")
                with self.assertRaises(grpc.RpcError) as raised:
                    stub.Embedding(embedding_request)
                self.assertEqual(raised.exception.code(), grpc.StatusCode.UNIMPLEMENTED)
        except Exception as err:
            print(err)
            self.fail("Embedding service failed")

    def test_concurrent_requests(self):
        """
        This method tests that concurrent requests don't corrupt each other's cache state.
        This is a regression test for the race condition in the original implementation.
        """
        import concurrent.futures

        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)

                def make_request(prompt):
                    req = backend_pb2.PredictOptions(Prompt=prompt, Tokens=20)
                    return stub.Predict(req)

                # Run 5 concurrent requests with different prompts
                prompts = [
                    "The capital of France is",
                    "The capital of Germany is",
                    "The capital of Italy is",
                    "The capital of Spain is",
                    "The capital of Portugal is",
                ]

                with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
                    futures = [executor.submit(make_request, p) for p in prompts]
                    results = [f.result() for f in concurrent.futures.as_completed(futures)]

                # All results should be non-empty
                messages = [r.message for r in results]
                self.assertTrue(all(len(m) > 0 for m in messages), "All requests should return non-empty responses")
                print(f"Concurrent test passed: {len(messages)} responses received")

        except Exception as err:
            print(err)
            self.fail("Concurrent requests test failed")

    def test_cache_reuse(self):
        """
        This method tests that repeated prompts reuse cached KV states.
        The second request should benefit from the cached prompt processing.
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)

                prompt = "The quick brown fox jumps over the lazy dog. "

                # First request - populates cache
                req1 = backend_pb2.PredictOptions(Prompt=prompt, Tokens=10)
                resp1 = stub.Predict(req1)
                self.assertIsNotNone(resp1.message)

                # Second request with same prompt - should reuse cache
                req2 = backend_pb2.PredictOptions(Prompt=prompt, Tokens=10)
                resp2 = stub.Predict(req2)
                self.assertIsNotNone(resp2.message)

                print(f"Cache reuse test passed: first={len(resp1.message)} bytes, second={len(resp2.message)} bytes")

        except Exception as err:
            print(err)
            self.fail("Cache reuse test failed")

    def test_prefix_cache_reuse(self):
        """
        This method tests that prompts sharing a common prefix benefit from cached KV states.
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit"))
                self.assertTrue(response.success)

                # First request with base prompt
                prompt_base = "Once upon a time in a land far away, "
                req1 = backend_pb2.PredictOptions(Prompt=prompt_base, Tokens=10)
                resp1 = stub.Predict(req1)
                self.assertIsNotNone(resp1.message)

                # Second request with extended prompt (same prefix)
                prompt_extended = prompt_base + "there lived a brave knight who "
                req2 = backend_pb2.PredictOptions(Prompt=prompt_extended, Tokens=10)
                resp2 = stub.Predict(req2)
                self.assertIsNotNone(resp2.message)

                print(f"Prefix cache test passed: base={len(resp1.message)} bytes, extended={len(resp2.message)} bytes")

        except Exception as err:
            print(err)
            self.fail("Prefix cache reuse test failed")


# Unit tests for ThreadSafeLRUPromptCache are in test_mlx_cache.py
================================================ FILE: backend/python/mlx/test_mlx_cache.py ================================================ """ Comprehensive unit tests for ThreadSafeLRUPromptCache. Tests all cache operation modes: - Exact match - Shorter prefix match - Longer prefix match (with trimming) - No match - LRU eviction - Reference counting - Multi-model namespacing - Thread safety with data integrity verification """ import unittest import concurrent.futures import threading import copy from mlx_cache import ThreadSafeLRUPromptCache class TestCacheExactMatch(unittest.TestCase): """Tests for exact match cache behavior.""" def setUp(self): self.cache = ThreadSafeLRUPromptCache(max_size=10) def test_exact_match_returns_cache_and_empty_remaining(self): """Exact match should return the cache with no remaining tokens.""" tokens = [1, 2, 3, 4, 5] mock_cache = ["kv_cache_data"] self.cache.insert_cache("model1", tokens, mock_cache) result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens) self.assertEqual(result_cache, mock_cache) self.assertEqual(remaining, []) def test_exact_match_extracts_and_removes_from_cache(self): """Fetching exact match with count=1 should remove entry from cache.""" tokens = [1, 2, 3] self.cache.insert_cache("model1", tokens, ["cache"]) self.assertEqual(len(self.cache), 1) # First fetch extracts the entry self.cache.fetch_nearest_cache("model1", tokens) # Cache should now be empty self.assertEqual(len(self.cache), 0) # Second fetch should return None (no match) result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens) self.assertIsNone(result_cache) self.assertEqual(remaining, tokens) class TestCacheShorterPrefix(unittest.TestCase): """Tests for shorter prefix match behavior.""" def setUp(self): self.cache = ThreadSafeLRUPromptCache(max_size=10) def test_shorter_prefix_returns_cache_with_remaining_tokens(self): """When cached prefix is shorter, return cache and remaining suffix.""" short_tokens = [1, 
2, 3] long_tokens = [1, 2, 3, 4, 5, 6] mock_cache = ["prefix_cache"] self.cache.insert_cache("model1", short_tokens, mock_cache) result_cache, remaining = self.cache.fetch_nearest_cache("model1", long_tokens) self.assertEqual(result_cache, mock_cache) self.assertEqual(remaining, [4, 5, 6]) def test_shorter_prefix_correct_remaining_calculation(self): """Verify remaining tokens are calculated correctly for various prefix lengths.""" # Note: Single-token prefixes ([1] -> [1,2,3]) are deliberately not matched # to allow longer cached sequences to be preferred for trimming. # This matches upstream mlx_lm/server.py behavior. test_cases = [ # (cached_tokens, requested_tokens, expected_remaining) ([1, 2], [1, 2, 3, 4, 5], [3, 4, 5]), ([10, 20, 30, 40], [10, 20, 30, 40, 50], [50]), ] for cached, requested, expected_remaining in test_cases: with self.subTest(cached=cached, requested=requested): cache = ThreadSafeLRUPromptCache(max_size=10) cache.insert_cache("model", cached, ["cache"]) result_cache, remaining = cache.fetch_nearest_cache("model", requested) self.assertIsNotNone(result_cache) self.assertEqual(remaining, expected_remaining) def test_single_token_prefix_not_matched(self): """Single-token prefixes are not matched (by design, matches upstream). This allows longer cached sequences to be preferred for trimming, which provides better KV cache reuse. Single-token caches are rare in practice since real prompts with chat templates are many tokens. 
""" cache = ThreadSafeLRUPromptCache(max_size=10) cache.insert_cache("model", [1], ["cache"]) result_cache, remaining = cache.fetch_nearest_cache("model", [1, 2, 3]) # Single-token prefix is NOT matched self.assertIsNone(result_cache) self.assertEqual(remaining, [1, 2, 3]) class TestCacheLongerPrefix(unittest.TestCase): """Tests for longer prefix match behavior (trimming).""" def setUp(self): # Track trim calls for verification self.trim_calls = [] def mock_can_trim(cache): return True def mock_trim(cache, num_to_trim): self.trim_calls.append(num_to_trim) # Simulate trimming by modifying the cache cache.append(f"trimmed_{num_to_trim}") self.cache = ThreadSafeLRUPromptCache( max_size=10, can_trim_fn=mock_can_trim, trim_fn=mock_trim, ) def test_longer_prefix_triggers_trim(self): """When cached sequence is longer, should trim to match requested prefix.""" long_tokens = [1, 2, 3, 4, 5] short_tokens = [1, 2, 3] self.cache.insert_cache("model1", long_tokens, ["original_cache"]) result_cache, remaining = self.cache.fetch_nearest_cache("model1", short_tokens) # Should have called trim self.assertTrue(len(self.trim_calls) > 0, "trim_fn should have been called") # Result should be a trimmed copy, not the original self.assertIn("trimmed_", str(result_cache)) def test_longer_prefix_without_trim_fn_returns_no_match(self): """Without trim functions, longer prefix should not match.""" cache_no_trim = ThreadSafeLRUPromptCache(max_size=10) long_tokens = [1, 2, 3, 4, 5] short_tokens = [1, 2, 3] cache_no_trim.insert_cache("model1", long_tokens, ["cache"]) result_cache, remaining = cache_no_trim.fetch_nearest_cache("model1", short_tokens) # Without trim_fn, should return no match self.assertIsNone(result_cache) self.assertEqual(remaining, short_tokens) def test_longer_prefix_can_trim_false_returns_no_match(self): """When can_trim_fn returns False, should not attempt trim.""" cache = ThreadSafeLRUPromptCache( max_size=10, can_trim_fn=lambda c: False, trim_fn=lambda c, n: None, ) 
cache.insert_cache("model1", [1, 2, 3, 4, 5], ["cache"]) result_cache, remaining = cache.fetch_nearest_cache("model1", [1, 2, 3]) self.assertIsNone(result_cache) self.assertEqual(remaining, [1, 2, 3]) class TestCacheNoMatch(unittest.TestCase): """Tests for no match behavior.""" def setUp(self): self.cache = ThreadSafeLRUPromptCache(max_size=10) def test_empty_cache_returns_none(self): """Empty cache should return None and all tokens as remaining.""" tokens = [1, 2, 3] result_cache, remaining = self.cache.fetch_nearest_cache("model1", tokens) self.assertIsNone(result_cache) self.assertEqual(remaining, tokens) def test_different_prefix_returns_none(self): """Tokens with different prefix should not match.""" self.cache.insert_cache("model1", [1, 2, 3], ["cache"]) # Completely different tokens result_cache, remaining = self.cache.fetch_nearest_cache("model1", [4, 5, 6]) self.assertIsNone(result_cache) self.assertEqual(remaining, [4, 5, 6]) def test_partial_prefix_mismatch_returns_none(self): """Tokens that diverge mid-sequence should not match.""" self.cache.insert_cache("model1", [1, 2, 3], ["cache"]) # Same start but diverges result_cache, remaining = self.cache.fetch_nearest_cache("model1", [1, 2, 99]) self.assertIsNone(result_cache) self.assertEqual(remaining, [1, 2, 99]) def test_wrong_model_returns_none(self): """Different model key should not match.""" self.cache.insert_cache("model1", [1, 2, 3], ["cache"]) result_cache, remaining = self.cache.fetch_nearest_cache("model2", [1, 2, 3]) self.assertIsNone(result_cache) self.assertEqual(remaining, [1, 2, 3]) class TestCacheLRUEviction(unittest.TestCase): """Tests for LRU eviction behavior.""" def setUp(self): self.cache = ThreadSafeLRUPromptCache(max_size=3) def test_evicts_oldest_when_full(self): """Should evict least recently used entry when capacity exceeded.""" self.cache.insert_cache("model", [1], ["cache1"]) self.cache.insert_cache("model", [2], ["cache2"]) self.cache.insert_cache("model", [3], ["cache3"]) 
self.assertEqual(len(self.cache), 3) # Insert 4th entry - should evict [1] self.cache.insert_cache("model", [4], ["cache4"]) self.assertEqual(len(self.cache), 3) # [1] should be evicted result, _ = self.cache.fetch_nearest_cache("model", [1]) self.assertIsNone(result) # [2], [3], [4] should still exist for tokens in [[2], [3], [4]]: # Re-insert since fetch extracts self.cache.insert_cache("model", tokens, [f"cache{tokens[0]}"]) result2, _ = self.cache.fetch_nearest_cache("model", [2]) self.assertIsNotNone(result2) def test_access_updates_lru_order(self): """Accessing an entry should move it to most recently used.""" self.cache.insert_cache("model", [1], ["cache1"]) self.cache.insert_cache("model", [2], ["cache2"]) self.cache.insert_cache("model", [3], ["cache3"]) # Access [1] to make it most recently used cache1, _ = self.cache.fetch_nearest_cache("model", [1]) # Re-insert it (simulating normal usage pattern) self.cache.insert_cache("model", [1], cache1) # Now insert two more entries - should evict [2] then [3], not [1] self.cache.insert_cache("model", [4], ["cache4"]) self.cache.insert_cache("model", [5], ["cache5"]) # [1] should still exist (was accessed, so not evicted) result1, _ = self.cache.fetch_nearest_cache("model", [1]) self.assertIsNotNone(result1) # [2] should be evicted (was oldest after [1] was accessed) result2, _ = self.cache.fetch_nearest_cache("model", [2]) self.assertIsNone(result2) class TestCacheReferenceCount(unittest.TestCase): """Tests for reference counting behavior.""" def setUp(self): self.cache = ThreadSafeLRUPromptCache(max_size=10) def test_multiple_inserts_increment_count(self): """Inserting same tokens multiple times should increment count.""" tokens = [1, 2, 3] self.cache.insert_cache("model", tokens, ["cache"]) self.cache.insert_cache("model", tokens, ["cache"]) self.cache.insert_cache("model", tokens, ["cache"]) # Should still be one entry (with count=3 internally) self.assertEqual(len(self.cache), 1) # First two fetches should 
return copies (count decremented) result1, _ = self.cache.fetch_nearest_cache("model", tokens) self.assertIsNotNone(result1) result2, _ = self.cache.fetch_nearest_cache("model", tokens) self.assertIsNotNone(result2) # Third fetch extracts the last reference result3, _ = self.cache.fetch_nearest_cache("model", tokens) self.assertIsNotNone(result3) # Fourth fetch should return None (entry fully extracted) result4, _ = self.cache.fetch_nearest_cache("model", tokens) self.assertIsNone(result4) def test_extract_with_high_count_returns_deep_copy(self): """When count > 1, extract should return a deep copy.""" tokens = [1, 2, 3] original_cache = [{"nested": "data"}] self.cache.insert_cache("model", tokens, original_cache) self.cache.insert_cache("model", tokens, original_cache) # count=2 result1, _ = self.cache.fetch_nearest_cache("model", tokens) # Modify the returned cache result1[0]["nested"] = "modified" # Second fetch should get unmodified copy result2, _ = self.cache.fetch_nearest_cache("model", tokens) self.assertEqual(result2[0]["nested"], "data") class TestCacheMultiModel(unittest.TestCase): """Tests for multi-model namespacing.""" def setUp(self): self.cache = ThreadSafeLRUPromptCache(max_size=10) def test_same_tokens_different_models_are_separate(self): """Same token sequence under different models should be independent.""" tokens = [1, 2, 3] self.cache.insert_cache("model_a", tokens, ["cache_a"]) self.cache.insert_cache("model_b", tokens, ["cache_b"]) self.assertEqual(len(self.cache), 2) result_a, _ = self.cache.fetch_nearest_cache("model_a", tokens) result_b, _ = self.cache.fetch_nearest_cache("model_b", tokens) self.assertEqual(result_a, ["cache_a"]) self.assertEqual(result_b, ["cache_b"]) def test_eviction_across_models(self): """LRU eviction should work across different models.""" cache = ThreadSafeLRUPromptCache(max_size=3) cache.insert_cache("model_a", [1], ["a1"]) cache.insert_cache("model_b", [1], ["b1"]) cache.insert_cache("model_a", [2], ["a2"]) 
self.assertEqual(len(cache), 3) # Insert 4th - should evict model_a:[1] (oldest) cache.insert_cache("model_b", [2], ["b2"]) result, _ = cache.fetch_nearest_cache("model_a", [1]) self.assertIsNone(result) class TestCacheThreadSafety(unittest.TestCase): """Tests for thread safety with data integrity verification.""" def test_concurrent_inserts_no_data_loss(self): """Concurrent inserts should not lose data.""" cache = ThreadSafeLRUPromptCache(max_size=100) num_threads = 10 inserts_per_thread = 20 def insert_entries(thread_id): for i in range(inserts_per_thread): tokens = [thread_id, i] cache.insert_cache("model", tokens, [f"cache_{thread_id}_{i}"]) with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: futures = [executor.submit(insert_entries, tid) for tid in range(num_threads)] concurrent.futures.wait(futures) # Verify expected number of entries (may be less due to LRU eviction with max_size=100) # But should be exactly 100 since we inserted exactly 200 and max_size is 100 self.assertEqual(len(cache), 100) def test_concurrent_fetch_and_insert_no_corruption(self): """Concurrent fetches and inserts should not corrupt data.""" cache = ThreadSafeLRUPromptCache(max_size=50) errors = [] lock = threading.Lock() # Pre-populate with known data for i in range(20): cache.insert_cache("model", [i], [f"original_{i}"]) def fetch_and_verify(thread_id): try: for _ in range(50): token_id = thread_id % 20 result, remaining = cache.fetch_nearest_cache("model", [token_id]) if result is not None: # Verify data integrity expected_prefix = f"original_{token_id}" if not str(result[0]).startswith("original_"): with lock: errors.append(f"Corrupted data: {result}") # Re-insert to keep cache populated cache.insert_cache("model", [token_id], result) except Exception as e: with lock: errors.append(str(e)) with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: futures = [executor.submit(fetch_and_verify, tid) for tid in range(10)] 
concurrent.futures.wait(futures) self.assertEqual(errors, [], f"Thread safety errors: {errors}") def test_concurrent_operations_maintain_cache_bounds(self): """Cache size should never exceed max_size under concurrent operations.""" max_size = 10 cache = ThreadSafeLRUPromptCache(max_size=max_size) size_violations = [] lock = threading.Lock() def random_operations(thread_id): import random for i in range(100): tokens = [random.randint(0, 50)] if random.random() < 0.7: cache.insert_cache("model", tokens, [f"cache_{thread_id}_{i}"]) else: cache.fetch_nearest_cache("model", tokens) current_size = len(cache) if current_size > max_size: with lock: size_violations.append(current_size) with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor: futures = [executor.submit(random_operations, tid) for tid in range(10)] concurrent.futures.wait(futures) self.assertEqual(size_violations, [], f"Size exceeded max: {size_violations}") self.assertLessEqual(len(cache), max_size) class TestCacheClear(unittest.TestCase): """Tests for cache clear operation.""" def setUp(self): self.cache = ThreadSafeLRUPromptCache(max_size=10) def test_clear_removes_all_entries(self): """Clear should remove all entries.""" self.cache.insert_cache("model1", [1, 2], ["cache1"]) self.cache.insert_cache("model2", [3, 4], ["cache2"]) self.cache.insert_cache("model1", [5, 6], ["cache3"]) self.assertEqual(len(self.cache), 3) self.cache.clear() self.assertEqual(len(self.cache), 0) def test_clear_allows_new_inserts(self): """After clear, new inserts should work normally.""" self.cache.insert_cache("model", [1], ["cache1"]) self.cache.clear() self.cache.insert_cache("model", [2], ["cache2"]) self.assertEqual(len(self.cache), 1) result, _ = self.cache.fetch_nearest_cache("model", [2]) self.assertEqual(result, ["cache2"]) if __name__ == "__main__": unittest.main() ================================================ FILE: backend/python/mlx-audio/Makefile ================================================ 
.PHONY: mlx-audio mlx-audio: bash install.sh .PHONY: run run: mlx-audio @echo "Running mlx-audio..." bash run.sh @echo "mlx run." .PHONY: test test: mlx-audio @echo "Testing mlx-audio..." bash test.sh @echo "mlx tested." .PHONY: protogen-clean protogen-clean: $(RM) backend_pb2_grpc.py backend_pb2.py .PHONY: clean clean: protogen-clean rm -rf venv __pycache__ ================================================ FILE: backend/python/mlx-audio/backend.py ================================================ #!/usr/bin/env python3 import asyncio from concurrent import futures import argparse import signal import sys import os import shutil import glob from typing import List import time import tempfile import backend_pb2 import backend_pb2_grpc import grpc from mlx_audio.tts.utils import load_model import soundfile as sf import numpy as np import uuid def is_float(s): """Check if a string can be converted to float.""" try: float(s) return True except ValueError: return False def is_int(s): """Check if a string can be converted to int.""" try: int(s) return True except ValueError: return False _ONE_DAY_IN_SECONDS = 60 * 60 * 24 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) # Implement the BackendServicer class with the service methods class BackendServicer(backend_pb2_grpc.BackendServicer): """ A gRPC servicer that implements the Backend service defined in backend.proto. This backend provides TTS (Text-to-Speech) functionality using MLX-Audio. """ def Health(self, request, context): """ Returns a health check message. Args: request: The health check request. context: The gRPC context. Returns: backend_pb2.Reply: The health check reply. """ return backend_pb2.Reply(message=bytes("OK", 'utf-8')) async def LoadModel(self, request, context): """ Loads a TTS model using MLX-Audio. Args: request: The load model request. context: The gRPC context. 
Returns: backend_pb2.Result: The load model result. """ try: print(f"Loading MLX-Audio TTS model: {request.Model}", file=sys.stderr) print(f"Request: {request}", file=sys.stderr) # Parse options like in the kokoro backend options = request.Options self.options = {} # The options are a list of strings in this form optname:optvalue # We store all the options in a dict for later use for opt in options: if ":" not in opt: continue key, value = opt.split(":", 1) # Split only on first colon to handle values with colons # Convert numeric values to appropriate types if is_float(value): value = float(value) elif is_int(value): value = int(value) elif value.lower() in ["true", "false"]: value = value.lower() == "true" self.options[key] = value print(f"Options: {self.options}", file=sys.stderr) # Load the model using MLX-Audio's load_model function try: self.tts_model = load_model(request.Model) self.model_path = request.Model print(f"TTS model loaded successfully from {request.Model}", file=sys.stderr) except Exception as model_err: print(f"Error loading TTS model: {model_err}", file=sys.stderr) return backend_pb2.Result(success=False, message=f"Failed to load model: {model_err}") except Exception as err: print(f"Error loading MLX-Audio TTS model {err=}, {type(err)=}", file=sys.stderr) return backend_pb2.Result(success=False, message=f"Error loading MLX-Audio TTS model: {err}") print("MLX-Audio TTS model loaded successfully", file=sys.stderr) return backend_pb2.Result(message="MLX-Audio TTS model loaded successfully", success=True) def TTS(self, request, context): """ Generates TTS audio from text using MLX-Audio. Args: request: A TTSRequest object containing text, model, destination, voice, and language. context: A grpc.ServicerContext object that provides information about the RPC. Returns: A Result object indicating success or failure. 
""" try: # Check if model is loaded if not hasattr(self, 'tts_model') or self.tts_model is None: return backend_pb2.Result(success=False, message="TTS model not loaded. Please call LoadModel first.") print(f"Generating TTS with MLX-Audio - text: {request.text[:50]}..., voice: {request.voice}, language: {request.language}", file=sys.stderr) # Handle speed parameter based on model type speed_value = self._handle_speed_parameter(request, self.model_path) # Map language names to codes if needed lang_code = self._map_language_code(request.language, request.voice) # Prepare generation parameters gen_params = { "text": request.text, "speed": speed_value, "verbose": False, } # Add model-specific parameters if request.voice and request.voice.strip(): gen_params["voice"] = request.voice # Check if model supports language codes (primarily Kokoro) if "kokoro" in self.model_path.lower(): gen_params["lang_code"] = lang_code # Add pitch and gender for Spark models if "spark" in self.model_path.lower(): gen_params["pitch"] = 1.0 # Default to moderate gen_params["gender"] = "female" # Default to female print(f"Generation parameters: {gen_params}", file=sys.stderr) # Generate audio using the loaded model try: results = self.tts_model.generate(**gen_params) except Exception as gen_err: print(f"Error during TTS generation: {gen_err}", file=sys.stderr) return backend_pb2.Result(success=False, message=f"TTS generation failed: {gen_err}") # Process the generated audio segments audio_arrays = [] for segment in results: audio_arrays.append(segment.audio) # If no segments, return error if not audio_arrays: print("No audio segments generated", file=sys.stderr) return backend_pb2.Result(success=False, message="No audio generated") # Concatenate all segments cat_audio = np.concatenate(audio_arrays, axis=0) # Generate output filename and path if request.dst: output_path = request.dst else: unique_id = str(uuid.uuid4()) filename = f"tts_{unique_id}.wav" output_path = filename # Write the audio 
as a WAV try: sf.write(output_path, cat_audio, 24000) print(f"Successfully wrote audio file to {output_path}", file=sys.stderr) # Verify the file exists and has content if not os.path.exists(output_path): print(f"File was not created at {output_path}", file=sys.stderr) return backend_pb2.Result(success=False, message="Failed to create audio file") file_size = os.path.getsize(output_path) if file_size == 0: print("File was created but is empty", file=sys.stderr) return backend_pb2.Result(success=False, message="Generated audio file is empty") print(f"Audio file size: {file_size} bytes", file=sys.stderr) except Exception as write_err: print(f"Error writing audio file: {write_err}", file=sys.stderr) return backend_pb2.Result(success=False, message=f"Failed to save audio: {write_err}") return backend_pb2.Result(success=True, message=f"TTS audio generated successfully: {output_path}") except Exception as e: print(f"Error in MLX-Audio TTS: {e}", file=sys.stderr) return backend_pb2.Result(success=False, message=f"TTS generation failed: {str(e)}") async def Predict(self, request, context): """ Generates TTS audio based on the given prompt using MLX-Audio TTS. This is a fallback method for compatibility with the Predict endpoint. Args: request: The predict request. context: The gRPC context. Returns: backend_pb2.Reply: The predict result. """ try: # Check if model is loaded if not hasattr(self, 'tts_model') or self.tts_model is None: context.set_code(grpc.StatusCode.FAILED_PRECONDITION) context.set_details("TTS model not loaded. 
Please call LoadModel first.") return backend_pb2.Reply(message=bytes("", encoding='utf-8')) # For TTS, we expect the prompt to contain the text to synthesize if not request.Prompt: context.set_code(grpc.StatusCode.INVALID_ARGUMENT) context.set_details("Prompt is required for TTS generation") return backend_pb2.Reply(message=bytes("", encoding='utf-8')) # Handle speed parameter based on model type speed_value = self._handle_speed_parameter(request, self.model_path) # Map language names to codes if needed lang_code = self._map_language_code(None, None) # Use defaults for Predict # Prepare generation parameters gen_params = { "text": request.Prompt, "speed": speed_value, "verbose": False, } # Add model-specific parameters if hasattr(self, 'options') and 'voice' in self.options: gen_params["voice"] = self.options['voice'] # Check if model supports language codes (primarily Kokoro) if "kokoro" in self.model_path.lower(): gen_params["lang_code"] = lang_code print(f"Generating TTS with MLX-Audio - text: {request.Prompt[:50]}..., params: {gen_params}", file=sys.stderr) # Generate audio using the loaded model try: results = self.tts_model.generate(**gen_params) except Exception as gen_err: print(f"Error during TTS generation: {gen_err}", file=sys.stderr) context.set_code(grpc.StatusCode.INTERNAL) context.set_details(f"TTS generation failed: {gen_err}") return backend_pb2.Reply(message=bytes("", encoding='utf-8')) # Process the generated audio segments audio_arrays = [] for segment in results: audio_arrays.append(segment.audio) # If no segments, return error if not audio_arrays: print("No audio segments generated", file=sys.stderr) return backend_pb2.Reply(message=bytes("No audio generated", encoding='utf-8')) # Concatenate all segments cat_audio = np.concatenate(audio_arrays, axis=0) duration = len(cat_audio) / 24000 # Assuming 24kHz sample rate # Return success message with audio information response = f"TTS audio generated successfully. 
Duration: {duration:.2f}s, Sample rate: 24000Hz" return backend_pb2.Reply(message=bytes(response, encoding='utf-8')) except Exception as e: print(f"Error in MLX-Audio TTS Predict: {e}", file=sys.stderr) context.set_code(grpc.StatusCode.INTERNAL) context.set_details(f"TTS generation failed: {str(e)}") return backend_pb2.Reply(message=bytes("", encoding='utf-8')) def _handle_speed_parameter(self, request, model_path): """ Handle speed parameter based on model type. Args: request: The TTSRequest object. model_path: The model path to determine model type. Returns: float: The processed speed value. """ # Get speed from options if available speed = 1.0 if hasattr(self, 'options') and 'speed' in self.options: speed = self.options['speed'] # Handle speed parameter based on model type if "spark" in model_path.lower(): # Spark actually expects float values that map to speed descriptions speed_map = { "very_low": 0.0, "low": 0.5, "moderate": 1.0, "high": 1.5, "very_high": 2.0, } if isinstance(speed, str) and speed in speed_map: speed_value = speed_map[speed] else: # Try to use as float, default to 1.0 (moderate) if invalid try: speed_value = float(speed) if speed_value not in [0.0, 0.5, 1.0, 1.5, 2.0]: speed_value = 1.0 # Default to moderate except: speed_value = 1.0 # Default to moderate else: # Other models use float speed values try: speed_value = float(speed) if speed_value < 0.5 or speed_value > 2.0: speed_value = 1.0 # Default to 1.0 if out of range except ValueError: speed_value = 1.0 # Default to 1.0 if invalid return speed_value def _map_language_code(self, language, voice): """ Map language names to codes if needed. Args: language: The language parameter from the request. voice: The voice parameter from the request. Returns: str: The language code. 
""" if not language: # Default to voice[0] if not found return voice[0] if voice else "a" # Map language names to codes if needed language_map = { "american_english": "a", "british_english": "b", "spanish": "e", "french": "f", "hindi": "h", "italian": "i", "portuguese": "p", "japanese": "j", "mandarin_chinese": "z", # Also accept direct language codes "a": "a", "b": "b", "e": "e", "f": "f", "h": "h", "i": "i", "p": "p", "j": "j", "z": "z", } return language_map.get(language.lower(), language) def _build_generation_params(self, request, default_speed=1.0): """ Build generation parameters from request attributes and options for MLX-Audio TTS. Args: request: The gRPC request. default_speed: Default speed if not specified. Returns: dict: Generation parameters for MLX-Audio """ # Initialize generation parameters for MLX-Audio TTS generation_params = { 'speed': default_speed, 'voice': 'af_heart', # Default voice 'lang_code': 'a', # Default language code } # Extract parameters from request attributes if hasattr(request, 'Temperature') and request.Temperature > 0: # Temperature could be mapped to speed variation generation_params['speed'] = 1.0 + (request.Temperature - 0.5) * 0.5 # Override with options if available if hasattr(self, 'options'): # Speed from options if 'speed' in self.options: generation_params['speed'] = self.options['speed'] # Voice from options if 'voice' in self.options: generation_params['voice'] = self.options['voice'] # Language code from options if 'lang_code' in self.options: generation_params['lang_code'] = self.options['lang_code'] # Model-specific parameters param_option_mapping = { 'temp': 'speed', 'temperature': 'speed', 'top_p': 'speed', # Map top_p to speed variation } for option_key, param_key in param_option_mapping.items(): if option_key in self.options: if param_key == 'speed': # Ensure speed is within reasonable bounds speed_val = float(self.options[option_key]) if 0.5 <= speed_val <= 2.0: generation_params[param_key] = speed_val return 
generation_params async def serve(address): # Start asyncio gRPC server server = grpc.aio.server(migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ]) # Add the servicer to the server backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) # Bind the server to the address server.add_insecure_port(address) # Gracefully shutdown the server on SIGTERM or SIGINT loop = asyncio.get_event_loop() for sig in (signal.SIGINT, signal.SIGTERM): loop.add_signal_handler( sig, lambda: asyncio.ensure_future(server.stop(5)) ) # Start the server await server.start() print("MLX-Audio TTS Server started. Listening on: " + address, file=sys.stderr) # Wait for the server to be terminated await server.wait_for_termination() if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the MLX-Audio TTS gRPC server.") parser.add_argument( "--addr", default="localhost:50051", help="The address to bind the server to." 
) args = parser.parse_args() asyncio.run(serve(args.addr)) ================================================ FILE: backend/python/mlx-audio/install.sh ================================================ #!/bin/bash set -e USE_PIP=true backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi installRequirements ================================================ FILE: backend/python/mlx-audio/requirements-cpu.txt ================================================ git+https://github.com/Blaizzy/mlx-audio mlx[cpu] ================================================ FILE: backend/python/mlx-audio/requirements-cublas12.txt ================================================ git+https://github.com/Blaizzy/mlx-audio mlx[cuda12] ================================================ FILE: backend/python/mlx-audio/requirements-cublas13.txt ================================================ git+https://github.com/Blaizzy/mlx-audio mlx[cuda13] ================================================ FILE: backend/python/mlx-audio/requirements-l4t12.txt ================================================ git+https://github.com/Blaizzy/mlx-audio mlx[cuda12] ================================================ FILE: backend/python/mlx-audio/requirements-l4t13.txt ================================================ git+https://github.com/Blaizzy/mlx-audio mlx[cuda13] ================================================ FILE: backend/python/mlx-audio/requirements-mps.txt ================================================ git+https://github.com/Blaizzy/mlx-audio ================================================ FILE: backend/python/mlx-audio/requirements.txt ================================================ grpcio==1.71.0 protobuf certifi setuptools mlx-audio soundfile numpy ================================================ FILE: backend/python/mlx-audio/run.sh ================================================ #!/bin/bash 
backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi startBackend $@ ================================================ FILE: backend/python/mlx-audio/test.py ================================================ import unittest import subprocess import time import backend_pb2 import backend_pb2_grpc import grpc import unittest import subprocess import time import grpc import backend_pb2_grpc import backend_pb2 class TestBackendServicer(unittest.TestCase): """ TestBackendServicer is the class that tests the gRPC service. This class contains methods to test the startup and shutdown of the gRPC service. """ def setUp(self): self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"]) time.sleep(10) def tearDown(self) -> None: self.service.terminate() self.service.wait() def test_server_startup(self): try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.Health(backend_pb2.HealthMessage()) self.assertEqual(response.message, b'OK') except Exception as err: print(err) self.fail("Server failed to start") finally: self.tearDown() def test_load_model(self): """ This method tests if the TTS model is loaded successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit")) self.assertTrue(response.success) self.assertEqual(response.message, "MLX-Audio TTS model loaded successfully") except Exception as err: print(err) self.fail("LoadModel service failed") finally: self.tearDown() def test_tts_generation(self): """ This method tests if TTS audio is generated successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = 
stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit")) self.assertTrue(response.success) # Test TTS generation tts_req = backend_pb2.TTSRequest( text="Hello, this is a test of the MLX-Audio TTS system.", model="mlx-community/Kokoro-82M-4bit", voice="af_heart", language="a" ) tts_resp = stub.TTS(tts_req) self.assertTrue(tts_resp.success) self.assertIn("TTS audio generated successfully", tts_resp.message) except Exception as err: print(err) self.fail("TTS service failed") finally: self.tearDown() def test_tts_with_options(self): """ This method tests if TTS works with various options and parameters """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions( Model="mlx-community/Kokoro-82M-4bit", Options=["voice:af_soft", "speed:1.2", "lang_code:b"] )) self.assertTrue(response.success) # Test TTS generation with different voice and language tts_req = backend_pb2.TTSRequest( text="Hello, this is a test with British English accent.", model="mlx-community/Kokoro-82M-4bit", voice="af_soft", language="b" ) tts_resp = stub.TTS(tts_req) self.assertTrue(tts_resp.success) self.assertIn("TTS audio generated successfully", tts_resp.message) except Exception as err: print(err) self.fail("TTS with options service failed") finally: self.tearDown() def test_tts_multilingual(self): """ This method tests if TTS works with different languages """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Kokoro-82M-4bit")) self.assertTrue(response.success) # Test Spanish TTS tts_req = backend_pb2.TTSRequest( text="Hola, esto es una prueba del sistema TTS MLX-Audio.", model="mlx-community/Kokoro-82M-4bit", voice="af_heart", language="e" ) tts_resp = stub.TTS(tts_req) self.assertTrue(tts_resp.success) 
self.assertIn("TTS audio generated successfully", tts_resp.message) except Exception as err: print(err) self.fail("Multilingual TTS service failed") finally: self.tearDown() ================================================ FILE: backend/python/mlx-audio/test.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi runUnittests ================================================ FILE: backend/python/mlx-distributed/Makefile ================================================ .PHONY: mlx-distributed mlx-distributed: bash install.sh .PHONY: run run: @echo "Running mlx-distributed..." bash run.sh @echo "mlx-distributed run." .PHONY: test test: @echo "Testing mlx-distributed..." bash test.sh @echo "mlx-distributed tested." .PHONY: protogen-clean protogen-clean: $(RM) backend_pb2_grpc.py backend_pb2.py .PHONY: clean clean: protogen-clean rm -rf venv __pycache__ ================================================ FILE: backend/python/mlx-distributed/backend.py ================================================ #!/usr/bin/env python3 """ MLX Distributed Inference Backend for LocalAI. Two startup modes: 1. Server mode (started by LocalAI automatically): run.sh --addr localhost:50051 Distributed config comes from LoadModel options or env vars. 2. Worker mode (started by CLI for remote ranks): run.sh --worker --hostfile hosts.json --rank 1 --backend ring Enters a loop waiting for commands from rank 0. """ import asyncio from concurrent import futures import argparse import json import os import signal import sys import tempfile from typing import List import grpc import backend_pb2 import backend_pb2_grpc MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) def mlx_distributed_init(rank, hostfile, backend="ring", coordinator=None): """Initialize MLX distributed runtime. 
Ring: MLX_HOSTFILE points to a JSON array of "ip:port" strings. Each rank binds to its own entry (hostfile[rank]) and connects to neighbors for the ring pipeline. JACCL: MLX_IBV_DEVICES points to a JSON 2D matrix of RDMA device names. MLX_JACCL_COORDINATOR is rank 0's ip:port where it runs a TCP service that helps all ranks establish RDMA connections. """ import mlx.core as mx if backend == "ring": os.environ["MLX_HOSTFILE"] = hostfile os.environ["MLX_RANK"] = str(rank) os.environ["MLX_RING_VERBOSE"] = "1" return mx.distributed.init(backend="ring", strict=True) elif backend == "jaccl": os.environ["MLX_IBV_DEVICES"] = hostfile os.environ["MLX_RANK"] = str(rank) if coordinator: os.environ["MLX_JACCL_COORDINATOR"] = coordinator return mx.distributed.init(backend="jaccl", strict=True) else: raise ValueError(f"Unknown backend: {backend}") def is_float(s): try: float(s) return True except ValueError: return False def is_int(s): try: int(s) return True except ValueError: return False def parse_options(options): """Parse key:value option strings into a dict.""" result = {} for opt in options: if ":" not in opt: continue key, value = opt.split(":", 1) if is_float(value): value = float(value) elif is_int(value): value = int(value) elif value.lower() in ["true", "false"]: value = value.lower() == "true" result[key] = value return result class BackendServicer(backend_pb2_grpc.BackendServicer): """gRPC servicer for distributed MLX inference (runs on rank 0). When started by LocalAI (server mode), distributed init happens at LoadModel time using config from model options or environment variables. 
""" def __init__(self): self.group = None self.dist_backend = None self.model = None self.tokenizer = None self.coordinator = None self.options = {} self.lru_cache = None self.model_key = None self.max_kv_size = None def Health(self, request, context): return backend_pb2.Reply(message=bytes("OK", 'utf-8')) async def LoadModel(self, request, context): try: import mlx.core as mx from mlx_lm import load from mlx_lm.models.cache import make_prompt_cache, can_trim_prompt_cache, trim_prompt_cache print(f"[Rank 0] Loading model: {request.Model}", file=sys.stderr) self.options = parse_options(request.Options) print(f"Options: {self.options}", file=sys.stderr) # Get distributed config from model options, falling back to env vars. # If neither is set, run as single-node (no distributed). hostfile = self.options.get("hostfile", os.environ.get("MLX_DISTRIBUTED_HOSTFILE", "")) dist_backend = str(self.options.get("distributed_backend", os.environ.get("MLX_DISTRIBUTED_BACKEND", "ring"))) # JACCL coordinator: rank 0 reads from env (set by CLI --coordinator). # Not in model options — rank 0 is the coordinator, workers get # the address via their own --coordinator CLI flag. 
jaccl_coordinator = os.environ.get("MLX_JACCL_COORDINATOR", "") if hostfile: from coordinator import DistributedCoordinator, CMD_LOAD_MODEL from sharding import pipeline_auto_parallel print(f"[Rank 0] Initializing distributed: backend={dist_backend}, hostfile={hostfile}", file=sys.stderr) self.dist_backend = dist_backend self.group = mlx_distributed_init( rank=0, hostfile=hostfile, backend=dist_backend, coordinator=jaccl_coordinator or None, ) self.coordinator = DistributedCoordinator(self.group) self.coordinator.broadcast_command(CMD_LOAD_MODEL) self.coordinator.broadcast_model_name(request.Model) else: print("[Rank 0] No hostfile configured, running single-node", file=sys.stderr) # Build tokenizer config from request and options tokenizer_config = {} if request.TrustRemoteCode or self.options.get("trust_remote_code", False): tokenizer_config["trust_remote_code"] = True # Token overrides from options for key in ["eos_token", "pad_token", "bos_token", "unk_token", "sep_token", "cls_token", "mask_token"]: if key in self.options: tokenizer_config[key] = self.options[key] if tokenizer_config: print(f"Loading with tokenizer_config: {tokenizer_config}", file=sys.stderr) self.model, self.tokenizer = load(request.Model, tokenizer_config=tokenizer_config) else: self.model, self.tokenizer = load(request.Model) if self.group is not None: from sharding import pipeline_auto_parallel self.model = pipeline_auto_parallel(self.model, self.group) print(f"[Rank 0] Model loaded and sharded across {self.group.size()} ranks", file=sys.stderr) else: # Single-node: set up prompt cache for efficient generation from mlx_cache import ThreadSafeLRUPromptCache max_cache_entries = self.options.get("max_cache_entries", 10) self.max_kv_size = self.options.get("max_kv_size", None) self.model_key = request.Model self.lru_cache = ThreadSafeLRUPromptCache( max_size=max_cache_entries, can_trim_fn=can_trim_prompt_cache, trim_fn=trim_prompt_cache, ) print("[Rank 0] Model loaded (single-node with prompt 
cache)", file=sys.stderr) except Exception as err: print(f"[Rank 0] Error loading model: {err}", file=sys.stderr) return backend_pb2.Result(success=False, message=f"Error loading model: {err}") return backend_pb2.Result(message="Model loaded successfully", success=True) async def Predict(self, request, context): prompt_cache = None cache_key = None try: import mlx.core as mx from mlx_lm import stream_generate from mlx_lm.sample_utils import make_sampler prompt_text = self._prepare_prompt(request) tokens = self._get_tokens_from_prompt(prompt_text) if self.coordinator: from coordinator import CMD_GENERATE self.coordinator.broadcast_command(CMD_GENERATE, len(tokens)) self.coordinator.broadcast_tokens(tokens) max_tokens, sampler_params = self._build_generation_params(request) if self.coordinator: gen_params = self.coordinator.broadcast_generation_params( max_tokens=max_tokens, temperature=sampler_params.get('temp', 0.6), top_p=sampler_params.get('top_p', 1.0), ) max_tokens = gen_params["max_tokens"] sampler = make_sampler(**sampler_params) # Use prompt cache in single-node mode gen_kwargs = {} if self.lru_cache is not None: from mlx_lm.models.cache import make_prompt_cache cache_key = list(tokens) prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache( self.model_key, cache_key ) if prompt_cache is None: prompt_cache = make_prompt_cache(self.model, self.max_kv_size) remaining_tokens = cache_key gen_kwargs['prompt_cache'] = prompt_cache tokens = remaining_tokens if remaining_tokens else cache_key generated = [] for response in stream_generate( self.model, self.tokenizer, prompt=tokens, max_tokens=max_tokens, sampler=sampler, **gen_kwargs, ): generated.append(response.text) if cache_key is not None: cache_key.append(response.token) if self.lru_cache is not None and cache_key is not None: self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache) return backend_pb2.Reply(message=bytes(''.join(generated), encoding='utf-8')) except Exception as e: 
print(f"[Rank 0] Error in Predict: {e}", file=sys.stderr) context.set_code(grpc.StatusCode.INTERNAL) context.set_details(f"Generation failed: {str(e)}") return backend_pb2.Reply(message=bytes("", encoding='utf-8')) async def PredictStream(self, request, context): prompt_cache = None cache_key = None try: import mlx.core as mx from mlx_lm import stream_generate from mlx_lm.sample_utils import make_sampler prompt_text = self._prepare_prompt(request) tokens = self._get_tokens_from_prompt(prompt_text) if self.coordinator: from coordinator import CMD_GENERATE self.coordinator.broadcast_command(CMD_GENERATE, len(tokens)) self.coordinator.broadcast_tokens(tokens) max_tokens, sampler_params = self._build_generation_params(request, default_max_tokens=512) if self.coordinator: gen_params = self.coordinator.broadcast_generation_params( max_tokens=max_tokens, temperature=sampler_params.get('temp', 0.6), top_p=sampler_params.get('top_p', 1.0), ) max_tokens = gen_params["max_tokens"] sampler = make_sampler(**sampler_params) # Use prompt cache in single-node mode gen_kwargs = {} if self.lru_cache is not None: from mlx_lm.models.cache import make_prompt_cache cache_key = list(tokens) prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache( self.model_key, cache_key ) if prompt_cache is None: prompt_cache = make_prompt_cache(self.model, self.max_kv_size) remaining_tokens = cache_key gen_kwargs['prompt_cache'] = prompt_cache tokens = remaining_tokens if remaining_tokens else cache_key for response in stream_generate( self.model, self.tokenizer, prompt=tokens, max_tokens=max_tokens, sampler=sampler, **gen_kwargs, ): if cache_key is not None: cache_key.append(response.token) yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8')) except Exception as e: print(f"[Rank 0] Error in PredictStream: {e}", file=sys.stderr) context.set_code(grpc.StatusCode.INTERNAL) context.set_details(f"Streaming failed: {str(e)}") yield backend_pb2.Reply(message=bytes("", 
encoding='utf-8')) finally: if self.lru_cache is not None and prompt_cache is not None and cache_key is not None: try: self.lru_cache.insert_cache(self.model_key, cache_key, prompt_cache) except Exception as e: print(f"Error inserting cache: {e}", file=sys.stderr) def Embedding(self, request, context): print("Embeddings not supported in MLX distributed backend", file=sys.stderr) context.set_code(grpc.StatusCode.UNIMPLEMENTED) context.set_details("Embeddings are not supported in the MLX distributed backend.") return backend_pb2.EmbeddingResult() def _prepare_prompt(self, request): if not request.Prompt and request.UseTokenizerTemplate and request.Messages: messages = [{"role": msg.role, "content": msg.content} for msg in request.Messages] return self.tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) return request.Prompt def _get_tokens_from_prompt(self, prompt_text: str) -> List[int]: tokens = self.tokenizer.encode(prompt_text) if hasattr(tokens, 'tolist'): return tokens.tolist() return list(tokens) def _build_generation_params(self, request, default_max_tokens=200): import mlx.core as mx max_tokens = getattr(request, 'Tokens', default_max_tokens) if max_tokens == 0: max_tokens = default_max_tokens temp = getattr(request, 'Temperature', 0.0) if temp == 0.0: temp = 0.6 top_p = getattr(request, 'TopP', 0.0) if top_p == 0.0: top_p = 1.0 sampler_params = { 'temp': temp, 'top_p': top_p, 'min_p': getattr(request, 'MinP', 0.0), 'top_k': getattr(request, 'TopK', 0), 'xtc_threshold': 0.0, 'xtc_probability': 0.0, } seed = getattr(request, 'Seed', 0) if seed != 0: mx.random.seed(seed) if hasattr(self, 'options'): if 'max_tokens' in self.options: max_tokens = self.options['max_tokens'] option_mapping = { 'temp': 'temp', 'temperature': 'temp', 'top_p': 'top_p', 'min_p': 'min_p', 'top_k': 'top_k', 'xtc_threshold': 'xtc_threshold', 'xtc_probability': 'xtc_probability', } for opt_key, param_key in option_mapping.items(): if opt_key in 
self.options: sampler_params[param_key] = self.options[opt_key] if 'seed' in self.options: mx.random.seed(self.options['seed']) # XTC special tokens xtc_special_tokens = [] if hasattr(self.tokenizer, 'eos_token_ids') and self.tokenizer.eos_token_ids: xtc_special_tokens = list(self.tokenizer.eos_token_ids) elif hasattr(self.tokenizer, 'eos_token_id') and self.tokenizer.eos_token_id is not None: xtc_special_tokens = [self.tokenizer.eos_token_id] try: newline_tokens = self.tokenizer.encode("\n") xtc_special_tokens.extend(newline_tokens) except: pass sampler_params['xtc_special_tokens'] = xtc_special_tokens return max_tokens, sampler_params def run_worker(group): """Worker loop for ranks > 0. Waits for commands from rank 0.""" from mlx_lm import load, stream_generate from mlx_lm.sample_utils import make_sampler from coordinator import DistributedCoordinator, CMD_LOAD_MODEL, CMD_GENERATE, CMD_SHUTDOWN from sharding import pipeline_auto_parallel import mlx.core as mx coordinator = DistributedCoordinator(group) model = None tokenizer = None print(f"[Rank {group.rank()}] Worker started, waiting for commands...", file=sys.stderr) while True: cmd, payload_size = coordinator.wait_for_command() if cmd == CMD_LOAD_MODEL: model_name = coordinator.broadcast_model_name() print(f"[Rank {group.rank()}] Loading model: {model_name}", file=sys.stderr) model, tokenizer = load(model_name) model = pipeline_auto_parallel(model, group) print(f"[Rank {group.rank()}] Model loaded and sharded", file=sys.stderr) elif cmd == CMD_GENERATE: if model is None: print(f"[Rank {group.rank()}] No model loaded, skipping generate", file=sys.stderr) continue token_count = coordinator.broadcast_token_count(payload_size) tokens_array = coordinator.broadcast_tokens([0] * token_count) tokens = tokens_array.tolist() gen_params = coordinator.broadcast_generation_params() sampler = make_sampler( temp=gen_params["temperature"], top_p=gen_params["top_p"], ) for _ in stream_generate( model, tokenizer, prompt=tokens, 
max_tokens=gen_params["max_tokens"], sampler=sampler, ): pass elif cmd == CMD_SHUTDOWN: print(f"[Rank {group.rank()}] Shutting down", file=sys.stderr) break async def serve(address): server = grpc.aio.server( migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ ('grpc.max_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024), ], ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) loop = asyncio.get_event_loop() for sig in (signal.SIGINT, signal.SIGTERM): loop.add_signal_handler(sig, lambda: asyncio.ensure_future(server.stop(5))) await server.start() print(f"[Rank 0] gRPC server listening on {address}", file=sys.stderr) await server.wait_for_termination() if __name__ == "__main__": parser = argparse.ArgumentParser(description="MLX Distributed Backend") parser.add_argument("--addr", default="localhost:50051", help="gRPC listen address (used by LocalAI to send requests)") parser.add_argument("--worker", action="store_true", help="Run in worker mode (for remote ranks started by CLI)") parser.add_argument("--backend", default="ring", choices=["ring", "jaccl"], help="ring = TCP pipeline parallelism, jaccl = RDMA tensor parallelism") parser.add_argument("--hostfile", default=None, help="Path to hostfile JSON (required for --worker mode)") parser.add_argument("--rank", type=int, default=0, help="Rank of this process (0 = server, >0 = worker)") parser.add_argument("--coordinator", default=None, help="JACCL coordinator ip:port (jaccl backend only)") args = parser.parse_args() if args.worker: if not args.hostfile: print("Error: --hostfile is required in worker mode", file=sys.stderr) sys.exit(1) group = mlx_distributed_init(args.rank, args.hostfile, args.backend, args.coordinator) run_worker(group) else: # Server mode: started by LocalAI with just --addr. 
# Distributed init deferred to LoadModel (reads config from model options/env vars). asyncio.run(serve(args.addr)) ================================================ FILE: backend/python/mlx-distributed/coordinator.py ================================================ """ Distributed coordination using MLX distributed primitives. Rank 0 broadcasts commands and tokens to all ranks via all_sum/all_gather. Worker ranks wait in a loop for commands from rank 0. """ import json import struct import mlx.core as mx CMD_IDLE = 0 CMD_GENERATE = 1 CMD_LOAD_MODEL = 2 CMD_SHUTDOWN = -1 class DistributedCoordinator: def __init__(self, group): self.group = group self.rank = group.rank() self.world_size = group.size() def broadcast_command(self, cmd, payload_size=0): """Rank 0 broadcasts a command to all ranks. Uses all_sum with only rank 0 providing non-zero values so every rank receives the same command array. """ if self.rank == 0: cmd_array = mx.array([cmd, payload_size], dtype=mx.int32) else: cmd_array = mx.zeros((2,), dtype=mx.int32) result = mx.distributed.all_sum(cmd_array, group=self.group) mx.eval(result) return int(result[0].item()), int(result[1].item()) def broadcast_tokens(self, tokens): """Broadcast input token ids from rank 0 to all ranks. Rank 0 provides the real token array; other ranks provide zeros of the same shape. ``all_sum`` ensures every rank ends up with identical data. 
""" if self.rank == 0: token_array = mx.array(tokens, dtype=mx.int32) else: token_array = mx.zeros((len(tokens),), dtype=mx.int32) result = mx.distributed.all_sum(token_array, group=self.group) mx.eval(result) return result def broadcast_token_count(self, count): """Broadcast the number of tokens so workers can prepare a buffer.""" if self.rank == 0: count_array = mx.array([count], dtype=mx.int32) else: count_array = mx.zeros((1,), dtype=mx.int32) result = mx.distributed.all_sum(count_array, group=self.group) mx.eval(result) return int(result[0].item()) def broadcast_generation_params(self, max_tokens=200, temperature=0.6, top_p=1.0): """Broadcast generation parameters from rank 0.""" if self.rank == 0: params = mx.array([max_tokens, temperature, top_p], dtype=mx.float32) else: params = mx.zeros((3,), dtype=mx.float32) result = mx.distributed.all_sum(params, group=self.group) mx.eval(result) return { "max_tokens": int(result[0].item()), "temperature": float(result[1].item()), "top_p": float(result[2].item()), } def wait_for_command(self): """Worker ranks block here until rank 0 broadcasts a command.""" return self.broadcast_command(CMD_IDLE, 0) def broadcast_model_name(self, model_name=""): """Broadcast model name string from rank 0 to all ranks. Encodes the model name as int32 codepoints so it can travel via all_sum. 
""" if self.rank == 0: encoded = [ord(c) for c in model_name] # First broadcast the length length = self.broadcast_token_count(len(encoded)) if length > 0: name_array = mx.array(encoded, dtype=mx.int32) result = mx.distributed.all_sum(name_array, group=self.group) mx.eval(result) return model_name return "" else: length = self.broadcast_token_count(0) if length > 0: name_array = mx.zeros((length,), dtype=mx.int32) result = mx.distributed.all_sum(name_array, group=self.group) mx.eval(result) return "".join(chr(int(c.item())) for c in result) return "" ================================================ FILE: backend/python/mlx-distributed/install.sh ================================================ #!/bin/bash set -e USE_PIP=true PYTHON_VERSION="" backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi installRequirements ================================================ FILE: backend/python/mlx-distributed/mlx_cache.py ================================================ """ Thread-safe LRU prompt cache for MLX-based backends. Ported from mlx_lm/server.py (MIT License, Copyright 2023-2024 Apple Inc.) with thread-safety additions for LocalAI's gRPC backend. Usage: from mlx_cache import ThreadSafeLRUPromptCache # In LoadModel: self.lru_cache = ThreadSafeLRUPromptCache(max_size=10) # In Predict/PredictStream: prompt_cache, remaining_tokens = self.lru_cache.fetch_nearest_cache(model_key, tokens) # ... generate ... 
self.lru_cache.insert_cache(model_key, tokens, prompt_cache) """ import copy import threading from collections import deque from dataclasses import dataclass from typing import Any, List, Optional, Tuple @dataclass class CacheEntry: """A cache entry with reference counting.""" prompt_cache: List[Any] count: int @dataclass class SearchResult: """Result of searching the cache trie.""" model: Any exact: Optional[List[int]] shorter: Optional[List[int]] longer: Optional[List[int]] common_prefix: int class ThreadSafeLRUPromptCache: """ Thread-safe LRU cache with prefix matching for prompt KV caches. This cache stores KV caches keyed by token sequences and supports: - Exact match: Return the cache for the exact token sequence - Shorter prefix match: Return a cache for a prefix of the tokens - Longer prefix match: If a longer sequence is cached and can be trimmed - LRU eviction: When max_size is exceeded, evict least recently used Thread safety is provided via a threading.Lock that protects all cache operations. Args: max_size: Maximum number of cache entries (default: 10) can_trim_fn: Optional function to check if a cache can be trimmed trim_fn: Optional function to trim a cache """ def __init__( self, max_size: int = 10, can_trim_fn: Optional[Any] = None, trim_fn: Optional[Any] = None, ): self.max_size = max_size self._cache = {} self._lru = deque() self._lock = threading.Lock() # Optional trim functions (for longer prefix reuse) self._can_trim_fn = can_trim_fn self._trim_fn = trim_fn def _search(self, model, tokens: List[int]) -> SearchResult: """ Search the cache for a prompt cache. Return exact or close match. The cache is organized as a trie where each node is keyed by a token. This allows efficient prefix matching. 
""" if model not in self._cache: return SearchResult(model, None, None, None, 0) current = self._cache[model] last_cache_index = -1 index = 0 # Traverse the trie following the token sequence while index < len(tokens) and tokens[index] in current: current = current[tokens[index]] if "cache" in current: last_cache_index = index index += 1 # Exact match - no need to search for longer or shorter caches if last_cache_index == len(tokens) - 1: return SearchResult(model, tuple(tokens), None, None, 0) # Find the shorter cache (a prefix that has a cache) # Note: Uses > 0 (not >= 0) to match upstream mlx_lm/server.py behavior. # Single-token prefixes are not matched, which allows longer cached # sequences to be preferred for trimming. This is acceptable because # real prompts with chat templates are always many tokens. shorter = None if last_cache_index > 0: shorter = tuple(tokens[: last_cache_index + 1]) # Check for caches that are longer than our token sequence longer = None common_prefix = index if index > 0 and last_cache_index <= 0: best = None stack = [(current, [])] while stack: current, extra = stack.pop() if "cache" in current: if best is None or len(extra) < len(best): best = extra else: for tok in current: stack.append((current[tok], extra + [tok])) if best is not None: longer = tuple(tokens[:index] + best) return SearchResult(model, None, shorter, longer, common_prefix) def _get(self, model, tokens: Tuple[int, ...]) -> CacheEntry: """Get a cache entry by traversing the trie.""" current = self._cache[model] for tok in tokens: current = current[tok] return current["cache"] def _delete(self, model, tokens: Tuple[int, ...]) -> None: """Delete a cache entry and clean up empty trie nodes.""" path = [self._cache[model]] for tok in tokens: path.append(path[-1][tok]) del path[-1]["cache"] # Clean up empty nodes bottom-up for i in reversed(range(len(tokens))): d_prev, d, t = path[i], path[i + 1], tokens[i] if len(d) > 0: break del d_prev[t] def _extract(self, model, 
tokens: Tuple[int, ...]) -> CacheEntry: """ Extract a cache entry for exclusive use. If the entry has count > 1, deep copy and decrement. If count == 1, remove from cache entirely. """ cache_entry = self._get(model, tokens) if cache_entry.count == 1: self._delete(model, tokens) self._lru.remove((model, tokens)) return cache_entry cache_entry.count -= 1 return CacheEntry( copy.deepcopy(cache_entry.prompt_cache), 1, ) def fetch_nearest_cache( self, model, tokens: List[int] ) -> Tuple[Optional[List[Any]], List[int]]: """ Fetch the nearest cache for the given token sequence. Thread-safe. Returns (cache, remaining_tokens) where: - cache: The KV cache to use (or None if no cache found) - remaining_tokens: Tokens that still need to be processed Args: model: Model identifier (used to namespace caches) tokens: The full token sequence for the prompt Returns: Tuple of (prompt_cache, remaining_tokens) """ with self._lock: tokens_tuple = tuple(tokens) result = self._search(model, tokens) # Exact match - extract and return if result.exact is not None: cache_entry = self._extract(result.model, result.exact) return cache_entry.prompt_cache, [] # Shorter prefix match - extract and return remaining if result.shorter is not None: cache_entry = self._extract(result.model, result.shorter) prefix_len = len(result.shorter) return cache_entry.prompt_cache, list(tokens[prefix_len:]) # Longer prefix match - try to trim if possible if result.longer is not None and self._can_trim_fn is not None: cache_entry = self._get(result.model, result.longer) if self._can_trim_fn(cache_entry.prompt_cache): # Deep copy and trim trimmed_cache = copy.deepcopy(cache_entry.prompt_cache) prefix = min(len(tokens) - 1, result.common_prefix) num_to_trim = len(result.longer) - prefix if self._trim_fn is not None: self._trim_fn(trimmed_cache, num_to_trim) return trimmed_cache, list(tokens[prefix:]) # No match found return None, list(tokens) def insert_cache( self, model, tokens: List[int], prompt_cache: List[Any] ) 
-> None: """ Insert a cache entry after generation completes. Thread-safe. Handles LRU eviction if max_size is exceeded. Args: model: Model identifier (used to namespace caches) tokens: The full token sequence (prompt + generated) prompt_cache: The KV cache to store """ with self._lock: tokens_tuple = tuple(tokens) if model not in self._cache: self._cache[model] = {} current = self._cache[model] # Build trie path for tok in tokens_tuple: if tok not in current: current[tok] = {} current = current[tok] # Update or create entry if "cache" in current: current["cache"].count += 1 self._lru.remove((model, tokens_tuple)) else: current["cache"] = CacheEntry(prompt_cache, 1) # Update LRU order self._lru.append((model, tokens_tuple)) # Evict if over capacity if len(self._lru) > self.max_size: evict_model, evict_tokens = self._lru.popleft() self._delete(evict_model, evict_tokens) def clear(self) -> None: """Clear all cache entries. Thread-safe.""" with self._lock: self._cache.clear() self._lru.clear() def __len__(self) -> int: """Return the number of cache entries. 
Thread-safe.""" with self._lock: return len(self._lru) ================================================ FILE: backend/python/mlx-distributed/requirements-cpu.txt ================================================ mlx-lm mlx[cpu] ================================================ FILE: backend/python/mlx-distributed/requirements-cublas12.txt ================================================ mlx-lm mlx[cuda12] ================================================ FILE: backend/python/mlx-distributed/requirements-cublas13.txt ================================================ mlx-lm mlx[cuda13] ================================================ FILE: backend/python/mlx-distributed/requirements-l4t12.txt ================================================ mlx-lm mlx[cuda12] ================================================ FILE: backend/python/mlx-distributed/requirements-l4t13.txt ================================================ mlx-lm mlx[cuda13] ================================================ FILE: backend/python/mlx-distributed/requirements-mps.txt ================================================ mlx-lm ================================================ FILE: backend/python/mlx-distributed/requirements.txt ================================================ grpcio==1.71.0 protobuf certifi setuptools ================================================ FILE: backend/python/mlx-distributed/run.sh ================================================ #!/bin/bash backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi startBackend $@ ================================================ FILE: backend/python/mlx-distributed/sharding.py ================================================ """ Auto-parallelism for MLX distributed inference. Provides pipeline parallelism (Ring backend) by wrapping model layers with distributed send/recv operations. 
Ported from exo's auto_parallel.py with simplifications for LocalAI's use case. """ import mlx.core as mx import mlx.nn as nn class PipelineFirstLayer(nn.Module): """Wraps the first layer on each rank to receive from the previous rank.""" def __init__(self, original_layer, rank, group): super().__init__() dict.__setitem__(self, "_original_layer", original_layer) self.rank = rank self.group = group @property def original_layer(self): return self["_original_layer"] def __getattr__(self, name): try: return super().__getattr__(name) except AttributeError: return getattr(self["_original_layer"], name) def __call__(self, x, *args, **kwargs): if self.rank != 0: mx.eval(x) x = mx.distributed.recv_like(x, self.rank - 1, group=self.group) mx.eval(x) return self.original_layer(x, *args, **kwargs) class PipelineLastLayer(nn.Module): """Wraps the last layer on each rank to send to the next rank.""" def __init__(self, original_layer, rank, world_size, group): super().__init__() dict.__setitem__(self, "_original_layer", original_layer) self.rank = rank self.world_size = world_size self.group = group @property def original_layer(self): return self["_original_layer"] def __getattr__(self, name): try: return super().__getattr__(name) except AttributeError: return getattr(self["_original_layer"], name) def __call__(self, x, *args, **kwargs): output = self.original_layer(x, *args, **kwargs) mx.eval(output) if self.rank != self.world_size - 1: output = mx.distributed.send( output, (self.rank + 1) % self.world_size, group=self.group ) mx.eval(output) # Gather output from all ranks so every rank has the final result output = mx.distributed.all_gather(output, group=self.group)[ -output.shape[0] : ] mx.eval(output) return output def get_inner_model(model): """Get the inner model (model.model or model.transformer).""" for attr in ("model", "transformer"): inner = getattr(model, attr, None) if isinstance(inner, nn.Module): # Some models have model.model (e.g. 
def pipeline_auto_parallel(model, group, start_layer=None, end_layer=None):
    """Apply pipeline parallelism to a model.

    Each rank only keeps its slice of layers. The first kept layer is wrapped
    in PipelineFirstLayer and the last kept layer in PipelineLastLayer (a
    single-layer slice receives both wrappers).

    Args:
        model: The MLX model (must have model.layers or similar)
        group: The distributed group; ``rank()`` and ``size()`` are read from it
        start_layer: First layer index for this rank (auto-computed if None)
        end_layer: Last layer index (exclusive) for this rank (auto-computed if None)

    Returns:
        The same ``model`` object, with its layer list replaced by this rank's slice.

    Raises:
        ValueError: If the selected layer range is empty, e.g. when there are
            more ranks than layers.
    """
    rank = group.rank()
    world_size = group.size()

    inner = get_inner_model(model)
    layers = list(get_layers(inner))
    total_layers = len(layers)

    # Even split; the first `remainder` ranks each take one extra layer.
    layers_per_rank, remainder = divmod(total_layers, world_size)
    auto_start = rank * layers_per_rank + min(rank, remainder)
    auto_end = auto_start + layers_per_rank + (1 if rank < remainder else 0)

    # Fill in each bound independently so an explicit value is never discarded
    # just because the other bound was left as None.
    if start_layer is None:
        start_layer = auto_start
    if end_layer is None:
        end_layer = auto_end

    layers = layers[start_layer:end_layer]
    if not layers:
        # Without this guard the wrapping below fails with a bare IndexError.
        raise ValueError(
            f"Rank {rank} of {world_size} has no layers to run "
            f"(range {start_layer}:{end_layer} of {total_layers} layers)"
        )

    for layer in layers:
        mx.eval(layer)

    # Wrap first and last layers
    layers[0] = PipelineFirstLayer(layers[0], rank, group=group)
    layers[-1] = PipelineLastLayer(layers[-1], rank, world_size, group=group)

    # Replace layers on the inner model
    if hasattr(inner, "layers"):
        inner.layers = layers
    elif hasattr(inner, "h"):
        inner.h = layers

    return model
TestBackendServicer(unittest.TestCase): def setUp(self): self.service = subprocess.Popen( ["python", "backend.py", "--addr", "localhost:50051"] ) time.sleep(10) def tearDown(self) -> None: self.service.terminate() self.service.wait() def test_server_startup(self): try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.Health(backend_pb2.HealthMessage()) self.assertEqual(response.message, b'OK') except Exception as err: print(err) self.fail("Server failed to start") finally: self.tearDown() def test_load_model(self): try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) self.assertTrue(response.success) self.assertEqual(response.message, "Model loaded successfully") except Exception as err: print(err) self.fail("LoadModel service failed") finally: self.tearDown() def test_text(self): try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) self.assertTrue(response.success) req = backend_pb2.PredictOptions(Prompt="The capital of France is") resp = stub.Predict(req) self.assertIsNotNone(resp.message) except Exception as err: print(err) self.fail("text service failed") finally: self.tearDown() def test_sampling_params(self): try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="mlx-community/Llama-3.2-1B-Instruct-4bit")) self.assertTrue(response.success) req = backend_pb2.PredictOptions( Prompt="The capital of France is", TopP=0.8, Tokens=50, Temperature=0.7, TopK=40, MinP=0.05, Seed=42, ) resp = stub.Predict(req) 
def is_float(s):
    """Check if a value can be converted to float.

    Args:
        s: The value to probe (normally a string).

    Returns:
        bool: True when ``float(s)`` succeeds, False otherwise. Values that
        ``float`` rejects with TypeError (e.g. None) or OverflowError (an
        integer too large for a float) also yield False instead of raising.
    """
    try:
        float(s)
    except (TypeError, ValueError, OverflowError):
        return False
    return True
async def LoadModel(self, request, context):
    """
    Loads a multimodal vision-language model using MLX-VLM.

    Options arrive as ``optname:optvalue`` strings and are stored in
    ``self.options``. Values are converted to int, then float, then bool
    when they parse as such; anything else stays a string.

    Args:
        request: The load model request.
        context: The gRPC context.

    Returns:
        backend_pb2.Result: The load model result.
    """
    try:
        print(f"Loading MLX-VLM model: {request.Model}", file=sys.stderr)
        print(f"Request: {request}", file=sys.stderr)

        # Parse options like in the diffusers backend
        options = request.Options
        self.options = {}

        # The options are a list of strings in this form optname:optvalue
        # We store all the options in a dict for later use
        for opt in options:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)  # Split only on first colon to handle values with colons

            # Try int before float: every integer literal is also a valid
            # float, so testing float first would turn "512" into 512.0.
            if is_int(value):
                value = int(value)
            elif is_float(value):
                value = float(value)
            elif value.lower() in ["true", "false"]:
                value = value.lower() == "true"

            self.options[key] = value

        print(f"Options: {self.options}", file=sys.stderr)

        # Load model and processor using MLX-VLM
        # mlx-vlm load function returns (model, processor) instead of (model, tokenizer)
        self.model, self.processor = load(request.Model)

        # Load model config for chat template support
        self.config = load_config(request.Model)

    except Exception as err:
        print(f"Error loading MLX-VLM model {err=}, {type(err)=}", file=sys.stderr)
        return backend_pb2.Result(success=False, message=f"Error loading MLX-VLM model: {err}")

    print("MLX-VLM model loaded successfully", file=sys.stderr)
    return backend_pb2.Result(message="MLX-VLM model loaded successfully", success=True)
async def Predict(self, request, context):
    """
    Run a single (non-streaming) MLX-VLM generation for the request.

    Images and audios attached to the request are decoded into temporary
    files, the prompt is prepared, and the generated text is returned. The
    temporary files are removed before returning, whether or not generation
    succeeded.

    Args:
        request: The predict request.
        context: The gRPC context.

    Returns:
        backend_pb2.Reply: The generated text, or an empty message (with the
        gRPC status set to INTERNAL) when generation fails.
    """
    temp_files = []
    try:
        # Decode attached media; entries that fail to decode come back falsy
        # and are skipped.
        image_paths = []
        for encoded_image in request.Images or ():
            image_file = self.load_image_from_base64(encoded_image)
            if not image_file:
                continue
            image_paths.append(image_file)
            temp_files.append(image_file)

        audio_paths = []
        for encoded_audio in request.Audios or ():
            audio_file = self.load_audio_from_base64(encoded_audio)
            if not audio_file:
                continue
            audio_paths.append(audio_file)
            temp_files.append(audio_file)

        prompt = self._prepare_prompt(
            request,
            num_images=len(image_paths),
            num_audios=len(audio_paths),
        )
        max_tokens, generation_params = self._build_generation_params(request)

        print(f"Generating text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr)
        print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr)

        # Empty media lists are passed as None.
        generate_kwargs = {
            "model": self.model,
            "processor": self.processor,
            "prompt": prompt,
            "image": image_paths or None,
            "audio": audio_paths or None,
            "max_tokens": max_tokens,
            "temperature": generation_params.get('temp', 0.6),
            "top_p": generation_params.get('top_p', 1.0),
            "verbose": False,
        }
        response = generate(**generate_kwargs)

        return backend_pb2.Reply(message=bytes(response, encoding='utf-8'))

    except Exception as e:
        print(f"Error in MLX-VLM Predict: {e}", file=sys.stderr)
        context.set_code(grpc.StatusCode.INTERNAL)
        context.set_details(f"Generation failed: {str(e)}")
        return backend_pb2.Reply(message=bytes("", encoding='utf-8'))
    finally:
        # Always remove the temporary media files.
        self.cleanup_temp_files(temp_files)
async def PredictStream(self, request, context):
    """
    Stream an MLX-VLM generation for the request, one reply per chunk.

    Images and audios attached to the request are decoded into temporary
    files first; those files are removed when the stream finishes or fails.

    Args:
        request: The predict stream request.
        context: The gRPC context.

    Yields:
        backend_pb2.Reply: One reply per generated chunk. On failure the gRPC
        status is set to INTERNAL and a single empty reply is yielded.
    """
    temp_files = []
    try:
        # Decode attached media; entries that fail to decode come back falsy
        # and are skipped.
        image_paths = []
        for encoded_image in request.Images or ():
            image_file = self.load_image_from_base64(encoded_image)
            if not image_file:
                continue
            image_paths.append(image_file)
            temp_files.append(image_file)

        audio_paths = []
        for encoded_audio in request.Audios or ():
            audio_file = self.load_audio_from_base64(encoded_audio)
            if not audio_file:
                continue
            audio_paths.append(audio_file)
            temp_files.append(audio_file)

        prompt = self._prepare_prompt(
            request,
            num_images=len(image_paths),
            num_audios=len(audio_paths),
        )
        max_tokens, generation_params = self._build_generation_params(request, default_max_tokens=512)

        print(f"Streaming text with MLX-VLM - max_tokens: {max_tokens}, params: {generation_params}", file=sys.stderr)
        print(f"Images: {len(image_paths)}, Audios: {len(audio_paths)}", file=sys.stderr)

        # Empty media lists are passed as None.
        stream_kwargs = {
            "model": self.model,
            "processor": self.processor,
            "prompt": prompt,
            "image": image_paths or None,
            "audio": audio_paths or None,
            "max_tokens": max_tokens,
            "temperature": generation_params.get('temp', 0.6),
            "top_p": generation_params.get('top_p', 1.0),
        }
        for chunk in stream_generate(**stream_kwargs):
            yield backend_pb2.Reply(message=bytes(chunk.text, encoding='utf-8'))

    except Exception as e:
        print(f"Error in MLX-VLM PredictStream: {e}", file=sys.stderr)
        context.set_code(grpc.StatusCode.INTERNAL)
        context.set_details(f"Streaming generation failed: {str(e)}")
        yield backend_pb2.Reply(message=bytes("", encoding='utf-8'))
    finally:
        # Always remove the temporary media files.
        self.cleanup_temp_files(temp_files)
prompt=prompt, image=image_paths if image_paths else None, audio=audio_paths if audio_paths else None, max_tokens=max_tokens, temperature=generation_params.get('temp', 0.6), top_p=generation_params.get('top_p', 1.0), ): yield backend_pb2.Reply(message=bytes(response.text, encoding='utf-8')) except Exception as e: print(f"Error in MLX-VLM PredictStream: {e}", file=sys.stderr) context.set_code(grpc.StatusCode.INTERNAL) context.set_details(f"Streaming generation failed: {str(e)}") yield backend_pb2.Reply(message=bytes("", encoding='utf-8')) finally: # Clean up temporary files self.cleanup_temp_files(temp_files) def _prepare_prompt(self, request, num_images=0, num_audios=0): """ Prepare the prompt for MLX-VLM generation, handling chat templates and multimodal inputs. Args: request: The gRPC request containing prompt and message information. num_images: Number of images in the request. num_audios: Number of audio files in the request. Returns: str: The prepared prompt. """ # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template if not request.Prompt and request.UseTokenizerTemplate and request.Messages: # Convert gRPC messages to the format expected by apply_chat_template messages = [] for msg in request.Messages: messages.append({"role": msg.role, "content": msg.content}) # Use mlx-vlm's apply_chat_template which handles multimodal inputs prompt = apply_chat_template( self.processor, self.config, messages, num_images=num_images, num_audios=num_audios ) return prompt elif request.Prompt: # If we have a direct prompt but also have images/audio, we need to format it properly if num_images > 0 or num_audios > 0: # Create a simple message structure for multimodal prompt messages = [{"role": "user", "content": request.Prompt}] prompt = apply_chat_template( self.processor, self.config, messages, num_images=num_images, num_audios=num_audios ) return prompt else: return request.Prompt else: # Fallback to empty prompt with 
multimodal template if we have media if num_images > 0 or num_audios > 0: messages = [{"role": "user", "content": ""}] prompt = apply_chat_template( self.processor, self.config, messages, num_images=num_images, num_audios=num_audios ) return prompt else: return "" def _build_generation_params(self, request, default_max_tokens=200): """ Build generation parameters from request attributes and options for MLX-VLM. Args: request: The gRPC request. default_max_tokens: Default max_tokens if not specified. Returns: tuple: (max_tokens, generation_params dict) """ # Extract max_tokens max_tokens = getattr(request, 'Tokens', default_max_tokens) if max_tokens == 0: max_tokens = default_max_tokens # Extract generation parameters from request attributes temp = getattr(request, 'Temperature', 0.0) if temp == 0.0: temp = 0.6 # Default temperature top_p = getattr(request, 'TopP', 0.0) if top_p == 0.0: top_p = 1.0 # Default top_p # Initialize generation parameters for MLX-VLM generation_params = { 'temp': temp, 'top_p': top_p, } # Add seed if specified seed = getattr(request, 'Seed', 0) if seed != 0: mx.random.seed(seed) # Override with options if available if hasattr(self, 'options'): # Max tokens from options if 'max_tokens' in self.options: max_tokens = self.options['max_tokens'] # Generation parameters from options param_option_mapping = { 'temp': 'temp', 'temperature': 'temp', # alias 'top_p': 'top_p', } for option_key, param_key in param_option_mapping.items(): if option_key in self.options: generation_params[param_key] = self.options[option_key] # Handle seed from options if 'seed' in self.options: mx.random.seed(self.options['seed']) return max_tokens, generation_params def load_image_from_base64(self, image_data: str): """ Load an image from base64 encoded data. Args: image_data (str): Base64 encoded image data. Returns: PIL.Image or str: The loaded image or path to the image. 
def load_image_from_base64(self, image_data: str):
    """
    Decode a base64 image and store it as a temporary JPEG file for mlx-vlm.

    Args:
        image_data (str): Base64 encoded image data.

    Returns:
        str or None: Path of the temporary ``.jpg`` file, or None when the
        image could not be decoded or written. The caller is responsible for
        deleting the returned file.
    """
    tmp_path = None
    try:
        decoded_data = base64.b64decode(image_data)
        image = Image.open(io.BytesIO(decoded_data))

        # The JPEG writer cannot store alpha or palette images (e.g. RGBA
        # PNGs); convert those so they are not silently dropped.
        if image.mode not in ("RGB", "L", "CMYK"):
            image = image.convert("RGB")

        # Reserve a file name, then write once the handle is closed.
        with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as tmp_file:
            tmp_path = tmp_file.name
        image.save(tmp_path, format='JPEG')
        return tmp_path

    except Exception as e:
        print(f"Error loading image from base64: {e}", file=sys.stderr)
        # The path is not returned on failure, so nobody else can clean up
        # the delete=False temp file; remove it here (best effort).
        if tmp_path is not None:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
        return None
async def serve(address):
    """
    Start the asyncio gRPC server and block until it terminates.

    Registers a BackendServicer, binds an insecure port on ``address`` and
    installs SIGINT/SIGTERM handlers that stop the server with a 5 second
    grace period.

    Args:
        address: The ``host:port`` address to bind the server to.
    """
    # Start asyncio gRPC server
    server = grpc.aio.server(
        migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
            ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
        ],
    )
    # Add the servicer to the server
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    # Bind the server to the address
    server.add_insecure_port(address)

    # Gracefully shutdown the server on SIGTERM or SIGINT.
    # We are inside a coroutine, so get_running_loop() is the documented way
    # to obtain the loop that is executing this code.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(
            sig, lambda: asyncio.ensure_future(server.stop(5))
        )

    # Start the server
    await server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)
    # Wait for the server to be terminated
    await server.wait_for_termination()
#!/bin/bash
# Installs the Python requirements for the mlx-vlm backend.
# Sources the shared helper library (bundled copy under ./common if present,
# otherwise the sibling ../common directory) and runs installRequirements.
set -e

# Presumably read by libbackend.sh to select pip as the installer — confirm there.
USE_PIP=true

# Quote "$0" so a checkout path containing spaces is not word-split.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

installRequirements
$backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi startBackend $@ ================================================ FILE: backend/python/mlx-vlm/test.py ================================================ import unittest import subprocess import time import backend_pb2 import backend_pb2_grpc import grpc import unittest import subprocess import time import grpc import backend_pb2_grpc import backend_pb2 class TestBackendServicer(unittest.TestCase): """ TestBackendServicer is the class that tests the gRPC service. This class contains methods to test the startup and shutdown of the gRPC service. """ def setUp(self): self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"]) time.sleep(10) def tearDown(self) -> None: self.service.terminate() self.service.wait() def test_server_startup(self): try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.Health(backend_pb2.HealthMessage()) self.assertEqual(response.message, b'OK') except Exception as err: print(err) self.fail("Server failed to start") finally: self.tearDown() def test_load_model(self): """ This method tests if the model is loaded successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) self.assertTrue(response.success) self.assertEqual(response.message, "Model loaded successfully") except Exception as err: print(err) self.fail("LoadModel service failed") finally: self.tearDown() def test_text(self): """ This method tests if the embeddings are generated successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) self.assertTrue(response.success) req = 
backend_pb2.PredictOptions(Prompt="The capital of France is") resp = stub.Predict(req) self.assertIsNotNone(resp.message) except Exception as err: print(err) self.fail("text service failed") finally: self.tearDown() def test_sampling_params(self): """ This method tests if all sampling parameters are correctly processed NOTE: this does NOT test for correctness, just that we received a compatible response """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) self.assertTrue(response.success) req = backend_pb2.PredictOptions( Prompt="The capital of France is", TopP=0.8, Tokens=50, Temperature=0.7, TopK=40, PresencePenalty=0.1, FrequencyPenalty=0.2, RepetitionPenalty=1.1, MinP=0.05, Seed=42, StopPrompts=["\n"], StopTokenIds=[50256], BadWords=["badword"], IncludeStopStrInOutput=True, IgnoreEOS=True, MinTokens=5, Logprobs=5, PromptLogprobs=5, SkipSpecialTokens=True, SpacesBetweenSpecialTokens=True, TruncatePromptTokens=10, GuidedDecoding=True, N=2, ) resp = stub.Predict(req) self.assertIsNotNone(resp.message) self.assertIsNotNone(resp.logprobs) except Exception as err: print(err) self.fail("sampling params service failed") finally: self.tearDown() def test_embedding(self): """ This method tests if the embeddings are generated successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="intfloat/e5-mistral-7b-instruct")) self.assertTrue(response.success) embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.") embedding_response = stub.Embedding(embedding_request) self.assertIsNotNone(embedding_response.embeddings) # assert that is a list of floats self.assertIsInstance(embedding_response.embeddings, list) # assert that the list is not empty 
def LoadModel(self, request, context):
    """
    Resolve and load the Moonshine model selected by the request options.

    Options are ``optname:optvalue`` strings; ``language`` (default "en") and
    ``model_arch`` (default None) are passed to get_model_for_language, and
    the model it returns is used to create the transcriber.

    Args:
        request: The load model request.
        context: The gRPC context.

    Returns:
        backend_pb2.Result: success=True once the transcriber exists,
        otherwise success=False with the error in the message.
    """
    try:
        print("Preparing models, please wait", file=sys.stderr)

        self.model_name = request.Model
        print(f"Model name set to: {self.model_name}", file=sys.stderr)

        # Entries without a colon are ignored; only the first colon splits.
        self.options = {}
        self.options.update(
            entry.split(":", 1) for entry in request.Options if ":" in entry
        )
        print(f"Options: {self.options}", file=sys.stderr)

        language = self.options.get("language", "en")
        requested_arch = self.options.get("model_arch", None)

        # Get the model path and architecture
        model_path, model_arch = get_model_for_language(language, requested_arch)
        print(f"Loading model: {model_path} with architecture: {model_arch} for language: {language}", file=sys.stderr)

        # Initialize the transcriber
        self.transcriber = Transcriber(model_path=model_path, model_arch=model_arch)
        print("Model loaded successfully", file=sys.stderr)
    except Exception as err:
        return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

    return backend_pb2.Result(message="Model loaded successfully", success=True)
def AudioTranscription(self, request, context):
    """
    Transcribe the WAV file referenced by ``request.dst``.

    Each transcript line becomes one segment whose start/end are the line's
    ``start_time`` and ``start_time + duration`` multiplied by 1000 and
    truncated to int; the stripped line texts are joined with single spaces
    to form the full text.

    Args:
        request: The transcript request; ``request.dst`` is the audio file path.
        context: The gRPC context.

    Returns:
        backend_pb2.TranscriptResult: Segments and full text. On failure an
        empty result is returned and the gRPC status is set to INTERNAL with
        the error in the details, so callers can tell a failed transcription
        apart from silent audio.
    """
    import traceback

    resultSegments = []
    text = ""
    try:
        if self.transcriber is None:
            raise Exception("Model not loaded. Call LoadModel first.")

        # Load the audio file
        audio_data, sample_rate = load_wav_file(request.dst)
        print(f"Loaded audio file: {request.dst} with sample rate: {sample_rate}", file=sys.stderr)

        # Transcribe without streaming
        transcript = self.transcriber.transcribe_without_streaming(
            audio_data, sample_rate=sample_rate, flags=0
        )

        # Process transcript lines
        full_text_parts = []
        for idx, line in enumerate(transcript.lines):
            line_text = line.text.strip()
            full_text_parts.append(line_text)

            # Create segment with timing information
            start_ms = int(line.start_time * 1000)
            end_ms = int((line.start_time + line.duration) * 1000)

            resultSegments.append(backend_pb2.TranscriptSegment(
                id=idx,
                start=start_ms,
                end=end_ms,
                text=line_text
            ))

            print(f"Segment {idx}: [{line.start_time:.2f}s - {line.start_time + line.duration:.2f}s] {line_text}", file=sys.stderr)

        # Combine all transcriptions into a single text
        text = " ".join(full_text_parts)

    except Exception as err:
        print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
        traceback.print_exc()
        # Report the failure through the RPC status instead of returning an
        # empty transcript with an OK status.
        context.set_code(grpc.StatusCode.INTERNAL)
        context.set_details(f"Transcription failed: {err}")
        return backend_pb2.TranscriptResult(segments=[], text="")

    return backend_pb2.TranscriptResult(segments=resultSegments, text=text)
def serve(address):
    """
    Run the Moonshine gRPC server on ``address`` until it is interrupted.

    The server is started first; SIGINT and SIGTERM handlers are installed
    afterwards and stop the server immediately before exiting the process.
    The calling thread then sleeps in one-day intervals.

    Args:
        address: The ``host:port`` address to bind the server to.
    """
    message_limit = 50 * 1024 * 1024  # 50MB
    channel_options = [
        ('grpc.max_message_length', message_limit),
        ('grpc.max_send_message_length', message_limit),
        ('grpc.max_receive_message_length', message_limit),
    ]
    worker_pool = futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)

    server = grpc.server(worker_pool, options=channel_options)
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    def signal_handler(sig, frame):
        # Stop without a grace period, then leave the process.
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    for handled_signal in (signal.SIGINT, signal.SIGTERM):
        signal.signal(handled_signal, signal_handler)

    # Keep the main thread alive; the server runs on the worker pool.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)
class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service

    unittest calls setUp() before and tearDown() after every test method, so
    the tests below never call them by hand. Calling setUp() a second time
    would start another server on the same port and overwrite self.service,
    leaving the first process with nobody to terminate it.
    """

    def setUp(self):
        """
        This method sets up the gRPC service by starting the server
        """
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
        # Fixed delay that gives the server time to bind before the first RPC
        time.sleep(10)

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by terminating the server
        """
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        with grpc.insecure_channel("localhost:50051") as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            response = stub.Health(backend_pb2.HealthMessage())
            self.assertEqual(response.message, b'OK')

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        with grpc.insecure_channel("localhost:50051") as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            response = stub.LoadModel(backend_pb2.ModelOptions(Model="moonshine/tiny"))
            self.assertTrue(response.success)
            self.assertEqual(response.message, "Model loaded successfully")

    def test_audio_transcription(self):
        """
        This method tests if audio transcription works successfully
        """
        # Create a temporary directory for the audio file; it is removed
        # whether the test passes, fails or errors.
        temp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
        audio_file = os.path.join(temp_dir, 'audio.wav')

        # Download the audio file to the temporary directory
        print(f"Downloading audio file to {audio_file}...")
        url = "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav"
        result = subprocess.run(
            ["wget", "-q", url, "-O", audio_file],
            capture_output=True,
            text=True
        )
        if result.returncode != 0:
            self.fail(f"Failed to download audio file: {result.stderr}")

        # Verify the file was downloaded
        if not os.path.exists(audio_file):
            self.fail(f"Audio file was not downloaded to {audio_file}")

        with grpc.insecure_channel("localhost:50051") as channel:
            stub = backend_pb2_grpc.BackendStub(channel)

            # Load the model first
            load_response = stub.LoadModel(backend_pb2.ModelOptions(Model="moonshine/tiny"))
            self.assertTrue(load_response.success)

            # Perform transcription
            transcript_request = backend_pb2.TranscriptRequest(dst=audio_file)
            transcript_response = stub.AudioTranscription(transcript_request)

            # Print the transcribed text for debugging
            print(f"Transcribed text: {transcript_response.text}")
            print(f"Number of segments: {len(transcript_response.segments)}")

            # Verify response structure
            self.assertIsNotNone(transcript_response)
            self.assertIsNotNone(transcript_response.text)
            # Protobuf repeated fields return a sequence, not a list
            self.assertIsNotNone(transcript_response.segments)
            # Check if segments is iterable (has length)
            self.assertGreaterEqual(len(transcript_response.segments), 0)

            # Verify the transcription contains the expected text
            expected_text = "This is the micro machine man"
            self.assertIn(
                expected_text.lower(),
                transcript_response.text.lower(),
                f"Expected text '{expected_text}' not found in transcription: '{transcript_response.text}'"
            )

            # If we got segments, verify they have the expected structure
            if len(transcript_response.segments) > 0:
                segment = transcript_response.segments[0]
                self.assertIsNotNone(segment.text)
                self.assertIsInstance(segment.id, int)
            else:
                # Even if no segments, we should have text
                self.assertIsNotNone(transcript_response.text)
                self.assertGreater(len(transcript_response.text), 0)
def is_float(s):
    """Return True when ``s`` can be converted with float(), False on ValueError."""
    try:
        float(s)
        return True
    except ValueError:
        return False


def is_int(s):
    """Return True when ``s`` can be converted with int(), False on ValueError."""
    try:
        int(s)
        return True
    except ValueError:
        return False


_ONE_DAY_IN_SECONDS = 60 * 60 * 24
# Number of gRPC worker threads; overridable through the environment.
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))


class BackendServicer(backend_pb2_grpc.BackendServicer):
    """gRPC servicer exposing a NeMo ASR model through the LocalAI backend API."""

    def Health(self, request, context):
        """Liveness probe: always answers with the bytes ``OK``."""
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """
        Pick a torch device, parse ``request.Options`` and load the ASR model.

        Options are ``name:value`` strings. Values are converted to int, float
        or bool when they look like one, and stored in ``self.options``.
        Returns a Result with success=False when CUDA was requested but is
        unavailable, or when loading the model raises.
        """
        if torch.cuda.is_available():
            device = "cuda"
        else:
            device = "cpu"
        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        if mps_available:
            device = "mps"
        if not torch.cuda.is_available() and request.CUDA:
            return backend_pb2.Result(success=False, message="CUDA is not available")

        # NOTE(review): the device is only recorded here; it is not passed to
        # from_pretrained() below. Confirm NeMo's default placement is intended.
        self.device = device
        self.options = {}

        for opt in request.Options:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)
            # int must be tried first: every int literal is also a valid float,
            # so testing float first would turn "8" into 8.0.
            if is_int(value):
                value = int(value)
            elif is_float(value):
                value = float(value)
            elif value.lower() in ["true", "false"]:
                value = value.lower() == "true"
            self.options[key] = value

        model_name = request.Model or "nvidia/parakeet-tdt-0.6b-v3"

        try:
            print(f"Loading NEMO ASR model from {model_name}", file=sys.stderr)
            self.model = nemo_asr.models.ASRModel.from_pretrained(model_name=model_name)
            print("NEMO ASR model loaded successfully", file=sys.stderr)
        except Exception as err:
            print(f"[ERROR] LoadModel failed: {err}", file=sys.stderr)
            import traceback
            traceback.print_exc(file=sys.stderr)
            return backend_pb2.Result(success=False, message=str(err))

        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def AudioTranscription(self, request, context):
        """
        Transcribe the audio file at ``request.dst``.

        Returns a TranscriptResult with one segment holding the whole text.
        A missing file, an empty result or any exception yields an empty
        TranscriptResult; the error is logged to stderr.
        """
        result_segments = []
        text = ""
        try:
            audio_path = request.dst
            if not audio_path or not os.path.exists(audio_path):
                print(f"Error: Audio file not found: {audio_path}", file=sys.stderr)
                return backend_pb2.TranscriptResult(segments=[], text="")

            # NEMO's transcribe method accepts a list of audio paths and returns one entry per path
            results = self.model.transcribe([audio_path])

            if not results or len(results) == 0:
                return backend_pb2.TranscriptResult(segments=[], text="")

            # Depending on the NeMo version an entry is either a plain string
            # or a Hypothesis-like object exposing the transcript as `.text`.
            # The protobuf field only accepts str, so normalise here.
            first = results[0]
            if isinstance(first, str):
                text = first
            else:
                text = getattr(first, "text", "") or ""

            if text:
                # Create a single segment with the full transcription
                result_segments.append(backend_pb2.TranscriptSegment(
                    id=0,
                    start=0,
                    end=0,
                    text=text
                ))

        except Exception as err:
            print(f"Error in AudioTranscription: {err}", file=sys.stderr)
            import traceback
            traceback.print_exc(file=sys.stderr)
            return backend_pb2.TranscriptResult(segments=[], text="")

        return backend_pb2.TranscriptResult(segments=result_segments, text=text)


def serve(address):
    """Start the gRPC server on ``address`` and block until SIGINT/SIGTERM."""
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', 50 * 1024 * 1024),
            ('grpc.max_send_message_length', 50 * 1024 * 1024),
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),
        ])
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # server.start() does not block, so keep the main thread alive.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument("--addr", default="localhost:50051", help="The address to bind the server to.")
    args = parser.parse_args()
    serve(args.addr)
backend.proto ================================================ FILE: backend/python/nemo/requirements-cpu.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cpu torch nemo_toolkit[asr] ================================================ FILE: backend/python/nemo/requirements-cublas12.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu128 torch nemo_toolkit[asr] ================================================ FILE: backend/python/nemo/requirements-cublas13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch nemo_toolkit[asr] ================================================ FILE: backend/python/nemo/requirements-hipblas.txt ================================================ --extra-index-url https://download.pytorch.org/whl/rocm6.3 torch nemo_toolkit[asr] ================================================ FILE: backend/python/nemo/requirements-intel.txt ================================================ --extra-index-url https://download.pytorch.org/whl/xpu torch nemo_toolkit[asr] ================================================ FILE: backend/python/nemo/requirements-l4t12.txt ================================================ --extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/ torch nemo_toolkit[asr] ================================================ FILE: backend/python/nemo/requirements-l4t13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch nemo_toolkit[asr] ================================================ FILE: backend/python/nemo/requirements-mps.txt ================================================ torch nemo_toolkit[asr] ================================================ FILE: backend/python/nemo/requirements.txt ================================================ grpcio==1.71.0 protobuf certifi packaging==24.1 setuptools 
# Skip heavy transcription test in CI (model download + inference)
SKIP_ASR_TESTS = os.environ.get("SKIP_ASR_TESTS", "false").lower() == "true"


class TestBackendServicer(unittest.TestCase):
    """
    Integration tests for the NEMO Toolkit ASR gRPC backend.

    unittest runs setUp()/tearDown() around every test method, so the tests
    never call them by hand. A second setUp() would spawn another server on
    the same port and overwrite self.service, leaking the first process.
    """

    def setUp(self):
        """Start the backend server and wait a fixed delay for it to bind."""
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
        time.sleep(15)

    def tearDown(self):
        """Terminate the backend server and reap the process."""
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """The Health RPC answers with OK."""
        with grpc.insecure_channel("localhost:50051") as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            response = stub.Health(backend_pb2.HealthMessage())
            self.assertEqual(response.message, b'OK')

    def test_load_model(self):
        """LoadModel succeeds for the default parakeet model."""
        with grpc.insecure_channel("localhost:50051") as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            response = stub.LoadModel(backend_pb2.ModelOptions(Model="nvidia/parakeet-tdt-0.6b-v3"))
            self.assertTrue(response.success, response.message)
            self.assertEqual(response.message, "Model loaded successfully")

    @unittest.skipIf(SKIP_ASR_TESTS, "ASR transcription test skipped (SKIP_ASR_TESTS=true)")
    def test_audio_transcription(self):
        """Download a sample, transcribe it and check the response shape."""
        temp_dir = tempfile.mkdtemp()
        # Removed whether the test passes, fails, errors or is skipped.
        self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
        audio_file = os.path.join(temp_dir, 'audio.wav')

        # Download a sample audio file for testing
        url = "https://audio-samples.github.io/samples/mp3/crowd-cheering-and-applause-sound-effect.mp3"
        result = subprocess.run(
            ["wget", "-q", url, "-O", audio_file],
            capture_output=True,
            text=True,
            timeout=30,
        )
        if result.returncode != 0:
            self.skipTest(f"Could not download sample audio: {result.stderr}")
        if not os.path.exists(audio_file):
            self.skipTest("Sample audio file not found after download")

        with grpc.insecure_channel("localhost:50051") as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            load_response = stub.LoadModel(backend_pb2.ModelOptions(Model="nvidia/parakeet-tdt-0.6b-v3"))
            self.assertTrue(load_response.success, load_response.message)

            transcript_response = stub.AudioTranscription(
                backend_pb2.TranscriptRequest(dst=audio_file)
            )

            self.assertIsNotNone(transcript_response)
            self.assertIsNotNone(transcript_response.text)
            self.assertGreaterEqual(len(transcript_response.segments), 0)

            all_text = ""
            for segment in transcript_response.segments:
                all_text += segment.text
            print(f"Transcription result: {all_text}")

            self.assertIn("big", all_text)
            if transcript_response.segments:
                self.assertIsNotNone(transcript_response.segments[0].text)


if __name__ == '__main__':
    unittest.main()
def is_float(s):
    """Check if a string can be converted to float."""
    try:
        float(s)
        return True
    except ValueError:
        return False


def is_int(s):
    """Check if a string can be converted to int."""
    try:
        int(s)
        return True
    except ValueError:
        return False


_ONE_DAY_IN_SECONDS = 60 * 60 * 24

# If MAX_WORKERS are specified in the environment use it, otherwise default to 1
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))


# Implement the BackendServicer class with the service methods
class BackendServicer(backend_pb2_grpc.BackendServicer):
    """
    BackendServicer is the class that implements the gRPC service
    """

    def Health(self, request, context):
        """Liveness probe: always answers with the bytes ``OK``."""
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """
        Pick a torch device, parse ``request.Options``, resolve the reference
        audio path and instantiate NeuTTSAir.

        The ``codec_repo`` and ``ref_text`` options are consumed here; every
        other option stays in ``self.options``. Returns a Result with
        success=False when CUDA was requested but is unavailable, or when
        constructing the model raises.
        """
        # Get device
        if torch.cuda.is_available():
            print("CUDA is available", file=sys.stderr)
            device = "cuda"
        else:
            print("CUDA is not available", file=sys.stderr)
            device = "cpu"
        mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
        if mps_available:
            device = "mps"
        if not torch.cuda.is_available() and request.CUDA:
            return backend_pb2.Result(success=False, message="CUDA is not available")

        self.options = {}

        # The options are a list of strings in this form optname:optvalue
        # We are storing all the options in a dict so we can use it later when
        # generating the audio
        for opt in request.Options:
            if ":" not in opt:
                continue
            # Split only on the first colon: values such as a ref_text
            # sentence or a URL may contain colons themselves.
            key, value = opt.split(":", 1)
            # int must be tried first: every int literal is also a valid
            # float, so testing float first would turn "8" into 8.0.
            if is_int(value):
                value = int(value)
            elif is_float(value):
                value = float(value)
            elif value.lower() in ["true", "false"]:
                value = value.lower() == "true"
            self.options[key] = value

        # Both options are removed from self.options once read.
        codec_repo = self.options.pop("codec_repo", "neuphonic/neucodec")
        self.ref_text = self.options.pop("ref_text", None)

        self.AudioPath = None
        if os.path.isabs(request.AudioPath):
            self.AudioPath = request.AudioPath
        elif request.AudioPath and request.ModelFile != "":
            # A relative AudioPath is resolved against the directory of ModelFile
            modelFileBase = os.path.dirname(request.ModelFile)
            self.AudioPath = os.path.join(modelFileBase, request.AudioPath)

        try:
            print("Preparing models, please wait", file=sys.stderr)
            self.model = NeuTTSAir(backbone_repo=request.Model, backbone_device=device, codec_repo=codec_repo, codec_device=device)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def TTS(self, request, context):
        """
        Synthesize ``request.text`` with the reference audio chosen in
        LoadModel and write the waveform to ``request.dst`` at 24000 Hz.

        Any exception is reported as a Result with success=False.
        """
        try:
            ref_codes = self.model.encode_reference(self.AudioPath)

            wav = self.model.infer(request.text, ref_codes, self.ref_text)

            sf.write(request.dst, wav, 24000)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(success=True)


def serve(address):
    """Start the gRPC server on ``address`` and block until SIGINT/SIGTERM."""
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', 50 * 1024 * 1024),  # 50MB
            ('grpc.max_send_message_length', 50 * 1024 * 1024),  # 50MB
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),  # 50MB
        ])
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)
    server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Define the signal handler function
    def signal_handler(sig, frame):
        print("Received termination signal. Shutting down...")
        server.stop(0)
        sys.exit(0)

    # Set the signal handlers for SIGINT and SIGTERM
    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)

    # server.start() does not block, so keep the main thread alive.
    try:
        while True:
            time.sleep(_ONE_DAY_IN_SECONDS)
    except KeyboardInterrupt:
        server.stop(0)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the gRPC server.")
    parser.add_argument(
        "--addr", default="localhost:50051", help="The address to bind the server to."
    )
    args = parser.parse_args()

    serve(args.addr)
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index # the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index if [ "x${BUILD_PROFILE}" == "xintel" ]; then EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match" fi if [ "x${BUILD_TYPE}" == "xcublas" ] || [ "x${BUILD_TYPE}" == "xl4t" ]; then export CMAKE_ARGS="-DGGML_CUDA=on" fi if [ "x${BUILD_TYPE}" == "xhipblas" ]; then export CMAKE_ARGS="-DGGML_HIPBLAS=on" fi EXTRA_PIP_INSTALL_FLAGS+=" --no-build-isolation" if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then USE_PIP=true fi git clone --depth 100 https://github.com/neuphonic/neutts-air neutts-air cd neutts-air git checkout 1737487debe5b40a0bb97875edce8c66b391722b cd .. cp -rfv neutts-air/neuttsair ./ installRequirements ================================================ FILE: backend/python/neutts/requirements-after.txt ================================================ datasets==4.1.1 torchtune==0.6.1 ================================================ FILE: backend/python/neutts/requirements-cpu.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cpu accelerate torch==2.8.0 transformers==4.56.1 librosa==0.11.0 neucodec>=0.0.4 phonemizer==3.3.0 soundfile==0.13.1 resemble-perth==1.0.1 llama-cpp-python ================================================ FILE: backend/python/neutts/requirements-cublas12.txt ================================================ librosa==0.11.0 neucodec>=0.0.4 phonemizer==3.3.0 soundfile==0.13.1 torch==2.8.0 transformers==4.56.1 resemble-perth==1.0.1 accelerate ================================================ FILE: backend/python/neutts/requirements-hipblas.txt ================================================ --extra-index-url https://download.pytorch.org/whl/rocm6.4 torch==2.8.0+rocm6.4 transformers==4.56.1 accelerate librosa==0.11.0 neucodec>=0.0.4 phonemizer==3.3.0 soundfile==0.13.1 
class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service

    unittest runs setUp()/tearDown() around every test method, so the tests
    never call them by hand. A second setUp() would spawn another server on
    the same port and overwrite self.service, leaking the first process.
    """

    def setUp(self):
        """
        This method sets up the gRPC service by starting the server
        """
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
        # Fixed delay that gives the server time to bind before the first RPC
        time.sleep(30)

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by terminating the server
        """
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        with grpc.insecure_channel("localhost:50051") as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            response = stub.Health(backend_pb2.HealthMessage())
            self.assertEqual(response.message, b'OK')

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        with grpc.insecure_channel("localhost:50051") as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            response = stub.LoadModel(backend_pb2.ModelOptions())
            print(response)
            self.assertTrue(response.success)
            self.assertEqual(response.message, "Model loaded successfully")

    def test_tts(self):
        """
        This method tests if the TTS request is answered after loading the model
        """
        with grpc.insecure_channel("localhost:50051") as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            response = stub.LoadModel(backend_pb2.ModelOptions())
            self.assertTrue(response.success)
            tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story")
            tts_response = stub.TTS(tts_request)
            self.assertIsNotNone(tts_response)
_ONE_DAY_IN_SECONDS = 60 * 60 * 24
# Number of worker threads; overridable through the environment.
MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1'))


class BackendServicer(backend_pb2_grpc.BackendServicer):
    """gRPC servicer exposing OuteTTS through the LocalAI backend API."""

    def Health(self, request, context):
        """Liveness probe: always answers with the bytes ``OK``."""
        return backend_pb2.Reply(message=bytes("OK", 'utf-8'))

    def LoadModel(self, request, context):
        """
        Parse ``request.Options``, build the OuteTTS interface and select a speaker.

        Options are ``name:value`` strings. Numeric-looking values are stored
        converted in ``self.options``; ``tokenizer``, ``version`` and
        ``speaker`` are additionally read as raw strings to configure the
        interface. Returns a Result with success=False when setting up the
        interface or the speaker raises.
        """
        model_name = request.Model

        if os.path.exists(request.ModelFile):
            model_name = request.ModelFile

        self.options = {}
        # Unconverted values: "version:0.3" must stay the string "0.3".
        raw_options = {}
        for opt in request.Options:
            if ":" not in opt:
                continue
            key, value = opt.split(":", 1)
            # First occurrence wins for the raw lookup below.
            raw_options.setdefault(key, value)
            try:
                if "." in value:
                    value = float(value)
                else:
                    value = int(value)
            except ValueError:
                pass
            self.options[key] = value

        # Each override is read independently, so tokenizer, version and
        # speaker can all be given at once, and values containing ':' are
        # kept whole.
        MODELNAME = "OuteAI/OuteTTS-0.3-1B"
        TOKENIZER = raw_options.get("tokenizer", "OuteAI/OuteTTS-0.3-1B")
        VERSION = raw_options.get("version", "0.3")
        SPEAKER = raw_options.get("speaker", "en_male_1")

        if model_name != "":
            MODELNAME = model_name

        try:
            model_config = outetts.HFModelConfig_v2(
                model_path=MODELNAME,
                tokenizer_path=TOKENIZER
            )
            self.interface = outetts.InterfaceHF(model_version=VERSION, cfg=model_config)
            self.interface.print_default_speakers()

            if request.AudioPath:
                # A relative AudioPath is resolved against request.ModelPath
                if os.path.isabs(request.AudioPath):
                    self.AudioPath = request.AudioPath
                else:
                    self.AudioPath = os.path.join(request.ModelPath, request.AudioPath)
                self.speaker = self.interface.create_speaker(audio_path=self.AudioPath)
            else:
                self.speaker = self.interface.load_default_speaker(name=SPEAKER)

            # ContextSize, when set, takes precedence over the max_new_tokens option
            if request.ContextSize > 0:
                self.max_tokens = request.ContextSize
            else:
                self.max_tokens = self.options.get("max_new_tokens", 512)

        except Exception as err:
            print("Error:", err, file=sys.stderr)
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

        return backend_pb2.Result(message="Model loaded successfully", success=True)

    def TTS(self, request, context):
        """
        Generate speech for ``request.text`` and save it to ``request.dst``.

        An empty text falls back to a built-in sample sentence. Any exception
        is reported as a Result with success=False.
        """
        try:
            text = request.text if request.text else "Speech synthesis is the artificial production of human speech."

            print("[OuteTTS] generating TTS", file=sys.stderr)
            gen_cfg = outetts.GenerationConfig(
                text=text,
                temperature=self.options.get("temperature", 0.1),
                repetition_penalty=self.options.get("repetition_penalty", 1.1),
                max_length=self.max_tokens,
                speaker=self.speaker,
            )
            output = self.interface.generate(config=gen_cfg)
            print("[OuteTTS] Generated TTS", file=sys.stderr)
            output.save(request.dst)
            print("[OuteTTS] TTS done", file=sys.stderr)
        except Exception as err:
            return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")
        return backend_pb2.Result(success=True)


async def serve(address):
    """Run the asyncio gRPC server on ``address`` until it is terminated."""
    server = grpc.aio.server(
        migration_thread_pool=futures.ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ('grpc.max_message_length', 50 * 1024 * 1024),
            ('grpc.max_send_message_length', 50 * 1024 * 1024),
            ('grpc.max_receive_message_length', 50 * 1024 * 1024),
        ])
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)

    # Inside a coroutine the running loop is the one to attach handlers to.
    loop = asyncio.get_running_loop()
    for sig in (signal.SIGINT, signal.SIGTERM):
        # Schedule a stop with a 5 second grace period on either signal.
        loop.add_signal_handler(
            sig, lambda: asyncio.ensure_future(server.stop(5))
        )

    await server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)
    await server.wait_for_termination()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run the OuteTTS gRPC server.")
    parser.add_argument("--addr", default="localhost:50051", help="The address to bind the server to.")
    args = parser.parse_args()
    asyncio.run(serve(args.addr))
class TestBackendServicer(unittest.TestCase):
    """Smoke test: spawn the OuteTTS backend and call its Health RPC."""

    def setUp(self):
        """Launch the backend as a child process and wait for it to come up."""
        server_cmd = ["python3", "backend.py", "--addr", "localhost:50051"]
        self.service = subprocess.Popen(server_cmd)
        time.sleep(5)

    def tearDown(self):
        """Stop the child process and reap it."""
        proc = self.service
        proc.terminate()
        proc.wait()

    def test_health(self):
        """Health must answer with the bytes OK."""
        try:
            channel = grpc.insecure_channel("localhost:50051")
            with channel:
                client = backend_pb2_grpc.BackendStub(channel)
                reply = client.Health(backend_pb2.HealthMessage())
                self.assertEqual(reply.message, b'OK')
        except Exception as err:
            self.fail(f"Health check failed: {err}")
        finally:
            self.tearDown()


if __name__ == "__main__":
    unittest.main()
run run: pocket-tts @echo "Running pocket-tts..." bash run.sh @echo "pocket-tts run." .PHONY: test test: pocket-tts @echo "Testing pocket-tts..." bash test.sh @echo "pocket-tts tested." .PHONY: protogen-clean protogen-clean: $(RM) backend_pb2_grpc.py backend_pb2.py .PHONY: clean clean: protogen-clean rm -rf venv __pycache__ ================================================ FILE: backend/python/pocket-tts/backend.py ================================================ #!/usr/bin/env python3 """ This is an extra gRPC server of LocalAI for Pocket TTS """ from concurrent import futures import time import argparse import signal import sys import os import traceback import scipy.io.wavfile import backend_pb2 import backend_pb2_grpc import torch from pocket_tts import TTSModel import grpc def is_float(s): """Check if a string can be converted to float.""" try: float(s) return True except ValueError: return False def is_int(s): """Check if a string can be converted to int.""" try: int(s) return True except ValueError: return False _ONE_DAY_IN_SECONDS = 60 * 60 * 24 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) # Implement the BackendServicer class with the service methods class BackendServicer(backend_pb2_grpc.BackendServicer): """ BackendServicer is the class that implements the gRPC service """ def Health(self, request, context): return backend_pb2.Reply(message=bytes("OK", 'utf-8')) def LoadModel(self, request, context): # Get device if torch.cuda.is_available(): print("CUDA is available", file=sys.stderr) device = "cuda" else: print("CUDA is not available", file=sys.stderr) device = "cpu" mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() if mps_available: device = "mps" if not torch.cuda.is_available() and request.CUDA: return backend_pb2.Result(success=False, message="CUDA is not available") # Normalize potential 'mpx' typo to 'mps' if 
device == "mpx": print("Note: device 'mpx' detected, treating it as 'mps'.", file=sys.stderr) device = "mps" # Validate mps availability if requested if device == "mps" and not torch.backends.mps.is_available(): print("Warning: MPS not available. Falling back to CPU.", file=sys.stderr) device = "cpu" self.device = device options = request.Options # empty dict self.options = {} # The options are a list of strings in this form optname:optvalue # We are storing all the options in a dict so we can use it later when # generating the audio for opt in options: if ":" not in opt: continue key, value = opt.split(":", 1) # Split only on first colon # if value is a number, convert it to the appropriate type if is_float(value): value = float(value) elif is_int(value): value = int(value) elif value.lower() in ["true", "false"]: value = value.lower() == "true" self.options[key] = value # Default voice for caching self.default_voice_url = self.options.get("default_voice", None) self._voice_cache = {} try: print("Loading Pocket TTS model", file=sys.stderr) self.tts_model = TTSModel.load_model() print(f"Model loaded successfully. Sample rate: {self.tts_model.sample_rate}", file=sys.stderr) # Pre-load default voice if specified if self.default_voice_url: try: print(f"Pre-loading default voice: {self.default_voice_url}", file=sys.stderr) voice_state = self.tts_model.get_state_for_audio_prompt(self.default_voice_url) self._voice_cache[self.default_voice_url] = voice_state print("Default voice loaded successfully", file=sys.stderr) except Exception as e: print(f"Warning: Failed to pre-load default voice: {e}", file=sys.stderr) except Exception as err: return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return backend_pb2.Result(message="Model loaded successfully", success=True) def _get_voice_state(self, voice_input): """ Get voice state from cache or load it. 
voice_input can be: - HuggingFace URL (e.g., hf://kyutai/tts-voices/alba-mackenna/casual.wav) - Local file path - None (use default) """ # Use default if no voice specified if not voice_input: voice_input = self.default_voice_url if not voice_input: return None # Check cache first if voice_input in self._voice_cache: return self._voice_cache[voice_input] # Load voice state try: print(f"Loading voice from: {voice_input}", file=sys.stderr) voice_state = self.tts_model.get_state_for_audio_prompt(voice_input) self._voice_cache[voice_input] = voice_state return voice_state except Exception as e: print(f"Error loading voice from {voice_input}: {e}", file=sys.stderr) return None def TTS(self, request, context): try: # Determine voice input # Priority: request.voice > AudioPath (from ModelOptions) > default voice_input = None if request.voice: voice_input = request.voice elif hasattr(request, 'AudioPath') and request.AudioPath: # Use AudioPath as voice file if os.path.isabs(request.AudioPath): voice_input = request.AudioPath elif hasattr(request, 'ModelFile') and request.ModelFile: model_file_base = os.path.dirname(request.ModelFile) voice_input = os.path.join(model_file_base, request.AudioPath) elif hasattr(request, 'ModelPath') and request.ModelPath: voice_input = os.path.join(request.ModelPath, request.AudioPath) else: voice_input = request.AudioPath # Get voice state print(f"DEBUG: voice_input={voice_input}", file=sys.stderr) voice_state = self._get_voice_state(voice_input) print(f"DEBUG: voice_state={voice_state}", file=sys.stderr) if voice_state is None: return backend_pb2.Result( success=False, message=f"Voice not found or failed to load: {voice_input}. Please provide a valid voice URL or file path." 
) # Prepare text text = request.text.strip() if not text: return backend_pb2.Result( success=False, message="Text is empty" ) print(f"Generating audio for text: {text[:50]}...", file=sys.stderr) # Generate audio audio = self.tts_model.generate_audio(voice_state, text) # Audio is a 1D torch tensor containing PCM data if audio is None or audio.numel() == 0: return backend_pb2.Result( success=False, message="No audio generated" ) # Save audio to file output_path = request.dst if not output_path: output_path = "/tmp/pocket-tts-output.wav" # Ensure output directory exists output_dir = os.path.dirname(output_path) if output_dir and not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True) # Convert torch tensor to numpy and save audio_numpy = audio.numpy() scipy.io.wavfile.write(output_path, self.tts_model.sample_rate, audio_numpy) print(f"Saved audio to {output_path}", file=sys.stderr) except Exception as err: print(f"Error in TTS: {err}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return backend_pb2.Result(success=True) def serve(address): server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ]) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() print("Server started. Listening on: " + address, file=sys.stderr) # Define the signal handler function def signal_handler(sig, frame): print("Received termination signal. 
Shutting down...") server.stop(0) sys.exit(0) # Set the signal handlers for SIGINT and SIGTERM signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: while True: time.sleep(_ONE_DAY_IN_SECONDS) except KeyboardInterrupt: server.stop(0) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the gRPC server.") parser.add_argument( "--addr", default="localhost:50051", help="The address to bind the server to." ) args = parser.parse_args() serve(args.addr) ================================================ FILE: backend/python/pocket-tts/install.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi # This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links. # This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match. 
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index # the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index if [ "x${BUILD_PROFILE}" == "xintel" ]; then EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match" fi # Use python 3.12 for l4t if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then PYTHON_VERSION="3.12" PYTHON_PATCH="12" PY_STANDALONE_TAG="20251120" fi if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then USE_PIP=true fi installRequirements ================================================ FILE: backend/python/pocket-tts/protogen.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. backend.proto ================================================ FILE: backend/python/pocket-tts/requirements-cpu.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cpu pocket-tts scipy torch ================================================ FILE: backend/python/pocket-tts/requirements-cublas12.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu121 pocket-tts scipy torch ================================================ FILE: backend/python/pocket-tts/requirements-cublas13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 pocket-tts scipy torch ================================================ FILE: backend/python/pocket-tts/requirements-hipblas.txt ================================================ --extra-index-url https://download.pytorch.org/whl/rocm6.3 pocket-tts scipy torch==2.7.1+rocm6.3 ================================================ FILE: 
backend/python/pocket-tts/requirements-intel.txt ================================================ --extra-index-url https://download.pytorch.org/whl/xpu pocket-tts scipy torch ================================================ FILE: backend/python/pocket-tts/requirements-l4t12.txt ================================================ --extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/ pocket-tts scipy torch ================================================ FILE: backend/python/pocket-tts/requirements-l4t13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 pocket-tts scipy torch ================================================ FILE: backend/python/pocket-tts/requirements-mps.txt ================================================ pocket-tts scipy torch==2.7.1 torchvision==0.22.1 ================================================ FILE: backend/python/pocket-tts/requirements.txt ================================================ grpcio==1.71.0 protobuf certifi packaging==24.1 ================================================ FILE: backend/python/pocket-tts/run.sh ================================================ #!/bin/bash backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi startBackend $@ ================================================ FILE: backend/python/pocket-tts/test.py ================================================ """ A test script to test the gRPC service """ import unittest import subprocess import time import os import tempfile import backend_pb2 import backend_pb2_grpc import grpc class TestBackendServicer(unittest.TestCase): """ TestBackendServicer is the class that tests the gRPC service """ def setUp(self): """ This method sets up the gRPC service by starting the server """ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"]) time.sleep(30) def tearDown(self) -> 
None: """ This method tears down the gRPC service by terminating the server """ self.service.terminate() self.service.wait() def test_server_startup(self): """ This method tests if the server starts up successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.Health(backend_pb2.HealthMessage()) self.assertEqual(response.message, b'OK') except Exception as err: print(err) self.fail("Server failed to start") finally: self.tearDown() def test_load_model(self): """ This method tests if the model is loaded successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions()) print(response) self.assertTrue(response.success) self.assertEqual(response.message, "Model loaded successfully") except Exception as err: print(err) self.fail("LoadModel service failed") finally: self.tearDown() def test_tts_with_hf_voice(self): """ This method tests TTS generation with HuggingFace voice URL """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) # Load model response = stub.LoadModel(backend_pb2.ModelOptions()) self.assertTrue(response.success) # Create temporary output file with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file: output_path = tmp_file.name # Test TTS with HuggingFace voice URL tts_request = backend_pb2.TTSRequest( text="Hello world, this is a test.", dst=output_path, voice="azelma" ) tts_response = stub.TTS(tts_request) self.assertTrue(tts_response.success) # Verify output file exists and is not empty self.assertTrue(os.path.exists(output_path)) self.assertGreater(os.path.getsize(output_path), 0) # Cleanup os.unlink(output_path) except Exception as err: print(err) self.fail("TTS service failed") finally: self.tearDown() def test_tts_with_default_voice(self): """ 
This method tests TTS generation with default voice (via AudioPath in LoadModel) """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) # Load model with default voice load_request = backend_pb2.ModelOptions( Options=["default_voice:azelma"] ) response = stub.LoadModel(load_request) self.assertTrue(response.success) # Create temporary output file with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file: output_path = tmp_file.name # Test TTS without specifying voice (should use default) tts_request = backend_pb2.TTSRequest( text="Hello world, this is a test.", dst=output_path ) tts_response = stub.TTS(tts_request) self.assertTrue(tts_response.success) # Verify output file exists and is not empty self.assertTrue(os.path.exists(output_path)) self.assertGreater(os.path.getsize(output_path), 0) # Cleanup os.unlink(output_path) except Exception as err: print(err) self.fail("TTS service with default voice failed") finally: self.tearDown() ================================================ FILE: backend/python/pocket-tts/test.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi runUnittests ================================================ FILE: backend/python/qwen-asr/Makefile ================================================ .PHONY: qwen-asr qwen-asr: bash install.sh .PHONY: run run: qwen-asr @echo "Running qwen-asr..." bash run.sh @echo "qwen-asr run." .PHONY: test test: qwen-asr @echo "Testing qwen-asr..." bash test.sh @echo "qwen-asr tested." 
.PHONY: protogen-clean protogen-clean: $(RM) backend_pb2_grpc.py backend_pb2.py .PHONY: clean clean: protogen-clean rm -rf venv __pycache__ ================================================ FILE: backend/python/qwen-asr/backend.py ================================================ #!/usr/bin/env python3 """ gRPC server of LocalAI for Qwen3-ASR (transformers backend, non-vLLM). """ from concurrent import futures import time import argparse import signal import sys import os import backend_pb2 import backend_pb2_grpc import torch from qwen_asr import Qwen3ASRModel import grpc def is_float(s): try: float(s) return True except ValueError: return False def is_int(s): try: int(s) return True except ValueError: return False _ONE_DAY_IN_SECONDS = 60 * 60 * 24 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) class BackendServicer(backend_pb2_grpc.BackendServicer): def Health(self, request, context): return backend_pb2.Reply(message=bytes("OK", 'utf-8')) def LoadModel(self, request, context): if torch.cuda.is_available(): device = "cuda" else: device = "cpu" mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() if mps_available: device = "mps" if not torch.cuda.is_available() and request.CUDA: return backend_pb2.Result(success=False, message="CUDA is not available") self.device = device self.options = {} for opt in request.Options: if ":" not in opt: continue key, value = opt.split(":", 1) if is_float(value): value = float(value) elif is_int(value): value = int(value) elif value.lower() in ["true", "false"]: value = value.lower() == "true" self.options[key] = value model_path = request.Model or "Qwen/Qwen3-ASR-1.7B" default_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32 load_dtype = default_dtype if "torch_dtype" in self.options: d = str(self.options["torch_dtype"]).lower() if d == "fp16": load_dtype = torch.float16 elif d == "bf16": load_dtype = torch.bfloat16 elif d == "fp32": load_dtype = torch.float32 del 
self.options["torch_dtype"] self.max_inference_batch_size = self.options.get("max_inference_batch_size", 32) self.max_new_tokens = self.options.get("max_new_tokens", 256) forced_aligner = self.options.get("forced_aligner") if forced_aligner is not None and isinstance(forced_aligner, str): forced_aligner = forced_aligner.strip() or None attn_implementation = self.options.get("attn_implementation") if attn_implementation is not None and isinstance(attn_implementation, str): attn_implementation = attn_implementation.strip() or None if self.device == "mps": device_map = None elif self.device == "cuda": device_map = "cuda:0" else: device_map = "cpu" load_kwargs = dict( dtype=load_dtype, device_map=device_map, max_inference_batch_size=self.max_inference_batch_size, max_new_tokens=self.max_new_tokens, ) if attn_implementation: load_kwargs["attn_implementation"] = attn_implementation if forced_aligner: load_kwargs["forced_aligner"] = forced_aligner forced_aligner_kwargs = dict( dtype=load_dtype, device_map=device_map, ) if attn_implementation: forced_aligner_kwargs["attn_implementation"] = attn_implementation load_kwargs["forced_aligner_kwargs"] = forced_aligner_kwargs try: print(f"Loading Qwen3-ASR from {model_path}", file=sys.stderr) if attn_implementation: print(f"Using attn_implementation: {attn_implementation}", file=sys.stderr) if forced_aligner: print(f"Loading with forced_aligner: {forced_aligner}", file=sys.stderr) self.model = Qwen3ASRModel.from_pretrained(model_path, **load_kwargs) print("Qwen3-ASR model loaded successfully", file=sys.stderr) except Exception as err: print(f"[ERROR] LoadModel failed: {err}", file=sys.stderr) import traceback traceback.print_exc(file=sys.stderr) return backend_pb2.Result(success=False, message=str(err)) return backend_pb2.Result(message="Model loaded successfully", success=True) def AudioTranscription(self, request, context): result_segments = [] text = "" try: audio_path = request.dst if not audio_path or not 
os.path.exists(audio_path): print(f"Error: Audio file not found: {audio_path}", file=sys.stderr) return backend_pb2.TranscriptResult(segments=[], text="") language = None if request.language and request.language.strip(): language = request.language.strip() results = self.model.transcribe(audio=audio_path, language=language) if not results: return backend_pb2.TranscriptResult(segments=[], text="") r = results[0] text = r.text or "" if getattr(r, 'time_stamps', None) and len(r.time_stamps) > 0: for idx, ts in enumerate(r.time_stamps): start_ms = 0 end_ms = 0 seg_text = text if isinstance(ts, (list, tuple)) and len(ts) >= 3: start_ms = int(float(ts[0]) * 1000) if ts[0] is not None else 0 end_ms = int(float(ts[1]) * 1000) if ts[1] is not None else 0 seg_text = ts[2] if len(ts) > 2 and ts[2] is not None else "" result_segments.append(backend_pb2.TranscriptSegment( id=idx, start=start_ms, end=end_ms, text=seg_text )) else: if text: result_segments.append(backend_pb2.TranscriptSegment( id=0, start=0, end=0, text=text )) except Exception as err: print(f"Error in AudioTranscription: {err}", file=sys.stderr) import traceback traceback.print_exc(file=sys.stderr) return backend_pb2.TranscriptResult(segments=[], text="") return backend_pb2.TranscriptResult(segments=result_segments, text=text) def serve(address): server = grpc.server( futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ ('grpc.max_message_length', 50 * 1024 * 1024), ('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024), ]) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() print("Server started. Listening on: " + address, file=sys.stderr) def signal_handler(sig, frame): print("Received termination signal. 
Shutting down...") server.stop(0) sys.exit(0) signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: while True: time.sleep(_ONE_DAY_IN_SECONDS) except KeyboardInterrupt: server.stop(0) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the gRPC server.") parser.add_argument("--addr", default="localhost:50051", help="The address to bind the server to.") args = parser.parse_args() serve(args.addr) ================================================ FILE: backend/python/qwen-asr/install.sh ================================================ #!/bin/bash set -e EXTRA_PIP_INSTALL_FLAGS="--no-build-isolation" backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi if [ "x${BUILD_PROFILE}" == "xintel" ]; then EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match" fi PYTHON_VERSION="3.12" PYTHON_PATCH="12" PY_STANDALONE_TAG="20251120" installRequirements ================================================ FILE: backend/python/qwen-asr/requirements-cpu.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cpu torch qwen-asr ================================================ FILE: backend/python/qwen-asr/requirements-cublas12-after.txt ================================================ flash-attn ================================================ FILE: backend/python/qwen-asr/requirements-cublas12.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu121 torch qwen-asr ================================================ FILE: backend/python/qwen-asr/requirements-cublas13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch qwen-asr ================================================ FILE: backend/python/qwen-asr/requirements-hipblas.txt 
================================================ --extra-index-url https://download.pytorch.org/whl/rocm6.3 torch==2.7.1+rocm6.3 qwen-asr ================================================ FILE: backend/python/qwen-asr/requirements-intel-after.txt ================================================ flash-attn ================================================ FILE: backend/python/qwen-asr/requirements-intel.txt ================================================ --extra-index-url https://download.pytorch.org/whl/xpu torch qwen-asr ================================================ FILE: backend/python/qwen-asr/requirements-l4t12.txt ================================================ --extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/ torch qwen-asr ================================================ FILE: backend/python/qwen-asr/requirements-l4t13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch qwen-asr ================================================ FILE: backend/python/qwen-asr/requirements-mps.txt ================================================ torch==2.7.1 qwen-asr ================================================ FILE: backend/python/qwen-asr/requirements.txt ================================================ grpcio==1.71.0 protobuf certifi packaging==24.1 setuptools h11 gradio uvicorn ================================================ FILE: backend/python/qwen-asr/run.sh ================================================ #!/bin/bash backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi startBackend $@ ================================================ FILE: backend/python/qwen-asr/test.py ================================================ """ Tests for the Qwen3-ASR gRPC backend. 
""" import unittest import subprocess import time import os import tempfile import shutil import backend_pb2 import backend_pb2_grpc import grpc # Skip heavy transcription test in CI (model download + inference) SKIP_ASR_TESTS = os.environ.get("SKIP_ASR_TESTS", "false").lower() == "true" class TestBackendServicer(unittest.TestCase): def setUp(self): self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"]) time.sleep(15) def tearDown(self): self.service.terminate() self.service.wait() def test_server_startup(self): try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.Health(backend_pb2.HealthMessage()) self.assertEqual(response.message, b'OK') except Exception as err: print(err) self.fail("Server failed to start") finally: self.tearDown() def test_load_model(self): try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="Qwen/Qwen3-ASR-1.7B")) self.assertTrue(response.success, response.message) self.assertEqual(response.message, "Model loaded successfully") except Exception as err: print(err) self.fail("LoadModel service failed") finally: self.tearDown() @unittest.skipIf(SKIP_ASR_TESTS, "ASR transcription test skipped (SKIP_ASR_TESTS=true)") def test_audio_transcription(self): temp_dir = tempfile.mkdtemp() audio_file = os.path.join(temp_dir, 'audio.wav') try: url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen3-ASR-Repo/asr_en.wav" result = subprocess.run( ["wget", "-q", url, "-O", audio_file], capture_output=True, text=True, timeout=30, ) if result.returncode != 0: self.skipTest(f"Could not download sample audio: {result.stderr}") if not os.path.exists(audio_file): self.skipTest("Sample audio file not found after download") self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = 
backend_pb2_grpc.BackendStub(channel) load_response = stub.LoadModel(backend_pb2.ModelOptions(Model="Qwen/Qwen3-ASR-0.6B")) self.assertTrue(load_response.success, load_response.message) transcript_response = stub.AudioTranscription( backend_pb2.TranscriptRequest(dst=audio_file) ) self.assertIsNotNone(transcript_response) self.assertIsNotNone(transcript_response.text) self.assertGreaterEqual(len(transcript_response.segments), 0) all_text = "" for segment in transcript_response.segments: all_text += segment.text print(f"All text: {all_text}") self.assertIn("big", all_text) if transcript_response.segments: self.assertIsNotNone(transcript_response.segments[0].text) finally: self.tearDown() if os.path.exists(temp_dir): shutil.rmtree(temp_dir) ================================================ FILE: backend/python/qwen-asr/test.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi runUnittests ================================================ FILE: backend/python/qwen-tts/Makefile ================================================ .PHONY: qwen-tts qwen-tts: bash install.sh .PHONY: run run: qwen-tts @echo "Running qwen-tts..." bash run.sh @echo "qwen-tts run." .PHONY: test test: qwen-tts @echo "Testing qwen-tts..." bash test.sh @echo "qwen-tts tested." 
.PHONY: protogen-clean protogen-clean: $(RM) backend_pb2_grpc.py backend_pb2.py .PHONY: clean clean: protogen-clean rm -rf venv __pycache__ ================================================ FILE: backend/python/qwen-tts/backend.py ================================================ #!/usr/bin/env python3 """ This is an extra gRPC server of LocalAI for Qwen3-TTS """ from concurrent import futures import time import argparse import signal import sys import os import copy import traceback from pathlib import Path import backend_pb2 import backend_pb2_grpc import torch import soundfile as sf from qwen_tts import Qwen3TTSModel import json import hashlib import pickle import grpc def is_float(s): """Check if a string can be converted to float.""" try: float(s) return True except ValueError: return False def is_int(s): """Check if a string can be converted to int.""" try: int(s) return True except ValueError: return False _ONE_DAY_IN_SECONDS = 60 * 60 * 24 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1 MAX_WORKERS = int(os.environ.get("PYTHON_GRPC_MAX_WORKERS", "1")) # Implement the BackendServicer class with the service methods class BackendServicer(backend_pb2_grpc.BackendServicer): """ BackendServicer is the class that implements the gRPC service """ def Health(self, request, context): return backend_pb2.Reply(message=bytes("OK", "utf-8")) def LoadModel(self, request, context): # Get device if torch.cuda.is_available(): print("CUDA is available", file=sys.stderr) device = "cuda" else: print("CUDA is not available", file=sys.stderr) device = "cpu" mps_available = ( hasattr(torch.backends, "mps") and torch.backends.mps.is_available() ) if mps_available: device = "mps" if not torch.cuda.is_available() and request.CUDA: return backend_pb2.Result(success=False, message="CUDA is not available") # Normalize potential 'mpx' typo to 'mps' if device == "mpx": print("Note: device 'mpx' detected, treating it as 'mps'.", file=sys.stderr) device = "mps" 
# Validate mps availability if requested if device == "mps" and not torch.backends.mps.is_available(): print("Warning: MPS not available. Falling back to CPU.", file=sys.stderr) device = "cpu" self.device = device self._torch_device = torch.device(device) options = request.Options # empty dict self.options = {} # The options are a list of strings in this form optname:optvalue # We are storing all the options in a dict so we can use it later when # generating the audio for opt in options: if ":" not in opt: continue key, value = opt.split(":", 1) # Split only on first colon # if value is a number, convert it to the appropriate type if is_float(value): value = float(value) elif is_int(value): value = int(value) elif value.lower() in ["true", "false"]: value = value.lower() == "true" self.options[key] = value # Parse voices configuration from options self.voices = {} if "voices" in self.options: try: voices_data = self.options["voices"] if isinstance(voices_data, str): voices_list = json.loads(voices_data) else: voices_list = voices_data # Validate and store voices for voice_entry in voices_list: if not isinstance(voice_entry, dict): print( f"[WARNING] Invalid voice entry (not a dict): {voice_entry}", file=sys.stderr, ) continue name = voice_entry.get("name") audio = voice_entry.get("audio") ref_text = voice_entry.get("ref_text") if not name or not isinstance(name, str): print( f"[WARNING] Voice entry missing required 'name' field: {voice_entry}", file=sys.stderr, ) continue if not audio or not isinstance(audio, str): print( f"[WARNING] Voice entry missing required 'audio' field: {voice_entry}", file=sys.stderr, ) continue if ref_text is None or not isinstance(ref_text, str): print( f"[WARNING] Voice entry missing required 'ref_text' field: {voice_entry}", file=sys.stderr, ) continue self.voices[name] = {"audio": audio, "ref_text": ref_text} print( f"[INFO] Registered voice '{name}' with audio: {audio}", file=sys.stderr, ) print(f"[INFO] Loaded {len(self.voices)} 
voice(s)", file=sys.stderr) except json.JSONDecodeError as e: print(f"[ERROR] Failed to parse voices JSON: {e}", file=sys.stderr) except Exception as e: print( f"[ERROR] Error processing voices configuration: {e}", file=sys.stderr, ) print(traceback.format_exc(), file=sys.stderr) # Get model path from request model_path = request.Model if not model_path: model_path = "Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice" # Determine model type from model path or options self.model_type = self.options.get("model_type", None) if not self.model_type: if "CustomVoice" in model_path: self.model_type = "CustomVoice" elif "VoiceDesign" in model_path: self.model_type = "VoiceDesign" elif "Base" in model_path or "0.6B" in model_path or "1.7B" in model_path: self.model_type = "Base" # VoiceClone model else: # Default to CustomVoice self.model_type = "CustomVoice" # Cache for voice clone prompts self._voice_clone_cache = {} # Pre-load cached voices if disk_cache is enabled self._preload_cached_voices() # Store AudioPath, ModelFile, and ModelPath from LoadModel request # These are used later in TTS for VoiceClone mode self.audio_path = ( request.AudioPath if hasattr(request, "AudioPath") and request.AudioPath else None ) self.model_file = ( request.ModelFile if hasattr(request, "ModelFile") and request.ModelFile else None ) self.model_path = ( request.ModelPath if hasattr(request, "ModelPath") and request.ModelPath else None ) # Decide dtype & attention implementation if self.device == "mps": load_dtype = torch.float32 # MPS requires float32 device_map = None attn_impl_primary = "sdpa" # flash_attention_2 not supported on MPS elif self.device == "cuda": load_dtype = torch.bfloat16 device_map = "cuda" attn_impl_primary = "flash_attention_2" else: # cpu load_dtype = torch.float32 device_map = "cpu" attn_impl_primary = "sdpa" print( f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}, model_type: {self.model_type}", file=sys.stderr, ) 
print(f"Loading model from: {model_path}", file=sys.stderr) # Load model with device-specific logic # Common parameters for all devices load_kwargs = { "dtype": load_dtype, "attn_implementation": attn_impl_primary, "trust_remote_code": True, # Required for qwen-tts models } try: if self.device == "mps": load_kwargs["device_map"] = None # load then move self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs) self.model.to("mps") elif self.device == "cuda": load_kwargs["device_map"] = device_map self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs) else: # cpu load_kwargs["device_map"] = device_map self.model = Qwen3TTSModel.from_pretrained(model_path, **load_kwargs) except Exception as e: error_msg = str(e) print( f"[ERROR] Loading model: {type(e).__name__}: {error_msg}", file=sys.stderr, ) print(traceback.format_exc(), file=sys.stderr) # Check if it's a missing feature extractor/tokenizer error if ( "speech_tokenizer" in error_msg or "preprocessor_config.json" in error_msg or "feature extractor" in error_msg.lower() ): print( "\n[ERROR] Model files appear to be incomplete. This usually means:", file=sys.stderr, ) print( " 1. The model download was interrupted or incomplete", file=sys.stderr, ) print(" 2. 
The model cache is corrupted", file=sys.stderr) print("\nTo fix this, try:", file=sys.stderr) print( f" rm -rf ~/.cache/huggingface/hub/models--Qwen--Qwen3-TTS-*", file=sys.stderr, ) print(" Then re-run to trigger a fresh download.", file=sys.stderr) print( "\nAlternatively, try using a different model variant:", file=sys.stderr, ) print(" - Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice", file=sys.stderr) print(" - Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign", file=sys.stderr) print(" - Qwen/Qwen3-TTS-12Hz-1.7B-Base", file=sys.stderr) if attn_impl_primary == "flash_attention_2": print( "\nTrying to use SDPA instead of flash_attention_2...", file=sys.stderr, ) load_kwargs["attn_implementation"] = "sdpa" try: if self.device == "mps": load_kwargs["device_map"] = None self.model = Qwen3TTSModel.from_pretrained( model_path, **load_kwargs ) self.model.to("mps") else: load_kwargs["device_map"] = ( self.device if self.device in ("cuda", "cpu") else None ) self.model = Qwen3TTSModel.from_pretrained( model_path, **load_kwargs ) except Exception as e2: print( f"[ERROR] Failed to load with SDPA: {type(e2).__name__}: {e2}", file=sys.stderr, ) print(traceback.format_exc(), file=sys.stderr) raise e2 else: raise e print(f"Model loaded successfully: {model_path}", file=sys.stderr) return backend_pb2.Result(message="Model loaded successfully", success=True) def _detect_mode(self, request): """Detect which mode to use based on request parameters.""" # Priority: VoiceClone > VoiceDesign > CustomVoice # model_type explicitly set if self.model_type == "CustomVoice": return "CustomVoice" if self.model_type == "VoiceClone": return "VoiceClone" if self.model_type == "VoiceDesign": return "VoiceDesign" # VoiceClone: AudioPath is provided OR voices dict is populated if self.audio_path or self.voices: return "VoiceClone" # VoiceDesign: instruct option is provided if "instruct" in self.options and self.options["instruct"]: return "VoiceDesign" # Default to CustomVoice return "CustomVoice" def 
_get_ref_audio_path(self, request, voice_name=None): """Get reference audio path from stored AudioPath or from voices dict.""" # If voice_name is provided and exists in voices dict, use that if voice_name and voice_name in self.voices: audio_path = self.voices[voice_name]["audio"] # If absolute path, use as-is if os.path.isabs(audio_path): return audio_path # Try relative to ModelFile if self.model_file: model_file_base = os.path.dirname(self.model_file) ref_path = os.path.join(model_file_base, audio_path) if os.path.exists(ref_path): return ref_path # Try relative to ModelPath if self.model_path: ref_path = os.path.join(self.model_path, audio_path) if os.path.exists(ref_path): return ref_path # Return as-is (might be URL or base64) return audio_path # Fall back to legacy single-voice mode using self.audio_path if not self.audio_path: return None # If absolute path, use as-is if os.path.isabs(self.audio_path): return self.audio_path # Try relative to ModelFile if self.model_file: model_file_base = os.path.dirname(self.model_file) ref_path = os.path.join(model_file_base, self.audio_path) if os.path.exists(ref_path): return ref_path # Try relative to ModelPath if self.model_path: ref_path = os.path.join(self.model_path, self.audio_path) if os.path.exists(ref_path): return ref_path # Return as-is (might be URL or base64) return self.audio_path def _get_voice_clone_prompt(self, request, ref_audio, ref_text): """Get or create voice clone prompt, with in-memory and disk caching.""" cache_key = self._get_voice_cache_key(ref_audio, ref_text) if cache_key not in self._voice_clone_cache: # Check disk cache first (if enabled) disk_cached = self._get_cached_voice_clone_prompt_from_disk( ref_audio, ref_text ) if disk_cached is not None: self._voice_clone_cache[cache_key] = disk_cached else: # Create new prompt print(f"Creating voice clone prompt from {ref_audio}", file=sys.stderr) try: prompt_items = self.model.create_voice_clone_prompt( ref_audio=ref_audio, ref_text=ref_text, 
x_vector_only_mode=self.options.get( "x_vector_only_mode", False ), ) self._voice_clone_cache[cache_key] = prompt_items # Save to disk cache if enabled self._save_voice_clone_prompt_to_disk( ref_audio, ref_text, prompt_items ) except Exception as e: print(f"Error creating voice clone prompt: {e}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) return None return self._voice_clone_cache[cache_key] def _is_text_file_path(self, text): """Check if the text is a file path to a text file.""" if not text or not isinstance(text, str): return False # Check if it looks like a file path (contains / or \ and ends with common text file extensions) text_extensions = [".txt", ".md", ".rst", ".text"] has_path_separator = "/" in text or "\\" in text ends_with_text_ext = any(text.lower().endswith(ext) for ext in text_extensions) return has_path_separator and ends_with_text_ext def _read_text_file(self, file_path): """Read text content from a file path, resolving relative paths.""" try: # If absolute path, use as-is if os.path.isabs(file_path): resolved_path = file_path else: # Try relative to ModelFile if self.model_file: model_file_base = os.path.dirname(self.model_file) candidate_path = os.path.join(model_file_base, file_path) if os.path.exists(candidate_path): resolved_path = candidate_path else: resolved_path = file_path else: resolved_path = file_path # Try relative to ModelPath if not os.path.exists(resolved_path) and self.model_path: candidate_path = os.path.join(self.model_path, file_path) if os.path.exists(candidate_path): resolved_path = candidate_path # Check if file exists and is readable if not os.path.exists(resolved_path): print( f"[ERROR] ref_text file not found: {resolved_path}", file=sys.stderr ) return None if not os.path.isfile(resolved_path): print( f"[ERROR] ref_text path is not a file: {resolved_path}", file=sys.stderr, ) return None # Read and return file contents with open(resolved_path, "r", encoding="utf-8") as f: content = 
f.read().strip() print( f"[INFO] Successfully read ref_text from file: {resolved_path}", file=sys.stderr, ) return content except Exception as e: print( f"[ERROR] Failed to read ref_text file {file_path}: {e}", file=sys.stderr, ) print(traceback.format_exc(), file=sys.stderr) return None def _compute_file_hash(self, file_path): """Compute SHA256 hash of file content.""" try: sha256 = hashlib.sha256() with open(file_path, "rb") as f: for chunk in iter(lambda: f.read(4096), b""): sha256.update(chunk) return sha256.hexdigest() except Exception as e: print( f"[ERROR] Failed to compute hash for {file_path}: {e}", file=sys.stderr ) return None def _compute_string_hash(self, text): """Compute SHA256 hash of string.""" return hashlib.sha256(text.encode("utf-8")).hexdigest() def _get_cached_voice_clone_prompt_from_disk(self, ref_audio, ref_text_content): """Load cached prompt from disk if available and valid.""" if not self.options.get("disk_cache", False): return None cache_file = f"{ref_audio}.voice_cache.pkl" if not os.path.exists(cache_file): return None try: with open(cache_file, "rb") as f: cached = pickle.load(f) # Validate checksums current_audio_hash = self._compute_file_hash(ref_audio) current_text_hash = self._compute_string_hash(ref_text_content) if current_audio_hash is None or cached["audio_hash"] != current_audio_hash: print("[INFO] Cache invalidation: audio file changed", file=sys.stderr) os.remove(cache_file) return None if cached["ref_text_hash"] != current_text_hash: print( "[INFO] Cache invalidation: ref_text content changed", file=sys.stderr, ) os.remove(cache_file) return None print( f"[INFO] Loaded voice clone prompt from disk cache: {cache_file}", file=sys.stderr, ) return cached["prompt_items"] except Exception as e: print( f"[WARNING] Failed to load disk cache {cache_file}: {e}", file=sys.stderr, ) return None def _save_voice_clone_prompt_to_disk( self, ref_audio, ref_text_content, prompt_items ): """Save prompt to disk cache alongside audio 
file.""" if not self.options.get("disk_cache", False): return cache_file = f"{ref_audio}.voice_cache.pkl" try: cache_data = { "audio_hash": self._compute_file_hash(ref_audio), "ref_text_hash": self._compute_string_hash(ref_text_content), "prompt_items": prompt_items, } with open(cache_file, "wb") as f: pickle.dump(cache_data, f) print( f"[INFO] Saved voice clone prompt to disk cache: {cache_file}", file=sys.stderr, ) except Exception as e: print( f"[WARNING] Failed to save disk cache {cache_file}: {e}", file=sys.stderr, ) def _get_voice_cache_key(self, ref_audio, ref_text): """Get the cache key for a voice.""" return f"{ref_audio}:{ref_text}" def _preload_cached_voices(self): """Pre-load cached voice prompts at model startup.""" if not self.voices or not self.options.get("disk_cache", False): return print( f"[INFO] Pre-loading {len(self.voices)} cached voice(s)...", file=sys.stderr ) loaded_count = 0 missing_count = 0 invalid_count = 0 for voice_name, voice_config in self.voices.items(): audio_path = voice_config["audio"] ref_text_path = voice_config["ref_text"] # Check for cache file cache_file = f"{audio_path}.voice_cache.pkl" if os.path.exists(cache_file): # Read ref_text content for validation ref_text_content = self._read_text_file(ref_text_path) if ref_text_content is None: invalid_count += 1 print( f"[INFO] Cannot read ref_text for {voice_name} (will recreate on first use)", file=sys.stderr, ) continue cached_prompt = self._get_cached_voice_clone_prompt_from_disk( audio_path, ref_text_content ) if cached_prompt: # Pre-populate memory cache with content-based key cache_key = self._get_voice_cache_key(audio_path, ref_text_content) self._voice_clone_cache[cache_key] = cached_prompt loaded_count += 1 print(f"[INFO] Pre-loaded voice: {voice_name}", file=sys.stderr) else: invalid_count += 1 print( f"[INFO] Cache invalid for {voice_name} (will recreate on first use)", file=sys.stderr, ) else: missing_count += 1 print( f"[INFO] No cache found for {voice_name} (will 
create on first use)", file=sys.stderr, ) # Summary line print( f"[INFO] Pre-loaded {loaded_count}/{len(self.voices)} voices ({missing_count} missing, {invalid_count} invalid)", file=sys.stderr, ) def TTS(self, request, context): try: # Check if dst is provided if not request.dst: return backend_pb2.Result( success=False, message="dst (output path) is required" ) # Prepare text text = request.text.strip() if not text: return backend_pb2.Result(success=False, message="Text is empty") # Get language (auto-detect if not provided) language = ( request.language if hasattr(request, "language") and request.language else None ) if not language or language == "": language = "Auto" # Auto-detect language # Detect mode mode = self._detect_mode(request) print(f"Detected mode: {mode}", file=sys.stderr) # Get generation parameters from options max_new_tokens = self.options.get("max_new_tokens", None) top_p = self.options.get("top_p", None) temperature = self.options.get("temperature", None) do_sample = self.options.get("do_sample", None) # Prepare generation kwargs generation_kwargs = {} if max_new_tokens is not None: generation_kwargs["max_new_tokens"] = max_new_tokens if top_p is not None: generation_kwargs["top_p"] = top_p if temperature is not None: generation_kwargs["temperature"] = temperature if do_sample is not None: generation_kwargs["do_sample"] = do_sample instruct = self.options.get("instruct", "") if instruct is not None and instruct != "": generation_kwargs["instruct"] = instruct # Generate audio based on mode if mode == "VoiceClone": # VoiceClone mode # Check if multi-voice mode is active (voices dict is populated) voice_name = None if self.voices: # Get voice from request (priority) or options voice_name = request.voice if request.voice else None if not voice_name: voice_name = self.options.get("voice", None) # Validate voice exists if voice_name and voice_name not in self.voices: available_voices = ", ".join(sorted(self.voices.keys())) return backend_pb2.Result( 
success=False, message=f"Voice '{voice_name}' not found. Available voices: {available_voices}", ) # Get reference audio path (with voice-specific lookup if in multi-voice mode) ref_audio = self._get_ref_audio_path(request, voice_name) if not ref_audio: if voice_name: return backend_pb2.Result( success=False, message=f"Audio path for voice '{voice_name}' could not be resolved", ) else: return backend_pb2.Result( success=False, message="AudioPath is required for VoiceClone mode", ) # Get reference text (from voice config if multi-voice, else from options/request) if voice_name and voice_name in self.voices: ref_text_source = self.voices[voice_name]["ref_text"] else: ref_text_source = self.options.get("ref_text", None) if not ref_text_source: # Try to get from request if available if hasattr(request, "ref_text") and request.ref_text: ref_text_source = request.ref_text if not ref_text_source: # x_vector_only_mode doesn't require ref_text if not self.options.get("x_vector_only_mode", False): return backend_pb2.Result( success=False, message="ref_text is required for VoiceClone mode (or set x_vector_only_mode=true)", ) # Determine if ref_text_source is a file path ref_text_is_file = ref_text_source and self._is_text_file_path( ref_text_source ) if ref_text_is_file: ref_text_content = self._read_text_file(ref_text_source) if ref_text_content is None: return backend_pb2.Result( success=False, message=f"Failed to read ref_text from file: {ref_text_source}", ) ref_text_source = ref_text_content print( f"[INFO] Loaded ref_text from file: {ref_text_content[:100]}...", file=sys.stderr, ) # For caching: use the content as the key (since we've read the file if it was one) ref_text_for_cache = ref_text_source # Check if we should use cached prompt use_cached_prompt = self.options.get("use_cached_prompt", True) voice_clone_prompt = None if use_cached_prompt: voice_clone_prompt = self._get_voice_clone_prompt( request, ref_audio, ref_text_for_cache ) if voice_clone_prompt is None: 
return backend_pb2.Result( success=False, message="Failed to create voice clone prompt" ) if voice_clone_prompt: # Use cached prompt start_time = time.time() wavs, sr = self.model.generate_voice_clone( text=text, language=language, voice_clone_prompt=voice_clone_prompt, **generation_kwargs, ) generation_duration = time.time() - start_time print( f"[INFO] Voice clone generation completed: {generation_duration:.2f}s, output_samples={len(wavs) if wavs else 0}", file=sys.stderr, flush=True, ) else: # Create prompt on-the-fly (only for non-file ref_text that wasn't cached) start_time = time.time() wavs, sr = self.model.generate_voice_clone( text=text, language=language, ref_audio=ref_audio, ref_text=ref_text_source, x_vector_only_mode=self.options.get( "x_vector_only_mode", False ), **generation_kwargs, ) generation_duration = time.time() - start_time print( f"[INFO] Voice clone generation (on-the-fly) completed: {generation_duration:.2f}s, output_samples={len(wavs) if wavs else 0}", file=sys.stderr, flush=True, ) elif mode == "VoiceDesign": # VoiceDesign mode if not instruct: return backend_pb2.Result( success=False, message="instruct option is required for VoiceDesign mode", ) wavs, sr = self.model.generate_voice_design( text=text, language=language, **generation_kwargs ) else: # CustomVoice mode (default) speaker = request.voice if request.voice else None if not speaker: # Try to get from options speaker = self.options.get("speaker", None) if not speaker: # Use default speaker speaker = "Vivian" print( f"No speaker specified, using default: {speaker}", file=sys.stderr, ) # Validate speaker if model supports it if hasattr(self.model, "get_supported_speakers"): try: supported_speakers = self.model.get_supported_speakers() if speaker not in supported_speakers: print( f"Warning: Speaker '{speaker}' not in supported list. 
Available: {supported_speakers}", file=sys.stderr, ) # Try to find a close match (case-insensitive) speaker_lower = speaker.lower() for sup_speaker in supported_speakers: if sup_speaker.lower() == speaker_lower: speaker = sup_speaker print( f"Using matched speaker: {speaker}", file=sys.stderr, ) break except Exception as e: print( f"Warning: Could not get supported speakers: {e}", file=sys.stderr, ) wavs, sr = self.model.generate_custom_voice( text=text, language=language, speaker=speaker, **generation_kwargs ) # Save output if wavs is not None and len(wavs) > 0: # wavs is a list, take first element audio_data = wavs[0] if isinstance(wavs, list) else wavs audio_duration = len(audio_data) / sr if sr > 0 else 0 sf.write(request.dst, audio_data, sr) print( f"Saved {audio_duration:.2f}s audio to {request.dst}", file=sys.stderr, ) else: return backend_pb2.Result( success=False, message="No audio output generated" ) except Exception as err: print(f"Error in TTS: {err}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) return backend_pb2.Result( success=False, message=f"Unexpected {err=}, {type(err)=}" ) return backend_pb2.Result(success=True) def serve(address): server = grpc.server( futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ ("grpc.max_message_length", 50 * 1024 * 1024), # 50MB ("grpc.max_send_message_length", 50 * 1024 * 1024), # 50MB ("grpc.max_receive_message_length", 50 * 1024 * 1024), # 50MB ], ) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() print("Server started. Listening on: " + address, file=sys.stderr) # Define the signal handler function def signal_handler(sig, frame): print("Received termination signal. 
Shutting down...") server.stop(0) sys.exit(0) # Set the signal handlers for SIGINT and SIGTERM signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: while True: time.sleep(_ONE_DAY_IN_SECONDS) except KeyboardInterrupt: server.stop(0) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the gRPC server.") parser.add_argument( "--addr", default="localhost:50051", help="The address to bind the server to." ) args = parser.parse_args() serve(args.addr) ================================================ FILE: backend/python/qwen-tts/install.sh ================================================ #!/bin/bash set -e EXTRA_PIP_INSTALL_FLAGS="--no-build-isolation" backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi installRequirements ================================================ FILE: backend/python/qwen-tts/requirements-cpu.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cpu torch torchaudio qwen-tts sox ================================================ FILE: backend/python/qwen-tts/requirements-cublas12-after.txt ================================================ flash-attn ================================================ FILE: backend/python/qwen-tts/requirements-cublas12.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu121 torch torchaudio qwen-tts sox ================================================ FILE: backend/python/qwen-tts/requirements-cublas13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch torchaudio qwen-tts sox ================================================ FILE: backend/python/qwen-tts/requirements-hipblas.txt ================================================ --extra-index-url https://download.pytorch.org/whl/rocm6.3 
torch==2.7.1+rocm6.3 torchaudio==2.7.1+rocm6.3 qwen-tts sox ================================================ FILE: backend/python/qwen-tts/requirements-intel-after.txt ================================================ flash-attn ================================================ FILE: backend/python/qwen-tts/requirements-intel.txt ================================================ --extra-index-url https://download.pytorch.org/whl/xpu torch torchaudio qwen-tts sox ================================================ FILE: backend/python/qwen-tts/requirements-l4t12.txt ================================================ --extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/ torch torchaudio qwen-tts sox ================================================ FILE: backend/python/qwen-tts/requirements-l4t13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch torchaudio qwen-tts sox ================================================ FILE: backend/python/qwen-tts/requirements-mps.txt ================================================ torch torchaudio qwen-tts sox ================================================ FILE: backend/python/qwen-tts/requirements.txt ================================================ grpcio==1.71.0 protobuf certifi packaging==24.1 soundfile setuptools six scipy librosa ================================================ FILE: backend/python/qwen-tts/run.sh ================================================ #!/bin/bash backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi startBackend $@ ================================================ FILE: backend/python/qwen-tts/test.py ================================================ """ A test script to test the gRPC service """ import unittest import subprocess import time import os import sys import tempfile import threading import backend_pb2 import backend_pb2_grpc 
import grpc class TestBackendServicer(unittest.TestCase): """ TestBackendServicer is the class that tests the gRPC service """ def setUp(self): """ This method sets up the gRPC service by starting the server """ self.service = subprocess.Popen( ["python3", "backend.py", "--addr", "localhost:50051"], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True ) time.sleep(30) def tearDown(self) -> None: """ This method tears down the gRPC service by terminating the server """ self.service.terminate() try: stdout, stderr = self.service.communicate(timeout=5) # Output should already be printed by threads, but print any remaining if stdout: print("=== REMAINING STDOUT ===") print(stdout) if stderr: print("=== REMAINING STDERR ===") print(stderr) except subprocess.TimeoutExpired: self.service.kill() stdout, stderr = self.service.communicate() if stdout: print("=== REMAINING STDOUT ===") print(stdout) if stderr: print("=== REMAINING STDERR ===") print(stderr) def test_tts(self): """ This method tests if the TTS generation works successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) # Allow up to 10 minutes for model download on first run response = stub.LoadModel( backend_pb2.ModelOptions(Model="Qwen/Qwen3-TTS-12Hz-0.6B-CustomVoice"), timeout=600.0 ) self.assertTrue(response.success) # Create temporary output file with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as tmp_file: output_path = tmp_file.name tts_request = backend_pb2.TTSRequest( text="Hello, this is a test of the qwen-tts backend.", voice="Vivian", dst=output_path ) # Allow up to 2 minutes for TTS generation tts_response = stub.TTS(tts_request, timeout=120.0) self.assertIsNotNone(tts_response) self.assertTrue(tts_response.success) # Verify output file exists and is not empty self.assertTrue(os.path.exists(output_path)) self.assertGreater(os.path.getsize(output_path), 0) # Cleanup os.unlink(output_path) except 
Exception as err: print(f"Exception: {err}", file=sys.stderr) # Give threads a moment to flush any remaining output time.sleep(1) self.fail("TTS service failed") finally: self.tearDown() ================================================ FILE: backend/python/qwen-tts/test.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi runUnittests ================================================ FILE: backend/python/rerankers/Makefile ================================================ .PHONY: rerankers rerankers: bash install.sh .PHONY: run run: rerankers @echo "Running rerankers..." bash run.sh @echo "rerankers run." # It is not working well by using command line. It only6 works with IDE like VSCode. .PHONY: test test: rerankers @echo "Testing rerankers..." bash test.sh @echo "rerankers tested." .PHONY: protogen-clean protogen-clean: $(RM) backend_pb2_grpc.py backend_pb2.py .PHONY: clean clean: protogen-clean rm -rf venv __pycache__ ================================================ FILE: backend/python/rerankers/README.md ================================================ # Creating a separate environment for the reranker project ``` make reranker ``` ================================================ FILE: backend/python/rerankers/backend.py ================================================ #!/usr/bin/env python3 """ Extra gRPC server for Rerankers models. 
""" from concurrent import futures import argparse import signal import sys import os import time import backend_pb2 import backend_pb2_grpc import grpc from rerankers import Reranker _ONE_DAY_IN_SECONDS = 60 * 60 * 24 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) # Implement the BackendServicer class with the service methods class BackendServicer(backend_pb2_grpc.BackendServicer): """ A gRPC servicer for the backend service. This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding. """ def Health(self, request, context): """ A gRPC method that returns the health status of the backend service. Args: request: A HealthRequest object that contains the request parameters. context: A grpc.ServicerContext object that provides information about the RPC. Returns: A Reply object that contains the health status of the backend service. """ return backend_pb2.Reply(message=bytes("OK", 'utf-8')) def LoadModel(self, request, context): """ A gRPC method that loads a model into memory. Args: request: A LoadModelRequest object that contains the request parameters. context: A grpc.ServicerContext object that provides information about the RPC. Returns: A Result object that contains the result of the LoadModel operation. 
""" model_name = request.Model try: kwargs = {} if request.Type != "": kwargs['model_type'] = request.Type if request.PipelineType != "": # Reuse the PipelineType field for language kwargs['lang'] = request.PipelineType self.model_name = model_name self.model = Reranker(model_name, **kwargs) except Exception as err: return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") # Implement your logic here for the LoadModel service # Replace this with your desired response return backend_pb2.Result(message="Model loaded successfully", success=True) def Rerank(self, request, context): documents = [] for idx, doc in enumerate(request.documents): documents.append(doc) ranked_results=self.model.rank(query=request.query, docs=documents, doc_ids=list(range(len(request.documents)))) # Prepare results to return cropped_results = ranked_results.top_k(request.top_n) if request.top_n > 0 else ranked_results results = [ backend_pb2.DocumentResult( index=res.doc_id, text=res.text, relevance_score=res.score ) for res in (cropped_results) ] # Calculate the usage and total tokens # TODO: Implement the usage calculation with reranker total_tokens = sum(len(doc.split()) for doc in request.documents) + len(request.query.split()) prompt_tokens = len(request.query.split()) usage = backend_pb2.Usage(total_tokens=total_tokens, prompt_tokens=prompt_tokens) return backend_pb2.RerankResult(usage=usage, results=results) def serve(address): server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ]) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() print("Server started. 
#!/bin/bash
# Sources the shared libbackend.sh helpers and runs installRequirements for
# this backend.
set -e

# Every path expansion is quoted so that a checkout directory containing
# spaces or glob characters is not word-split by the shell.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
    EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
fi

installRequirements
class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service.

    unittest runs setUp() before and tearDown() after every test method, so
    the tests must not call them again: a second setUp() would start another
    server on the same address and overwrite self.service, leaving the first
    process running with nothing left to kill it.
    """
    ADDRESS = "localhost:50051"
    MODEL = "cross-encoder"

    def setUp(self):
        """
        This method sets up the gRPC service by starting the server
        """
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", self.ADDRESS])
        # Fixed wait for the backend process to come up before connecting.
        time.sleep(10)

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by terminating the server
        """
        self.service.kill()
        self.service.wait()

    def _rerank(self, documents, top_n):
        """
        Load the model, then rerank `documents` against the query "I love you".

        Failures are left to propagate so that unittest reports the real
        assertion or RPC error instead of a generic message.
        """
        with grpc.insecure_channel(self.ADDRESS) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            load_response = stub.LoadModel(backend_pb2.ModelOptions(Model=self.MODEL))
            self.assertTrue(load_response.success)
            request = backend_pb2.RerankRequest(
                query="I love you",
                documents=documents,
                top_n=top_n,
            )
            return stub.Rerank(request)

    def _assert_expected_ranking(self, rerank_response):
        """
        Check that exactly two results came back, best match first.
        """
        self.assertIsNotNone(rerank_response.results)
        self.assertEqual(len(rerank_response.results), 2)
        self.assertEqual(rerank_response.results[0].text, "I really like you")
        self.assertEqual(rerank_response.results[1].text, "I hate you")

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        with grpc.insecure_channel(self.ADDRESS) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            response = stub.Health(backend_pb2.HealthMessage())
            self.assertEqual(response.message, b'OK')

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        with grpc.insecure_channel(self.ADDRESS) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            response = stub.LoadModel(backend_pb2.ModelOptions(Model=self.MODEL))
            self.assertTrue(response.success)
            self.assertEqual(response.message, "Model loaded successfully")

    def test_rerank(self):
        """
        This method tests if the documents are reranked successfully
        """
        rerank_response = self._rerank(["I hate you", "I really like you"], top_n=2)
        self._assert_expected_ranking(rerank_response)

    def test_rerank_omit_top_n(self):
        """
        This method tests if the documents are reranked successfully even
        when top_n is omitted (left at 0)
        """
        rerank_response = self._rerank(["I hate you", "I really like you"], top_n=0)
        self._assert_expected_ranking(rerank_response)

    def test_rerank_crop(self):
        """
        This method tests top_n cropping
        """
        rerank_response = self._rerank(
            ["I hate you", "I really like you", "I hate ignoring top_n"],
            top_n=2,
        )
        self._assert_expected_ranking(rerank_response)
""" from concurrent import futures import argparse import signal import sys import os import time import base64 import backend_pb2 import backend_pb2_grpc import grpc import requests import supervision as sv from inference import get_model from PIL import Image from io import BytesIO _ONE_DAY_IN_SECONDS = 60 * 60 * 24 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) # Implement the BackendServicer class with the service methods class BackendServicer(backend_pb2_grpc.BackendServicer): """ A gRPC servicer for the RFDETR backend service. This class implements the gRPC methods for object detection using RFDETR models. """ def __init__(self): self.model = None self.model_name = None def Health(self, request, context): """ A gRPC method that returns the health status of the backend service. Args: request: A HealthMessage object that contains the request parameters. context: A grpc.ServicerContext object that provides information about the RPC. Returns: A Reply object that contains the health status of the backend service. """ return backend_pb2.Reply(message=bytes("OK", 'utf-8')) def LoadModel(self, request, context): """ A gRPC method that loads a RFDETR model into memory. Args: request: A ModelOptions object that contains the model parameters. context: A grpc.ServicerContext object that provides information about the RPC. Returns: A Result object that contains the result of the LoadModel operation. """ model_name = request.Model try: # Load the RFDETR model self.model = get_model(model_name) self.model_name = model_name print(f'Loaded RFDETR model: {model_name}') except Exception as err: return backend_pb2.Result(success=False, message=f"Failed to load model: {err}") return backend_pb2.Result(message="Model loaded successfully", success=True) def Detect(self, request, context): """ A gRPC method that performs object detection on an image. 
Args: request: A DetectOptions object that contains the image source. context: A grpc.ServicerContext object that provides information about the RPC. Returns: A DetectResponse object that contains the detection results. """ if self.model is None: print(f"Model is None") return backend_pb2.DetectResponse() print(f"Model is not None") try: print(f"Decoding image") # Decode the base64 image print(f"Image data: {request.src}") image_data = base64.b64decode(request.src) image = Image.open(BytesIO(image_data)) # Perform inference predictions = self.model.infer(image, confidence=0.5)[0] # Convert to proto format proto_detections = [] for i in range(len(predictions.predictions)): pred = predictions.predictions[i] print(f"Prediction: {pred}") proto_detection = backend_pb2.Detection( x=float(pred.x), y=float(pred.y), width=float(pred.width), height=float(pred.height), confidence=float(pred.confidence), class_name=pred.class_name ) proto_detections.append(proto_detection) return backend_pb2.DetectResponse(Detections=proto_detections) except Exception as err: print(f"Detection error: {err}") return backend_pb2.DetectResponse() def Status(self, request, context): """ A gRPC method that returns the status of the backend service. Args: request: A HealthMessage object that contains the request parameters. context: A grpc.ServicerContext object that provides information about the RPC. Returns: A StatusResponse object that contains the status information. 
""" state = backend_pb2.StatusResponse.READY if self.model is not None else backend_pb2.StatusResponse.UNINITIALIZED return backend_pb2.StatusResponse(state=state) def serve(address): server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ]) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() print("[RFDETR] Server started. Listening on: " + address, file=sys.stderr) # Define the signal handler function def signal_handler(sig, frame): print("[RFDETR] Received termination signal. Shutting down...") server.stop(0) sys.exit(0) # Set the signal handlers for SIGINT and SIGTERM signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: while True: time.sleep(_ONE_DAY_IN_SECONDS) except KeyboardInterrupt: server.stop(0) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the RFDETR gRPC server.") parser.add_argument( "--addr", default="localhost:50051", help="The address to bind the server to." ) args = parser.parse_args() print(f"[RFDETR] startup: {args}", file=sys.stderr) serve(args.addr) ================================================ FILE: backend/python/rfdetr/install.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi # This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links. # This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match. 
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index # the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index if [ "x${BUILD_PROFILE}" == "xintel" ]; then EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match" fi installRequirements ================================================ FILE: backend/python/rfdetr/requirements-cpu.txt ================================================ rfdetr opencv-python accelerate peft inference torch==2.7.1 optimum-quanto ================================================ FILE: backend/python/rfdetr/requirements-cublas12.txt ================================================ torch==2.7.1 rfdetr opencv-python accelerate inference peft optimum-quanto ================================================ FILE: backend/python/rfdetr/requirements-cublas13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch==2.9.1 rfdetr opencv-python accelerate inference peft optimum-quanto ================================================ FILE: backend/python/rfdetr/requirements-hipblas.txt ================================================ --extra-index-url https://download.pytorch.org/whl/rocm6.4 torch==2.8.0+rocm6.4 torchvision==0.23.0+rocm6.4 rfdetr opencv-python accelerate inference peft optimum-quanto ================================================ FILE: backend/python/rfdetr/requirements-intel.txt ================================================ --extra-index-url https://download.pytorch.org/whl/xpu torch torchvision optimum[openvino] setuptools rfdetr inference opencv-python accelerate peft optimum-quanto ================================================ FILE: backend/python/rfdetr/requirements-mps.txt ================================================ torch==2.7.1 rfdetr opencv-python accelerate peft inference optimum-quanto ================================================ FILE: 
backend/python/rfdetr/requirements.txt ================================================ grpcio==1.71.0 protobuf grpcio-tools ================================================ FILE: backend/python/rfdetr/run.sh ================================================ #!/bin/bash backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi startBackend $@ ================================================ FILE: backend/python/rfdetr/test.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi runUnittests ================================================ FILE: backend/python/transformers/Makefile ================================================ .PHONY: transformers transformers: bash install.sh .PHONY: run run: transformers @echo "Running transformers..." bash run.sh @echo "transformers run." # It does not work well from the command line. It only works with an IDE like VSCode. .PHONY: test test: transformers @echo "Testing transformers..." bash test.sh @echo "transformers tested." .PHONY: protogen-clean protogen-clean: $(RM) backend_pb2_grpc.py backend_pb2.py .PHONY: clean clean: protogen-clean rm -rf venv __pycache__ ================================================ FILE: backend/python/transformers/README.md ================================================ # Creating a separate environment for the transformers project ``` make transformers ``` ================================================ FILE: backend/python/transformers/backend.py ================================================ #!/usr/bin/env python3 """ Extra gRPC server for HuggingFace AutoModel models.
""" from concurrent import futures import argparse import signal import sys import os from threading import Thread import asyncio import time import backend_pb2 import backend_pb2_grpc import grpc import torch import torch.cuda XPU=os.environ.get("XPU", "0") == "1" from transformers import AutoTokenizer, AutoModel, set_seed, TextIteratorStreamer, StoppingCriteriaList, StopStringCriteria, MambaConfig, MambaForCausalLM from transformers import AutoProcessor, MusicgenForConditionalGeneration, DiaForConditionalGeneration from scipy.io import wavfile from sentence_transformers import SentenceTransformer _ONE_DAY_IN_SECONDS = 60 * 60 * 24 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) def mean_pooling(model_output, attention_mask): """ Mean pooling to get sentence embeddings. See: https://huggingface.co/sentence-transformers/paraphrase-distilroberta-base-v1 """ token_embeddings = model_output[0] input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1) # Sum columns sum_mask = torch.clamp(input_mask_expanded.sum(1), min=1e-9) return sum_embeddings / sum_mask # Implement the BackendServicer class with the service methods class BackendServicer(backend_pb2_grpc.BackendServicer): """ A gRPC servicer for the backend service. This class implements the gRPC methods for the backend service, including Health, LoadModel, and Embedding. """ def Health(self, request, context): """ A gRPC method that returns the health status of the backend service. Args: request: A HealthRequest object that contains the request parameters. context: A grpc.ServicerContext object that provides information about the RPC. Returns: A Reply object that contains the health status of the backend service. 
""" return backend_pb2.Reply(message=bytes("OK", 'utf-8')) def LoadModel(self, request, context): """ A gRPC method that loads a model into memory. Args: request: A LoadModelRequest object that contains the request parameters. context: A grpc.ServicerContext object that provides information about the RPC. Returns: A Result object that contains the result of the LoadModel operation. """ model_name = request.Model # Check to see if the Model exists in the filesystem already. if os.path.exists(request.ModelFile): model_name = request.ModelFile compute = torch.float16 if request.F16Memory == True: compute=torch.bfloat16 self.CUDA = torch.cuda.is_available() self.OV=False self.DiaTTS=False self.SentenceTransformer = False device_map="cpu" mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() if mps_available: device_map = "mps" quantization = None autoTokenizer = True # Parse options from request.Options self.options = {} options = request.Options # The options are a list of strings in this form optname:optvalue # We are storing all the options in a dict so we can use it later when generating # Example options: ["max_new_tokens:3072", "guidance_scale:3.0", "temperature:1.8", "top_p:0.90", "top_k:45"] for opt in options: if ":" not in opt: continue key, value = opt.split(":", 1) # if value is a number, convert it to the appropriate type try: if "." 
in value: value = float(value) else: value = int(value) except ValueError: # Keep as string if conversion fails pass self.options[key] = value print(f"Parsed options: {self.options}", file=sys.stderr) if self.CUDA: from transformers import BitsAndBytesConfig, AutoModelForCausalLM if request.MainGPU: device_map=request.MainGPU else: device_map="cuda:0" if request.Quantization == "bnb_4bit": quantization = BitsAndBytesConfig( load_in_4bit = True, bnb_4bit_compute_dtype = compute, bnb_4bit_quant_type = "nf4", bnb_4bit_use_double_quant = True, load_in_8bit = False, ) elif request.Quantization == "bnb_8bit": quantization = BitsAndBytesConfig( load_in_4bit=False, bnb_4bit_compute_dtype = None, load_in_8bit=True, ) try: if request.Type == "AutoModelForCausalLM": if XPU: import intel_extension_for_pytorch as ipex from intel_extension_for_transformers.transformers.modeling import AutoModelForCausalLM device_map="xpu" compute=torch.float16 if request.Quantization == "xpu_4bit": xpu_4bit = True xpu_8bit = False elif request.Quantization == "xpu_8bit": xpu_4bit = False xpu_8bit = True else: xpu_4bit = False xpu_8bit = False self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True, device_map=device_map, load_in_4bit=xpu_4bit, load_in_8bit=xpu_8bit, torch_dtype=compute) else: self.model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True, quantization_config=quantization, device_map=device_map, torch_dtype=compute) elif request.Type == "OVModelForCausalLM": from optimum.intel.openvino import OVModelForCausalLM from openvino.runtime import Core if request.MainGPU: device_map=request.MainGPU else: device_map="AUTO" devices = Core().available_devices if "GPU" in " ".join(devices): device_map="AUTO:GPU" # While working on a fine tuned model, inference may give an inaccuracy and performance drop on GPU if winograd convolutions are selected. 
# https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html if "CPU" or "NPU" in device_map: if "-CPU" or "-NPU" not in device_map: ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"} else: ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"} self.model = OVModelForCausalLM.from_pretrained(model_name, compile=True, trust_remote_code=request.TrustRemoteCode, ov_config=ovconfig, device=device_map) self.OV = True elif request.Type == "OVModelForFeatureExtraction": from optimum.intel.openvino import OVModelForFeatureExtraction from openvino.runtime import Core if request.MainGPU: device_map=request.MainGPU else: device_map="AUTO" devices = Core().available_devices if "GPU" in " ".join(devices): device_map="AUTO:GPU" # While working on a fine tuned model, inference may give an inaccuracy and performance drop on GPU if winograd convolutions are selected. # https://docs.openvino.ai/2024/openvino-workflow/running-inference/inference-devices-and-modes/gpu-device.html if "CPU" or "NPU" in device_map: if "-CPU" or "-NPU" not in device_map: ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT"} else: ovconfig={"PERFORMANCE_HINT": "CUMULATIVE_THROUGHPUT","GPU_DISABLE_WINOGRAD_CONVOLUTION": "YES"} self.model = OVModelForFeatureExtraction.from_pretrained(model_name, compile=True, trust_remote_code=request.TrustRemoteCode, ov_config=ovconfig, export=True, device=device_map) self.OV = True elif request.Type == "MusicgenForConditionalGeneration": autoTokenizer = False self.processor = AutoProcessor.from_pretrained(model_name) self.model = MusicgenForConditionalGeneration.from_pretrained(model_name) elif request.Type == "DiaForConditionalGeneration": autoTokenizer = False print("DiaForConditionalGeneration", file=sys.stderr) self.processor = AutoProcessor.from_pretrained(model_name) self.model = DiaForConditionalGeneration.from_pretrained(model_name) if self.CUDA: self.model = 
self.model.to("cuda") self.processor = self.processor.to("cuda") print("DiaForConditionalGeneration loaded", file=sys.stderr) self.DiaTTS = True elif request.Type == "SentenceTransformer": autoTokenizer = False self.model = SentenceTransformer(model_name, trust_remote_code=request.TrustRemoteCode) self.SentenceTransformer = True elif request.Type == "Mamba": autoTokenizer = False self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = MambaForCausalLM.from_pretrained(model_name) else: print("Automodel", file=sys.stderr) self.model = AutoModel.from_pretrained(model_name, trust_remote_code=request.TrustRemoteCode, use_safetensors=True, quantization_config=quantization, device_map=device_map, torch_dtype=compute) if request.ContextSize > 0: self.max_tokens = request.ContextSize elif hasattr(self.model, 'config') and hasattr(self.model.config, 'max_position_embeddings'): self.max_tokens = self.model.config.max_position_embeddings else: self.max_tokens = self.options.get("max_new_tokens", 512) if autoTokenizer: self.tokenizer = AutoTokenizer.from_pretrained(model_name, use_safetensors=True) self.XPU = False if XPU and self.OV == False: self.XPU = True try: print("Optimizing model", model_name, "to XPU.", file=sys.stderr) self.model = ipex.optimize_transformers(self.model, inplace=True, dtype=torch.float16, device="xpu") except Exception as err: print("Not using XPU:", err, file=sys.stderr) except Exception as err: print("Error:", err, file=sys.stderr) return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") # Implement your logic here for the LoadModel service # Replace this with your desired response return backend_pb2.Result(message="Model loaded successfully", success=True) def Embedding(self, request, context): """ A gRPC method that calculates embeddings for a given sentence. Args: request: An EmbeddingRequest object that contains the request parameters. 
async def _predict(self, request, context, streaming=False):
    """
    Generate text for `request`, shared by Predict and PredictStream.

    Args:
        request: The predict request (prompt or messages, sampling
            parameters, token limit and stop prompts). It is not modified.
        context: The gRPC context (not used here).
        streaming: When True, yield one Reply per decoded text chunk;
            otherwise yield a single Reply with the full generated text.
            Streaming is switched off on the non-OpenVINO XPU path.

    Yields:
        backend_pb2.Reply messages containing UTF-8 encoded text.
    """
    set_seed(request.Seed)

    # Sampling parameters are kept in locals so the caller's request is left
    # untouched.
    top_p = request.TopP
    if top_p < 0 or top_p > 1:
        top_p = 1

    top_k = request.TopK
    if top_k <= 0:
        top_k = 50

    temperature = request.Temperature
    sample = temperature > 0
    if not sample:
        # Greedy decoding: drop the sampling-only parameters. The previous
        # code wrote `request.TopP == None` (a comparison, not an
        # assignment), so they were still forwarded to generate().
        top_p = None
        top_k = None
        temperature = None

    prompt = request.Prompt
    if not request.Prompt and request.UseTokenizerTemplate and request.Messages:
        prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True)

    inputs = self.tokenizer(prompt, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[-1]

    if request.Tokens > 0:
        max_tokens = request.Tokens
    else:
        # No explicit limit: use what is left of the loaded budget.
        max_tokens = self.max_tokens - prompt_length

    use_xpu = XPU and self.OV == False
    if self.CUDA:
        inputs = inputs.to("cuda")
    if use_xpu:
        inputs = inputs.to("xpu")
        streaming = False

    criteria = []
    if request.StopPrompts:
        criteria = StoppingCriteriaList(
            [
                StopStringCriteria(tokenizer=self.tokenizer, stop_strings=request.StopPrompts),
            ]
        )

    if streaming:
        streamer = TextIteratorStreamer(self.tokenizer,
                                        skip_prompt=True,
                                        skip_special_tokens=True)
        config = dict(inputs,
                      max_new_tokens=max_tokens,
                      temperature=temperature,
                      top_p=top_p,
                      top_k=top_k,
                      do_sample=sample,
                      attention_mask=inputs["attention_mask"],
                      eos_token_id=self.tokenizer.eos_token_id,
                      pad_token_id=self.tokenizer.eos_token_id,
                      streamer=streamer,
                      stopping_criteria=criteria,
                      use_cache=True,
                      )
        # generate() runs in a worker thread and feeds the streamer, which
        # is drained here.
        # NOTE(review): iterating the streamer blocks this coroutine between
        # chunks; confirm that is acceptable for the asyncio server.
        thread = Thread(target=self.model.generate, kwargs=config)
        thread.start()
        try:
            for new_text in streamer:
                yield backend_pb2.Reply(message=bytes(new_text, encoding='utf-8'))
        finally:
            thread.join()
        return

    if use_xpu:
        # `pad_token_id` is the generate() argument name; this call
        # previously passed `pad_token=`.
        outputs = self.model.generate(inputs["input_ids"],
                                      max_new_tokens=max_tokens,
                                      temperature=temperature,
                                      top_p=top_p,
                                      top_k=top_k,
                                      do_sample=sample,
                                      pad_token_id=self.tokenizer.eos_token_id)
    else:
        outputs = self.model.generate(**inputs,
                                      max_new_tokens=max_tokens,
                                      temperature=temperature,
                                      top_p=top_p,
                                      top_k=top_k,
                                      do_sample=sample,
                                      eos_token_id=self.tokenizer.eos_token_id,
                                      pad_token_id=self.tokenizer.eos_token_id,
                                      stopping_criteria=criteria,
                                      use_cache=True,
                                      )

    # Decode only the tokens that follow the prompt.
    generated_text = self.tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True)[0]
    yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8'))
""" iterations = self._predict(request, context, streaming=True) try: async for iteration in iterations: yield iteration finally: await iterations.aclose() def SoundGeneration(self, request, context): model_name = request.model try: if self.processor is None: if model_name == "": return backend_pb2.Result(success=False, message="request.model is required") self.processor = AutoProcessor.from_pretrained(model_name) if self.model is None: if model_name == "": return backend_pb2.Result(success=False, message="request.model is required") self.model = MusicgenForConditionalGeneration.from_pretrained(model_name) inputs = None if request.text == "": inputs = self.model.get_unconditional_inputs(num_samples=1) elif request.HasField('src'): # TODO SECURITY CODE GOES HERE LOL # WHO KNOWS IF THIS WORKS??? sample_rate, wsamples = wavfile.read('path_to_your_file.wav') if request.HasField('src_divisor'): wsamples = wsamples[: len(wsamples) // request.src_divisor] inputs = self.processor( audio=wsamples, sampling_rate=sample_rate, text=[request.text], padding=True, return_tensors="pt", ) else: inputs = self.processor( text=[request.text], padding=True, return_tensors="pt", ) if request.HasField('duration'): tokens = int(request.duration * 51.2) # 256 tokens = 5 seconds, therefore 51.2 tokens is one second guidance = self.options.get("guidance_scale", 3.0) if request.HasField('temperature'): guidance = request.temperature dosample = self.options.get("do_sample", True) if request.HasField('sample'): dosample = request.sample audio_values = self.model.generate(**inputs, do_sample=dosample, guidance_scale=guidance, max_new_tokens=self.max_tokens) print("[transformers-musicgen] SoundGeneration generated!", file=sys.stderr) sampling_rate = self.model.config.audio_encoder.sampling_rate wavfile.write(request.dst, rate=sampling_rate, data=audio_values[0, 0].numpy()) print("[transformers-musicgen] SoundGeneration saved to", request.dst, file=sys.stderr) print("[transformers-musicgen] 
SoundGeneration for", file=sys.stderr) print("[transformers-musicgen] SoundGeneration requested tokens", tokens, file=sys.stderr) print(request, file=sys.stderr) except Exception as err: return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return backend_pb2.Result(success=True) def CallDiaTTS(self, request, context): """ Generates dialogue audio using the Dia model. Args: request: A TTSRequest containing text dialogue and generation parameters context: The gRPC context Returns: A Result object indicating success or failure """ try: print("[DiaTTS] generating dialogue audio", file=sys.stderr) # Prepare text input - expect dialogue format like [S1] ... [S2] ... text = [request.text] # Process the input inputs = self.processor(text=text, padding=True, return_tensors="pt") # Generate audio with parameters from options or defaults generation_params = { **inputs, "max_new_tokens": self.max_tokens, "guidance_scale": self.options.get("guidance_scale", 3.0), "temperature": self.options.get("temperature", 1.8), "top_p": self.options.get("top_p", 0.90), "top_k": self.options.get("top_k", 45) } outputs = self.model.generate(**generation_params) # Decode and save audio outputs = self.processor.batch_decode(outputs) self.processor.save_audio(outputs, request.dst) print("[DiaTTS] Generated dialogue audio", file=sys.stderr) print("[DiaTTS] Audio saved to", request.dst, file=sys.stderr) print("[DiaTTS] Dialogue generation done", file=sys.stderr) except Exception as err: return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return backend_pb2.Result(success=True) # The TTS endpoint is older, and provides fewer features, but exists for compatibility reasons def TTS(self, request, context): if self.DiaTTS: print("DiaTTS", file=sys.stderr) return self.CallDiaTTS(request, context) model_name = request.model try: if self.processor is None: if model_name == "": return backend_pb2.Result(success=False, message="request.model 
async def serve(address):
    """
    Run the asyncio gRPC backend server on ``address`` until it terminates.

    A BackendServicer is registered on a server whose message size options are
    all set to 50MB, the server is bound to ``address`` without TLS, and SIGINT
    and SIGTERM are wired to ``server.stop(5)``.

    Args:
        address: The ``host:port`` string to bind the server to.
    """
    message_limit = 50 * 1024 * 1024  # 50MB
    channel_options = [
        ('grpc.max_message_length', message_limit),
        ('grpc.max_send_message_length', message_limit),
        ('grpc.max_receive_message_length', message_limit),
    ]
    worker_pool = futures.ThreadPoolExecutor(max_workers=MAX_WORKERS)

    # Start asyncio gRPC server
    server = grpc.aio.server(migration_thread_pool=worker_pool, options=channel_options)
    # Register the servicer, then bind the listening address
    backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server)
    server.add_insecure_port(address)

    def request_shutdown():
        # Schedule the stop on the loop; the signal handler itself must not block
        asyncio.ensure_future(server.stop(5))

    event_loop = asyncio.get_event_loop()
    for shutdown_signal in (signal.SIGINT, signal.SIGTERM):
        event_loop.add_signal_handler(shutdown_signal, request_shutdown)

    await server.start()
    print("Server started. Listening on: " + address, file=sys.stderr)

    # Block here until the server has been stopped
    await server.wait_for_termination()
backend/python/transformers/requirements-cublas13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch==2.9.0 llvmlite==0.43.0 numba==0.60.0 transformers bitsandbytes sentence-transformers==5.2.3 protobuf==6.33.5 ================================================ FILE: backend/python/transformers/requirements-hipblas.txt ================================================ --extra-index-url https://download.pytorch.org/whl/rocm6.4 torch==2.8.0+rocm6.4 accelerate transformers llvmlite==0.43.0 numba==0.60.0 bitsandbytes sentence-transformers==5.2.3 protobuf==6.33.5 ================================================ FILE: backend/python/transformers/requirements-intel.txt ================================================ --extra-index-url https://download.pytorch.org/whl/xpu torch optimum[openvino] llvmlite==0.43.0 numba==0.60.0 transformers bitsandbytes sentence-transformers==5.2.3 protobuf==6.33.5 ================================================ FILE: backend/python/transformers/requirements-mps.txt ================================================ torch==2.7.1 llvmlite==0.43.0 numba==0.60.0 accelerate transformers bitsandbytes sentence-transformers==5.2.3 protobuf==6.33.5 ================================================ FILE: backend/python/transformers/requirements.txt ================================================ grpcio==1.78.1 protobuf==6.33.5 certifi setuptools scipy==1.15.1 numpy>=2.0.0 ================================================ FILE: backend/python/transformers/run.sh ================================================ #!/bin/bash backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi if [ -d "/opt/intel" ]; then # Assumes we are using the Intel oneAPI container image # https://github.com/intel/intel-extension-for-pytorch/issues/538 export XPU=1 fi startBackend $@ 
================================================ FILE: backend/python/transformers/test.py ================================================ """ A test script to test the gRPC service """ import unittest import subprocess import time import backend_pb2 import backend_pb2_grpc import grpc class TestBackendServicer(unittest.TestCase): """ TestBackendServicer is the class that tests the gRPC service """ def setUp(self): """ This method sets up the gRPC service by starting the server """ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"]) time.sleep(10) def tearDown(self) -> None: """ This method tears down the gRPC service by terminating the server """ self.service.kill() self.service.wait() def test_server_startup(self): """ This method tests if the server starts up successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.Health(backend_pb2.HealthMessage()) self.assertEqual(response.message, b'OK') except Exception as err: print(err) self.fail("Server failed to start") finally: self.tearDown() def test_load_model(self): """ This method tests if the model is loaded successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-cased")) self.assertTrue(response.success) self.assertEqual(response.message, "Model loaded successfully") except Exception as err: print(err) self.fail("LoadModel service failed") finally: self.tearDown() def test_embedding(self): """ This method tests if the embeddings are generated successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-cased")) print(response.message) self.assertTrue(response.success) embedding_request = 
backend_pb2.PredictOptions(Embeddings="This is a test sentence.") embedding_response = stub.Embedding(embedding_request) self.assertIsNotNone(embedding_response.embeddings) except Exception as err: print(err) self.fail("Embedding service failed") finally: self.tearDown() def test_audio_load_model(self): """ This method tests if the model is loaded successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/musicgen-small",Type="MusicgenForConditionalGeneration")) self.assertTrue(response.success) self.assertEqual(response.message, "Model loaded successfully") except Exception as err: print(err) self.fail("LoadModel service failed") finally: self.tearDown() def test_tts(self): """ This method tests if TTS is generated successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/musicgen-small",Type="MusicgenForConditionalGeneration")) self.assertTrue(response.success) tts_request = backend_pb2.TTSRequest(text="80s TV news production music hit for tonight's biggest story") tts_response = stub.TTS(tts_request) self.assertIsNotNone(tts_response) except Exception as err: print(err) self.fail("TTS service failed") finally: self.tearDown() def test_sound_generation(self): """ This method tests if SoundGeneration is generated successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/musicgen-small",Type="MusicgenForConditionalGeneration")) self.assertTrue(response.success) sg_request = backend_pb2.SoundGenerationRequest(text="80s TV news production music hit for tonight's biggest story") sg_response = stub.SoundGeneration(sg_request) 
self.assertIsNotNone(sg_response) except Exception as err: print(err) self.fail("SoundGeneration service failed") finally: self.tearDown() def test_embed_load_model(self): """ This method tests if the model is loaded successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-nli-mean-tokens",Type="SentenceTransformer")) self.assertTrue(response.success) self.assertEqual(response.message, "Model loaded successfully") except Exception as err: print(err) self.fail("LoadModel service failed") finally: self.tearDown() def test_sentencetransformers_embedding(self): """ This method tests if the embeddings are generated successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="bert-base-nli-mean-tokens",Type="SentenceTransformer")) self.assertTrue(response.success) embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.") embedding_response = stub.Embedding(embedding_request) self.assertIsNotNone(embedding_response.embeddings) except Exception as err: print(err) self.fail("Embedding service failed") finally: self.tearDown() ================================================ FILE: backend/python/transformers/test.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi runUnittests ================================================ FILE: backend/python/vibevoice/Makefile ================================================ .PHONY: vibevoice vibevoice: bash install.sh .PHONY: download-voices download-voices: @echo "Downloading voice preset files..." 
@mkdir -p voices/streaming_model @if command -v wget >/dev/null 2>&1; then \ wget -q -O voices/streaming_model/en-Frank_man.pt \ https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Frank_man.pt && \ wget -q -O voices/streaming_model/en-Grace_woman.pt \ https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Grace_woman.pt && \ wget -q -O voices/streaming_model/en-Mike_man.pt \ https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Mike_man.pt && \ wget -q -O voices/streaming_model/en-Emma_woman.pt \ https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Emma_woman.pt && \ wget -q -O voices/streaming_model/en-Carter_man.pt \ https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Carter_man.pt && \ wget -q -O voices/streaming_model/en-Davis_man.pt \ https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Davis_man.pt && \ echo "Voice files downloaded successfully"; \ elif command -v curl >/dev/null 2>&1; then \ curl -sL -o voices/streaming_model/en-Frank_man.pt \ https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Frank_man.pt && \ curl -sL -o voices/streaming_model/en-Grace_woman.pt \ https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Grace_woman.pt && \ curl -sL -o voices/streaming_model/en-Mike_man.pt \ https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Mike_man.pt && \ curl -sL -o voices/streaming_model/en-Emma_woman.pt \ https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Emma_woman.pt && \ curl -sL -o voices/streaming_model/en-Carter_man.pt \ https://raw.githubusercontent.com/microsoft/VibeVoice/main/demo/voices/streaming_model/en-Carter_man.pt && \ curl -sL -o voices/streaming_model/en-Davis_man.pt \ 
def is_float(s):
    """Report whether ``float(s)`` succeeds without raising ValueError."""
    try:
        float(s)
    except ValueError:
        return False
    else:
        return True
def LoadModel(self, request, context):
    """
    Load a VibeVoice model: TTS by default, ASR when the ``asr_mode`` option is true.

    Picks the device (``mps`` if available, otherwise ``cuda`` if available,
    otherwise ``cpu``), parses ``request.Options`` (``key:value`` strings) into
    ``self.options``, validates the generation parameters, resolves the voices
    directory and finally loads processor and model.

    Args:
        request: The model options. Fields read here: ``CUDA``, ``Options``,
            ``Model``, ``ModelFile``, ``ModelPath`` and ``AudioPath``.
        context: The gRPC context (unused).

    Returns:
        A Result with ``success=True`` and a confirmation message, or
        ``success=False`` when CUDA was requested but is unavailable or when
        loading raises.
    """
    # Get device
    if torch.cuda.is_available():
        print("CUDA is available", file=sys.stderr)
        device = "cuda"
    else:
        print("CUDA is not available", file=sys.stderr)
        device = "cpu"
    mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    if mps_available:
        device = "mps"
    if not torch.cuda.is_available() and request.CUDA:
        return backend_pb2.Result(success=False, message="CUDA is not available")

    # `device` can only be "cuda", "cpu" or (when available) "mps" here, so no
    # further normalisation or availability fallback is needed.
    self.device = device
    self._torch_device = torch.device(device)

    # The options are a list of strings in this form optname:optvalue
    # We are storing all the options in a dict so we can use it later when
    # generating the audio
    self.options = {}
    for opt in request.Options:
        if ":" not in opt:
            continue
        key, value = opt.split(":", 1)  # Split only on first colon
        # Try int before float: every integer string is also a valid float, and
        # integer options (inference_steps, max_new_tokens, num_beams) are
        # validated with isinstance(..., int) below.
        if is_int(value):
            value = int(value)
        elif is_float(value):
            value = float(value)
        elif value.lower() in ["true", "false"]:
            value = value.lower() == "true"
        self.options[key] = value

    # Check if ASR mode is enabled
    self.asr_mode = self.options.get("asr_mode", False)
    if not isinstance(self.asr_mode, bool):
        # Handle string "true"/"false" case
        self.asr_mode = str(self.asr_mode).lower() == "true"

    # Get model path from request
    model_path = request.Model
    if not model_path:
        if self.asr_mode:
            model_path = "microsoft/VibeVoice-ASR"  # Default ASR model
        else:
            model_path = "microsoft/VibeVoice-Realtime-0.5B"  # Default TTS model

    default_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
    load_dtype = default_dtype
    if "torch_dtype" in self.options:
        torch_dtype_str = str(self.options["torch_dtype"]).lower()
        if torch_dtype_str == "fp16":
            load_dtype = torch.float16
        elif torch_dtype_str == "bf16":
            load_dtype = torch.bfloat16
        elif torch_dtype_str == "fp32":
            load_dtype = torch.float32
        # remove it from options after reading
        del self.options["torch_dtype"]

    # Get inference steps from options, default to 5 (TTS only)
    self.inference_steps = self.options.get("inference_steps", 5)
    if not isinstance(self.inference_steps, int) or self.inference_steps <= 0:
        self.inference_steps = 5

    # Get cfg_scale from options, default to 1.5 (TTS only)
    self.cfg_scale = self.options.get("cfg_scale", 1.5)
    if not isinstance(self.cfg_scale, (int, float)) or self.cfg_scale <= 0:
        self.cfg_scale = 1.5

    # Get ASR generation parameters from options
    self.max_new_tokens = self.options.get("max_new_tokens", 512)
    if not isinstance(self.max_new_tokens, int) or self.max_new_tokens <= 0:
        self.max_new_tokens = 512

    self.temperature = self.options.get("temperature", 0.0)
    if not isinstance(self.temperature, (int, float)) or self.temperature < 0:
        self.temperature = 0.0

    self.top_p = self.options.get("top_p", 1.0)
    if not isinstance(self.top_p, (int, float)) or self.top_p <= 0:
        self.top_p = 1.0

    self.do_sample = self.options.get("do_sample", None)
    if self.do_sample is None:
        # Default: use sampling if temperature > 0
        self.do_sample = self.temperature > 0
    elif not isinstance(self.do_sample, bool):
        self.do_sample = str(self.do_sample).lower() == "true"

    self.num_beams = self.options.get("num_beams", 1)
    if not isinstance(self.num_beams, int) or self.num_beams < 1:
        self.num_beams = 1

    self.repetition_penalty = self.options.get("repetition_penalty", 1.0)
    if not isinstance(self.repetition_penalty, (int, float)) or self.repetition_penalty <= 0:
        self.repetition_penalty = 1.0

    # Determine voices directory
    # Priority order:
    # 1. voices_dir option (explicitly set by user - highest priority)
    # 2. Relative to ModelFile if provided
    # 3. Relative to ModelPath (models directory) if provided
    # 4. Backend directory
    # 5. Absolute path from AudioPath if provided
    voices_dir = None

    # First check if voices_dir is explicitly set in options
    if "voices_dir" in self.options:
        voices_dir_option = self.options["voices_dir"]
        if isinstance(voices_dir_option, str) and voices_dir_option.strip():
            voices_dir = voices_dir_option.strip()
            # If relative path, try to resolve it relative to ModelPath or ModelFile
            if not os.path.isabs(voices_dir):
                if hasattr(request, 'ModelPath') and request.ModelPath:
                    voices_dir = os.path.join(request.ModelPath, voices_dir)
                elif request.ModelFile:
                    model_file_base = os.path.dirname(request.ModelFile)
                    voices_dir = os.path.join(model_file_base, voices_dir)
                # If still relative, make it absolute from current working directory
                if not os.path.isabs(voices_dir):
                    voices_dir = os.path.abspath(voices_dir)
            # Check if the directory exists
            if not os.path.exists(voices_dir):
                print(f"Warning: voices_dir option specified but directory does not exist: {voices_dir}", file=sys.stderr)
                voices_dir = None

    # If not set via option, try relative to ModelFile if provided
    if not voices_dir and request.ModelFile:
        model_file_base = os.path.dirname(request.ModelFile)
        voices_dir = os.path.join(model_file_base, "voices", "streaming_model")
        if not os.path.exists(voices_dir):
            voices_dir = None

    # If not found, try relative to ModelPath (models directory)
    if not voices_dir and hasattr(request, 'ModelPath') and request.ModelPath:
        voices_dir = os.path.join(request.ModelPath, "voices", "streaming_model")
        if not os.path.exists(voices_dir):
            voices_dir = None

    # If not found, try relative to backend directory
    if not voices_dir:
        backend_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        voices_dir = os.path.join(backend_dir, "vibevoice", "voices", "streaming_model")
        if not os.path.exists(voices_dir):
            # Try absolute path from AudioPath if provided
            if request.AudioPath and os.path.isabs(request.AudioPath):
                voices_dir = os.path.dirname(request.AudioPath)
            else:
                voices_dir = None

    # Initialize voice-related attributes (TTS only)
    self.voices_dir = voices_dir
    self.voice_presets = {}
    self._voice_cache = {}
    self.default_voice_key = None

    # Store AudioPath, ModelFile, and ModelPath from LoadModel request for use in TTS
    self.audio_path = request.AudioPath if hasattr(request, 'AudioPath') and request.AudioPath else None
    self.model_file = request.ModelFile if hasattr(request, 'ModelFile') and request.ModelFile else None
    self.model_path = request.ModelPath if hasattr(request, 'ModelPath') and request.ModelPath else None

    # Decide attention implementation and device_map (matching upstream example)
    if self.device == "mps":
        device_map = None
        attn_impl_primary = "sdpa"  # flash_attention_2 not supported on MPS
    elif self.device == "cuda":
        device_map = "cuda"
        attn_impl_primary = "flash_attention_2"
    else:  # cpu
        device_map = "cpu"  # Match upstream example: use "cpu" for CPU device_map
        attn_impl_primary = "sdpa"

    try:
        if self.asr_mode:
            # Load ASR model and processor
            print(f"Loading ASR processor & model from {model_path}", file=sys.stderr)

            # Load ASR processor
            self.processor = VibeVoiceASRProcessor.from_pretrained(
                model_path,
                language_model_pretrained_name="Qwen/Qwen2.5-7B"
            )

            print(f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}", file=sys.stderr)

            # Load ASR model - use device_map=None and move manually to avoid JSON serialization issues
            # Load with dtype to ensure all components are in correct dtype from the start
            try:
                print(f"Using attention implementation: {attn_impl_primary}", file=sys.stderr)
                self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
                    model_path,
                    dtype=load_dtype,
                    device_map=None,  # Always use None, move manually to avoid JSON serialization issues
                    attn_implementation=attn_impl_primary,
                    trust_remote_code=True
                )
                # Move to device manually
                self.model = self.model.to(self.device)
            except Exception as e:
                if attn_impl_primary == 'flash_attention_2':
                    print(f"[ERROR] : {type(e).__name__}: {e}", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)
                    print("Error loading the ASR model. Trying to use SDPA.", file=sys.stderr)
                    self.model = VibeVoiceASRForConditionalGeneration.from_pretrained(
                        model_path,
                        dtype=load_dtype,
                        device_map=None,
                        attn_implementation='sdpa',
                        trust_remote_code=True
                    )
                    # Move to device manually
                    self.model = self.model.to(self.device)
                else:
                    raise e

            self.model.eval()
            print("ASR model loaded successfully", file=sys.stderr)
        else:
            # Load TTS model and processor
            # Load voice presets if directory exists
            if self.voices_dir and os.path.exists(self.voices_dir):
                self._load_voice_presets()
            else:
                print("Warning: Voices directory not found. Voice presets will not be available.", file=sys.stderr)

            print(f"Loading TTS processor & model from {model_path}", file=sys.stderr)
            self.processor = VibeVoiceStreamingProcessor.from_pretrained(model_path)

            print(f"Using device: {self.device}, torch_dtype: {load_dtype}, attn_implementation: {attn_impl_primary}", file=sys.stderr)

            # Load model with device-specific logic (matching upstream example exactly)
            try:
                if self.device == "mps":
                    self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                        model_path,
                        torch_dtype=load_dtype,
                        attn_implementation=attn_impl_primary,
                        device_map=None,  # load then move
                    )
                    self.model.to("mps")
                elif self.device == "cuda":
                    self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                        model_path,
                        torch_dtype=load_dtype,
                        device_map=device_map,
                        attn_implementation=attn_impl_primary,
                    )
                else:  # cpu
                    # Match upstream example: use device_map="cpu" for CPU
                    self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                        model_path,
                        torch_dtype=load_dtype,
                        device_map="cpu",
                        attn_implementation=attn_impl_primary,
                    )
            except Exception as e:
                if attn_impl_primary == 'flash_attention_2':
                    print(f"[ERROR] : {type(e).__name__}: {e}", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)
                    print("Error loading the model. Trying to use SDPA. However, note that only flash_attention_2 has been fully tested, and using SDPA may result in lower audio quality.", file=sys.stderr)
                    # Match upstream example fallback pattern
                    self.model = VibeVoiceStreamingForConditionalGenerationInference.from_pretrained(
                        model_path,
                        torch_dtype=load_dtype,
                        device_map=(self.device if self.device in ("cuda", "cpu") else None),
                        attn_implementation='sdpa'
                    )
                    if self.device == "mps":
                        self.model.to("mps")
                else:
                    raise e

            self.model.eval()
            self.model.set_ddpm_inference_steps(num_steps=self.inference_steps)

            # Set default voice key
            if self.voice_presets:
                # Try to get default from environment or use first available
                preset_name = os.environ.get("VOICE_PRESET")
                self.default_voice_key = self._determine_voice_key(preset_name)
                print(f"Default voice preset: {self.default_voice_key}", file=sys.stderr)
            else:
                print("Warning: No voice presets available. Voice selection will not work.", file=sys.stderr)

    except Exception as err:
        # Format error message safely, avoiding JSON serialization issues
        error_msg = str(err)
        error_type = type(err).__name__
        # Include traceback for debugging
        tb_str = traceback.format_exc()
        print(f"[ERROR] LoadModel failed: {error_type}: {error_msg}", file=sys.stderr)
        print(tb_str, file=sys.stderr)
        return backend_pb2.Result(success=False, message=f"{error_type}: {error_msg}")

    return backend_pb2.Result(message="Model loaded successfully", success=True)
name.""" if not self.voice_presets: return None # First try exact match if speaker_name and speaker_name in self.voice_presets: return self.voice_presets[speaker_name] # Try partial matching (case insensitive) if speaker_name: speaker_lower = speaker_name.lower() for preset_name, path in self.voice_presets.items(): if preset_name.lower() in speaker_lower or speaker_lower in preset_name.lower(): return path # Default to first voice if no match found if self.default_voice_key and self.default_voice_key in self.voice_presets: return self.voice_presets[self.default_voice_key] elif self.voice_presets: default_voice = list(self.voice_presets.values())[0] print(f"Warning: No voice preset found for '{speaker_name}', using default voice: {default_voice}", file=sys.stderr) return default_voice return None def _ensure_voice_cached(self, voice_path): """Load and cache voice preset.""" if not voice_path or not os.path.exists(voice_path): return None # Ensure cache exists (should be initialized in LoadModel) if not hasattr(self, '_voice_cache'): self._voice_cache = {} # Use path as cache key if voice_path not in self._voice_cache: print(f"Loading prefilled prompt from {voice_path}", file=sys.stderr) # Match self-test.py: use string device name for map_location # Ensure self.device exists (should be set in LoadModel) try: if not hasattr(self, 'device'): # Fallback to CPU if device not set device_str = "cpu" else: device_str = str(self.device) except AttributeError as e: print(f"Error accessing self.device: {e}, falling back to CPU", file=sys.stderr) device_str = "cpu" if device_str != "cpu": map_loc = device_str else: map_loc = "cpu" # Call torch.load with explicit arguments prefilled_outputs = torch.load(voice_path, map_location=map_loc, weights_only=False) self._voice_cache[voice_path] = prefilled_outputs return self._voice_cache[voice_path] def TTS(self, request, context): try: # Get voice selection # Priority: request.voice > AudioPath > default voice_path = None voice_key = 
None if request.voice: # Try to get voice by name voice_path = self._get_voice_path(request.voice) if voice_path: voice_key = request.voice elif self.audio_path: # Use AudioPath from LoadModel as voice file if os.path.isabs(self.audio_path): voice_path = self.audio_path elif self.model_file: model_file_base = os.path.dirname(self.model_file) voice_path = os.path.join(model_file_base, self.audio_path) elif self.model_path: voice_path = os.path.join(self.model_path, self.audio_path) else: voice_path = self.audio_path elif self.default_voice_key: voice_path = self._get_voice_path(self.default_voice_key) voice_key = self.default_voice_key if not voice_path or not os.path.exists(voice_path): return backend_pb2.Result( success=False, message=f"Voice file not found: {voice_path}. Please provide a valid voice preset or AudioPath." ) # Load voice preset prefilled_outputs = self._ensure_voice_cached(voice_path) if prefilled_outputs is None: return backend_pb2.Result( success=False, message=f"Failed to load voice preset from {voice_path}" ) # Get generation parameters from options cfg_scale = self.options.get("cfg_scale", self.cfg_scale) inference_steps = self.options.get("inference_steps", self.inference_steps) do_sample = self.options.get("do_sample", False) temperature = self.options.get("temperature", 0.9) top_p = self.options.get("top_p", 0.9) # Update inference steps if needed if inference_steps != self.inference_steps: self.model.set_ddpm_inference_steps(num_steps=inference_steps) self.inference_steps = inference_steps # Prepare text text = request.text.strip().replace("'", "'").replace('"', '"').replace('"', '"') # Prepare inputs inputs = self.processor.process_input_with_cached_prompt( text=text, cached_prompt=prefilled_outputs, padding=True, return_tensors="pt", return_attention_mask=True, ) # Move tensors to target device (matching self-test.py exactly) # Explicitly ensure it's a string to avoid any variable name collisions target_device = str(self.device) if 
def AudioTranscription(self, request, context):
    """Transcribe audio file to text using ASR model.

    Args:
        request: Transcription request; ``request.dst`` is read as the path of
            the audio file to transcribe.
        context: gRPC context (not used by this method).

    Returns:
        backend_pb2.TranscriptResult: the parsed segments plus the combined
        text. An empty result (no segments, empty text) is returned when ASR
        mode is off, when the audio file is missing, or when any exception is
        raised while transcribing.
    """
    try:
        # Validate ASR mode is active
        # Note: We return empty result instead of error to match faster-whisper behavior
        if not self.asr_mode:
            return backend_pb2.TranscriptResult(
                segments=[],
                text="",
            )

        # Get audio file path
        audio_path = request.dst
        if not audio_path or not os.path.exists(audio_path):
            print(f"Error: Audio file not found: {audio_path}", file=sys.stderr)
            return backend_pb2.TranscriptResult(
                segments=[],
                text="",
            )

        print(f"Transcribing audio file: {audio_path}", file=sys.stderr)

        # Get context_info from options if available; anything that is not a
        # non-blank string is normalised to None.
        context_info = self.options.get("context_info", None)
        if context_info and isinstance(context_info, str) and context_info.strip():
            context_info = context_info.strip()
        else:
            context_info = None

        # Process audio with ASR processor (matching gradio example)
        inputs = self.processor(
            audio=audio_path,
            sampling_rate=None,
            return_tensors="pt",
            add_generation_prompt=True,
            context_info=context_info
        )

        # Move to device (matching gradio example); non-tensor values are kept as-is
        inputs = {k: v.to(self.device) if isinstance(v, torch.Tensor) else v for k, v in inputs.items()}

        # Prepare generation config (matching gradio example)
        generation_config = {
            "max_new_tokens": self.max_new_tokens,
            "temperature": self.temperature if self.temperature > 0 else None,
            "top_p": self.top_p if self.do_sample else None,
            "do_sample": self.do_sample,
            "num_beams": self.num_beams,
            "repetition_penalty": self.repetition_penalty,
            "pad_token_id": self.processor.pad_id,
            "eos_token_id": self.processor.tokenizer.eos_token_id,
        }

        # Remove None values (matching gradio example)
        generation_config = {k: v for k, v in generation_config.items() if v is not None}

        print(f"Generating transcription with max_new_tokens: {self.max_new_tokens}, temperature: {self.temperature}, do_sample: {self.do_sample}, num_beams: {self.num_beams}, repetition_penalty: {self.repetition_penalty}", file=sys.stderr)

        # Generate transcription (matching gradio example)
        with torch.no_grad():
            output_ids = self.model.generate(
                **inputs,
                **generation_config
            )

        # Decode output (matching gradio example): skip the first
        # input_ids.shape[1] positions of the first returned sequence.
        generated_ids = output_ids[0, inputs['input_ids'].shape[1]:]
        generated_text = self.processor.decode(generated_ids, skip_special_tokens=True)

        # Parse structured output to get segments
        result_segments = []
        try:
            transcription_segments = self.processor.post_process_transcription(generated_text)
            if transcription_segments:
                # Map segments to TranscriptSegment format
                for idx, seg in enumerate(transcription_segments):
                    # Extract timing information (if available)
                    # Handle both dict and object with attributes
                    if isinstance(seg, dict):
                        start_time = seg.get('start_time', 0)
                        end_time = seg.get('end_time', 0)
                        text = seg.get('text', '')
                        speaker_id = seg.get('speaker_id', None)
                    else:
                        # Handle object with attributes
                        start_time = getattr(seg, 'start_time', 0)
                        end_time = getattr(seg, 'end_time', 0)
                        text = getattr(seg, 'text', '')
                        speaker_id = getattr(seg, 'speaker_id', None)

                    # Convert time to milliseconds (assuming seconds);
                    # non-numeric values become 0
                    start_ms = int(start_time * 1000) if isinstance(start_time, (int, float)) else 0
                    end_ms = int(end_time * 1000) if isinstance(end_time, (int, float)) else 0

                    # Add speaker info to text if available
                    if speaker_id is not None:
                        text = f"[Speaker {speaker_id}] {text}"

                    result_segments.append(backend_pb2.TranscriptSegment(
                        id=idx,
                        start=start_ms,
                        end=end_ms,
                        text=text,
                        tokens=[]  # Token IDs not extracted for now
                    ))
        except Exception as e:
            print(f"Warning: Failed to parse structured output: {e}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
            # Fallback: create a single segment with the full text
            # NOTE(review): block nesting was reconstructed from whitespace-collapsed
            # text; confirm this fallback belongs inside the except branch.
            if generated_text:
                result_segments.append(backend_pb2.TranscriptSegment(
                    id=0,
                    start=0,
                    end=0,
                    text=generated_text,
                    tokens=[]
                ))

        # Combine all segment texts into full transcription; with no segments
        # the raw decoded text (or an empty string) is used instead.
        if result_segments:
            full_text = " ".join([seg.text for seg in result_segments])
        else:
            full_text = generated_text if generated_text else ""

        print(f"Transcription completed: {len(result_segments)} segments", file=sys.stderr)

        return backend_pb2.TranscriptResult(
            segments=result_segments,
            text=full_text
        )

    except Exception as err:
        # Any failure is logged and reported to the caller as an empty result.
        print(f"Error in AudioTranscription: {err}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        return backend_pb2.TranscriptResult(
            segments=[],
            text="",
        )
#!/bin/bash
# Installs the Python requirements for the vibevoice backend, then clones and
# installs the upstream VibeVoice package when no local checkout exists yet.
set -e

# Quote every path expansion so the script also works when it lives in a
# directory whose name contains spaces.
backend_dir=$(dirname "$0")
if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
    EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
fi

# Use python 3.12 for l4t
if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then
    PYTHON_VERSION="3.12"
    PYTHON_PATCH="12"
    PY_STANDALONE_TAG="20251120"
fi

if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then
    USE_PIP=true
fi

installRequirements

# Only clone and install VibeVoice when there is no checkout in the current directory.
if [ ! -d VibeVoice ]; then
    git clone https://github.com/microsoft/VibeVoice.git
    cd VibeVoice/
    # EXTRA_PIP_INSTALL_FLAGS is intentionally left unquoted: it holds several
    # space-separated flags that must be split into separate arguments.
    if [ "x${USE_PIP}" == "xtrue" ]; then
        pip install ${EXTRA_PIP_INSTALL_FLAGS:-} .
    else
        uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} .
    fi
fi
sentencepiece optimum-quanto ftfy llvmlite>=0.40.0 numba>=0.57.0 tqdm numpy scipy librosa ml-collections absl-py gradio av ================================================ FILE: backend/python/vibevoice/requirements-intel.txt ================================================ --extra-index-url https://download.pytorch.org/whl/xpu torch torchvision optimum[openvino] setuptools git+https://github.com/huggingface/diffusers opencv-python transformers>=4.51.3,<5.0.0 accelerate compel peft sentencepiece optimum-quanto ftfy llvmlite>=0.40.0 numba>=0.57.0 tqdm numpy scipy librosa ml-collections absl-py gradio av ================================================ FILE: backend/python/vibevoice/requirements-l4t12.txt ================================================ --extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/ torch git+https://github.com/huggingface/diffusers transformers>=4.51.3,<5.0.0 accelerate compel peft optimum-quanto numpy<2 sentencepiece torchvision ftfy llvmlite>=0.40.0 numba>=0.57.0 tqdm numpy scipy librosa ml-collections absl-py gradio av ================================================ FILE: backend/python/vibevoice/requirements-l4t13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch git+https://github.com/huggingface/diffusers transformers>=4.51.3,<5.0.0 accelerate compel peft optimum-quanto numpy<2 sentencepiece torchvision ftfy llvmlite>=0.40.0 numba>=0.57.0 tqdm numpy scipy librosa ml-collections absl-py gradio av ================================================ FILE: backend/python/vibevoice/requirements-mps.txt ================================================ torch==2.7.1 torchvision==0.22.1 git+https://github.com/huggingface/diffusers opencv-python transformers>=4.51.3,<5.0.0 accelerate compel peft sentencepiece optimum-quanto ftfy llvmlite>=0.40.0 numba>=0.57.0 tqdm numpy scipy librosa ml-collections absl-py gradio av ================================================ FILE: 
class TestBackendServicer(unittest.TestCase):
    """
    TestBackendServicer is the class that tests the gRPC service.

    unittest calls setUp() before and tearDown() after every test method, so
    the tests must not call them again themselves: a second setUp() would spawn
    another server on the same port and overwrite self.service, leaving the
    first process running with nothing left to terminate it.
    """
    def setUp(self):
        """
        This method sets up the gRPC service by starting the server
        """
        self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"])
        # Fixed wait that gives the backend process time to start listening.
        time.sleep(30)

    def tearDown(self) -> None:
        """
        This method tears down the gRPC service by terminating the server
        """
        self.service.terminate()
        self.service.wait()

    def test_server_startup(self):
        """
        This method tests if the server starts up successfully
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.Health(backend_pb2.HealthMessage())
                self.assertEqual(response.message, b'OK')
        except Exception as err:
            print(err)
            self.fail("Server failed to start")

    def test_load_tts_model(self):
        """
        This method tests if the TTS model is loaded successfully
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="microsoft/VibeVoice-Realtime-0.5B"))
                print(response)
                self.assertTrue(response.success)
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(err)
            self.fail("LoadModel service failed")

    @unittest.skipIf(SKIP_ASR_TESTS, "ASR tests require large models (~14B parameters) and are skipped in CI")
    def test_load_asr_model(self):
        """
        This method tests if the ASR model is loaded successfully with asr_mode option
        """
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                response = stub.LoadModel(backend_pb2.ModelOptions(
                    Model="microsoft/VibeVoice-ASR",
                    Options=["asr_mode:true"]
                ))
                print(f"LoadModel response: {response}")
                if not response.success:
                    print(f"LoadModel failed with message: {response.message}")
                self.assertTrue(response.success, f"LoadModel failed: {response.message}")
                self.assertEqual(response.message, "Model loaded successfully")
        except Exception as err:
            print(f"Exception during LoadModel: {err}")
            import traceback
            traceback.print_exc()
            self.fail("LoadModel service failed for ASR mode")

    def test_tts(self):
        """
        This method tests if TTS generation works successfully
        """
        # Create a temporary directory for the output audio file
        temp_dir = tempfile.mkdtemp()
        output_file = os.path.join(temp_dir, 'output.wav')
        try:
            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)
                # Load TTS model
                response = stub.LoadModel(backend_pb2.ModelOptions(Model="microsoft/VibeVoice-Realtime-0.5B"))
                self.assertTrue(response.success)

                # Generate TTS
                tts_request = backend_pb2.TTSRequest(
                    text="Hello, this is a test of the VibeVoice text to speech system.",
                    dst=output_file
                )
                tts_response = stub.TTS(tts_request)

                # Verify response
                self.assertIsNotNone(tts_response)
                self.assertTrue(tts_response.success)

                # Verify output file was created
                self.assertTrue(os.path.exists(output_file), f"Output file was not created: {output_file}")
                self.assertGreater(os.path.getsize(output_file), 0, "Output file is empty")
        except Exception as err:
            print(err)
            self.fail("TTS service failed")
        finally:
            # Clean up the temporary directory
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)

    @unittest.skipIf(SKIP_ASR_TESTS, "ASR tests require large models (~14B parameters) and are skipped in CI")
    def test_audio_transcription(self):
        """
        This method tests if audio transcription works successfully
        """
        # Create a temporary directory for the audio file
        temp_dir = tempfile.mkdtemp()
        audio_file = os.path.join(temp_dir, 'audio.wav')

        try:
            # Download the audio file to the temporary directory
            print(f"Downloading audio file to {audio_file}...")
            url = "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav"
            result = subprocess.run(
                ["wget", "-q", url, "-O", audio_file],
                capture_output=True,
                text=True
            )
            if result.returncode != 0:
                self.fail(f"Failed to download audio file: {result.stderr}")

            # Verify the file was downloaded
            if not os.path.exists(audio_file):
                self.fail(f"Audio file was not downloaded to {audio_file}")

            with grpc.insecure_channel("localhost:50051") as channel:
                stub = backend_pb2_grpc.BackendStub(channel)

                # Load the ASR model first
                load_response = stub.LoadModel(backend_pb2.ModelOptions(
                    Model="microsoft/VibeVoice-ASR",
                    Options=["asr_mode:true"]
                ))
                print(f"LoadModel response: {load_response}")
                if not load_response.success:
                    print(f"LoadModel failed with message: {load_response.message}")
                self.assertTrue(load_response.success, f"LoadModel failed: {load_response.message}")

                # Perform transcription
                transcript_request = backend_pb2.TranscriptRequest(dst=audio_file)
                transcript_response = stub.AudioTranscription(transcript_request)

                # Print the transcribed text for debugging
                print(f"Transcribed text: {transcript_response.text}")
                print(f"Number of segments: {len(transcript_response.segments)}")

                # Verify response structure
                self.assertIsNotNone(transcript_response)
                self.assertIsNotNone(transcript_response.text)
                # Protobuf repeated fields return a sequence, not a list
                self.assertIsNotNone(transcript_response.segments)
                # Check if segments is iterable (has length)
                self.assertGreaterEqual(len(transcript_response.segments), 0)

                # Verify the transcription contains some text
                self.assertGreater(len(transcript_response.text), 0, "Transcription should not be empty")

                # If we got segments, verify they have the expected structure
                if len(transcript_response.segments) > 0:
                    segment = transcript_response.segments[0]
                    self.assertIsNotNone(segment.text)
                    self.assertIsInstance(segment.id, int)
                else:
                    # Even if no segments, we should have text
                    self.assertIsNotNone(transcript_response.text)
                    self.assertGreater(len(transcript_response.text), 0)
        except Exception as err:
            print(err)
            self.fail("AudioTranscription service failed")
        finally:
            # Clean up the temporary directory
            if os.path.exists(temp_dir):
                shutil.rmtree(temp_dir)
.PHONY: protogen-clean protogen-clean: $(RM) backend_pb2_grpc.py backend_pb2.py .PHONY: clean clean: protogen-clean rm -rf venv __pycache__ ================================================ FILE: backend/python/vllm/README.md ================================================ # Creating a separate environment for the vllm project ``` make vllm ``` ================================================ FILE: backend/python/vllm/backend.py ================================================ #!/usr/bin/env python3 import asyncio from concurrent import futures import argparse import signal import sys import os from typing import List from PIL import Image import backend_pb2 import backend_pb2_grpc import grpc from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.multimodal.utils import fetch_image from vllm.assets.video import VideoAsset import base64 import io _ONE_DAY_IN_SECONDS = 60 * 60 * 24 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) # Implement the BackendServicer class with the service methods class BackendServicer(backend_pb2_grpc.BackendServicer): """ A gRPC servicer that implements the Backend service defined in backend.proto. """ def generate(self,prompt, max_new_tokens): """ Generates text based on the given prompt and maximum number of new tokens. Args: prompt (str): The prompt to generate text from. max_new_tokens (int): The maximum number of new tokens to generate. Returns: str: The generated text. 
async def LoadModel(self, request, context):
    """
    Loads a language model.

    Builds the vLLM engine arguments from the request, creates the async
    engine (stored in ``self.llm``) and then loads the matching tokenizer
    (stored in ``self.tokenizer``).

    Args:
        request: The load model request.
        context: The gRPC context.

    Returns:
        backend_pb2.Result: The load model result. ``success`` is False and
        ``message`` carries the error when the engine or the tokenizer cannot
        be created.
    """
    engine_args = AsyncEngineArgs(
        model=request.Model,
    )

    # Only override engine defaults for fields the request actually sets.
    if request.Quantization != "":
        engine_args.quantization = request.Quantization
    if request.LoadFormat != "":
        engine_args.load_format = request.LoadFormat
    if request.GPUMemoryUtilization != 0:
        engine_args.gpu_memory_utilization = request.GPUMemoryUtilization
    if request.TrustRemoteCode:
        engine_args.trust_remote_code = request.TrustRemoteCode
    if request.EnforceEager:
        engine_args.enforce_eager = request.EnforceEager
    if request.TensorParallelSize:
        engine_args.tensor_parallel_size = request.TensorParallelSize
    if request.SwapSpace != 0:
        engine_args.swap_space = request.SwapSpace
    if request.MaxModelLen != 0:
        engine_args.max_model_len = request.MaxModelLen
    if request.DisableLogStatus:
        # The vLLM engine argument is named ``disable_log_stats``; assigning to
        # ``disable_log_status`` only created an unused attribute, so the
        # request option never reached the engine.
        engine_args.disable_log_stats = request.DisableLogStatus
    if request.DType != "":
        engine_args.dtype = request.DType
    if request.LimitImagePerPrompt != 0 or request.LimitVideoPerPrompt != 0 or request.LimitAudioPerPrompt != 0:
        # limit-mm-per-prompt defaults to 1 per modality, based on vLLM docs
        engine_args.limit_mm_per_prompt = {
            "image": max(request.LimitImagePerPrompt, 1),
            "video": max(request.LimitVideoPerPrompt, 1),
            "audio": max(request.LimitAudioPerPrompt, 1)
        }

    try:
        self.llm = AsyncLLMEngine.from_engine_args(engine_args)
    except Exception as err:
        print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
        return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

    try:
        engine_model_config = await self.llm.get_model_config()
        self.tokenizer = get_tokenizer(
            engine_model_config.tokenizer,
            tokenizer_mode=engine_model_config.tokenizer_mode,
            trust_remote_code=engine_model_config.trust_remote_code,
            truncation_side="left",
        )
    except Exception as err:
        # Log tokenizer failures too, matching the engine-creation branch above.
        print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr)
        return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}")

    print("Model loaded successfully", file=sys.stderr)
    return backend_pb2.Result(message="Model loaded successfully", success=True)
def Embedding(self, request, context):
    """
    Compute the embedding of the text carried in ``request.Embeddings``.

    Args:
        request: An EmbeddingRequest object; its ``Embeddings`` field holds the text to encode.
        context: A grpc.ServicerContext object, used to report an error status when nothing was produced.

    Returns:
        An EmbeddingResult built from the first encoder output, or an empty
        EmbeddingResult (with INVALID_ARGUMENT set on the context) when the
        encoder returned no outputs.
    """
    sentence = request.Embeddings
    print("Calculated embeddings for: " + sentence, file=sys.stderr)

    # NOTE(review): self.model is not assigned by the LoadModel shown in this
    # file (it sets self.llm and self.tokenizer) — confirm where it is set.
    encoded = self.model.encode(sentence)

    # Only the first output is used for the reply.
    if len(encoded) != 0:
        first = encoded[0]
        return backend_pb2.EmbeddingResult(embeddings=first.outputs.embedding)

    context.set_code(grpc.StatusCode.INVALID_ARGUMENT)
    context.set_details("No embeddings were calculated.")
    return backend_pb2.EmbeddingResult()
""" iterations = self._predict(request, context, streaming=True) try: async for iteration in iterations: yield iteration finally: await iterations.aclose() async def _predict(self, request, context, streaming=False): # Build the sampling parameters # NOTE: this must stay in sync with the vllm backend request_to_sampling_params = { "N": "n", "PresencePenalty": "presence_penalty", "FrequencyPenalty": "frequency_penalty", "RepetitionPenalty": "repetition_penalty", "Temperature": "temperature", "TopP": "top_p", "TopK": "top_k", "MinP": "min_p", "Seed": "seed", "StopPrompts": "stop", "StopTokenIds": "stop_token_ids", "BadWords": "bad_words", "IncludeStopStrInOutput": "include_stop_str_in_output", "IgnoreEOS": "ignore_eos", "Tokens": "max_tokens", "MinTokens": "min_tokens", "Logprobs": "logprobs", "PromptLogprobs": "prompt_logprobs", "SkipSpecialTokens": "skip_special_tokens", "SpacesBetweenSpecialTokens": "spaces_between_special_tokens", "TruncatePromptTokens": "truncate_prompt_tokens", "GuidedDecoding": "guided_decoding", } sampling_params = SamplingParams(top_p=0.9, max_tokens=200) for request_field, param_field in request_to_sampling_params.items(): if hasattr(request, request_field): value = getattr(request, request_field) if value not in (None, 0, [], False): setattr(sampling_params, param_field, value) # Extract image paths and process images prompt = request.Prompt image_paths = request.Images image_data = [self.load_image(img_path) for img_path in image_paths] videos_path = request.Videos video_data = [self.load_video(video_path) for video_path in videos_path] # If tokenizer template is enabled and messages are provided instead of prompt, apply the tokenizer template if not request.Prompt and request.UseTokenizerTemplate and request.Messages: prompt = self.tokenizer.apply_chat_template(request.Messages, tokenize=False, add_generation_prompt=True) # Generate text using the LLM engine request_id = random_uuid() print(f"Generating text with request_id: 
def load_video(self, video_path: str):
    """
    Load a video from base64 encoded data.

    The payload is decoded into a uniquely named temporary file, read back
    through ``VideoAsset`` and the temporary file is always removed afterwards.

    Args:
        video_path (str): The base64 encoded video data.

    Returns:
        The ``np_ndarrays`` value of the ``VideoAsset`` built from the decoded
        file, or None when decoding or loading fails.
    """
    # Imported locally so this method does not depend on a module-level import;
    # the previous implementation called time.time() although ``time`` was never
    # imported, so every call failed and returned None.
    import tempfile

    tmp_path = None
    try:
        # A unique temp file avoids collisions between concurrent requests,
        # which a millisecond timestamp in the file name could not guarantee.
        with tempfile.NamedTemporaryFile(prefix="vl-", suffix=".data", delete=False) as f:
            tmp_path = f.name
            f.write(base64.b64decode(video_path))
        return VideoAsset(name=tmp_path).np_ndarrays
    except Exception as e:
        print(f"Error loading video {video_path}: {e}", file=sys.stderr)
        return None
    finally:
        # Remove the temp file on success and on failure alike.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)
#!/bin/bash
# Installs the vllm backend: either builds vLLM for CPU from source (when
# BUILD_TYPE is empty and FROM_SOURCE=true) or installs the pinned requirements.
set -e

EXTRA_PIP_INSTALL_FLAGS="--no-build-isolation"

# Avoid overcommitting the CPU during the build
# https://github.com/vllm-project/vllm/issues/20079
# https://docs.vllm.ai/en/v0.8.3/serving/env_vars.html
# https://docs.redhat.com/it/documentation/red_hat_ai_inference_server/3.0/html/vllm_server_arguments/environment_variables-server-arguments
export NVCC_THREADS=2
export MAX_JOBS=1

# Quote every path expansion so the script also works when it lives in a
# directory whose name contains spaces.
backend_dir=$(dirname "$0")

if [ -d "$backend_dir/common" ]; then
    source "$backend_dir/common/libbackend.sh"
else
    source "$backend_dir/../common/libbackend.sh"
fi

# This is here because the Intel pip index is broken and returns 200 status codes for every package name, it just doesn't return any package links.
# This makes uv think that the package exists in the Intel pip index, and by default it stops looking at other pip indexes once it finds a match.
# We need uv to continue falling through to the pypi default index to find optimum[openvino] in the pypi index
# the --upgrade actually allows us to *downgrade* torch to the version provided in the Intel pip index
if [ "x${BUILD_PROFILE}" == "xintel" ]; then
    EXTRA_PIP_INSTALL_FLAGS+=" --upgrade --index-strategy=unsafe-first-match"
fi

# We don't embed this into the images as it is a large dependency and not always needed.
# Besides, the speed inference are not actually usable in the current state for production use-cases.
if [ "x${BUILD_TYPE}" == "x" ] && [ "x${FROM_SOURCE:-}" == "xtrue" ]; then
    ensureVenv
    # https://docs.vllm.ai/en/v0.6.1/getting_started/cpu-installation.html
    if [ ! -d vllm ]; then
        git clone https://github.com/vllm-project/vllm
    fi
    pushd vllm
        uv pip install wheel packaging ninja "setuptools>=49.4.0" numpy typing-extensions pillow setuptools-scm grpcio==1.68.1 protobuf bitsandbytes
        uv pip install -v -r requirements-cpu.txt --extra-index-url https://download.pytorch.org/whl/cpu
        VLLM_TARGET_DEVICE=cpu python setup.py install
    popd
    # The source checkout is only needed for the build.
    rm -rf vllm
else
    installRequirements
fi
backend/python/vllm/requirements-intel.txt ================================================ --extra-index-url https://download.pytorch.org/whl/xpu accelerate torch transformers optimum[openvino] setuptools bitsandbytes ================================================ FILE: backend/python/vllm/requirements.txt ================================================ grpcio==1.78.1 protobuf certifi setuptools ================================================ FILE: backend/python/vllm/run.sh ================================================ #!/bin/bash backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi startBackend $@ ================================================ FILE: backend/python/vllm/test.py ================================================ import unittest import subprocess import time import backend_pb2 import backend_pb2_grpc import grpc import unittest import subprocess import time import grpc import backend_pb2_grpc import backend_pb2 class TestBackendServicer(unittest.TestCase): """ TestBackendServicer is the class that tests the gRPC service. This class contains methods to test the startup and shutdown of the gRPC service. 
""" def setUp(self): self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"]) time.sleep(10) def tearDown(self) -> None: self.service.terminate() self.service.wait() def test_server_startup(self): try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.Health(backend_pb2.HealthMessage()) self.assertEqual(response.message, b'OK') except Exception as err: print(err) self.fail("Server failed to start") finally: self.tearDown() def test_load_model(self): """ This method tests if the model is loaded successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) self.assertTrue(response.success) self.assertEqual(response.message, "Model loaded successfully") except Exception as err: print(err) self.fail("LoadModel service failed") finally: self.tearDown() def test_text(self): """ This method tests if the embeddings are generated successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) self.assertTrue(response.success) req = backend_pb2.PredictOptions(Prompt="The capital of France is") resp = stub.Predict(req) self.assertIsNotNone(resp.message) except Exception as err: print(err) self.fail("text service failed") finally: self.tearDown() def test_sampling_params(self): """ This method tests if all sampling parameters are correctly processed NOTE: this does NOT test for correctness, just that we received a compatible response """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="facebook/opt-125m")) 
self.assertTrue(response.success) req = backend_pb2.PredictOptions( Prompt="The capital of France is", TopP=0.8, Tokens=50, Temperature=0.7, TopK=40, PresencePenalty=0.1, FrequencyPenalty=0.2, RepetitionPenalty=1.1, MinP=0.05, Seed=42, StopPrompts=["\n"], StopTokenIds=[50256], BadWords=["badword"], IncludeStopStrInOutput=True, IgnoreEOS=True, MinTokens=5, Logprobs=5, PromptLogprobs=5, SkipSpecialTokens=True, SpacesBetweenSpecialTokens=True, TruncatePromptTokens=10, GuidedDecoding=True, N=2, ) resp = stub.Predict(req) self.assertIsNotNone(resp.message) self.assertIsNotNone(resp.logprobs) except Exception as err: print(err) self.fail("sampling params service failed") finally: self.tearDown() def test_embedding(self): """ This method tests if the embeddings are generated successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="intfloat/e5-mistral-7b-instruct")) self.assertTrue(response.success) embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.") embedding_response = stub.Embedding(embedding_request) self.assertIsNotNone(embedding_response.embeddings) # assert that is a list of floats self.assertIsInstance(embedding_response.embeddings, list) # assert that the list is not empty self.assertTrue(len(embedding_response.embeddings) > 0) except Exception as err: print(err) self.fail("Embedding service failed") finally: self.tearDown() ================================================ FILE: backend/python/vllm/test.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi runUnittests ================================================ FILE: backend/python/vllm-omni/Makefile ================================================ .PHONY: vllm-omni 
vllm-omni: bash install.sh .PHONY: run run: vllm-omni @echo "Running vllm-omni..." bash run.sh @echo "vllm-omni run." .PHONY: test test: vllm-omni @echo "Testing vllm-omni..." bash test.sh @echo "vllm-omni tested." .PHONY: protogen-clean protogen-clean: $(RM) backend_pb2_grpc.py backend_pb2.py .PHONY: clean clean: protogen-clean rm -rf venv __pycache__ ================================================ FILE: backend/python/vllm-omni/backend.py ================================================ #!/usr/bin/env python3 """ LocalAI vLLM-Omni Backend This backend provides gRPC access to vllm-omni for multimodal generation: - Image generation (text-to-image, image editing) - Video generation (text-to-video, image-to-video) - Text generation with multimodal inputs (LLM) - Text-to-speech generation """ from concurrent import futures import traceback import argparse import signal import sys import time import os import base64 import io from PIL import Image import torch import numpy as np import soundfile as sf import backend_pb2 import backend_pb2_grpc import grpc from vllm_omni.entrypoints.omni import Omni from vllm_omni.outputs import OmniRequestOutput from vllm_omni.diffusion.data import DiffusionParallelConfig from vllm_omni.utils.platform_utils import detect_device_type, is_npu from vllm import SamplingParams from diffusers.utils import export_to_video _ONE_DAY_IN_SECONDS = 60 * 60 * 24 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) def is_float(s): """Check if a string can be converted to float.""" try: float(s) return True except ValueError: return False def is_int(s): """Check if a string can be converted to int.""" try: int(s) return True except ValueError: return False # Implement the BackendServicer class with the service methods class BackendServicer(backend_pb2_grpc.BackendServicer): def _detect_model_type(self, model_name): """Detect model type from model name.""" 
model_lower = model_name.lower() if "tts" in model_lower or "qwen3-tts" in model_lower: return "tts" elif "omni" in model_lower and "qwen3" in model_lower: return "llm" elif "wan" in model_lower or "t2v" in model_lower or "i2v" in model_lower: return "video" elif "image" in model_lower or "z-image" in model_lower or "qwen-image" in model_lower: return "image" else: # Default to image for diffusion models, llm for others return "image" def _detect_tts_task_type(self): """Detect TTS task type from model name.""" model_lower = self.model_name.lower() if "customvoice" in model_lower: return "CustomVoice" elif "voicedesign" in model_lower: return "VoiceDesign" elif "base" in model_lower: return "Base" else: # Default to CustomVoice return "CustomVoice" def _load_image(self, image_path): """Load an image from file path or base64 encoded data.""" # Try file path first if os.path.exists(image_path): return Image.open(image_path) # Try base64 decode try: image_data = base64.b64decode(image_path) return Image.open(io.BytesIO(image_data)) except: return None def _load_video(self, video_path): """Load a video from file path or base64 encoded data.""" from vllm.assets.video import VideoAsset, video_to_ndarrays if os.path.exists(video_path): return video_to_ndarrays(video_path, num_frames=16) # Try base64 decode try: timestamp = str(int(time.time() * 1000)) p = f"/tmp/vl-{timestamp}.data" with open(p, "wb") as f: f.write(base64.b64decode(video_path)) video = VideoAsset(name=p).np_ndarrays os.remove(p) return video except: return None def _load_audio(self, audio_path): """Load audio from file path or base64 encoded data.""" import librosa if os.path.exists(audio_path): audio_signal, sr = librosa.load(audio_path, sr=16000) return (audio_signal.astype(np.float32), sr) # Try base64 decode try: audio_data = base64.b64decode(audio_path) # Save to temp file and load timestamp = str(int(time.time() * 1000)) p = f"/tmp/audio-{timestamp}.wav" with open(p, "wb") as f: f.write(audio_data) 
audio_signal, sr = librosa.load(p, sr=16000) os.remove(p) return (audio_signal.astype(np.float32), sr) except: return None def Health(self, request, context): return backend_pb2.Reply(message=bytes("OK", 'utf-8')) def LoadModel(self, request, context): try: print(f"Loading model {request.Model}...", file=sys.stderr) print(f"Request {request}", file=sys.stderr) # Parse options from request.Options (key:value pairs) self.options = {} for opt in request.Options: if ":" not in opt: continue key, value = opt.split(":", 1) # Convert value to appropriate type if is_float(value): value = float(value) elif is_int(value): value = int(value) elif value.lower() in ["true", "false"]: value = value.lower() == "true" self.options[key] = value print(f"Options: {self.options}", file=sys.stderr) # Detect model type self.model_name = request.Model self.model_type = request.Type if request.Type else self._detect_model_type(request.Model) print(f"Detected model type: {self.model_type}", file=sys.stderr) # Build DiffusionParallelConfig if diffusion model (image or video) parallel_config = None if self.model_type in ["image", "video"]: parallel_config = DiffusionParallelConfig( ulysses_degree=self.options.get("ulysses_degree", 1), ring_degree=self.options.get("ring_degree", 1), cfg_parallel_size=self.options.get("cfg_parallel_size", 1), tensor_parallel_size=self.options.get("tensor_parallel_size", 1), ) # Build cache_config dict if cache_backend specified cache_backend = self.options.get("cache_backend") # "cache_dit" or "tea_cache" cache_config = None if cache_backend == "cache_dit": cache_config = { "Fn_compute_blocks": self.options.get("cache_dit_fn_compute_blocks", 1), "Bn_compute_blocks": self.options.get("cache_dit_bn_compute_blocks", 0), "max_warmup_steps": self.options.get("cache_dit_max_warmup_steps", 4), "residual_diff_threshold": self.options.get("cache_dit_residual_diff_threshold", 0.24), "max_continuous_cached_steps": self.options.get("cache_dit_max_continuous_cached_steps", 
3), "enable_taylorseer": self.options.get("cache_dit_enable_taylorseer", False), "taylorseer_order": self.options.get("cache_dit_taylorseer_order", 1), "scm_steps_mask_policy": self.options.get("cache_dit_scm_steps_mask_policy"), "scm_steps_policy": self.options.get("cache_dit_scm_steps_policy", "dynamic"), } elif cache_backend == "tea_cache": cache_config = { "rel_l1_thresh": self.options.get("tea_cache_rel_l1_thresh", 0.2), } # Base Omni initialization parameters omni_kwargs = { "model": request.Model, } # Add diffusion-specific parameters (image/video models) if self.model_type in ["image", "video"]: omni_kwargs.update({ "vae_use_slicing": is_npu(), "vae_use_tiling": is_npu(), "cache_backend": cache_backend, "cache_config": cache_config, "parallel_config": parallel_config, "enforce_eager": self.options.get("enforce_eager", request.EnforceEager), "enable_cpu_offload": self.options.get("enable_cpu_offload", False), }) # Video-specific parameters if self.model_type == "video": omni_kwargs.update({ "boundary_ratio": self.options.get("boundary_ratio", 0.875), "flow_shift": self.options.get("flow_shift", 5.0), }) # Add LLM/TTS-specific parameters if self.model_type in ["llm", "tts"]: omni_kwargs.update({ "stage_configs_path": self.options.get("stage_configs_path"), "log_stats": self.options.get("enable_stats", False), "stage_init_timeout": self.options.get("stage_init_timeout", 300), }) # vllm engine options (passed through Omni for LLM/TTS) if request.GPUMemoryUtilization > 0: omni_kwargs["gpu_memory_utilization"] = request.GPUMemoryUtilization if request.TensorParallelSize > 0: omni_kwargs["tensor_parallel_size"] = request.TensorParallelSize if request.TrustRemoteCode: omni_kwargs["trust_remote_code"] = request.TrustRemoteCode if request.MaxModelLen > 0: omni_kwargs["max_model_len"] = request.MaxModelLen self.omni = Omni(**omni_kwargs) print("Model loaded successfully", file=sys.stderr) return backend_pb2.Result(message="Model loaded successfully", success=True) 
except Exception as err: print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr) traceback.print_exc() return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") def GenerateImage(self, request, context): try: # Validate model is loaded and is image/diffusion type if not hasattr(self, 'omni'): return backend_pb2.Result(success=False, message="Model not loaded. Call LoadModel first.") if self.model_type not in ["image"]: return backend_pb2.Result(success=False, message=f"Model type {self.model_type} does not support image generation") # Extract parameters prompt = request.positive_prompt negative_prompt = request.negative_prompt if request.negative_prompt else None width = request.width if request.width > 0 else 1024 height = request.height if request.height > 0 else 1024 seed = request.seed if request.seed > 0 else None num_inference_steps = request.step if request.step > 0 else 50 cfg_scale = self.options.get("cfg_scale", 4.0) guidance_scale = self.options.get("guidance_scale", 1.0) # Create generator if seed provided generator = None if seed: device = detect_device_type() generator = torch.Generator(device=device).manual_seed(seed) # Handle image input for image editing pil_image = None if request.src or (request.ref_images and len(request.ref_images) > 0): image_path = request.ref_images[0] if request.ref_images else request.src pil_image = self._load_image(image_path) if pil_image is None: return backend_pb2.Result(success=False, message=f"Invalid image source: {image_path}") pil_image = pil_image.convert("RGB") # Build generate kwargs generate_kwargs = { "prompt": prompt, "negative_prompt": negative_prompt, "height": height, "width": width, "generator": generator, "true_cfg_scale": cfg_scale, "guidance_scale": guidance_scale, "num_inference_steps": num_inference_steps, } if pil_image: generate_kwargs["pil_image"] = pil_image # Call omni.generate() outputs = self.omni.generate(**generate_kwargs) # Extract images (following example 
pattern) if not outputs or len(outputs) == 0: return backend_pb2.Result(success=False, message="No output generated") first_output = outputs[0] if not hasattr(first_output, "request_output") or not first_output.request_output: return backend_pb2.Result(success=False, message="Invalid output structure") req_out = first_output.request_output[0] if not isinstance(req_out, OmniRequestOutput) or not hasattr(req_out, "images"): return backend_pb2.Result(success=False, message="No images in output") images = req_out.images if not images or len(images) == 0: return backend_pb2.Result(success=False, message="Empty images list") # Save image output_image = images[0] output_image.save(request.dst) return backend_pb2.Result(message="Image generated successfully", success=True) except Exception as err: print(f"Error generating image: {err}", file=sys.stderr) traceback.print_exc() return backend_pb2.Result(success=False, message=f"Error generating image: {err}") def GenerateVideo(self, request, context): try: # Validate model is loaded and is video/diffusion type if not hasattr(self, 'omni'): return backend_pb2.Result(success=False, message="Model not loaded. 
Call LoadModel first.") if self.model_type not in ["video"]: return backend_pb2.Result(success=False, message=f"Model type {self.model_type} does not support video generation") # Extract parameters prompt = request.prompt negative_prompt = request.negative_prompt if request.negative_prompt else "" width = request.width if request.width > 0 else 1280 height = request.height if request.height > 0 else 720 num_frames = request.num_frames if request.num_frames > 0 else 81 fps = request.fps if request.fps > 0 else 24 seed = request.seed if request.seed > 0 else None guidance_scale = request.cfg_scale if request.cfg_scale > 0 else 4.0 guidance_scale_high = self.options.get("guidance_scale_high") num_inference_steps = request.step if request.step > 0 else 40 # Create generator generator = None if seed: device = detect_device_type() generator = torch.Generator(device=device).manual_seed(seed) # Handle image input for image-to-video pil_image = None if request.start_image: pil_image = self._load_image(request.start_image) if pil_image is None: return backend_pb2.Result(success=False, message=f"Invalid start_image: {request.start_image}") pil_image = pil_image.convert("RGB") # Resize to target dimensions pil_image = pil_image.resize((width, height), Image.Resampling.LANCZOS) # Build generate kwargs generate_kwargs = { "prompt": prompt, "negative_prompt": negative_prompt, "height": height, "width": width, "generator": generator, "guidance_scale": guidance_scale, "num_inference_steps": num_inference_steps, "num_frames": num_frames, } if pil_image: generate_kwargs["pil_image"] = pil_image if guidance_scale_high: generate_kwargs["guidance_scale_2"] = guidance_scale_high # Call omni.generate() frames = self.omni.generate(**generate_kwargs) # Extract video frames (following example pattern) if isinstance(frames, list) and len(frames) > 0: first_item = frames[0] if hasattr(first_item, "final_output_type"): if first_item.final_output_type != "image": return 
backend_pb2.Result(success=False, message=f"Unexpected output type: {first_item.final_output_type}") # Pipeline mode: extract from nested request_output if hasattr(first_item, "is_pipeline_output") and first_item.is_pipeline_output: if isinstance(first_item.request_output, list) and len(first_item.request_output) > 0: inner_output = first_item.request_output[0] if isinstance(inner_output, OmniRequestOutput) and hasattr(inner_output, "images"): frames = inner_output.images[0] if inner_output.images else None # Diffusion mode: use direct images field elif hasattr(first_item, "images") and first_item.images: frames = first_item.images else: return backend_pb2.Result(success=False, message="No video frames found") if frames is None: return backend_pb2.Result(success=False, message="No video frames found in output") # Convert frames to numpy array (following example) if isinstance(frames, torch.Tensor): video_tensor = frames.detach().cpu() # Handle different tensor shapes [B, C, F, H, W] or [B, F, H, W, C] if video_tensor.dim() == 5: if video_tensor.shape[1] in (3, 4): video_tensor = video_tensor[0].permute(1, 2, 3, 0) else: video_tensor = video_tensor[0] elif video_tensor.dim() == 4 and video_tensor.shape[0] in (3, 4): video_tensor = video_tensor.permute(1, 2, 3, 0) # Normalize from [-1,1] to [0,1] if float if video_tensor.is_floating_point(): video_tensor = video_tensor.clamp(-1, 1) * 0.5 + 0.5 video_array = video_tensor.float().numpy() else: video_array = frames if hasattr(video_array, "shape") and video_array.ndim == 5: video_array = video_array[0] # Convert 4D array (frames, H, W, C) to list of frames if isinstance(video_array, np.ndarray) and video_array.ndim == 4: video_array = list(video_array) # Save video export_to_video(video_array, request.dst, fps=fps) return backend_pb2.Result(message="Video generated successfully", success=True) except Exception as err: print(f"Error generating video: {err}", file=sys.stderr) traceback.print_exc() return 
backend_pb2.Result(success=False, message=f"Error generating video: {err}") def Predict(self, request, context): """Non-streaming text generation with multimodal inputs.""" gen = self._predict(request, context, streaming=False) try: res = next(gen) return res except StopIteration: return backend_pb2.Reply(message=bytes("", 'utf-8')) def PredictStream(self, request, context): """Streaming text generation with multimodal inputs.""" return self._predict(request, context, streaming=True) def _predict(self, request, context, streaming=False): """Internal method for text generation (streaming and non-streaming).""" try: # Validate model is loaded and is LLM type if not hasattr(self, 'omni'): yield backend_pb2.Reply(message=bytes("Model not loaded. Call LoadModel first.", 'utf-8')) return if self.model_type not in ["llm"]: yield backend_pb2.Reply(message=bytes(f"Model type {self.model_type} does not support text generation", 'utf-8')) return # Extract prompt if request.Prompt: prompt = request.Prompt elif request.Messages and request.UseTokenizerTemplate: # Build prompt from messages (simplified - would need tokenizer for full template) prompt = "" for msg in request.Messages: role = msg.role content = msg.content prompt += f"<|im_start|>{role}\n{content}<|im_end|>\n" prompt += "<|im_start|>assistant\n" else: yield backend_pb2.Reply(message=bytes("", 'utf-8')) return # Build multi_modal_data dict multi_modal_data = {} # Process images if request.Images: image_data = [] for img_path in request.Images: img = self._load_image(img_path) if img: # Convert to format expected by vllm from vllm.multimodal.image import convert_image_mode img_data = convert_image_mode(img, "RGB") image_data.append(img_data) if image_data: multi_modal_data["image"] = image_data # Process videos if request.Videos: video_data = [] for video_path in request.Videos: video = self._load_video(video_path) if video is not None: video_data.append(video) if video_data: multi_modal_data["video"] = video_data # 
Process audio if request.Audios: audio_data = [] for audio_path in request.Audios: audio = self._load_audio(audio_path) if audio is not None: audio_data.append(audio) if audio_data: multi_modal_data["audio"] = audio_data # Build inputs dict inputs = { "prompt": prompt, "multi_modal_data": multi_modal_data if multi_modal_data else None, } # Build sampling params sampling_params = SamplingParams( temperature=request.Temperature if request.Temperature > 0 else 0.7, top_p=request.TopP if request.TopP > 0 else 0.9, top_k=request.TopK if request.TopK > 0 else -1, max_tokens=request.Tokens if request.Tokens > 0 else 200, presence_penalty=request.PresencePenalty if request.PresencePenalty != 0 else 0.0, frequency_penalty=request.FrequencyPenalty if request.FrequencyPenalty != 0 else 0.0, repetition_penalty=request.RepetitionPenalty if request.RepetitionPenalty != 0 else 1.0, seed=request.Seed if request.Seed > 0 else None, stop=request.StopPrompts if request.StopPrompts else None, stop_token_ids=request.StopTokenIds if request.StopTokenIds else None, ignore_eos=request.IgnoreEOS, ) sampling_params_list = [sampling_params] # Call omni.generate() (returns generator for LLM mode) omni_generator = self.omni.generate([inputs], sampling_params_list) # Extract text from outputs generated_text = "" for stage_outputs in omni_generator: if stage_outputs.final_output_type == "text": for output in stage_outputs.request_output: text_output = output.outputs[0].text if streaming: # Remove already sent text (vllm concatenates) delta_text = text_output.removeprefix(generated_text) yield backend_pb2.Reply(message=bytes(delta_text, encoding='utf-8')) generated_text = text_output if not streaming: yield backend_pb2.Reply(message=bytes(generated_text, encoding='utf-8')) except Exception as err: print(f"Error in Predict: {err}", file=sys.stderr) traceback.print_exc() yield backend_pb2.Reply(message=bytes(f"Error: {err}", encoding='utf-8')) def TTS(self, request, context): try: # Validate model 
is loaded and is TTS type if not hasattr(self, 'omni'): return backend_pb2.Result(success=False, message="Model not loaded. Call LoadModel first.") if self.model_type not in ["tts"]: return backend_pb2.Result(success=False, message=f"Model type {self.model_type} does not support TTS") # Extract parameters text = request.text language = request.language if request.language else "Auto" voice = request.voice if request.voice else None task_type = self._detect_tts_task_type() # Build prompt with chat template # TODO: for now vllm-omni supports only qwen3-tts, so we hardcode it, however, we want to support other models in the future. # and we might need to use the chat template here prompt = f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n" # Build inputs dict inputs = { "prompt": prompt, "additional_information": { "task_type": [task_type], "text": [text], "language": [language], "max_new_tokens": [2048], } } # Add task-specific fields if task_type == "CustomVoice": if voice: inputs["additional_information"]["speaker"] = [voice] # Add instruct if provided in options if "instruct" in self.options: inputs["additional_information"]["instruct"] = [self.options["instruct"]] elif task_type == "VoiceDesign": if "instruct" in self.options: inputs["additional_information"]["instruct"] = [self.options["instruct"]] inputs["additional_information"]["non_streaming_mode"] = [True] elif task_type == "Base": # Voice cloning requires ref_audio and ref_text if "ref_audio" in self.options: inputs["additional_information"]["ref_audio"] = [self.options["ref_audio"]] if "ref_text" in self.options: inputs["additional_information"]["ref_text"] = [self.options["ref_text"]] if "x_vector_only_mode" in self.options: inputs["additional_information"]["x_vector_only_mode"] = [self.options["x_vector_only_mode"]] # Build sampling params sampling_params = SamplingParams( temperature=0.9, top_p=1.0, top_k=50, max_tokens=2048, seed=42, detokenize=False, repetition_penalty=1.05, ) 
sampling_params_list = [sampling_params] # Call omni.generate() omni_generator = self.omni.generate(inputs, sampling_params_list) # Extract audio (following TTS example) for stage_outputs in omni_generator: for output in stage_outputs.request_output: if "audio" in output.multimodal_output: audio_tensor = output.multimodal_output["audio"] audio_samplerate = output.multimodal_output["sr"].item() # Convert to numpy audio_numpy = audio_tensor.float().detach().cpu().numpy() if audio_numpy.ndim > 1: audio_numpy = audio_numpy.flatten() # Save audio file sf.write(request.dst, audio_numpy, samplerate=audio_samplerate, format="WAV") return backend_pb2.Result(message="TTS audio generated successfully", success=True) return backend_pb2.Result(success=False, message="No audio output generated") except Exception as err: print(f"Error generating TTS: {err}", file=sys.stderr) traceback.print_exc() return backend_pb2.Result(success=False, message=f"Error generating TTS: {err}") def serve(address): server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), ('grpc.max_receive_message_length', 50 * 1024 * 1024), ]) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() print("Server started. Listening on: " + address, file=sys.stderr) # Signal handlers for graceful shutdown def signal_handler(sig, frame): print("Received termination signal. Shutting down...") server.stop(0) sys.exit(0) signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: while True: time.sleep(_ONE_DAY_IN_SECONDS) except KeyboardInterrupt: server.stop(0) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the gRPC server.") parser.add_argument( "--addr", default="localhost:50051", help="The address to bind the server to." 
) args = parser.parse_args() serve(args.addr) ================================================ FILE: backend/python/vllm-omni/install.sh ================================================ #!/bin/bash set -e PYTHON_VERSION="3.12" PYTHON_PATCH="12" PY_STANDALONE_TAG="20251120" backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi # Handle l4t build profiles (Python 3.12, pip fallback) if needed if [ "x${BUILD_PROFILE}" == "xl4t13" ]; then PYTHON_VERSION="3.12" PYTHON_PATCH="12" PY_STANDALONE_TAG="20251120" fi if [ "x${BUILD_PROFILE}" == "xl4t12" ]; then USE_PIP=true fi # Install base requirements first installRequirements # Install vllm based on build type if [ "x${BUILD_TYPE}" == "xhipblas" ]; then # ROCm if [ "x${USE_PIP}" == "xtrue" ]; then pip install vllm==0.14.0 --extra-index-url https://wheels.vllm.ai/rocm/0.14.0/rocm700 else uv pip install vllm==0.14.0 --extra-index-url https://wheels.vllm.ai/rocm/0.14.0/rocm700 fi elif [ "x${BUILD_TYPE}" == "xcublas" ] || [ "x${BUILD_TYPE}" == "x" ]; then # CUDA (default) or CPU if [ "x${USE_PIP}" == "xtrue" ]; then pip install vllm==0.14.0 --torch-backend=auto else uv pip install vllm==0.14.0 --torch-backend=auto fi else echo "Unsupported build type: ${BUILD_TYPE}" >&2 exit 1 fi # Clone and install vllm-omni from source if [ ! -d vllm-omni ]; then git clone https://github.com/vllm-project/vllm-omni.git fi cd vllm-omni/ if [ "x${USE_PIP}" == "xtrue" ]; then pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -e . else uv pip install ${EXTRA_PIP_INSTALL_FLAGS:-} -e . fi cd .. 
================================================ FILE: backend/python/vllm-omni/requirements-after.txt ================================================ diffusers librosa ================================================ FILE: backend/python/vllm-omni/requirements-cublas12-after.txt ================================================ https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.3/flash_attn-2.8.3+cu12torch2.7cxx11abiTRUE-cp310-cp310-linux_x86_64.whl ================================================ FILE: backend/python/vllm-omni/requirements-cublas12.txt ================================================ accelerate torch==2.7.0 transformers bitsandbytes ================================================ FILE: backend/python/vllm-omni/requirements-hipblas.txt ================================================ --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.4 accelerate torch transformers bitsandbytes ================================================ FILE: backend/python/vllm-omni/requirements.txt ================================================ grpcio==1.76.0 protobuf certifi setuptools pillow numpy soundfile ================================================ FILE: backend/python/vllm-omni/run.sh ================================================ #!/bin/bash backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi startBackend $@ ================================================ FILE: backend/python/vllm-omni/test.py ================================================ import unittest import subprocess import time import backend_pb2 import backend_pb2_grpc import grpc class TestBackendServicer(unittest.TestCase): """ TestBackendServicer is the class that tests the gRPC service. This class contains methods to test the startup and shutdown of the gRPC service. 
""" def setUp(self): self.service = subprocess.Popen(["python", "backend.py", "--addr", "localhost:50051"]) time.sleep(10) def tearDown(self) -> None: self.service.terminate() self.service.wait() def test_server_startup(self): try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.Health(backend_pb2.HealthMessage()) self.assertEqual(response.message, b'OK') except Exception as err: print(err) self.fail("Server failed to start") finally: self.tearDown() def test_load_model(self): """ This method tests if the model is loaded successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) # Use a small image generation model for testing response = stub.LoadModel(backend_pb2.ModelOptions(Model="Tongyi-MAI/Z-Image-Turbo")) self.assertTrue(response.success) self.assertEqual(response.message, "Model loaded successfully") except Exception as err: print(err) self.fail("LoadModel service failed") finally: self.tearDown() def test_generate_image(self): """ This method tests if image generation works """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="Tongyi-MAI/Z-Image-Turbo")) self.assertTrue(response.success) req = backend_pb2.GenerateImageRequest( positive_prompt="a cup of coffee on the table", dst="/tmp/test_output.png", width=512, height=512, step=20, seed=42additional_information ) resp = stub.GenerateImage(req) self.assertTrue(resp.success) except Exception as err: print(err) self.fail("GenerateImage service failed") finally: self.tearDown() additional_information if __name__ == "__main__": unittest.main() ================================================ FILE: backend/python/vllm-omni/test.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if 
[ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi runUnittests ================================================ FILE: backend/python/voxcpm/Makefile ================================================ .PHONY: voxcpm voxcpm: bash install.sh .PHONY: run run: voxcpm @echo "Running voxcpm..." bash run.sh @echo "voxcpm run." .PHONY: test test: voxcpm @echo "Testing voxcpm..." bash test.sh @echo "voxcpm tested." .PHONY: protogen-clean protogen-clean: $(RM) backend_pb2_grpc.py backend_pb2.py .PHONY: clean clean: protogen-clean rm -rf venv __pycache__ ================================================ FILE: backend/python/voxcpm/backend.py ================================================ #!/usr/bin/env python3 """ This is an extra gRPC server of LocalAI for VoxCPM """ from concurrent import futures import time import argparse import signal import sys import os import traceback import numpy as np import soundfile as sf from voxcpm import VoxCPM import backend_pb2 import backend_pb2_grpc import torch import grpc def is_float(s): """Check if a string can be converted to float.""" try: float(s) return True except ValueError: return False def is_int(s): """Check if a string can be converted to int.""" try: int(s) return True except ValueError: return False _ONE_DAY_IN_SECONDS = 60 * 60 * 24 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) # Implement the BackendServicer class with the service methods class BackendServicer(backend_pb2_grpc.BackendServicer): """ BackendServicer is the class that implements the gRPC service """ def Health(self, request, context): return backend_pb2.Reply(message=bytes("OK", 'utf-8')) def LoadModel(self, request, context): # Get device if torch.cuda.is_available(): print("CUDA is available", file=sys.stderr) device = "cuda" else: print("CUDA is not available", file=sys.stderr) 
device = "cpu" mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() if mps_available: device = "mps" if not torch.cuda.is_available() and request.CUDA: return backend_pb2.Result(success=False, message="CUDA is not available") # Normalize potential 'mpx' typo to 'mps' if device == "mpx": print("Note: device 'mpx' detected, treating it as 'mps'.", file=sys.stderr) device = "mps" # Validate mps availability if requested if device == "mps" and not torch.backends.mps.is_available(): print("Warning: MPS not available. Falling back to CPU.", file=sys.stderr) device = "cpu" self.device = device options = request.Options # empty dict self.options = {} # The options are a list of strings in this form optname:optvalue # We are storing all the options in a dict so we can use it later when # generating the audio for opt in options: if ":" not in opt: continue key, value = opt.split(":", 1) # Split only on first colon # if value is a number, convert it to the appropriate type if is_float(value): value = float(value) elif is_int(value): value = int(value) elif value.lower() in ["true", "false"]: value = value.lower() == "true" self.options[key] = value # Get model path from request model_path = request.Model if not model_path: model_path = "openbmb/VoxCPM1.5" try: print(f"Loading model from {model_path}", file=sys.stderr) self.model = VoxCPM.from_pretrained(model_path) print(f"Model loaded successfully on device: {self.device}", file=sys.stderr) except Exception as err: return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return backend_pb2.Result(message="Model loaded successfully", success=True) def TTS(self, request, context): try: # Get generation parameters from options with defaults cfg_value = self.options.get("cfg_value", 2.0) inference_timesteps = self.options.get("inference_timesteps", 10) normalize = self.options.get("normalize", False) denoise = self.options.get("denoise", False) retry_badcase = 
self.options.get("retry_badcase", True) retry_badcase_max_times = self.options.get("retry_badcase_max_times", 3) retry_badcase_ratio_threshold = self.options.get("retry_badcase_ratio_threshold", 6.0) use_streaming = self.options.get("streaming", False) # Handle voice cloning via prompt_wav_path and prompt_text prompt_wav_path = None prompt_text = None # Priority: request.voice > AudioPath > options if hasattr(request, 'voice') and request.voice: # If voice is provided, try to use it as a path if os.path.exists(request.voice): prompt_wav_path = request.voice elif hasattr(request, 'ModelFile') and request.ModelFile: model_file_base = os.path.dirname(request.ModelFile) potential_path = os.path.join(model_file_base, request.voice) if os.path.exists(potential_path): prompt_wav_path = potential_path elif hasattr(request, 'ModelPath') and request.ModelPath: potential_path = os.path.join(request.ModelPath, request.voice) if os.path.exists(potential_path): prompt_wav_path = potential_path if hasattr(request, 'AudioPath') and request.AudioPath: if os.path.isabs(request.AudioPath): prompt_wav_path = request.AudioPath elif hasattr(request, 'ModelFile') and request.ModelFile: model_file_base = os.path.dirname(request.ModelFile) prompt_wav_path = os.path.join(model_file_base, request.AudioPath) elif hasattr(request, 'ModelPath') and request.ModelPath: prompt_wav_path = os.path.join(request.ModelPath, request.AudioPath) else: prompt_wav_path = request.AudioPath # Get prompt_text from options if available if "prompt_text" in self.options: prompt_text = self.options["prompt_text"] # Prepare text text = request.text.strip() print(f"Generating audio with cfg_value: {cfg_value}, inference_timesteps: {inference_timesteps}, streaming: {use_streaming}", file=sys.stderr) # Generate audio if use_streaming: # Streaming generation chunks = [] for chunk in self.model.generate_streaming( text=text, prompt_wav_path=prompt_wav_path, prompt_text=prompt_text, cfg_value=cfg_value, 
inference_timesteps=inference_timesteps, normalize=normalize, denoise=denoise, retry_badcase=retry_badcase, retry_badcase_max_times=retry_badcase_max_times, retry_badcase_ratio_threshold=retry_badcase_ratio_threshold, ): chunks.append(chunk) wav = np.concatenate(chunks) else: # Non-streaming generation wav = self.model.generate( text=text, prompt_wav_path=prompt_wav_path, prompt_text=prompt_text, cfg_value=cfg_value, inference_timesteps=inference_timesteps, normalize=normalize, denoise=denoise, retry_badcase=retry_badcase, retry_badcase_max_times=retry_badcase_max_times, retry_badcase_ratio_threshold=retry_badcase_ratio_threshold, ) # Get sample rate from model sample_rate = self.model.tts_model.sample_rate # Save output sf.write(request.dst, wav, sample_rate) print(f"Saved output to {request.dst}", file=sys.stderr) except Exception as err: print(f"Error in TTS: {err}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return backend_pb2.Result(success=True) def TTSStream(self, request, context): try: # Get generation parameters from options with defaults cfg_value = self.options.get("cfg_value", 2.0) inference_timesteps = self.options.get("inference_timesteps", 10) normalize = self.options.get("normalize", False) denoise = self.options.get("denoise", False) retry_badcase = self.options.get("retry_badcase", True) retry_badcase_max_times = self.options.get("retry_badcase_max_times", 3) retry_badcase_ratio_threshold = self.options.get("retry_badcase_ratio_threshold", 6.0) # Handle voice cloning via prompt_wav_path and prompt_text prompt_wav_path = None prompt_text = None # Priority: request.voice > AudioPath > options if hasattr(request, 'voice') and request.voice: # If voice is provided, try to use it as a path if os.path.exists(request.voice): prompt_wav_path = request.voice elif hasattr(request, 'ModelFile') and request.ModelFile: model_file_base = 
os.path.dirname(request.ModelFile) potential_path = os.path.join(model_file_base, request.voice) if os.path.exists(potential_path): prompt_wav_path = potential_path elif hasattr(request, 'ModelPath') and request.ModelPath: potential_path = os.path.join(request.ModelPath, request.voice) if os.path.exists(potential_path): prompt_wav_path = potential_path if hasattr(request, 'AudioPath') and request.AudioPath: if os.path.isabs(request.AudioPath): prompt_wav_path = request.AudioPath elif hasattr(request, 'ModelFile') and request.ModelFile: model_file_base = os.path.dirname(request.ModelFile) prompt_wav_path = os.path.join(model_file_base, request.AudioPath) elif hasattr(request, 'ModelPath') and request.ModelPath: prompt_wav_path = os.path.join(request.ModelPath, request.AudioPath) else: prompt_wav_path = request.AudioPath # Get prompt_text from options if available if "prompt_text" in self.options: prompt_text = self.options["prompt_text"] # Prepare text text = request.text.strip() # Get sample rate from model (needed for WAV header) sample_rate = self.model.tts_model.sample_rate print(f"Streaming audio with cfg_value: {cfg_value}, inference_timesteps: {inference_timesteps}, sample_rate: {sample_rate}", file=sys.stderr) # Send sample rate as first message (in message field as JSON or string) # Format: "sample_rate:16000" so we can parse it import json sample_rate_info = json.dumps({"sample_rate": int(sample_rate)}) yield backend_pb2.Reply(message=bytes(sample_rate_info, 'utf-8')) # Stream audio chunks for chunk in self.model.generate_streaming( text=text, prompt_wav_path=prompt_wav_path, prompt_text=prompt_text, cfg_value=cfg_value, inference_timesteps=inference_timesteps, normalize=normalize, denoise=denoise, retry_badcase=retry_badcase, retry_badcase_max_times=retry_badcase_max_times, retry_badcase_ratio_threshold=retry_badcase_ratio_threshold, ): # Convert numpy array to int16 PCM and then to bytes # Ensure values are in int16 range chunk_int16 = np.clip(chunk * 
32767, -32768, 32767).astype(np.int16) chunk_bytes = chunk_int16.tobytes() yield backend_pb2.Reply(audio=chunk_bytes) except Exception as err: print(f"Error in TTSStream: {err}", file=sys.stderr) print(traceback.format_exc(), file=sys.stderr) # Yield an error reply yield backend_pb2.Reply(message=bytes(f"Error: {err}", 'utf-8')) def serve(address): server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ]) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() print("Server started. Listening on: " + address, file=sys.stderr) # Define the signal handler function def signal_handler(sig, frame): print("Received termination signal. Shutting down...") server.stop(0) sys.exit(0) # Set the signal handlers for SIGINT and SIGTERM signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: while True: time.sleep(_ONE_DAY_IN_SECONDS) except KeyboardInterrupt: server.stop(0) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the gRPC server.") parser.add_argument( "--addr", default="localhost:50051", help="The address to bind the server to." 
) args = parser.parse_args() serve(args.addr) ================================================ FILE: backend/python/voxcpm/install.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi installRequirements if [ "x${USE_PIP}" == "xtrue" ]; then pip install "setuptools<70.0.0" else uv pip install "setuptools<70.0.0" fi # Apply patch to fix PyTorch compatibility issue in voxcpm # This fixes the "Dimension out of range" error in scaled_dot_product_attention # by changing .contiguous() to .unsqueeze(0) in the attention module # The patch is needed because voxcpm's initialization test generation fails with # certain PyTorch versions due to a bug in scaled_dot_product_attention # https://github.com/OpenBMB/VoxCPM/issues/71#issuecomment-3441789452 VOXCPM_PATH=$(python -c "import voxcpm; import os; print(os.path.dirname(voxcpm.__file__))" 2>/dev/null || echo "") if [ -n "$VOXCPM_PATH" ] && [ -f "$VOXCPM_PATH/modules/minicpm4/model.py" ]; then echo "Applying patch to voxcpm at $VOXCPM_PATH/modules/minicpm4/model.py" # Replace .contiguous() with .unsqueeze(0) for the three lines in the attention forward_step method # This fixes the dimension error in scaled_dot_product_attention # Use temp file for in-place edit so it works on both BSD sed (macOS) and GNU sed (Linux) PATCH_FILE="$VOXCPM_PATH/modules/minicpm4/model.py" sed 's/query_states = query_states\.contiguous()/query_states = query_states.unsqueeze(0)/g; s/key_cache = key_cache\.contiguous()/key_cache = key_cache.unsqueeze(0)/g; s/value_cache = value_cache\.contiguous()/value_cache = value_cache.unsqueeze(0)/g' "$PATCH_FILE" > "${PATCH_FILE}.tmp" && mv "${PATCH_FILE}.tmp" "$PATCH_FILE" echo "Patch applied successfully" else echo "Warning: Could not find voxcpm installation to apply patch (path: ${VOXCPM_PATH:-not found})" fi 
================================================ FILE: backend/python/voxcpm/protogen.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi runProtogen ================================================ FILE: backend/python/voxcpm/requirements-cpu.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cpu torch soundfile numpy voxcpm torchcodec ================================================ FILE: backend/python/voxcpm/requirements-cublas12.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu121 torch soundfile numpy voxcpm ================================================ FILE: backend/python/voxcpm/requirements-cublas13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch soundfile numpy voxcpm ================================================ FILE: backend/python/voxcpm/requirements-hipblas.txt ================================================ --extra-index-url https://download.pytorch.org/whl/rocm6.3 torch==2.7.1+rocm6.3 soundfile numpy voxcpm ================================================ FILE: backend/python/voxcpm/requirements-intel.txt ================================================ --extra-index-url https://download.pytorch.org/whl/xpu torch setuptools soundfile numpy voxcpm ================================================ FILE: backend/python/voxcpm/requirements-l4t12.txt ================================================ --extra-index-url https://pypi.jetson-ai-lab.io/jp6/cu129/ torch soundfile numpy voxcpm ================================================ FILE: backend/python/voxcpm/requirements-l4t13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch soundfile 
numpy voxcpm ================================================ FILE: backend/python/voxcpm/requirements-mps.txt ================================================ torch soundfile numpy voxcpm ================================================ FILE: backend/python/voxcpm/requirements.txt ================================================ setuptools grpcio==1.76.0 protobuf certifi packaging==24.1 soundfile numpy voxcpm ================================================ FILE: backend/python/voxcpm/run.sh ================================================ #!/bin/bash backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi startBackend $@ ================================================ FILE: backend/python/voxcpm/test.py ================================================ """ A test script to test the gRPC service """ import unittest import subprocess import time import backend_pb2 import backend_pb2_grpc import grpc class TestBackendServicer(unittest.TestCase): """ TestBackendServicer is the class that tests the gRPC service """ def setUp(self): """ This method sets up the gRPC service by starting the server """ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"]) time.sleep(30) def tearDown(self) -> None: """ This method tears down the gRPC service by terminating the server """ self.service.terminate() self.service.wait() def test_load_model(self): """ This method tests if the model is loaded successfully """ try: self.setUp() print("Starting test_load_model") with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="openbmb/VoxCPM1.5")) print(response) self.assertTrue(response.success) self.assertEqual(response.message, "Model loaded successfully") tts_request = backend_pb2.TTSRequest(text="VoxCPM is an innovative end-to-end TTS model from 
ModelBest.", dst="test.wav") tts_response = stub.TTS(tts_request) self.assertIsNotNone(tts_response) except Exception as err: print(err) self.fail("LoadModel service failed") finally: self.tearDown() def test_tts_stream(self): """ This method tests if TTS streaming works correctly """ try: self.setUp() print("Starting test_tts_stream") with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="openbmb/VoxCPM1.5")) print(response) self.assertTrue(response.success) self.assertEqual(response.message, "Model loaded successfully") # Test TTSStream tts_request = backend_pb2.TTSRequest(text="VoxCPM is an innovative end-to-end TTS model from ModelBest. This is a streaming test.", dst="test_stream.wav") chunks_received = 0 total_audio_bytes = 0 for reply in stub.TTSStream(tts_request): # Verify that we receive audio chunks if reply.audio: chunks_received += 1 total_audio_bytes += len(reply.audio) self.assertGreater(len(reply.audio), 0, "Audio chunk should not be empty") # Verify that we received multiple chunks self.assertGreater(chunks_received, 0, "Should receive at least one audio chunk") self.assertGreater(total_audio_bytes, 0, "Total audio bytes should be greater than 0") print(f"Received {chunks_received} chunks with {total_audio_bytes} total bytes") except Exception as err: print(err) self.fail("TTSStream service failed") finally: self.tearDown() ================================================ FILE: backend/python/voxcpm/test.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi runUnittests ================================================ FILE: backend/python/whisperx/Makefile ================================================ .DEFAULT_GOAL := install .PHONY: install install: bash install.sh 
.PHONY: protogen-clean protogen-clean: $(RM) backend_pb2_grpc.py backend_pb2.py .PHONY: clean clean: protogen-clean rm -rf venv __pycache__ test: install bash test.sh ================================================ FILE: backend/python/whisperx/backend.py ================================================ #!/usr/bin/env python3 """ This is an extra gRPC server of LocalAI for WhisperX transcription with speaker diarization, word-level timestamps, and forced alignment. """ from concurrent import futures import time import argparse import signal import sys import os import backend_pb2 import backend_pb2_grpc import grpc _ONE_DAY_IN_SECONDS = 60 * 60 * 24 # If MAX_WORKERS are specified in the environment use it, otherwise default to 1 MAX_WORKERS = int(os.environ.get('PYTHON_GRPC_MAX_WORKERS', '1')) # Implement the BackendServicer class with the service methods class BackendServicer(backend_pb2_grpc.BackendServicer): """ BackendServicer is the class that implements the gRPC service """ def Health(self, request, context): return backend_pb2.Reply(message=bytes("OK", 'utf-8')) def LoadModel(self, request, context): import whisperx import torch device = "cpu" if request.CUDA: device = "cuda" mps_available = hasattr(torch.backends, "mps") and torch.backends.mps.is_available() if mps_available: device = "mps" try: print("Preparing WhisperX model, please wait", file=sys.stderr) compute_type = "float16" if device != "cpu" else "int8" self.model = whisperx.load_model( request.Model, device, compute_type=compute_type, ) self.device = device self.model_name = request.Model # Store HF token for diarization if available self.hf_token = os.environ.get("HF_TOKEN", None) self.diarize_pipeline = None # Cache for alignment models keyed by language code self.align_cache = {} print(f"WhisperX model loaded: {request.Model} on {device}", file=sys.stderr) except Exception as err: return backend_pb2.Result(success=False, message=f"Unexpected {err=}, {type(err)=}") return 
backend_pb2.Result(message="Model loaded successfully", success=True) def _get_align_model(self, language_code): """Load or return cached alignment model for a given language.""" import whisperx if language_code not in self.align_cache: model_a, metadata = whisperx.load_align_model( language_code=language_code, device=self.device, ) self.align_cache[language_code] = (model_a, metadata) return self.align_cache[language_code] def AudioTranscription(self, request, context): import whisperx resultSegments = [] text = "" try: audio = whisperx.load_audio(request.dst) # Transcribe transcript = self.model.transcribe( audio, batch_size=16, language=request.language if request.language else None, ) # Align for word-level timestamps model_a, metadata = self._get_align_model(transcript["language"]) transcript = whisperx.align( transcript["segments"], model_a, metadata, audio, self.device, return_char_alignments=False, ) # Diarize if requested and HF token is available if request.diarize and self.hf_token: if self.diarize_pipeline is None: self.diarize_pipeline = whisperx.DiarizationPipeline( use_auth_token=self.hf_token, device=self.device, ) diarize_segments = self.diarize_pipeline(audio) transcript = whisperx.assign_word_speakers(diarize_segments, transcript) # Build result segments for idx, seg in enumerate(transcript["segments"]): seg_text = seg.get("text", "") start = int(seg.get("start", 0)) end = int(seg.get("end", 0)) speaker = seg.get("speaker", "") resultSegments.append(backend_pb2.TranscriptSegment( id=idx, start=start, end=end, text=seg_text, speaker=speaker, )) text += seg_text except Exception as err: print(f"Unexpected {err=}, {type(err)=}", file=sys.stderr) return backend_pb2.TranscriptResult(segments=[], text="") return backend_pb2.TranscriptResult(segments=resultSegments, text=text) def serve(address): server = grpc.server(futures.ThreadPoolExecutor(max_workers=MAX_WORKERS), options=[ ('grpc.max_message_length', 50 * 1024 * 1024), # 50MB 
('grpc.max_send_message_length', 50 * 1024 * 1024), # 50MB ('grpc.max_receive_message_length', 50 * 1024 * 1024), # 50MB ]) backend_pb2_grpc.add_BackendServicer_to_server(BackendServicer(), server) server.add_insecure_port(address) server.start() print("Server started. Listening on: " + address, file=sys.stderr) # Define the signal handler function def signal_handler(sig, frame): print("Received termination signal. Shutting down...") server.stop(0) sys.exit(0) # Set the signal handlers for SIGINT and SIGTERM signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) try: while True: time.sleep(_ONE_DAY_IN_SECONDS) except KeyboardInterrupt: server.stop(0) if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run the gRPC server.") parser.add_argument( "--addr", default="localhost:50051", help="The address to bind the server to." ) args = parser.parse_args() serve(args.addr) ================================================ FILE: backend/python/whisperx/install.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi if [ "x${BUILD_PROFILE}" != "xmetal" ] && [ "x${BUILD_PROFILE}" != "xmps" ]; then EXTRA_PIP_INSTALL_FLAGS+=" --index-strategy unsafe-best-match" fi installRequirements ================================================ FILE: backend/python/whisperx/protogen.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi python3 -m grpc_tools.protoc -I../.. -I./ --python_out=. --grpc_python_out=. 
backend.proto ================================================ FILE: backend/python/whisperx/requirements-cpu.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cpu torch==2.8.0 whisperx @ git+https://github.com/m-bain/whisperX.git ================================================ FILE: backend/python/whisperx/requirements-cublas12.txt ================================================ torch whisperx @ git+https://github.com/m-bain/whisperX.git ================================================ FILE: backend/python/whisperx/requirements-cublas13.txt ================================================ --extra-index-url https://download.pytorch.org/whl/cu130 torch whisperx @ git+https://github.com/m-bain/whisperX.git ================================================ FILE: backend/python/whisperx/requirements-hipblas.txt ================================================ --extra-index-url https://download.pytorch.org/whl/rocm6.4 torch==2.8.0 whisperx @ git+https://github.com/m-bain/whisperX.git ================================================ FILE: backend/python/whisperx/requirements-mps.txt ================================================ torch whisperx @ git+https://github.com/m-bain/whisperX.git ================================================ FILE: backend/python/whisperx/requirements.txt ================================================ grpcio==1.71.0 protobuf grpcio-tools ================================================ FILE: backend/python/whisperx/run.sh ================================================ #!/bin/bash backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi startBackend $@ ================================================ FILE: backend/python/whisperx/test.py ================================================ """ A test script to test the gRPC service for WhisperX transcription """ import unittest import 
subprocess import time import os import tempfile import shutil import backend_pb2 import backend_pb2_grpc import grpc class TestBackendServicer(unittest.TestCase): """ TestBackendServicer is the class that tests the gRPC service """ def setUp(self): """ This method sets up the gRPC service by starting the server """ self.service = subprocess.Popen(["python3", "backend.py", "--addr", "localhost:50051"]) time.sleep(10) def tearDown(self) -> None: """ This method tears down the gRPC service by terminating the server """ self.service.terminate() self.service.wait() def test_server_startup(self): """ This method tests if the server starts up successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.Health(backend_pb2.HealthMessage()) self.assertEqual(response.message, b'OK') except Exception as err: print(err) self.fail("Server failed to start") finally: self.tearDown() def test_load_model(self): """ This method tests if the model is loaded successfully """ try: self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) response = stub.LoadModel(backend_pb2.ModelOptions(Model="tiny")) self.assertTrue(response.success) self.assertEqual(response.message, "Model loaded successfully") except Exception as err: print(err) self.fail("LoadModel service failed") finally: self.tearDown() def test_audio_transcription(self): """ This method tests if audio transcription works successfully """ # Create a temporary directory for the audio file temp_dir = tempfile.mkdtemp() audio_file = os.path.join(temp_dir, 'audio.wav') try: # Download the audio file to the temporary directory print(f"Downloading audio file to {audio_file}...") url = "https://cdn.openai.com/whisper/draft-20220913a/micro-machines.wav" result = subprocess.run( ["wget", "-q", url, "-O", audio_file], capture_output=True, text=True ) if result.returncode != 0: 
self.fail(f"Failed to download audio file: {result.stderr}") # Verify the file was downloaded if not os.path.exists(audio_file): self.fail(f"Audio file was not downloaded to {audio_file}") self.setUp() with grpc.insecure_channel("localhost:50051") as channel: stub = backend_pb2_grpc.BackendStub(channel) # Load the model first load_response = stub.LoadModel(backend_pb2.ModelOptions(Model="tiny")) self.assertTrue(load_response.success) # Perform transcription without diarization transcript_request = backend_pb2.TranscriptRequest(dst=audio_file) transcript_response = stub.AudioTranscription(transcript_request) # Print the transcribed text for debugging print(f"Transcribed text: {transcript_response.text}") print(f"Number of segments: {len(transcript_response.segments)}") # Verify response structure self.assertIsNotNone(transcript_response) self.assertIsNotNone(transcript_response.text) self.assertGreater(len(transcript_response.text), 0) self.assertGreater(len(transcript_response.segments), 0) # Verify segments have timing info segment = transcript_response.segments[0] self.assertIsNotNone(segment.text) self.assertIsInstance(segment.id, int) except Exception as err: print(err) self.fail("AudioTranscription service failed") finally: self.tearDown() # Clean up the temporary directory if os.path.exists(temp_dir): shutil.rmtree(temp_dir) ================================================ FILE: backend/python/whisperx/test.sh ================================================ #!/bin/bash set -e backend_dir=$(dirname $0) if [ -d $backend_dir/common ]; then source $backend_dir/common/libbackend.sh else source $backend_dir/../common/libbackend.sh fi runUnittests ================================================ FILE: cmd/launcher/icon.go ================================================ package main import ( _ "embed" "fyne.io/fyne/v2" ) //go:embed logo.png var logoData []byte // resourceIconPng is the LocalAI logo icon var resourceIconPng = &fyne.StaticResource{ StaticName: 
"logo.png", StaticContent: logoData, } ================================================ FILE: cmd/launcher/internal/launcher.go ================================================ package launcher import ( "bufio" "context" "encoding/json" "fmt" "io" "log" "net/url" "os" "os/exec" "path/filepath" "strings" "sync" "syscall" "time" "fyne.io/fyne/v2" "fyne.io/fyne/v2/container" "fyne.io/fyne/v2/dialog" "fyne.io/fyne/v2/widget" ) // Config represents the launcher configuration type Config struct { ModelsPath string `json:"models_path"` BackendsPath string `json:"backends_path"` Address string `json:"address"` AutoStart bool `json:"auto_start"` StartOnBoot bool `json:"start_on_boot"` LogLevel string `json:"log_level"` EnvironmentVars map[string]string `json:"environment_vars"` ShowWelcome *bool `json:"show_welcome"` } // Launcher represents the main launcher application type Launcher struct { // Core components releaseManager *ReleaseManager config *Config ui *LauncherUI systray *SystrayManager ctx context.Context window fyne.Window app fyne.App // Process management localaiCmd *exec.Cmd isRunning bool logBuffer *strings.Builder logMutex sync.RWMutex statusChannel chan string // Logging logFile *os.File logPath string // UI state lastUpdateCheck time.Time } // NewLauncher creates a new launcher instance func NewLauncher(ui *LauncherUI, window fyne.Window, app fyne.App) *Launcher { return &Launcher{ releaseManager: NewReleaseManager(), config: &Config{}, logBuffer: &strings.Builder{}, statusChannel: make(chan string, 100), ctx: context.Background(), ui: ui, window: window, app: app, } } // setupLogging sets up log file for LocalAI process output func (l *Launcher) setupLogging() error { // Create logs directory in data folder dataPath := l.GetDataPath() logsDir := filepath.Join(dataPath, "logs") if err := os.MkdirAll(logsDir, 0755); err != nil { return fmt.Errorf("failed to create logs directory: %w", err) } // Create log file with timestamp timestamp := 
time.Now().Format("2006-01-02_15-04-05") l.logPath = filepath.Join(logsDir, fmt.Sprintf("localai_%s.log", timestamp)) logFile, err := os.Create(l.logPath) if err != nil { return fmt.Errorf("failed to create log file: %w", err) } l.logFile = logFile return nil } // Initialize sets up the launcher func (l *Launcher) Initialize() error { if l.app == nil { return fmt.Errorf("app is nil") } log.Printf("Initializing launcher...") // Setup logging if err := l.setupLogging(); err != nil { return fmt.Errorf("failed to setup logging: %w", err) } // Load configuration log.Printf("Loading configuration...") if err := l.loadConfig(); err != nil { return fmt.Errorf("failed to load config: %w", err) } log.Printf("Configuration loaded, current state: ModelsPath=%s, BackendsPath=%s, Address=%s, LogLevel=%s", l.config.ModelsPath, l.config.BackendsPath, l.config.Address, l.config.LogLevel) // Clean up any partial downloads log.Printf("Cleaning up partial downloads...") if err := l.releaseManager.CleanupPartialDownloads(); err != nil { log.Printf("Warning: failed to cleanup partial downloads: %v", err) } if l.config.StartOnBoot { l.StartLocalAI() } // Set default paths if not configured (only if not already loaded from config) if l.config.ModelsPath == "" { homeDir, _ := os.UserHomeDir() l.config.ModelsPath = filepath.Join(homeDir, ".localai", "models") log.Printf("Setting default ModelsPath: %s", l.config.ModelsPath) } if l.config.BackendsPath == "" { homeDir, _ := os.UserHomeDir() l.config.BackendsPath = filepath.Join(homeDir, ".localai", "backends") log.Printf("Setting default BackendsPath: %s", l.config.BackendsPath) } if l.config.Address == "" { l.config.Address = "127.0.0.1:8080" log.Printf("Setting default Address: %s", l.config.Address) } if l.config.LogLevel == "" { l.config.LogLevel = "info" log.Printf("Setting default LogLevel: %s", l.config.LogLevel) } if l.config.EnvironmentVars == nil { l.config.EnvironmentVars = make(map[string]string) log.Printf("Initializing empty 
EnvironmentVars map") } // Set default welcome window preference if l.config.ShowWelcome == nil { true := true l.config.ShowWelcome = &true log.Printf("Setting default ShowWelcome: true") } // Create directories os.MkdirAll(l.config.ModelsPath, 0755) os.MkdirAll(l.config.BackendsPath, 0755) // Save the configuration with default values if err := l.saveConfig(); err != nil { log.Printf("Warning: failed to save default configuration: %v", err) } // System tray is now handled in main.go using Fyne's built-in approach // Check if LocalAI is installed if !l.releaseManager.IsLocalAIInstalled() { log.Printf("No LocalAI installation found") fyne.Do(func() { l.updateStatus("No LocalAI installation found") if l.ui != nil { // Show dialog offering to download LocalAI l.showDownloadLocalAIDialog() } }) } // Check for updates periodically go l.periodicUpdateCheck() return nil } // StartLocalAI starts the LocalAI server func (l *Launcher) StartLocalAI() error { if l.isRunning { return fmt.Errorf("LocalAI is already running") } // Verify binary integrity before starting if err := l.releaseManager.VerifyInstalledBinary(); err != nil { // Binary is corrupted, remove it and offer to reinstall binaryPath := l.releaseManager.GetBinaryPath() if removeErr := os.Remove(binaryPath); removeErr != nil { log.Printf("Failed to remove corrupted binary: %v", removeErr) } return fmt.Errorf("LocalAI binary is corrupted: %v. Please reinstall LocalAI", err) } binaryPath := l.releaseManager.GetBinaryPath() if _, err := os.Stat(binaryPath); os.IsNotExist(err) { return fmt.Errorf("LocalAI binary not found. Please download a release first") } // Build command arguments args := []string{ "run", "--models-path", l.config.ModelsPath, "--backends-path", l.config.BackendsPath, "--address", l.config.Address, "--log-level", l.config.LogLevel, } l.localaiCmd = exec.CommandContext(l.ctx, binaryPath, args...) 
// Apply environment variables if len(l.config.EnvironmentVars) > 0 { env := os.Environ() for key, value := range l.config.EnvironmentVars { env = append(env, fmt.Sprintf("%s=%s", key, value)) } l.localaiCmd.Env = env } // Setup logging stdout, err := l.localaiCmd.StdoutPipe() if err != nil { return fmt.Errorf("failed to create stdout pipe: %w", err) } stderr, err := l.localaiCmd.StderrPipe() if err != nil { return fmt.Errorf("failed to create stderr pipe: %w", err) } // Start the process if err := l.localaiCmd.Start(); err != nil { return fmt.Errorf("failed to start LocalAI: %w", err) } l.isRunning = true fyne.Do(func() { l.updateStatus("LocalAI is starting...") l.updateRunningState(true) }) // Start log monitoring go l.monitorLogs(stdout, "STDOUT") go l.monitorLogs(stderr, "STDERR") // Monitor process with startup timeout go func() { // Wait for process to start or fail err := l.localaiCmd.Wait() l.isRunning = false fyne.Do(func() { l.updateRunningState(false) if err != nil { l.updateStatus(fmt.Sprintf("LocalAI stopped with error: %v", err)) } else { l.updateStatus("LocalAI stopped") } }) }() // Add startup timeout detection go func() { time.Sleep(10 * time.Second) // Wait 10 seconds for startup if l.isRunning { // Check if process is still alive if l.localaiCmd.Process != nil { if err := l.localaiCmd.Process.Signal(syscall.Signal(0)); err != nil { // Process is dead, mark as not running l.isRunning = false fyne.Do(func() { l.updateRunningState(false) l.updateStatus("LocalAI failed to start properly") }) } } } }() return nil } // StopLocalAI stops the LocalAI server func (l *Launcher) StopLocalAI() error { if !l.isRunning || l.localaiCmd == nil { return fmt.Errorf("LocalAI is not running") } // Gracefully terminate the process if err := l.localaiCmd.Process.Signal(os.Interrupt); err != nil { // If graceful termination fails, force kill if killErr := l.localaiCmd.Process.Kill(); killErr != nil { return fmt.Errorf("failed to kill LocalAI process: %w", killErr) } } 
l.isRunning = false fyne.Do(func() { l.updateRunningState(false) l.updateStatus("LocalAI stopped") }) return nil } // IsRunning returns whether LocalAI is currently running func (l *Launcher) IsRunning() bool { return l.isRunning } // Shutdown performs cleanup when the application is closing func (l *Launcher) Shutdown() error { log.Printf("Launcher shutting down, stopping LocalAI...") // Stop LocalAI if it's running if l.isRunning { if err := l.StopLocalAI(); err != nil { log.Printf("Error stopping LocalAI during shutdown: %v", err) } } // Close log file if open if l.logFile != nil { if err := l.logFile.Close(); err != nil { log.Printf("Error closing log file: %v", err) } l.logFile = nil } log.Printf("Launcher shutdown complete") return nil } // GetLogs returns the current log buffer func (l *Launcher) GetLogs() string { l.logMutex.RLock() defer l.logMutex.RUnlock() return l.logBuffer.String() } // GetRecentLogs returns the most recent logs (last 50 lines) for better error display func (l *Launcher) GetRecentLogs() string { l.logMutex.RLock() defer l.logMutex.RUnlock() content := l.logBuffer.String() lines := strings.Split(content, "\n") // Get last 50 lines if len(lines) > 50 { lines = lines[len(lines)-50:] } return strings.Join(lines, "\n") } // GetConfig returns the current configuration func (l *Launcher) GetConfig() *Config { return l.config } // SetConfig updates the configuration func (l *Launcher) SetConfig(config *Config) error { l.config = config return l.saveConfig() } func (l *Launcher) GetUI() *LauncherUI { return l.ui } func (l *Launcher) SetSystray(systray *SystrayManager) { l.systray = systray } // GetReleaseManager returns the release manager func (l *Launcher) GetReleaseManager() *ReleaseManager { return l.releaseManager } // GetWebUIURL returns the URL for the WebUI func (l *Launcher) GetWebUIURL() string { address := l.config.Address if strings.HasPrefix(address, ":") { address = "localhost" + address } if !strings.HasPrefix(address, "http") { 
address = "http://" + address } return address } // GetDataPath returns the path where LocalAI data and logs are stored func (l *Launcher) GetDataPath() string { // LocalAI typically stores data in the current working directory or a models directory // First check if models path is configured if l.config != nil && l.config.ModelsPath != "" { // Return the parent directory of models path return filepath.Dir(l.config.ModelsPath) } // Fallback to home directory LocalAI folder homeDir, err := os.UserHomeDir() if err != nil { return "." } return filepath.Join(homeDir, ".localai") } // CheckForUpdates checks if there are any available updates func (l *Launcher) CheckForUpdates() (bool, string, error) { log.Printf("CheckForUpdates: checking for available updates...") available, version, err := l.releaseManager.IsUpdateAvailable() if err != nil { log.Printf("CheckForUpdates: error occurred: %v", err) return false, "", err } log.Printf("CheckForUpdates: result - available=%v, version=%s", available, version) l.lastUpdateCheck = time.Now() return available, version, nil } // DownloadUpdate downloads the latest version func (l *Launcher) DownloadUpdate(version string, progressCallback func(float64)) error { return l.releaseManager.DownloadRelease(version, progressCallback) } // GetCurrentVersion returns the current installed version func (l *Launcher) GetCurrentVersion() string { return l.releaseManager.GetInstalledVersion() } // GetCurrentStatus returns the current status func (l *Launcher) GetCurrentStatus() string { select { case status := <-l.statusChannel: return status default: if l.isRunning { return "LocalAI is running" } return "Ready" } } // GetLastStatus returns the last known status without consuming from channel func (l *Launcher) GetLastStatus() string { if l.isRunning { return "LocalAI is running" } // Check if LocalAI is installed if !l.releaseManager.IsLocalAIInstalled() { return "LocalAI not installed" } return "Ready" } func (l *Launcher) 
// showDownloadLocalAIDialog shows a dialog offering to download LocalAI.
//
// It opens a standalone window with three actions: download and install,
// view the release notes of the latest release, or skip. It does nothing
// (besides logging) when the launcher has no app.
func (l *Launcher) showDownloadLocalAIDialog() {
	if l.app == nil {
		log.Printf("Cannot show download dialog: app is nil")
		return
	}

	// All window and widget construction happens inside fyne.DoAndWait.
	fyne.DoAndWait(func() {
		// Create a standalone window for the download dialog
		dialogWindow := l.app.NewWindow("LocalAI Installation Required")
		dialogWindow.Resize(fyne.NewSize(500, 350))
		dialogWindow.CenterOnScreen()
		dialogWindow.SetCloseIntercept(func() {
			dialogWindow.Close()
		})

		// Create the dialog content
		titleLabel := widget.NewLabel("LocalAI Not Found")
		titleLabel.TextStyle = fyne.TextStyle{Bold: true}
		titleLabel.Alignment = fyne.TextAlignCenter

		messageLabel := widget.NewLabel("LocalAI is not installed on your system.\n\nWould you like to download and install the latest version?")
		messageLabel.Wrapping = fyne.TextWrapWord
		messageLabel.Alignment = fyne.TextAlignCenter

		// Buttons
		downloadButton := widget.NewButton("Download & Install", func() {
			dialogWindow.Close()
			// downloadAndInstallLocalAI does its work in a goroutine, so the
			// systray menu below is rebuilt right after the download has been
			// kicked off, not after it has completed.
			l.downloadAndInstallLocalAI()
			if l.systray != nil {
				l.systray.recreateMenu()
			}
		})
		downloadButton.Importance = widget.HighImportance

		// Release notes button
		releaseNotesButton := widget.NewButton("View Release Notes", func() {
			// Get latest release info and open release notes.
			// Done in a goroutine so the GitHub request does not run inside
			// the button callback itself.
			go func() {
				release, err := l.releaseManager.GetLatestRelease()
				if err != nil {
					log.Printf("Failed to get latest release info: %v", err)
					return
				}

				releaseNotesURL, err := l.githubReleaseNotesURL(release.Version)
				if err != nil {
					log.Printf("Failed to parse URL: %v", err)
					return
				}

				l.app.OpenURL(releaseNotesURL)
			}()
		})

		skipButton := widget.NewButton("Skip for Now", func() {
			dialogWindow.Close()
		})

		// Layout - put release notes button above the main action buttons
		actionButtons := container.NewHBox(skipButton, downloadButton)
		content := container.NewVBox(
			titleLabel,
			widget.NewSeparator(),
			messageLabel,
			widget.NewSeparator(),
			releaseNotesButton,
			widget.NewSeparator(),
			actionButtons,
		)

		dialogWindow.SetContent(content)
		dialogWindow.Show()
	})
}
titleLabel, widget.NewSeparator(), messageLabel, widget.NewSeparator(), closeButton, ) errorWindow.SetContent(content) errorWindow.Show() }) } // showDownloadProgress shows a standalone progress window for downloading LocalAI func (l *Launcher) showDownloadProgress(version, title string) { fyne.DoAndWait(func() { // Create progress window progressWindow := l.app.NewWindow("Downloading LocalAI") progressWindow.Resize(fyne.NewSize(400, 250)) progressWindow.CenterOnScreen() progressWindow.SetCloseIntercept(func() { progressWindow.Close() }) // Progress bar progressBar := widget.NewProgressBar() progressBar.SetValue(0) // Status label statusLabel := widget.NewLabel("Preparing download...") // Release notes button releaseNotesButton := widget.NewButton("View Release Notes", func() { releaseNotesURL, err := l.githubReleaseNotesURL(version) if err != nil { log.Printf("Failed to parse URL: %v", err) return } l.app.OpenURL(releaseNotesURL) }) // Progress container progressContainer := container.NewVBox( widget.NewLabel(title), progressBar, statusLabel, widget.NewSeparator(), releaseNotesButton, ) progressWindow.SetContent(progressContainer) progressWindow.Show() // Start download in background go func() { err := l.DownloadUpdate(version, func(progress float64) { // Update progress bar fyne.Do(func() { progressBar.SetValue(progress) percentage := int(progress * 100) statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage)) }) }) // Handle completion fyne.Do(func() { if err != nil { statusLabel.SetText(fmt.Sprintf("Download failed: %v", err)) // Show error dialog dialog.ShowError(err, progressWindow) } else { statusLabel.SetText("Download completed successfully!") progressBar.SetValue(1.0) // Show success dialog dialog.ShowConfirm("Installation Complete", "LocalAI has been downloaded and installed successfully. 
You can now start LocalAI from the launcher.", func(close bool) { progressWindow.Close() // Update status and refresh systray menu l.updateStatus("LocalAI installed successfully") if l.systray != nil { l.systray.recreateMenu() } }, progressWindow) } }) }() }) } // monitorLogs monitors the output of LocalAI and adds it to the log buffer func (l *Launcher) monitorLogs(reader io.Reader, prefix string) { scanner := bufio.NewScanner(reader) for scanner.Scan() { line := scanner.Text() timestamp := time.Now().Format("15:04:05") logLine := fmt.Sprintf("[%s] %s: %s\n", timestamp, prefix, line) l.logMutex.Lock() l.logBuffer.WriteString(logLine) // Keep log buffer size reasonable if l.logBuffer.Len() > 100000 { // 100KB content := l.logBuffer.String() // Keep last 50KB if len(content) > 50000 { l.logBuffer.Reset() l.logBuffer.WriteString(content[len(content)-50000:]) } } l.logMutex.Unlock() // Write to log file if available if l.logFile != nil { if _, err := l.logFile.WriteString(logLine); err != nil { log.Printf("Failed to write to log file: %v", err) } } fyne.Do(func() { // Notify UI of new log content if l.ui != nil { l.ui.OnLogUpdate(logLine) } // Check for startup completion if strings.Contains(line, "API server listening") { l.updateStatus("LocalAI is running") } }) } } // updateStatus updates the status and notifies UI func (l *Launcher) updateStatus(status string) { select { case l.statusChannel <- status: default: // Channel full, skip } if l.ui != nil { l.ui.UpdateStatus(status) } if l.systray != nil { l.systray.UpdateStatus(status) } } // updateRunningState updates the running state in UI and systray func (l *Launcher) updateRunningState(isRunning bool) { if l.ui != nil { l.ui.UpdateRunningState(isRunning) } if l.systray != nil { l.systray.UpdateRunningState(isRunning) } } // periodicUpdateCheck checks for updates periodically func (l *Launcher) periodicUpdateCheck() { ticker := time.NewTicker(1 * time.Hour) defer ticker.Stop() for { select { case <-ticker.C: 
available, version, err := l.CheckForUpdates() if err == nil && available { fyne.Do(func() { l.updateStatus(fmt.Sprintf("Update available: %s", version)) if l.systray != nil { l.systray.NotifyUpdateAvailable(version) } if l.ui != nil { l.ui.NotifyUpdateAvailable(version) } }) } case <-l.ctx.Done(): return } } } // loadConfig loads configuration from file func (l *Launcher) loadConfig() error { homeDir, err := os.UserHomeDir() if err != nil { return fmt.Errorf("failed to get home directory: %w", err) } configPath := filepath.Join(homeDir, ".localai", "launcher.json") log.Printf("Loading config from: %s", configPath) if _, err := os.Stat(configPath); os.IsNotExist(err) { log.Printf("Config file not found, creating default config") // Create default config return l.saveConfig() } // Load existing config configData, err := os.ReadFile(configPath) if err != nil { return fmt.Errorf("failed to read config file: %w", err) } log.Printf("Config file content: %s", string(configData)) log.Printf("loadConfig: about to unmarshal JSON data") if err := json.Unmarshal(configData, l.config); err != nil { return fmt.Errorf("failed to parse config file: %w", err) } log.Printf("loadConfig: JSON unmarshaled successfully") log.Printf("Loaded config: ModelsPath=%s, BackendsPath=%s, Address=%s, LogLevel=%s", l.config.ModelsPath, l.config.BackendsPath, l.config.Address, l.config.LogLevel) log.Printf("Environment vars: %v", l.config.EnvironmentVars) return nil } // saveConfig saves configuration to file func (l *Launcher) saveConfig() error { homeDir, err := os.UserHomeDir() if err != nil { return fmt.Errorf("failed to get home directory: %w", err) } configDir := filepath.Join(homeDir, ".localai") if err := os.MkdirAll(configDir, 0755); err != nil { return fmt.Errorf("failed to create config directory: %w", err) } // Marshal config to JSON log.Printf("saveConfig: marshaling config with EnvironmentVars: %v", l.config.EnvironmentVars) configData, err := json.MarshalIndent(l.config, "", " ") if 
err != nil { return fmt.Errorf("failed to marshal config: %w", err) } log.Printf("saveConfig: JSON marshaled successfully, length: %d", len(configData)) configPath := filepath.Join(configDir, "launcher.json") log.Printf("Saving config to: %s", configPath) log.Printf("Config content: %s", string(configData)) if err := os.WriteFile(configPath, configData, 0644); err != nil { return fmt.Errorf("failed to write config file: %w", err) } log.Printf("Config saved successfully") return nil } ================================================ FILE: cmd/launcher/internal/launcher_suite_test.go ================================================ package launcher_test import ( "testing" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) func TestLauncher(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Launcher Suite") } ================================================ FILE: cmd/launcher/internal/launcher_test.go ================================================ package launcher_test import ( "os" "path/filepath" "strings" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" "fyne.io/fyne/v2/app" launcher "github.com/mudler/LocalAI/cmd/launcher/internal" ) var _ = Describe("Launcher", func() { var ( launcherInstance *launcher.Launcher tempDir string ) BeforeEach(func() { var err error tempDir, err = os.MkdirTemp("", "launcher-test-*") Expect(err).ToNot(HaveOccurred()) ui := launcher.NewLauncherUI() app := app.NewWithID("com.localai.launcher") launcherInstance = launcher.NewLauncher(ui, nil, app) }) AfterEach(func() { os.RemoveAll(tempDir) }) Describe("NewLauncher", func() { It("should create a launcher with default configuration", func() { Expect(launcherInstance.GetConfig()).ToNot(BeNil()) }) }) Describe("Initialize", func() { It("should set default paths when not configured", func() { err := launcherInstance.Initialize() Expect(err).ToNot(HaveOccurred()) config := launcherInstance.GetConfig() Expect(config.ModelsPath).ToNot(BeEmpty()) Expect(config.BackendsPath).ToNot(BeEmpty()) }) It("should set default ShowWelcome to true", func() { err := launcherInstance.Initialize() Expect(err).ToNot(HaveOccurred()) config := launcherInstance.GetConfig() Expect(config.ShowWelcome).To(BeTrue()) Expect(config.Address).To(Equal("127.0.0.1:8080")) Expect(config.LogLevel).To(Equal("info")) }) It("should create models and backends directories", func() { // Set custom paths for testing config := launcherInstance.GetConfig() config.ModelsPath = filepath.Join(tempDir, "models") config.BackendsPath = filepath.Join(tempDir, "backends") launcherInstance.SetConfig(config) err := launcherInstance.Initialize() Expect(err).ToNot(HaveOccurred()) // Check if directories were created _, err = os.Stat(config.ModelsPath) Expect(err).ToNot(HaveOccurred()) _, err = os.Stat(config.BackendsPath) Expect(err).ToNot(HaveOccurred()) }) }) Describe("Configuration", func() { It("should get and set configuration", func() { config := launcherInstance.GetConfig() config.ModelsPath = "/test/models" config.BackendsPath = "/test/backends" config.Address = 
":9090" config.LogLevel = "debug" err := launcherInstance.SetConfig(config) Expect(err).ToNot(HaveOccurred()) retrievedConfig := launcherInstance.GetConfig() Expect(retrievedConfig.ModelsPath).To(Equal("/test/models")) Expect(retrievedConfig.BackendsPath).To(Equal("/test/backends")) Expect(retrievedConfig.Address).To(Equal(":9090")) Expect(retrievedConfig.LogLevel).To(Equal("debug")) }) }) Describe("WebUI URL", func() { It("should return correct WebUI URL for localhost", func() { config := launcherInstance.GetConfig() config.Address = ":8080" launcherInstance.SetConfig(config) url := launcherInstance.GetWebUIURL() Expect(url).To(Equal("http://localhost:8080")) }) It("should return correct WebUI URL for full address", func() { config := launcherInstance.GetConfig() config.Address = "127.0.0.1:8080" launcherInstance.SetConfig(config) url := launcherInstance.GetWebUIURL() Expect(url).To(Equal("http://127.0.0.1:8080")) }) It("should handle http prefix correctly", func() { config := launcherInstance.GetConfig() config.Address = "http://localhost:8080" launcherInstance.SetConfig(config) url := launcherInstance.GetWebUIURL() Expect(url).To(Equal("http://localhost:8080")) }) }) Describe("Process Management", func() { It("should not be running initially", func() { Expect(launcherInstance.IsRunning()).To(BeFalse()) }) It("should handle start when binary doesn't exist", func() { err := launcherInstance.StartLocalAI() Expect(err).To(HaveOccurred()) // Could be either "not found" or "permission denied" depending on test environment errMsg := err.Error() hasExpectedError := strings.Contains(errMsg, "LocalAI binary") || strings.Contains(errMsg, "permission denied") Expect(hasExpectedError).To(BeTrue(), "Expected error about binary not found or permission denied, got: %s", errMsg) }) It("should handle stop when not running", func() { err := launcherInstance.StopLocalAI() Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("LocalAI is not running")) }) }) 
Describe("Logs", func() { It("should return empty logs initially", func() { logs := launcherInstance.GetLogs() Expect(logs).To(BeEmpty()) }) }) Describe("Version Management", func() { It("should return empty version when no binary installed", func() { version := launcherInstance.GetCurrentVersion() Expect(version).To(BeEmpty()) // No binary installed in test environment }) It("should handle update checks", func() { // This test would require mocking HTTP responses // For now, we'll just test that the method doesn't panic _, _, err := launcherInstance.CheckForUpdates() // We expect either success or a network error, not a panic if err != nil { // Network error is acceptable in tests Expect(err.Error()).To(ContainSubstring("failed to fetch")) } }) }) }) var _ = Describe("Config", func() { It("should have proper JSON tags", func() { config := &launcher.Config{ ModelsPath: "/test/models", BackendsPath: "/test/backends", Address: ":8080", AutoStart: true, LogLevel: "info", EnvironmentVars: map[string]string{"TEST": "value"}, } Expect(config.ModelsPath).To(Equal("/test/models")) Expect(config.BackendsPath).To(Equal("/test/backends")) Expect(config.Address).To(Equal(":8080")) Expect(config.AutoStart).To(BeTrue()) Expect(config.LogLevel).To(Equal("info")) Expect(config.EnvironmentVars).To(HaveKeyWithValue("TEST", "value")) }) It("should initialize environment variables map", func() { config := &launcher.Config{} Expect(config.EnvironmentVars).To(BeNil()) ui := launcher.NewLauncherUI() app := app.NewWithID("com.localai.launcher") launcher := launcher.NewLauncher(ui, nil, app) err := launcher.Initialize() Expect(err).ToNot(HaveOccurred()) retrievedConfig := launcher.GetConfig() Expect(retrievedConfig.EnvironmentVars).ToNot(BeNil()) Expect(retrievedConfig.EnvironmentVars).To(BeEmpty()) }) }) ================================================ FILE: cmd/launcher/internal/release_manager.go ================================================ package launcher import ( "bufio" 
"crypto/sha256" "encoding/hex" "encoding/json" "fmt" "io" "log" "net/http" "os" "os/exec" "path/filepath" "runtime" "strings" "time" "github.com/mudler/LocalAI/internal" ) // Release represents a LocalAI release type Release struct { Version string `json:"tag_name"` Name string `json:"name"` Body string `json:"body"` PublishedAt time.Time `json:"published_at"` Assets []Asset `json:"assets"` } // Asset represents a release asset type Asset struct { Name string `json:"name"` BrowserDownloadURL string `json:"browser_download_url"` Size int64 `json:"size"` } // ReleaseManager handles LocalAI release management type ReleaseManager struct { // GitHubOwner is the GitHub repository owner GitHubOwner string // GitHubRepo is the GitHub repository name GitHubRepo string // BinaryPath is where the LocalAI binary is stored locally BinaryPath string // CurrentVersion is the currently installed version CurrentVersion string // ChecksumsPath is where checksums are stored ChecksumsPath string // MetadataPath is where version metadata is stored MetadataPath string // HTTPClient is the HTTP client used for downloads HTTPClient *http.Client } // NewReleaseManager creates a new release manager func NewReleaseManager() *ReleaseManager { homeDir, _ := os.UserHomeDir() binaryPath := filepath.Join(homeDir, ".localai", "bin") checksumsPath := filepath.Join(homeDir, ".localai", "checksums") metadataPath := filepath.Join(homeDir, ".localai", "metadata") return &ReleaseManager{ GitHubOwner: "mudler", GitHubRepo: "LocalAI", BinaryPath: binaryPath, CurrentVersion: internal.PrintableVersion(), ChecksumsPath: checksumsPath, MetadataPath: metadataPath, HTTPClient: &http.Client{ Timeout: 30 * time.Second, }, } } // GetLatestRelease fetches the latest release information from GitHub func (rm *ReleaseManager) GetLatestRelease() (*Release, error) { url := fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", rm.GitHubOwner, rm.GitHubRepo) resp, err := rm.HTTPClient.Get(url) if err != nil { 
return nil, fmt.Errorf("failed to fetch latest release: %w", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("failed to fetch latest release: status %d", resp.StatusCode) } // Parse the JSON response properly body, err := io.ReadAll(resp.Body) if err != nil { return nil, fmt.Errorf("failed to read response body: %w", err) } release := &Release{} if err := json.Unmarshal(body, release); err != nil { return nil, fmt.Errorf("failed to parse JSON response: %w", err) } // Validate the release data if release.Version == "" { return nil, fmt.Errorf("no version found in release data") } return release, nil } // DownloadRelease downloads a specific version of LocalAI func (rm *ReleaseManager) DownloadRelease(version string, progressCallback func(float64)) error { // Ensure the binary directory exists if err := os.MkdirAll(rm.BinaryPath, 0755); err != nil { return fmt.Errorf("failed to create binary directory: %w", err) } // Determine the binary name based on OS and architecture binaryName := rm.GetBinaryName(version) localPath := filepath.Join(rm.BinaryPath, "local-ai") // Download the binary downloadURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/%s", rm.GitHubOwner, rm.GitHubRepo, version, binaryName) if err := rm.downloadFile(downloadURL, localPath, progressCallback); err != nil { return fmt.Errorf("failed to download binary: %w", err) } // Download and verify checksums checksumURL := fmt.Sprintf("https://github.com/%s/%s/releases/download/%s/LocalAI-%s-checksums.txt", rm.GitHubOwner, rm.GitHubRepo, version, version) checksumPath := filepath.Join(rm.BinaryPath, "checksums.txt") manualChecksumPath := filepath.Join(rm.ChecksumsPath, fmt.Sprintf("checksums-%s.txt", version)) // First, check if there's already a checksum file (either manually placed or previously downloaded) // and honor that, skipping download entirely in such case var downloadErr error if _, err := os.Stat(manualChecksumPath); err == nil { 
log.Printf("Using existing checksums from: %s", manualChecksumPath) checksumPath = manualChecksumPath } else if _, err := os.Stat(checksumPath); err == nil { log.Printf("Using existing checksums from: %s", checksumPath) } else { // No existing checksum file found, try to download downloadErr = rm.downloadFile(checksumURL, checksumPath, nil) if downloadErr != nil { log.Printf("Warning: failed to download checksums: %v", downloadErr) log.Printf("Warning: Checksum verification will be skipped. For security, you can manually place checksums at: %s", manualChecksumPath) log.Printf("Download checksums from: %s", checksumURL) // Continue without verification - log warning but don't fail } } // Verify the checksum if we have a checksum file if _, err := os.Stat(checksumPath); err == nil { if err := rm.VerifyChecksum(localPath, checksumPath, binaryName); err != nil { return fmt.Errorf("checksum verification failed: %w", err) } log.Printf("Checksum verification successful") // Save checksums persistently for future verification if downloadErr == nil { if err := rm.saveChecksums(version, checksumPath, binaryName); err != nil { log.Printf("Warning: failed to save checksums: %v", err) } } } else { log.Printf("Warning: Proceeding without checksum verification") } // Make the binary executable if err := os.Chmod(localPath, 0755); err != nil { return fmt.Errorf("failed to make binary executable: %w", err) } return nil } // GetBinaryName returns the appropriate binary name for the current platform func (rm *ReleaseManager) GetBinaryName(version string) string { versionStr := strings.TrimPrefix(version, "v") os := runtime.GOOS arch := runtime.GOARCH // Map Go arch names to the release naming convention switch arch { case "amd64": arch = "amd64" case "arm64": arch = "arm64" default: arch = "amd64" // fallback } return fmt.Sprintf("local-ai-v%s-%s-%s", versionStr, os, arch) } // downloadFile downloads a file from a URL to a local path with optional progress callback func (rm 
*ReleaseManager) downloadFile(url, filepath string, progressCallback func(float64)) error { return rm.downloadFileWithRetry(url, filepath, progressCallback, 3) } // downloadFileWithRetry downloads a file from a URL with retry logic func (rm *ReleaseManager) downloadFileWithRetry(url, filepath string, progressCallback func(float64), maxRetries int) error { var lastErr error for attempt := 1; attempt <= maxRetries; attempt++ { if attempt > 1 { log.Printf("Retrying download (attempt %d/%d): %s", attempt, maxRetries, url) time.Sleep(time.Duration(attempt) * time.Second) } resp, err := rm.HTTPClient.Get(url) if err != nil { lastErr = err continue } if resp.StatusCode != http.StatusOK { resp.Body.Close() lastErr = fmt.Errorf("bad status: %s", resp.Status) continue } out, err := os.Create(filepath) if err != nil { resp.Body.Close() return err } // Create a progress reader if callback is provided var reader io.Reader = resp.Body if progressCallback != nil && resp.ContentLength > 0 { reader = &progressReader{ Reader: resp.Body, Total: resp.ContentLength, Callback: progressCallback, } } _, err = io.Copy(out, reader) resp.Body.Close() out.Close() if err != nil { lastErr = err os.Remove(filepath) continue } return nil } return fmt.Errorf("failed after %d attempts: %w", maxRetries, lastErr) } // saveChecksums saves checksums persistently for future verification func (rm *ReleaseManager) saveChecksums(version, checksumPath, binaryName string) error { // Ensure checksums directory exists if err := os.MkdirAll(rm.ChecksumsPath, 0755); err != nil { return fmt.Errorf("failed to create checksums directory: %w", err) } // Read the downloaded checksums file checksumData, err := os.ReadFile(checksumPath) if err != nil { return fmt.Errorf("failed to read checksums file: %w", err) } // Save to persistent location with version info persistentPath := filepath.Join(rm.ChecksumsPath, fmt.Sprintf("checksums-%s.txt", version)) if err := os.WriteFile(persistentPath, checksumData, 0644); err != 
nil { return fmt.Errorf("failed to write persistent checksums: %w", err) } // Also save a "latest" checksums file for the current version latestPath := filepath.Join(rm.ChecksumsPath, "checksums-latest.txt") if err := os.WriteFile(latestPath, checksumData, 0644); err != nil { return fmt.Errorf("failed to write latest checksums: %w", err) } // Save version metadata if err := rm.saveVersionMetadata(version); err != nil { log.Printf("Warning: failed to save version metadata: %v", err) } log.Printf("Checksums saved for version %s", version) return nil } // saveVersionMetadata saves the installed version information func (rm *ReleaseManager) saveVersionMetadata(version string) error { // Ensure metadata directory exists if err := os.MkdirAll(rm.MetadataPath, 0755); err != nil { return fmt.Errorf("failed to create metadata directory: %w", err) } // Create metadata structure metadata := struct { Version string `json:"version"` InstalledAt time.Time `json:"installed_at"` BinaryPath string `json:"binary_path"` }{ Version: version, InstalledAt: time.Now(), BinaryPath: rm.GetBinaryPath(), } // Marshal to JSON metadataData, err := json.MarshalIndent(metadata, "", " ") if err != nil { return fmt.Errorf("failed to marshal metadata: %w", err) } // Save metadata file metadataPath := filepath.Join(rm.MetadataPath, "installed-version.json") if err := os.WriteFile(metadataPath, metadataData, 0644); err != nil { return fmt.Errorf("failed to write metadata file: %w", err) } log.Printf("Version metadata saved: %s", version) return nil } // progressReader wraps an io.Reader to provide download progress type progressReader struct { io.Reader Total int64 Current int64 Callback func(float64) } func (pr *progressReader) Read(p []byte) (int, error) { n, err := pr.Reader.Read(p) pr.Current += int64(n) if pr.Callback != nil { progress := float64(pr.Current) / float64(pr.Total) pr.Callback(progress) } return n, err } // VerifyChecksum verifies the downloaded file against the provided checksums 
func (rm *ReleaseManager) VerifyChecksum(filePath, checksumPath, binaryName string) error { // Calculate the SHA256 of the downloaded file file, err := os.Open(filePath) if err != nil { return fmt.Errorf("failed to open file for checksum: %w", err) } defer file.Close() hasher := sha256.New() if _, err := io.Copy(hasher, file); err != nil { return fmt.Errorf("failed to calculate checksum: %w", err) } calculatedHash := hex.EncodeToString(hasher.Sum(nil)) // Read the checksums file checksumFile, err := os.Open(checksumPath) if err != nil { return fmt.Errorf("failed to open checksums file: %w", err) } defer checksumFile.Close() scanner := bufio.NewScanner(checksumFile) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if strings.Contains(line, binaryName) { parts := strings.Fields(line) if len(parts) >= 2 { expectedHash := parts[0] if calculatedHash == expectedHash { return nil // Checksum verified } return fmt.Errorf("checksum mismatch: expected %s, got %s", expectedHash, calculatedHash) } } } return fmt.Errorf("checksum not found for %s", binaryName) } // GetInstalledVersion returns the currently installed version func (rm *ReleaseManager) GetInstalledVersion() string { // Fallback: Check if the LocalAI binary exists and try to get its version binaryPath := rm.GetBinaryPath() if _, err := os.Stat(binaryPath); os.IsNotExist(err) { return "" // No version installed } // try to get version from metadata if version := rm.loadVersionMetadata(); version != "" { return version } // Try to run the binary to get the version (fallback method) version, err := exec.Command(binaryPath, "--version").Output() if err != nil { // If binary exists but --version fails, try to determine from filename or other means log.Printf("Binary exists but --version failed: %v", err) return "" } stringVersion := strings.TrimSpace(string(version)) stringVersion = strings.TrimRight(stringVersion, "\n") return stringVersion } // loadVersionMetadata loads the installed version from 
// metadata file
//
// It returns "" when the file is missing, unreadable, unparsable, or when the
// binary path recorded in it differs from the current GetBinaryPath.
func (rm *ReleaseManager) loadVersionMetadata() string {
	metadataPath := filepath.Join(rm.MetadataPath, "installed-version.json")

	// Check if metadata file exists
	if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
		return ""
	}

	// Read metadata file
	metadataData, err := os.ReadFile(metadataPath)
	if err != nil {
		log.Printf("Failed to read metadata file: %v", err)
		return ""
	}

	// Parse metadata
	var metadata struct {
		Version     string    `json:"version"`
		InstalledAt time.Time `json:"installed_at"`
		BinaryPath  string    `json:"binary_path"`
	}

	if err := json.Unmarshal(metadataData, &metadata); err != nil {
		log.Printf("Failed to parse metadata file: %v", err)
		return ""
	}

	// Verify that the binary path in metadata matches current binary path
	if metadata.BinaryPath != rm.GetBinaryPath() {
		log.Printf("Binary path mismatch in metadata, ignoring")
		return ""
	}

	log.Printf("Loaded version from metadata: %s (installed at %s)", metadata.Version, metadata.InstalledAt.Format("2006-01-02 15:04:05"))
	return metadata.Version
}

// GetBinaryPath returns the path to the LocalAI binary
// (<BinaryPath>/local-ai).
func (rm *ReleaseManager) GetBinaryPath() string {
	return filepath.Join(rm.BinaryPath, "local-ai")
}

// IsUpdateAvailable checks if an update is available.
//
// It returns whether an update should be offered, the latest release tag, and
// any error from fetching the latest release. When nothing is installed the
// latest release is always offered; otherwise the two version strings are
// compared for plain inequality (no semantic-version ordering).
func (rm *ReleaseManager) IsUpdateAvailable() (bool, string, error) {
	log.Printf("IsUpdateAvailable: checking for updates...")

	latest, err := rm.GetLatestRelease()
	if err != nil {
		log.Printf("IsUpdateAvailable: failed to get latest release: %v", err)
		return false, "", err
	}
	log.Printf("IsUpdateAvailable: latest release version: %s", latest.Version)

	current := rm.GetInstalledVersion()
	log.Printf("IsUpdateAvailable: current installed version: %s", current)

	if current == "" {
		// No version installed, offer to download latest
		log.Printf("IsUpdateAvailable: no version installed, offering latest: %s", latest.Version)
		return true, latest.Version, nil
	}

	updateAvailable := latest.Version != current
	log.Printf("IsUpdateAvailable: update available: %v (latest: %s, current: %s)", updateAvailable, latest.Version, current)
	return updateAvailable, latest.Version, nil
}

// IsLocalAIInstalled checks if LocalAI binary exists and is valid.
//
// NOTE: this is not a pure query. Every call runs VerifyInstalledBinary
// (which hashes the whole binary), and on ANY verification error, including
// "no saved checksums found" and missing metadata, the binary is deleted
// and false is returned.
func (rm *ReleaseManager) IsLocalAIInstalled() bool {
	binaryPath := rm.GetBinaryPath()
	if _, err := os.Stat(binaryPath); os.IsNotExist(err) {
		return false
	}

	// Verify the binary integrity
	if err := rm.VerifyInstalledBinary(); err != nil {
		log.Printf("Binary integrity check failed: %v", err)
		// Remove corrupted binary
		if removeErr := os.Remove(binaryPath); removeErr != nil {
			log.Printf("Failed to remove corrupted binary: %v", removeErr)
		}
		return false
	}

	return true
}

// VerifyInstalledBinary verifies the installed binary against saved checksums.
// It needs both checksums-latest.txt and a readable version metadata file;
// the metadata version determines which entry of the checksums file is used.
func (rm *ReleaseManager) VerifyInstalledBinary() error {
	binaryPath := rm.GetBinaryPath()

	// Check if we have saved checksums
	latestChecksumsPath := filepath.Join(rm.ChecksumsPath, "checksums-latest.txt")
	if _, err := os.Stat(latestChecksumsPath); os.IsNotExist(err) {
		return fmt.Errorf("no saved checksums found")
	}

	// Get the binary name for the current version from metadata
	currentVersion := rm.loadVersionMetadata()
	if currentVersion == "" {
		return fmt.Errorf("cannot determine current version from metadata")
	}

	binaryName := rm.GetBinaryName(currentVersion)

	// Verify against saved checksums
	return rm.VerifyChecksum(binaryPath, latestChecksumsPath, binaryName)
}

// CleanupPartialDownloads removes any partial or corrupted downloads
func (rm *ReleaseManager) CleanupPartialDownloads() error {
	binaryPath := rm.GetBinaryPath()

	// Check if binary exists but is corrupted
	if _, err := os.Stat(binaryPath); err == nil {
		// Binary exists, verify it
		if verifyErr := rm.VerifyInstalledBinary(); verifyErr != nil {
			log.Printf("Found corrupted binary, removing: %v", verifyErr)
			if removeErr := os.Remove(binaryPath); removeErr != nil {
				log.Printf("Failed to remove corrupted binary: %v", removeErr)
			}
			// Clear metadata since binary is corrupted
			rm.clearVersionMetadata()
		}
	}
// Clean up any temporary checksum files
	// (the unversioned checksums.txt that DownloadRelease writes next to the binary)
	tempChecksumsPath := filepath.Join(rm.BinaryPath, "checksums.txt")
	if _, err := os.Stat(tempChecksumsPath); err == nil {
		if removeErr := os.Remove(tempChecksumsPath); removeErr != nil {
			log.Printf("Failed to remove temporary checksums: %v", removeErr)
		}
	}

	return nil
}

// clearVersionMetadata clears the version metadata (used when binary is corrupted or removed).
// A missing metadata file is not treated as a failure; in that case the
// "cleared" message is still logged.
func (rm *ReleaseManager) clearVersionMetadata() {
	metadataPath := filepath.Join(rm.MetadataPath, "installed-version.json")
	if err := os.Remove(metadataPath); err != nil && !os.IsNotExist(err) {
		log.Printf("Failed to clear version metadata: %v", err)
	} else {
		log.Printf("Version metadata cleared")
	}
}

================================================
FILE: cmd/launcher/internal/release_manager_test.go
================================================
package launcher_test

import (
	"os"
	"path/filepath"
	"runtime"
	"time"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"

	launcher "github.com/mudler/LocalAI/cmd/launcher/internal"
)

// Ginkgo specs for ReleaseManager. Each spec gets a fresh temporary directory
// that replaces BinaryPath; ChecksumsPath and MetadataPath keep the defaults
// from NewReleaseManager.
var _ = Describe("ReleaseManager", func() {
	var (
		rm      *launcher.ReleaseManager
		tempDir string
	)

	BeforeEach(func() {
		var err error
		tempDir, err = os.MkdirTemp("", "launcher-test-*")
		Expect(err).ToNot(HaveOccurred())

		rm = launcher.NewReleaseManager()
		// Override binary path for testing
		rm.BinaryPath = tempDir
	})

	AfterEach(func() {
		os.RemoveAll(tempDir)
	})

	Describe("NewReleaseManager", func() {
		It("should create a release manager with correct defaults", func() {
			newRM := launcher.NewReleaseManager()

			Expect(newRM.GitHubOwner).To(Equal("mudler"))
			Expect(newRM.GitHubRepo).To(Equal("LocalAI"))
			Expect(newRM.BinaryPath).To(ContainSubstring(".localai"))
			Expect(newRM.HTTPClient).ToNot(BeNil())
			Expect(newRM.HTTPClient.Timeout).To(Equal(30 * time.Second))
		})
	})

	Describe("GetBinaryName", func() {
		// NOTE(review): the expectation uses runtime.GOARCH directly, so it only
		// holds on amd64/arm64; GetBinaryName maps every other arch to amd64.
		It("should return correct binary name for current platform", func() {
			binaryName := rm.GetBinaryName("v3.4.0")
			expectedOS := runtime.GOOS
			expectedArch := runtime.GOARCH

			expected := "local-ai-v3.4.0-" + expectedOS + "-" + expectedArch
			Expect(binaryName).To(Equal(expected))
		})

		It("should handle version with and without 'v' prefix", func() {
			withV := rm.GetBinaryName("v3.4.0")
			withoutV := rm.GetBinaryName("3.4.0")

			// Both should produce the same result
			Expect(withV).To(Equal(withoutV))
		})
	})

	Describe("GetBinaryPath", func() {
		It("should return the correct binary path", func() {
			path := rm.GetBinaryPath()
			expected := filepath.Join(tempDir, "local-ai")
			Expect(path).To(Equal(expected))
		})
	})

	Describe("GetInstalledVersion", func() {
		It("should return empty when no binary exists", func() {
			version := rm.GetInstalledVersion()
			Expect(version).To(BeEmpty()) // No binary installed in test
		})

		It("should return empty version when binary exists but no metadata", func() {
			// Create a fake binary for testing
			err := os.MkdirAll(rm.BinaryPath, 0755)
			Expect(err).ToNot(HaveOccurred())

			binaryPath := rm.GetBinaryPath()
			err = os.WriteFile(binaryPath, []byte("fake binary"), 0755)
			Expect(err).ToNot(HaveOccurred())

			version := rm.GetInstalledVersion()
			Expect(version).To(BeEmpty())
		})
	})

	Context("with mocked responses", func() {
		// Note: In a real implementation, we'd mock HTTP responses
		// For now, we'll test the structure and error handling

		Describe("GetLatestRelease", func() {
			It("should handle network errors gracefully", func() {
				// This test would require mocking HTTP client
				// For demonstration, we're just testing the method exists
				_, err := rm.GetLatestRelease()
				// We expect either success or a network error, not a panic
				// In a real test, we'd mock the HTTP response
				if err != nil {
					Expect(err.Error()).To(ContainSubstring("failed to fetch"))
				}
			})
		})

		Describe("DownloadRelease", func() {
			It("should create binary directory if it doesn't exist", func() {
				// Remove the temp directory to test creation
				os.RemoveAll(tempDir)

				// This will fail due to network, but should create the directory
				rm.DownloadRelease("v3.4.0", nil)

				// Check if directory
// was created
				_, err := os.Stat(tempDir)
				Expect(err).ToNot(HaveOccurred())
			})
		})
	})

	// Specs for VerifyChecksum using small files inside tempDir.
	Describe("VerifyChecksum functionality", func() {
		var (
			testFile     string
			checksumFile string
		)

		BeforeEach(func() {
			testFile = filepath.Join(tempDir, "test-binary")
			checksumFile = filepath.Join(tempDir, "checksums.txt")
		})

		// Despite its name this spec only covers the mismatch path: the digest
		// written below is the SHA-256 of the empty input, not of testContent.
		It("should verify checksums correctly", func() {
			// Create a test file with known content
			testContent := []byte("test content for checksum")
			err := os.WriteFile(testFile, testContent, 0644)
			Expect(err).ToNot(HaveOccurred())

			// Calculate expected SHA256
			// This is a simplified test - in practice we'd use the actual checksum
			checksumContent := "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855 test-binary\n"
			err = os.WriteFile(checksumFile, []byte(checksumContent), 0644)
			Expect(err).ToNot(HaveOccurred())

			// Test checksum verification
			// Note: This will fail because our content doesn't match the empty string hash
			// In a real test, we'd calculate the actual hash
			err = rm.VerifyChecksum(testFile, checksumFile, "test-binary")
			// We expect this to fail since we're using a dummy checksum
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("checksum mismatch"))
		})

		It("should handle missing checksum file", func() {
			// Create test file but no checksum file
			err := os.WriteFile(testFile, []byte("test"), 0644)
			Expect(err).ToNot(HaveOccurred())

			err = rm.VerifyChecksum(testFile, checksumFile, "test-binary")
			Expect(err).To(HaveOccurred())
			Expect(err.Error()).To(ContainSubstring("failed to open checksums file"))
		})

		It("should handle missing binary in checksums", func() {
			// Create files but checksum doesn't contain our binary
			err := os.WriteFile(testFile, []byte("test"), 0644)
			Expect(err).ToNot(HaveOccurred())

			checksumContent := "hash other-binary\n"
			err = os.WriteFile(checksumFile, []byte(checksumContent), 0644)
			Expect(err).ToNot(HaveOccurred())

			err = rm.VerifyChecksum(testFile, checksumFile, "test-binary")
			Expect(err).To(HaveOccurred())
Expect(err.Error()).To(ContainSubstring("checksum not found")) }) }) }) ================================================ FILE: cmd/launcher/internal/systray_manager.go ================================================ package launcher import ( "fmt" "log" "net/url" "fyne.io/fyne/v2" "fyne.io/fyne/v2/container" "fyne.io/fyne/v2/dialog" "fyne.io/fyne/v2/driver/desktop" "fyne.io/fyne/v2/widget" ) // SystrayManager manages the system tray functionality type SystrayManager struct { launcher *Launcher window fyne.Window app fyne.App desk desktop.App // Menu items that need dynamic updates startStopItem *fyne.MenuItem hasUpdateAvailable bool latestVersion string icon *fyne.StaticResource } // NewSystrayManager creates a new systray manager func NewSystrayManager(launcher *Launcher, window fyne.Window, desktop desktop.App, app fyne.App, icon *fyne.StaticResource) *SystrayManager { sm := &SystrayManager{ launcher: launcher, window: window, app: app, desk: desktop, icon: icon, } sm.setupMenu(desktop) return sm } // setupMenu sets up the system tray menu func (sm *SystrayManager) setupMenu(desk desktop.App) { sm.desk = desk // Create the start/stop toggle item sm.startStopItem = fyne.NewMenuItem("Start LocalAI", func() { sm.toggleLocalAI() }) desk.SetSystemTrayIcon(sm.icon) // Initialize the menu state using recreateMenu sm.recreateMenu() } // toggleLocalAI starts or stops LocalAI based on current state func (sm *SystrayManager) toggleLocalAI() { if sm.launcher.IsRunning() { go func() { if err := sm.launcher.StopLocalAI(); err != nil { log.Printf("Failed to stop LocalAI: %v", err) sm.showErrorDialog("Failed to Stop LocalAI", err.Error()) } }() } else { go func() { if err := sm.launcher.StartLocalAI(); err != nil { log.Printf("Failed to start LocalAI: %v", err) sm.showStartupErrorDialog(err) } }() } } // openWebUI opens the LocalAI WebUI in the default browser func (sm *SystrayManager) openWebUI() { if !sm.launcher.IsRunning() { return // LocalAI is not running } webURL := 
sm.launcher.GetWebUIURL() if parsedURL, err := url.Parse(webURL); err == nil { sm.app.OpenURL(parsedURL) } } // openDocumentation opens the LocalAI documentation func (sm *SystrayManager) openDocumentation() { if parsedURL, err := url.Parse("https://localai.io"); err == nil { sm.app.OpenURL(parsedURL) } } // updateStartStopItem updates the start/stop menu item based on current state func (sm *SystrayManager) updateStartStopItem() { // Since Fyne menu items can't change text dynamically, we recreate the menu sm.recreateMenu() } // recreateMenu recreates the entire menu with updated state func (sm *SystrayManager) recreateMenu() { if sm.desk == nil { return } // Determine the action based on LocalAI installation and running state var actionItem *fyne.MenuItem if !sm.launcher.GetReleaseManager().IsLocalAIInstalled() { // LocalAI not installed - show install option actionItem = fyne.NewMenuItem("📥 Install Latest Version", func() { sm.launcher.showDownloadLocalAIDialog() }) } else if sm.launcher.IsRunning() { // LocalAI is running - show stop option actionItem = fyne.NewMenuItem("🛑 Stop LocalAI", func() { sm.toggleLocalAI() }) } else { // LocalAI is installed but not running - show start option actionItem = fyne.NewMenuItem("▶️ Start LocalAI", func() { sm.toggleLocalAI() }) } menuItems := []*fyne.MenuItem{} // Add status at the top (clickable for details) status := sm.launcher.GetLastStatus() statusText := sm.truncateText(status, 30) statusItem := fyne.NewMenuItem("📊 Status: "+statusText, func() { sm.showStatusDetails(status, "") }) menuItems = append(menuItems, statusItem) // Only show version if LocalAI is installed if sm.launcher.GetReleaseManager().IsLocalAIInstalled() { version := sm.launcher.GetCurrentVersion() versionText := sm.truncateText(version, 25) versionItem := fyne.NewMenuItem("🔧 Version: "+versionText, func() { sm.showStatusDetails(status, version) }) menuItems = append(menuItems, versionItem) } menuItems = append(menuItems, fyne.NewMenuItemSeparator()) 
// Add update notification if available if sm.hasUpdateAvailable { updateItem := fyne.NewMenuItem("🔔 New version available ("+sm.latestVersion+")", func() { sm.downloadUpdate() }) menuItems = append(menuItems, updateItem) menuItems = append(menuItems, fyne.NewMenuItemSeparator()) } // Core actions menuItems = append(menuItems, actionItem, ) // Only show WebUI option if LocalAI is installed if sm.launcher.GetReleaseManager().IsLocalAIInstalled() && sm.launcher.IsRunning() { menuItems = append(menuItems, fyne.NewMenuItem("Open WebUI", func() { sm.openWebUI() }), ) } menuItems = append(menuItems, fyne.NewMenuItemSeparator(), fyne.NewMenuItem("Check for Updates", func() { sm.checkForUpdates() }), fyne.NewMenuItemSeparator(), fyne.NewMenuItem("Settings", func() { sm.showSettings() }), fyne.NewMenuItem("Show Welcome Window", func() { sm.showWelcomeWindow() }), fyne.NewMenuItem("Open Data Folder", func() { sm.openDataFolder() }), fyne.NewMenuItemSeparator(), fyne.NewMenuItem("Documentation", func() { sm.openDocumentation() }), fyne.NewMenuItemSeparator(), fyne.NewMenuItem("Quit", func() { // Perform cleanup before quitting if err := sm.launcher.Shutdown(); err != nil { log.Printf("Error during shutdown: %v", err) } sm.app.Quit() }), ) menu := fyne.NewMenu("LocalAI", menuItems...) 
sm.desk.SetSystemTrayMenu(menu)
}

// UpdateRunningState updates the systray based on running state.
// The isRunning argument is not consulted; the menu is rebuilt from the
// launcher's current state.
func (sm *SystrayManager) UpdateRunningState(isRunning bool) {
	sm.updateStartStopItem()
}

// UpdateStatus updates the systray menu to reflect status changes.
// The status argument is not consulted; the menu is rebuilt from the
// launcher's current state.
func (sm *SystrayManager) UpdateStatus(status string) {
	sm.recreateMenu()
}

// checkForUpdates checks for available updates in a background goroutine and
// rebuilds the menu when one is found. Errors are only logged.
func (sm *SystrayManager) checkForUpdates() {
	go func() {
		log.Printf("Checking for updates...")
		available, version, err := sm.launcher.CheckForUpdates()
		if err != nil {
			log.Printf("Failed to check for updates: %v", err)
			return
		}

		log.Printf("Update check result: available=%v, version=%s", available, version)
		if available {
			sm.hasUpdateAvailable = true
			sm.latestVersion = version
			sm.recreateMenu()
		}
	}()
}

// downloadUpdate downloads the latest update
func (sm *SystrayManager) downloadUpdate() {
	if !sm.hasUpdateAvailable {
		return
	}

	// Show progress window
	sm.showDownloadProgress(sm.latestVersion)
}

// showSettings shows the settings window
func (sm *SystrayManager) showSettings() {
	sm.window.Show()
	sm.window.RequestFocus()
}

// showWelcomeWindow shows the welcome window
func (sm *SystrayManager) showWelcomeWindow() {
	if sm.launcher.GetUI() != nil {
		sm.launcher.GetUI().ShowWelcomeWindow()
	}
}

// openDataFolder opens the data folder in file manager
func (sm *SystrayManager) openDataFolder() {
	dataPath := sm.launcher.GetDataPath()
	if parsedURL, err := url.Parse("file://" + dataPath); err == nil {
		sm.app.OpenURL(parsedURL)
	}
}

// NotifyUpdateAvailable sets update notification in systray
func (sm *SystrayManager) NotifyUpdateAvailable(version string) {
	sm.hasUpdateAvailable = true
	sm.latestVersion = version
	sm.recreateMenu()
}

// truncateText truncates text to specified length and adds ellipsis if needed.
//
// Length is measured in runes, not bytes, so multi-byte UTF-8 characters are
// never cut in half. Text that already fits is returned unchanged. When
// maxLength leaves no room for the three-character ellipsis (maxLength <= 3)
// the text is cut to maxLength runes without one; a negative maxLength is
// treated as zero.
func (sm *SystrayManager) truncateText(text string, maxLength int) string {
	runes := []rune(text)
	if len(runes) <= maxLength {
		return text
	}
	if maxLength <= 3 {
		if maxLength < 0 {
			maxLength = 0
		}
		return string(runes[:maxLength])
	}
	return string(runes[:maxLength-3]) + "..."
}

// showStatusDetails shows a detailed status window with full information.
//
// status is shown verbatim; version is shown only when non-empty. The whole
// window is built inside fyne.DoAndWait.
func (sm *SystrayManager) showStatusDetails(status, version string) {
	fyne.DoAndWait(func() {
		// Create status details window
		statusWindow := sm.app.NewWindow("LocalAI Status Details")
		statusWindow.Resize(fyne.NewSize(500, 400))
		statusWindow.CenterOnScreen()

		// Status information
		statusLabel := widget.NewLabel("Current Status:")
		statusValue := widget.NewLabel(status)
		statusValue.Wrapping = fyne.TextWrapWord

		// Version information (only show if version exists)
		var versionContainer fyne.CanvasObject
		if version != "" {
			versionLabel := widget.NewLabel("Installed Version:")
			versionValue := widget.NewLabel(version)
			versionValue.Wrapping = fyne.TextWrapWord
			versionContainer = container.NewVBox(versionLabel, versionValue)
		}

		// Running state
		runningLabel := widget.NewLabel("Running State:")
		runningValue := widget.NewLabel("")
		if sm.launcher.IsRunning() {
			runningValue.SetText("🟢 Running")
		} else {
			runningValue.SetText("🔴 Stopped")
		}

		// WebUI URL
		webuiLabel := widget.NewLabel("WebUI URL:")
		webuiValue := widget.NewLabel(sm.launcher.GetWebUIURL())
		webuiValue.Wrapping = fyne.TextWrapWord

		// Recent logs (last 20 lines)
		logsLabel := widget.NewLabel("Recent Logs:")
		logsText := widget.NewMultiLineEntry()
		logsText.SetText(sm.launcher.GetRecentLogs())
		logsText.Wrapping = fyne.TextWrapWord
		logsText.Disable() // Make it read-only

		// Buttons
		closeButton := widget.NewButton("Close", func() {
			statusWindow.Close()
		})

		// Refresh re-reads status, running state and logs from the launcher.
		refreshButton := widget.NewButton("Refresh", func() {
			// Refresh the status information
			statusValue.SetText(sm.launcher.GetLastStatus())

			// Note: Version refresh is not implemented for simplicity
			// The version will be updated when the status details window is reopened

			if sm.launcher.IsRunning() {
				runningValue.SetText("🟢 Running")
			} else {
				runningValue.SetText("🔴 Stopped")
			}
			logsText.SetText(sm.launcher.GetRecentLogs())
		})

		openWebUIButton := widget.NewButton("Open WebUI", func() {
			sm.openWebUI()
		})

		//
// Layout
		buttons := container.NewHBox(closeButton, refreshButton, openWebUIButton)

		// Build info container dynamically
		infoItems := []fyne.CanvasObject{
			statusLabel, statusValue,
			widget.NewSeparator(),
		}

		// Add version section if it exists
		if versionContainer != nil {
			infoItems = append(infoItems, versionContainer, widget.NewSeparator())
		}

		infoItems = append(infoItems,
			runningLabel, runningValue,
			widget.NewSeparator(),
			webuiLabel, webuiValue,
		)

		infoContainer := container.NewVBox(infoItems...)

		content := container.NewVBox(
			infoContainer,
			widget.NewSeparator(),
			logsLabel,
			logsText,
			widget.NewSeparator(),
			buttons,
		)

		statusWindow.SetContent(content)
		statusWindow.Show()
	})
}

// showErrorDialog shows a simple error dialog on the main window.
// Only message is displayed; the title argument is currently not used.
func (sm *SystrayManager) showErrorDialog(title, message string) {
	fyne.DoAndWait(func() {
		dialog.ShowError(fmt.Errorf("%s", message), sm.window)
	})
}

// showStartupErrorDialog shows a detailed error dialog with process logs.
// The window offers Close, Retry (starts LocalAI again in a goroutine and
// reopens this dialog on failure) and a shortcut to the data folder.
func (sm *SystrayManager) showStartupErrorDialog(err error) {
	fyne.DoAndWait(func() {
		// Get the recent process logs (more useful for debugging)
		logs := sm.launcher.GetRecentLogs()

		// Create error window
		errorWindow := sm.app.NewWindow("LocalAI Startup Failed")
		errorWindow.Resize(fyne.NewSize(600, 500))
		errorWindow.CenterOnScreen()

		// Error message
		errorLabel := widget.NewLabel(fmt.Sprintf("Failed to start LocalAI:\n%s", err.Error()))
		errorLabel.Wrapping = fyne.TextWrapWord

		// Logs display
		logsLabel := widget.NewLabel("Process Logs:")
		logsText := widget.NewMultiLineEntry()
		logsText.SetText(logs)
		logsText.Wrapping = fyne.TextWrapWord
		logsText.Disable() // Make it read-only

		// Buttons
		closeButton := widget.NewButton("Close", func() {
			errorWindow.Close()
		})

		retryButton := widget.NewButton("Retry", func() {
			errorWindow.Close()
			// Try to start again
			go func() {
				if retryErr := sm.launcher.StartLocalAI(); retryErr != nil {
					sm.showStartupErrorDialog(retryErr)
				}
			}()
		})

		openLogsButton := widget.NewButton("Open Logs Folder", func() {
sm.openDataFolder()
		})

		// Layout
		buttons := container.NewHBox(closeButton, retryButton, openLogsButton)
		content := container.NewVBox(
			errorLabel,
			widget.NewSeparator(),
			logsLabel,
			logsText,
			widget.NewSeparator(),
			buttons,
		)

		errorWindow.SetContent(content)
		errorWindow.Show()
	})
}

// showDownloadProgress shows a progress window for downloading updates.
//
// The window is created synchronously; the download itself runs in a
// goroutine that pushes UI changes through fyne.Do. On success the pending
// update flag is cleared and the tray menu is rebuilt.
func (sm *SystrayManager) showDownloadProgress(version string) {
	// Create a new window for download progress
	progressWindow := sm.app.NewWindow("Downloading LocalAI Update")
	progressWindow.Resize(fyne.NewSize(400, 250))
	progressWindow.CenterOnScreen()

	// Progress bar
	progressBar := widget.NewProgressBar()
	progressBar.SetValue(0)

	// Status label
	statusLabel := widget.NewLabel("Preparing download...")

	// Release notes button
	releaseNotesButton := widget.NewButton("View Release Notes", func() {
		releaseNotesURL, err := sm.launcher.githubReleaseNotesURL(version)
		if err != nil {
			log.Printf("Failed to parse URL: %v", err)
			return
		}

		sm.app.OpenURL(releaseNotesURL)
	})

	// Progress container
	progressContainer := container.NewVBox(
		widget.NewLabel(fmt.Sprintf("Downloading LocalAI version %s", version)),
		progressBar,
		statusLabel,
		widget.NewSeparator(),
		releaseNotesButton,
	)

	progressWindow.SetContent(progressContainer)
	progressWindow.Show()

	// Start download in background
	go func() {
		err := sm.launcher.DownloadUpdate(version, func(progress float64) {
			// Update progress bar
			fyne.Do(func() {
				progressBar.SetValue(progress)
				percentage := int(progress * 100)
				statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage))
			})
		})

		// Handle completion
		fyne.Do(func() {
			if err != nil {
				statusLabel.SetText(fmt.Sprintf("Download failed: %v", err))
				// Show error dialog
				dialog.ShowError(err, progressWindow)
			} else {
				statusLabel.SetText("Download completed successfully!")
				progressBar.SetValue(1.0)

				// Show restart dialog
				// Confirming quits the app; either answer closes the progress window.
				dialog.ShowConfirm("Update Downloaded",
					"LocalAI has been updated successfully. Please restart the launcher to use the new version.",
					func(restart bool) {
						if restart {
							sm.app.Quit()
						}
						progressWindow.Close()
					}, progressWindow)
			}
		})

		// Update systray menu
		if err == nil {
			sm.hasUpdateAvailable = false
			sm.latestVersion = ""
			sm.recreateMenu()
		}
	}()
}

================================================
FILE: cmd/launcher/internal/ui.go
================================================
package launcher

import (
	"fmt"
	"log"
	"net/url"

	"fyne.io/fyne/v2"
	"fyne.io/fyne/v2/container"
	"fyne.io/fyne/v2/dialog"
	"fyne.io/fyne/v2/widget"
)

// EnvVar represents an environment variable
type EnvVar struct {
	Key   string
	Value string
}

// LauncherUI handles the user interface
type LauncherUI struct {
	// Status display
	statusLabel  *widget.Label
	versionLabel *widget.Label

	// Control buttons
	startStopButton *widget.Button
	webUIButton     *widget.Button
	updateButton    *widget.Button
	downloadButton  *widget.Button

	// Configuration
	modelsPathEntry   *widget.Entry
	backendsPathEntry *widget.Entry
	addressEntry      *widget.Entry
	logLevelSelect    *widget.Select
	startOnBootCheck  *widget.Check

	// Environment Variables
	envVarsData              []EnvVar
	newEnvKeyEntry           *widget.Entry
	newEnvValueEntry         *widget.Entry
	updateEnvironmentDisplay func()

	// Logs
	logText *widget.Entry

	// Progress
	progressBar *widget.ProgressBar

	// Update management
	latestVersion string

	// Reference to launcher
	launcher *Launcher
}

// NewLauncherUI creates a new UI instance.
// Only the widgets listed below are created here; downloadButton, the
// environment-variable entries and the launcher reference are left at their
// zero values by this constructor.
func NewLauncherUI() *LauncherUI {
	return &LauncherUI{
		statusLabel:       widget.NewLabel("Initializing..."),
		versionLabel:      widget.NewLabel("Version: Unknown"),
		startStopButton:   widget.NewButton("Start LocalAI", nil),
		webUIButton:       widget.NewButton("Open WebUI", nil),
		updateButton:      widget.NewButton("Check for Updates", nil),
		modelsPathEntry:   widget.NewEntry(),
		backendsPathEntry: widget.NewEntry(),
		addressEntry:      widget.NewEntry(),
		logLevelSelect:    widget.NewSelect([]string{"error", "warn", "info", "debug", "trace"}, nil),
		startOnBootCheck:  widget.NewCheck("Start LocalAI on system boot", nil),
		logText:           widget.NewMultiLineEntry(),
		progressBar:       widget.NewProgressBar(),
		envVarsData:       []EnvVar{}, // Initialize the environment variables slice
	}
}

// CreateMainUI creates the main UI layout.
// It stores the launcher reference, wires the bindings and returns a single
// card holding the configuration form.
func (ui *LauncherUI) CreateMainUI(launcher *Launcher) *fyne.Container {
	ui.launcher = launcher
	ui.setupBindings()

	// Main tab with status and controls
	// Configuration is now the main content
	configTab := ui.createConfigTab()

	// Create a simple container instead of tabs since we only have settings
	tabs := container.NewVBox(
		widget.NewCard("LocalAI Launcher Settings", "", configTab),
	)

	return tabs
}

// createConfigTab creates the configuration tab: path settings, server
// settings, the environment-variable section and a save button.
func (ui *LauncherUI) createConfigTab() *fyne.Container {
	// Path configuration
	pathsCard := widget.NewCard("Paths", "", container.NewGridWithColumns(2,
		widget.NewLabel("Models Path:"),
		ui.modelsPathEntry,
		widget.NewLabel("Backends Path:"),
		ui.backendsPathEntry,
	))

	// Server configuration
	serverCard := widget.NewCard("Server", "", container.NewVBox(
		container.NewGridWithColumns(2,
			widget.NewLabel("Address:"),
			ui.addressEntry,
			widget.NewLabel("Log Level:"),
			ui.logLevelSelect,
		),
		ui.startOnBootCheck,
	))

	// Save button
	saveButton := widget.NewButton("Save Configuration", func() {
		ui.saveConfiguration()
	})

	// Environment Variables section
	envCard := ui.createEnvironmentSection()

	return container.NewVBox(
		pathsCard,
		serverCard,
		envCard,
		saveButton,
	)
}

// createEnvironmentSection creates the environment variables section for the config tab
func (ui *LauncherUI) createEnvironmentSection() *fyne.Container {
	// Initialize environment variables widgets
	ui.newEnvKeyEntry = widget.NewEntry()
	ui.newEnvKeyEntry.SetPlaceHolder("Environment Variable Name")
	ui.newEnvValueEntry = widget.NewEntry()
	ui.newEnvValueEntry.SetPlaceHolder("Environment Variable Value")

	// Add button
	addButton := widget.NewButton("Add Environment Variable", func() {
		ui.addEnvironmentVariable()
	})

	// Environment variables list with delete buttons
	ui.envVarsData =
[]EnvVar{} // Create container for environment variables envVarsContainer := container.NewVBox() // Update function to rebuild the environment variables display ui.updateEnvironmentDisplay = func() { envVarsContainer.Objects = nil for i, envVar := range ui.envVarsData { index := i // Capture index for closure // Create row with label and delete button envLabel := widget.NewLabel(fmt.Sprintf("%s = %s", envVar.Key, envVar.Value)) deleteBtn := widget.NewButton("Delete", func() { ui.confirmDeleteEnvironmentVariable(index) }) deleteBtn.Importance = widget.DangerImportance row := container.NewBorder(nil, nil, nil, deleteBtn, envLabel) envVarsContainer.Add(row) } envVarsContainer.Refresh() } // Create a scrollable container for the environment variables envScroll := container.NewScroll(envVarsContainer) envScroll.SetMinSize(fyne.NewSize(400, 150)) // Input section for adding new environment variables inputSection := container.NewVBox( container.NewGridWithColumns(2, ui.newEnvKeyEntry, ui.newEnvValueEntry, ), addButton, ) // Environment variables card envCard := widget.NewCard("Environment Variables", "", container.NewVBox( inputSection, widget.NewSeparator(), envScroll, )) return container.NewVBox(envCard) } // addEnvironmentVariable adds a new environment variable func (ui *LauncherUI) addEnvironmentVariable() { key := ui.newEnvKeyEntry.Text value := ui.newEnvValueEntry.Text log.Printf("addEnvironmentVariable: attempting to add %s=%s", key, value) log.Printf("addEnvironmentVariable: current ui.envVarsData has %d items: %v", len(ui.envVarsData), ui.envVarsData) if key == "" { log.Printf("addEnvironmentVariable: key is empty, showing error") dialog.ShowError(fmt.Errorf("environment variable name cannot be empty"), ui.launcher.window) return } // Check if key already exists for _, envVar := range ui.envVarsData { if envVar.Key == key { log.Printf("addEnvironmentVariable: key %s already exists, showing error", key) dialog.ShowError(fmt.Errorf("environment variable '%s' 
already exists", key), ui.launcher.window) return } } log.Printf("addEnvironmentVariable: adding new env var %s=%s", key, value) ui.envVarsData = append(ui.envVarsData, EnvVar{Key: key, Value: value}) log.Printf("addEnvironmentVariable: after adding, ui.envVarsData has %d items: %v", len(ui.envVarsData), ui.envVarsData) fyne.Do(func() { if ui.updateEnvironmentDisplay != nil { ui.updateEnvironmentDisplay() } // Clear input fields ui.newEnvKeyEntry.SetText("") ui.newEnvValueEntry.SetText("") }) log.Printf("addEnvironmentVariable: calling saveEnvironmentVariables") // Save to configuration ui.saveEnvironmentVariables() } // removeEnvironmentVariable removes an environment variable by index func (ui *LauncherUI) removeEnvironmentVariable(index int) { if index >= 0 && index < len(ui.envVarsData) { ui.envVarsData = append(ui.envVarsData[:index], ui.envVarsData[index+1:]...) fyne.Do(func() { if ui.updateEnvironmentDisplay != nil { ui.updateEnvironmentDisplay() } }) ui.saveEnvironmentVariables() } } // saveEnvironmentVariables saves environment variables to the configuration func (ui *LauncherUI) saveEnvironmentVariables() { if ui.launcher == nil { log.Printf("saveEnvironmentVariables: launcher is nil") return } config := ui.launcher.GetConfig() log.Printf("saveEnvironmentVariables: before - Environment vars: %v", config.EnvironmentVars) config.EnvironmentVars = make(map[string]string) for _, envVar := range ui.envVarsData { config.EnvironmentVars[envVar.Key] = envVar.Value log.Printf("saveEnvironmentVariables: adding %s=%s", envVar.Key, envVar.Value) } log.Printf("saveEnvironmentVariables: after - Environment vars: %v", config.EnvironmentVars) log.Printf("saveEnvironmentVariables: calling SetConfig with %d environment variables", len(config.EnvironmentVars)) err := ui.launcher.SetConfig(config) if err != nil { log.Printf("saveEnvironmentVariables: failed to save config: %v", err) } else { log.Printf("saveEnvironmentVariables: config saved successfully") } } // 
confirmDeleteEnvironmentVariable shows confirmation dialog for deleting an environment variable func (ui *LauncherUI) confirmDeleteEnvironmentVariable(index int) { if index >= 0 && index < len(ui.envVarsData) { envVar := ui.envVarsData[index] dialog.ShowConfirm("Remove Environment Variable", fmt.Sprintf("Remove environment variable '%s'?", envVar.Key), func(remove bool) { if remove { ui.removeEnvironmentVariable(index) } }, ui.launcher.window) } } // setupBindings sets up event handlers for UI elements func (ui *LauncherUI) setupBindings() { // Start/Stop button ui.startStopButton.OnTapped = func() { if ui.launcher.IsRunning() { ui.stopLocalAI() } else { ui.startLocalAI() } } // WebUI button ui.webUIButton.OnTapped = func() { ui.openWebUI() } ui.webUIButton.Disable() // Disabled until LocalAI is running // Update button ui.updateButton.OnTapped = func() { ui.checkForUpdates() } // Log level selection ui.logLevelSelect.OnChanged = func(selected string) { if ui.launcher != nil { config := ui.launcher.GetConfig() config.LogLevel = selected ui.launcher.SetConfig(config) } } } // startLocalAI starts the LocalAI service func (ui *LauncherUI) startLocalAI() { fyne.Do(func() { ui.startStopButton.Disable() }) ui.UpdateStatus("Starting LocalAI...") go func() { err := ui.launcher.StartLocalAI() if err != nil { ui.UpdateStatus("Failed to start: " + err.Error()) fyne.DoAndWait(func() { dialog.ShowError(err, ui.launcher.window) }) } else { fyne.Do(func() { ui.startStopButton.SetText("Stop LocalAI") ui.webUIButton.Enable() }) } fyne.Do(func() { ui.startStopButton.Enable() }) }() } // stopLocalAI stops the LocalAI service func (ui *LauncherUI) stopLocalAI() { fyne.Do(func() { ui.startStopButton.Disable() }) ui.UpdateStatus("Stopping LocalAI...") go func() { err := ui.launcher.StopLocalAI() if err != nil { fyne.DoAndWait(func() { dialog.ShowError(err, ui.launcher.window) }) } else { fyne.Do(func() { ui.startStopButton.SetText("Start LocalAI") ui.webUIButton.Disable() }) } 
fyne.Do(func() { ui.startStopButton.Enable() }) }() } // openWebUI opens the LocalAI WebUI in the default browser func (ui *LauncherUI) openWebUI() { webURL := ui.launcher.GetWebUIURL() parsedURL, err := url.Parse(webURL) if err != nil { dialog.ShowError(err, ui.launcher.window) return } // Open URL in default browser fyne.CurrentApp().OpenURL(parsedURL) } // saveConfiguration saves the current configuration func (ui *LauncherUI) saveConfiguration() { log.Printf("saveConfiguration: starting to save configuration") config := ui.launcher.GetConfig() log.Printf("saveConfiguration: current config Environment vars: %v", config.EnvironmentVars) log.Printf("saveConfiguration: ui.envVarsData has %d items: %v", len(ui.envVarsData), ui.envVarsData) config.ModelsPath = ui.modelsPathEntry.Text config.BackendsPath = ui.backendsPathEntry.Text config.Address = ui.addressEntry.Text config.LogLevel = ui.logLevelSelect.Selected config.StartOnBoot = ui.startOnBootCheck.Checked // Ensure environment variables are included in the configuration config.EnvironmentVars = make(map[string]string) for _, envVar := range ui.envVarsData { config.EnvironmentVars[envVar.Key] = envVar.Value log.Printf("saveConfiguration: adding env var %s=%s", envVar.Key, envVar.Value) } log.Printf("saveConfiguration: final config Environment vars: %v", config.EnvironmentVars) err := ui.launcher.SetConfig(config) if err != nil { log.Printf("saveConfiguration: failed to save config: %v", err) dialog.ShowError(err, ui.launcher.window) } else { log.Printf("saveConfiguration: config saved successfully") dialog.ShowInformation("Configuration", "Configuration saved successfully", ui.launcher.window) } } // checkForUpdates checks for available updates func (ui *LauncherUI) checkForUpdates() { fyne.Do(func() { ui.updateButton.Disable() }) ui.UpdateStatus("Checking for updates...") go func() { available, version, err := ui.launcher.CheckForUpdates() if err != nil { ui.UpdateStatus("Failed to check updates: " + 
err.Error()) fyne.DoAndWait(func() { dialog.ShowError(err, ui.launcher.window) }) } else if available { ui.latestVersion = version // Store the latest version ui.UpdateStatus("Update available: " + version) fyne.Do(func() { if ui.downloadButton != nil { ui.downloadButton.Enable() } }) ui.NotifyUpdateAvailable(version) } else { ui.UpdateStatus("No updates available") fyne.DoAndWait(func() { dialog.ShowInformation("Updates", "You are running the latest version", ui.launcher.window) }) } fyne.Do(func() { ui.updateButton.Enable() }) }() } // downloadUpdate downloads the latest update func (ui *LauncherUI) downloadUpdate() { // Use stored version or check for updates version := ui.latestVersion if version == "" { _, v, err := ui.launcher.CheckForUpdates() if err != nil { dialog.ShowError(err, ui.launcher.window) return } version = v ui.latestVersion = version } if version == "" { dialog.ShowError(fmt.Errorf("no version information available"), ui.launcher.window) return } // Disable buttons during download if ui.downloadButton != nil { fyne.Do(func() { ui.downloadButton.Disable() }) } fyne.Do(func() { ui.progressBar.Show() ui.progressBar.SetValue(0) }) ui.UpdateStatus("Downloading update " + version + "...") go func() { err := ui.launcher.DownloadUpdate(version, func(progress float64) { // Update progress bar fyne.Do(func() { ui.progressBar.SetValue(progress) }) // Update status with percentage percentage := int(progress * 100) ui.UpdateStatus(fmt.Sprintf("Downloading update %s... %d%%", version, percentage)) }) fyne.Do(func() { ui.progressBar.Hide() }) // Re-enable buttons after download if ui.downloadButton != nil { fyne.Do(func() { ui.downloadButton.Enable() }) } if err != nil { fyne.DoAndWait(func() { ui.UpdateStatus("Failed to download update: " + err.Error()) dialog.ShowError(err, ui.launcher.window) }) } else { fyne.DoAndWait(func() { ui.UpdateStatus("Update downloaded successfully") dialog.ShowInformation("Update", "Update downloaded successfully. 
Please restart the launcher to use the new version.", ui.launcher.window) }) } }() } // UpdateStatus updates the status label func (ui *LauncherUI) UpdateStatus(status string) { if ui.statusLabel != nil { fyne.Do(func() { ui.statusLabel.SetText(status) }) } } // OnLogUpdate handles new log content func (ui *LauncherUI) OnLogUpdate(logLine string) { if ui.logText != nil { fyne.Do(func() { currentText := ui.logText.Text ui.logText.SetText(currentText + logLine) // Auto-scroll to bottom (simplified) ui.logText.CursorRow = len(ui.logText.Text) }) } } // NotifyUpdateAvailable shows an update notification func (ui *LauncherUI) NotifyUpdateAvailable(version string) { if ui.launcher != nil && ui.launcher.window != nil { fyne.DoAndWait(func() { dialog.ShowConfirm("Update Available", "A new version ("+version+") is available. Would you like to download it?", func(confirmed bool) { if confirmed { ui.downloadUpdate() } }, ui.launcher.window) }) } } // LoadConfiguration loads the current configuration into UI elements func (ui *LauncherUI) LoadConfiguration() { if ui.launcher == nil { log.Printf("UI LoadConfiguration: launcher is nil") return } config := ui.launcher.GetConfig() log.Printf("UI LoadConfiguration: loading config - ModelsPath=%s, BackendsPath=%s, Address=%s, LogLevel=%s", config.ModelsPath, config.BackendsPath, config.Address, config.LogLevel) log.Printf("UI LoadConfiguration: Environment vars: %v", config.EnvironmentVars) ui.modelsPathEntry.SetText(config.ModelsPath) ui.backendsPathEntry.SetText(config.BackendsPath) ui.addressEntry.SetText(config.Address) ui.logLevelSelect.SetSelected(config.LogLevel) ui.startOnBootCheck.SetChecked(config.StartOnBoot) // Load environment variables ui.envVarsData = []EnvVar{} for key, value := range config.EnvironmentVars { ui.envVarsData = append(ui.envVarsData, EnvVar{Key: key, Value: value}) } if ui.updateEnvironmentDisplay != nil { fyne.Do(func() { ui.updateEnvironmentDisplay() }) } // Update version display version := 
ui.launcher.GetCurrentVersion() ui.versionLabel.SetText("Version: " + version) log.Printf("UI LoadConfiguration: configuration loaded successfully") } // showDownloadProgress shows a progress window for downloading LocalAI func (ui *LauncherUI) showDownloadProgress(version, title string) { fyne.DoAndWait(func() { // Create progress window using the launcher's app progressWindow := ui.launcher.app.NewWindow("Downloading LocalAI") progressWindow.Resize(fyne.NewSize(400, 250)) progressWindow.CenterOnScreen() // Progress bar progressBar := widget.NewProgressBar() progressBar.SetValue(0) // Status label statusLabel := widget.NewLabel("Preparing download...") // Release notes button releaseNotesButton := widget.NewButton("View Release Notes", func() { releaseNotesURL, err := ui.launcher.githubReleaseNotesURL(version) if err != nil { log.Printf("Failed to parse URL: %v", err) return } ui.launcher.app.OpenURL(releaseNotesURL) }) // Progress container progressContainer := container.NewVBox( widget.NewLabel(title), progressBar, statusLabel, widget.NewSeparator(), releaseNotesButton, ) progressWindow.SetContent(progressContainer) progressWindow.Show() // Start download in background go func() { err := ui.launcher.DownloadUpdate(version, func(progress float64) { // Update progress bar fyne.Do(func() { progressBar.SetValue(progress) percentage := int(progress * 100) statusLabel.SetText(fmt.Sprintf("Downloading... %d%%", percentage)) }) }) // Handle completion fyne.Do(func() { if err != nil { statusLabel.SetText(fmt.Sprintf("Download failed: %v", err)) // Show error dialog dialog.ShowError(err, progressWindow) } else { statusLabel.SetText("Download completed successfully!") progressBar.SetValue(1.0) // Show success dialog dialog.ShowConfirm("Installation Complete", "LocalAI has been downloaded and installed successfully. 
You can now start LocalAI from the launcher.", func(close bool) { progressWindow.Close() // Update status ui.UpdateStatus("LocalAI installed successfully") }, progressWindow) } }) }() }) } // UpdateRunningState updates UI based on LocalAI running state func (ui *LauncherUI) UpdateRunningState(isRunning bool) { fyne.Do(func() { if isRunning { ui.startStopButton.SetText("Stop LocalAI") ui.webUIButton.Enable() } else { ui.startStopButton.SetText("Start LocalAI") ui.webUIButton.Disable() } }) } // ShowWelcomeWindow displays the welcome window with helpful information func (ui *LauncherUI) ShowWelcomeWindow() { if ui.launcher == nil || ui.launcher.window == nil { log.Printf("Cannot show welcome window: launcher or window is nil") return } fyne.DoAndWait(func() { // Create welcome window welcomeWindow := ui.launcher.app.NewWindow("Welcome to LocalAI Launcher") welcomeWindow.Resize(fyne.NewSize(600, 500)) welcomeWindow.CenterOnScreen() welcomeWindow.SetCloseIntercept(func() { welcomeWindow.Close() }) // Title titleLabel := widget.NewLabel("Welcome to LocalAI Launcher!") titleLabel.TextStyle = fyne.TextStyle{Bold: true} titleLabel.Alignment = fyne.TextAlignCenter // Welcome message welcomeText := `LocalAI Launcher makes it easy to run LocalAI on your system. What you can do: • Start and stop LocalAI server • Configure models and backends paths • Set environment variables • Check for updates automatically • Access LocalAI WebUI when running Getting Started: 1. Configure your models and backends paths 2. Click "Start LocalAI" to begin 3. Use "Open WebUI" to access the interface 4. 
Check the system tray for quick access` welcomeLabel := widget.NewLabel(welcomeText) welcomeLabel.Wrapping = fyne.TextWrapWord // Useful links section linksTitle := widget.NewLabel("Useful Links:") linksTitle.TextStyle = fyne.TextStyle{Bold: true} // Create link buttons docsButton := widget.NewButton("📚 Documentation", func() { ui.openURL("https://localai.io/docs/") }) githubButton := widget.NewButton("🐙 GitHub Repository", func() { ui.openURL("https://github.com/mudler/LocalAI") }) modelsButton := widget.NewButton("🤖 Model Gallery", func() { ui.openURL("https://localai.io/models/") }) communityButton := widget.NewButton("💬 Community", func() { ui.openURL("https://discord.gg/XgwjKptP7Z") }) // Checkbox to disable welcome window dontShowAgainCheck := widget.NewCheck("Don't show this welcome window again", func(checked bool) { if ui.launcher != nil { config := ui.launcher.GetConfig() v := !checked config.ShowWelcome = &v ui.launcher.SetConfig(config) } }) config := ui.launcher.GetConfig() if config.ShowWelcome != nil { dontShowAgainCheck.SetChecked(*config.ShowWelcome) } // Close button closeButton := widget.NewButton("Get Started", func() { welcomeWindow.Close() }) closeButton.Importance = widget.HighImportance // Layout linksContainer := container.NewVBox( linksTitle, docsButton, githubButton, modelsButton, communityButton, ) content := container.NewVBox( titleLabel, widget.NewSeparator(), welcomeLabel, widget.NewSeparator(), linksContainer, widget.NewSeparator(), dontShowAgainCheck, widget.NewSeparator(), closeButton, ) welcomeWindow.SetContent(content) welcomeWindow.Show() }) } // openURL opens a URL in the default browser func (ui *LauncherUI) openURL(urlString string) { parsedURL, err := url.Parse(urlString) if err != nil { log.Printf("Failed to parse URL %s: %v", urlString, err) return } fyne.CurrentApp().OpenURL(parsedURL) } ================================================ FILE: cmd/launcher/main.go ================================================ package 
main import ( "log" "fyne.io/fyne/v2" "fyne.io/fyne/v2/app" "fyne.io/fyne/v2/driver/desktop" coreLauncher "github.com/mudler/LocalAI/cmd/launcher/internal" "github.com/mudler/LocalAI/pkg/signals" ) func main() { // Create the application with unique ID myApp := app.NewWithID("com.localai.launcher") myApp.SetIcon(resourceIconPng) myWindow := myApp.NewWindow("LocalAI Launcher") myWindow.Resize(fyne.NewSize(800, 600)) // Create the launcher UI ui := coreLauncher.NewLauncherUI() // Initialize the launcher with UI context launcher := coreLauncher.NewLauncher(ui, myWindow, myApp) // Setup the UI content := ui.CreateMainUI(launcher) myWindow.SetContent(content) // Setup window close behavior - minimize to tray instead of closing myWindow.SetCloseIntercept(func() { myWindow.Hide() }) // Setup system tray using Fyne's built-in approach`` if desk, ok := myApp.(desktop.App); ok { // Create a dynamic systray manager systray := coreLauncher.NewSystrayManager(launcher, myWindow, desk, myApp, resourceIconPng) launcher.SetSystray(systray) } // Setup signal handling for graceful shutdown signals.RegisterGracefulTerminationHandler(func() { // Perform cleanup if err := launcher.Shutdown(); err != nil { log.Printf("Error during shutdown: %v", err) } }) // Initialize the launcher state go func() { if err := launcher.Initialize(); err != nil { log.Printf("Failed to initialize launcher: %v", err) if launcher.GetUI() != nil { launcher.GetUI().UpdateStatus("Failed to initialize: " + err.Error()) } } else { // Load configuration into UI launcher.GetUI().LoadConfiguration() launcher.GetUI().UpdateStatus("Ready") // Show welcome window if configured to do so config := launcher.GetConfig() if *config.ShowWelcome { launcher.GetUI().ShowWelcomeWindow() } } }() // Run the application in background (window only shown when "Settings" is clicked) myApp.Run() } ================================================ FILE: cmd/local-ai/main.go ================================================ package main 
import (
	"os"
	"path/filepath"

	"github.com/alecthomas/kong"
	"github.com/joho/godotenv"

	"github.com/mudler/LocalAI/core/cli"
	"github.com/mudler/LocalAI/internal"
	"github.com/mudler/xlog"

	_ "github.com/mudler/LocalAI/swagger"
)

// main is the local-ai entry point. It installs a provisional logger, loads
// environment variables from a fixed list of env files, parses the command
// line with kong, configures the final logger from the parsed options and
// runs the selected command.
func main() {
	var err error

	// Initialize xlog at a level of INFO, we will set the desired level after we parse the CLI options
	xlog.SetLogger(xlog.NewLogger(xlog.LogLevel("info"), "text"))

	// handle loading environment variables from .env files
	// Candidates are tried in this order: working directory, then the user's
	// home directory (when it can be resolved), then /etc.
	envFiles := []string{".env", "localai.env"}
	homeDir, err := os.UserHomeDir()
	if err == nil {
		envFiles = append(envFiles, filepath.Join(homeDir, "localai.env"), filepath.Join(homeDir, ".config/localai.env"))
	}
	envFiles = append(envFiles, "/etc/localai.env")

	for _, envFile := range envFiles {
		// Only files that can be stat'ed are loaded; a load failure is logged and the next file is tried
		if _, err := os.Stat(envFile); err == nil {
			xlog.Debug("env file found, loading environment variables from file", "envFile", envFile)
			err = godotenv.Load(envFile)
			if err != nil {
				xlog.Error("failed to load environment variables from file", "error", err, "envFile", envFile)
				continue
			}
		}
	}

	// Actually parse the CLI options
	k := kong.Must(&cli.CLI,
		kong.Description(
			` LocalAI is a drop-in replacement OpenAI API for running LLM, GPT and genAI models locally on CPU, GPUs with consumer grade hardware. For a list of all available models run local-ai models list Copyright: Ettore Di Giacinto Version: ${version} For documentation and support: Documentation: https://localai.io/ Getting Started: https://localai.io/basics/getting_started/ GitHub Issues: https://github.com/mudler/LocalAI/issues `,
		),
		kong.UsageOnError(),
		kong.Vars{
			"basepath":  kong.ExpandPath("."),
			"galleries": `[{"name":"localai", "url":"github:mudler/LocalAI/gallery/index.yaml@master"}]`,
			"backends":  `[{"name":"localai", "url":"github:mudler/LocalAI/backend/index.yaml@master"}]`,
			"version":   internal.PrintableVersion(),
		},
	)

	ctx, err := k.Parse(os.Args[1:])
	if err != nil {
		k.FatalIfErrorf(err)
	}

	// Pass Kong model to the completion command for dynamic script generation
	cli.CLI.Completion.SetApplication(k.Model)

	// Configure the logging level before we run the application
	// This is here to preserve the existing --debug flag functionality
	logLevel := "info"
	// --debug only raises the level when no explicit log level was given
	if cli.CLI.Debug && cli.CLI.LogLevel == nil {
		logLevel = "debug"
		cli.CLI.LogLevel = &logLevel
	}

	if cli.CLI.LogLevel == nil {
		cli.CLI.LogLevel = &logLevel
	}

	// Set xlog logger with the desired level and text format
	// NOTE(review): LogFormat is dereferenced without a nil check — presumably kong always populates it; confirm.
	xlog.SetLogger(xlog.NewLogger(xlog.LogLevel(*cli.CLI.LogLevel), *cli.CLI.LogFormat))

	// Run the thing!
	err = ctx.Run(&cli.CLI.Context)
	if err != nil {
		xlog.Fatal("Error running the application", "error", err)
	}
}

================================================
FILE: configuration/.keep
================================================

================================================
FILE: core/application/agent_jobs.go
================================================
package application

import (
	"time"

	"github.com/mudler/LocalAI/core/services"
	"github.com/mudler/xlog"
)

// RestartAgentJobService restarts the agent job service with current ApplicationConfig settings.
//
// The whole operation runs under agentJobMutex. An existing service is
// stopped first (a stop error is only logged), then a new instance is created
// and started. The new instance replaces a.agentJobService only after it
// started successfully; a start error is returned to the caller.
func (a *Application) RestartAgentJobService() error {
	a.agentJobMutex.Lock()
	defer a.agentJobMutex.Unlock()

	// Stop existing service if running
	if a.agentJobService != nil {
		if err := a.agentJobService.Stop(); err != nil {
			xlog.Warn("Error stopping agent job service", "error", err)
		}
		// Wait a bit for shutdown to complete
		time.Sleep(200 * time.Millisecond)
	}

	// Create new service instance
	agentJobService := services.NewAgentJobService(
		a.ApplicationConfig(),
		a.ModelLoader(),
		a.ModelConfigLoader(),
		a.TemplatesEvaluator(),
	)

	// Start the service
	err := agentJobService.Start(a.ApplicationConfig().Context)
	if err != nil {
		xlog.Error("Failed to start agent job service", "error", err)
		return err
	}

	a.agentJobService = agentJobService
	xlog.Info("Agent job service restarted")
	return nil
}

================================================
FILE: core/application/application.go
================================================
package application

import (
	"context"
	"sync"
	"sync/atomic"

	"github.com/mudler/LocalAI/core/config"
	mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
	"github.com/mudler/LocalAI/core/services"
	"github.com/mudler/LocalAI/core/templates"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/xlog"
	"gorm.io/gorm"
)

// Application bundles the loaders, configuration and long-running services
// that make up a LocalAI instance, together with the mutexes and
// cancellation handles that guard them.
type Application struct {
	backendLoader      *config.ModelConfigLoader // exposed through ModelConfigLoader()
	modelLoader        *model.ModelLoader
	applicationConfig  *config.ApplicationConfig
	startupConfig      *config.ApplicationConfig // Stores original config from env vars (before file loading)
	templatesEvaluator *templates.Evaluator
	galleryService     *services.GalleryService
	agentJobService    *services.AgentJobService
	agentPoolService   atomic.Pointer[services.AgentPoolService] // stored by StartAgentPool, read via AgentPoolService()
	authDB             *gorm.DB
	watchdogMutex      sync.Mutex
	watchdogStop       chan bool
	p2pMutex           sync.Mutex
	p2pCtx             context.Context
	p2pCancel          context.CancelFunc
	agentJobMutex      sync.Mutex // held by RestartAgentJobService
}

// newApplication builds an Application from appConfig. It creates the model
// loader, the model config loader and the templates evaluator; services are
// not started here.
func newApplication(appConfig *config.ApplicationConfig) *Application {
	ml := model.NewModelLoader(appConfig.SystemState)

	// Close MCP sessions when a model is unloaded (watchdog eviction, manual shutdown, etc.)
	ml.OnModelUnload(func(modelName string) {
		mcpTools.CloseMCPSessions(modelName)
	})

	return &Application{
		backendLoader:      config.NewModelConfigLoader(appConfig.SystemState.Model.ModelsPath),
		modelLoader:        ml,
		applicationConfig:  appConfig,
		templatesEvaluator: templates.NewEvaluator(appConfig.SystemState.Model.ModelsPath),
	}
}

// ModelConfigLoader returns the model configuration loader.
func (a *Application) ModelConfigLoader() *config.ModelConfigLoader {
	return a.backendLoader
}

// ModelLoader returns the model loader.
func (a *Application) ModelLoader() *model.ModelLoader {
	return a.modelLoader
}

// ApplicationConfig returns the application configuration.
func (a *Application) ApplicationConfig() *config.ApplicationConfig {
	return a.applicationConfig
}

// TemplatesEvaluator returns the templates evaluator.
func (a *Application) TemplatesEvaluator() *templates.Evaluator {
	return a.templatesEvaluator
}

// GalleryService returns the gallery service, which is nil until start has run.
func (a *Application) GalleryService() *services.GalleryService {
	return a.galleryService
}

// AgentJobService returns the current agent job service instance.
func (a *Application) AgentJobService() *services.AgentJobService {
	return a.agentJobService
}

// AgentPoolService returns the agent pool service, or nil if none has been stored yet.
func (a *Application) AgentPoolService() *services.AgentPoolService {
	return a.agentPoolService.Load()
}

// AuthDB returns the auth database connection, or nil if auth is not enabled.
func (a *Application) AuthDB() *gorm.DB {
	return a.authDB
}

// StartupConfig exposes the configuration as it was at startup, i.e. built
// from env vars before any configuration file was loaded.
func (a *Application) StartupConfig() *config.ApplicationConfig {
	return a.startupConfig
}

// start brings up the gallery service followed by the agent job service.
// Each service is published on the Application only once its Start call has
// succeeded; the first failure aborts the sequence and is returned.
func (a *Application) start() error {
	gallery := services.NewGalleryService(a.ApplicationConfig(), a.ModelLoader())
	if startErr := gallery.Start(
		a.ApplicationConfig().Context,
		a.ModelConfigLoader(),
		a.ApplicationConfig().SystemState,
	); startErr != nil {
		return startErr
	}
	a.galleryService = gallery

	// Agent job service comes second
	jobs := services.NewAgentJobService(
		a.ApplicationConfig(),
		a.ModelLoader(),
		a.ModelConfigLoader(),
		a.TemplatesEvaluator(),
	)
	if startErr := jobs.Start(a.ApplicationConfig().Context); startErr != nil {
		return startErr
	}
	a.agentJobService = jobs

	return nil
}

// StartAgentPool initializes and starts the agent pool service (LocalAGI integration).
// This must be called after the HTTP server is listening, because backends like
// PostgreSQL need to call the embeddings API during collection initialization.
func (a *Application) StartAgentPool() { if !a.applicationConfig.AgentPool.Enabled { return } aps, err := services.NewAgentPoolService(a.applicationConfig) if err != nil { xlog.Error("Failed to create agent pool service", "error", err) return } if a.authDB != nil { aps.SetAuthDB(a.authDB) } if err := aps.Start(a.applicationConfig.Context); err != nil { xlog.Error("Failed to start agent pool", "error", err) return } // Wire per-user scoped services so collections, skills, and jobs are isolated per user usm := services.NewUserServicesManager( aps.UserStorage(), a.applicationConfig, a.modelLoader, a.backendLoader, a.templatesEvaluator, ) aps.SetUserServicesManager(usm) a.agentPoolService.Store(aps) } ================================================ FILE: core/application/config_file_watcher.go ================================================ package application import ( "encoding/json" "fmt" "os" "path" "path/filepath" "slices" "time" "dario.cat/mergo" "github.com/fsnotify/fsnotify" "github.com/mudler/LocalAI/core/config" "github.com/mudler/xlog" ) type fileHandler func(fileContent []byte, appConfig *config.ApplicationConfig) error type configFileHandler struct { handlers map[string]fileHandler watcher *fsnotify.Watcher appConfig *config.ApplicationConfig } // TODO: This should be a singleton eventually so other parts of the code can register config file handlers, // then we can export it to other packages func newConfigFileHandler(appConfig *config.ApplicationConfig) configFileHandler { c := configFileHandler{ handlers: make(map[string]fileHandler), appConfig: appConfig, } err := c.Register("api_keys.json", readApiKeysJson(*appConfig), true) if err != nil { xlog.Error("unable to register config file handler", "error", err, "file", "api_keys.json") } err = c.Register("external_backends.json", readExternalBackendsJson(*appConfig), true) if err != nil { xlog.Error("unable to register config file handler", "error", err, "file", "external_backends.json") } err = 
c.Register("runtime_settings.json", readRuntimeSettingsJson(*appConfig), true) if err != nil { xlog.Error("unable to register config file handler", "error", err, "file", "runtime_settings.json") } // Note: agent_tasks.json and agent_jobs.json are handled by AgentJobService directly // The service watches and reloads these files internally return c } func (c *configFileHandler) Register(filename string, handler fileHandler, runNow bool) error { _, ok := c.handlers[filename] if ok { return fmt.Errorf("handler already registered for file %s", filename) } c.handlers[filename] = handler if runNow { c.callHandler(filename, handler) } return nil } func (c *configFileHandler) callHandler(filename string, handler fileHandler) { rootedFilePath := filepath.Join(c.appConfig.DynamicConfigsDir, filepath.Clean(filename)) xlog.Debug("reading file for dynamic config update", "filename", rootedFilePath) fileContent, err := os.ReadFile(rootedFilePath) if err != nil && !os.IsNotExist(err) { xlog.Error("could not read file", "error", err, "filename", rootedFilePath) } if err = handler(fileContent, c.appConfig); err != nil { xlog.Error("WatchConfigDirectory goroutine failed to update options", "error", err) } } func (c *configFileHandler) Watch() error { configWatcher, err := fsnotify.NewWatcher() c.watcher = configWatcher if err != nil { return err } if c.appConfig.DynamicConfigsDirPollInterval > 0 { xlog.Debug("Poll interval set, falling back to polling for configuration changes") ticker := time.NewTicker(c.appConfig.DynamicConfigsDirPollInterval) go func() { for { <-ticker.C for file, handler := range c.handlers { xlog.Debug("polling config file", "file", file) c.callHandler(file, handler) } } }() } // Start listening for events. 
go func() { for { select { case event, ok := <-c.watcher.Events: if !ok { return } if event.Has(fsnotify.Write | fsnotify.Create | fsnotify.Remove) { handler, ok := c.handlers[path.Base(event.Name)] if !ok { continue } c.callHandler(filepath.Base(event.Name), handler) } case err, ok := <-c.watcher.Errors: xlog.Error("config watcher error received", "error", err) if !ok { return } } } }() // Add a path. err = c.watcher.Add(c.appConfig.DynamicConfigsDir) if err != nil { return fmt.Errorf("unable to create a watcher on the configuration directory: %+v", err) } return nil } // TODO: When we institute graceful shutdown, this should be called func (c *configFileHandler) Stop() error { return c.watcher.Close() } func readApiKeysJson(startupAppConfig config.ApplicationConfig) fileHandler { handler := func(fileContent []byte, appConfig *config.ApplicationConfig) error { xlog.Debug("processing api keys runtime update", "numKeys", len(startupAppConfig.ApiKeys)) if len(fileContent) > 0 { // Parse JSON content from the file var fileKeys []string err := json.Unmarshal(fileContent, &fileKeys) if err != nil { return err } xlog.Debug("discovered API keys from api keys dynamic config file", "numKeys", len(fileKeys)) appConfig.ApiKeys = append(startupAppConfig.ApiKeys, fileKeys...) 
} else { xlog.Debug("no API keys discovered from dynamic config file") appConfig.ApiKeys = startupAppConfig.ApiKeys } xlog.Debug("total api keys after processing", "numKeys", len(appConfig.ApiKeys)) return nil } return handler } func readExternalBackendsJson(startupAppConfig config.ApplicationConfig) fileHandler { handler := func(fileContent []byte, appConfig *config.ApplicationConfig) error { xlog.Debug("processing external_backends.json") if len(fileContent) > 0 { // Parse JSON content from the file var fileBackends map[string]string err := json.Unmarshal(fileContent, &fileBackends) if err != nil { return err } appConfig.ExternalGRPCBackends = startupAppConfig.ExternalGRPCBackends err = mergo.Merge(&appConfig.ExternalGRPCBackends, &fileBackends) if err != nil { return err } } else { appConfig.ExternalGRPCBackends = startupAppConfig.ExternalGRPCBackends } xlog.Debug("external backends loaded from external_backends.json") return nil } return handler } func readRuntimeSettingsJson(startupAppConfig config.ApplicationConfig) fileHandler { handler := func(fileContent []byte, appConfig *config.ApplicationConfig) error { xlog.Debug("processing runtime_settings.json") // Determine if settings came from env vars by comparing with startup config // startupAppConfig contains the original values set from env vars at startup. // If current values match startup values, they came from env vars (or defaults). // We apply file settings only if current values match startup values (meaning not from env vars). 
envWatchdogIdle := appConfig.WatchDogIdle == startupAppConfig.WatchDogIdle envWatchdogBusy := appConfig.WatchDogBusy == startupAppConfig.WatchDogBusy envWatchdogIdleTimeout := appConfig.WatchDogIdleTimeout == startupAppConfig.WatchDogIdleTimeout envWatchdogBusyTimeout := appConfig.WatchDogBusyTimeout == startupAppConfig.WatchDogBusyTimeout envSingleBackend := appConfig.SingleBackend == startupAppConfig.SingleBackend envMaxActiveBackends := appConfig.MaxActiveBackends == startupAppConfig.MaxActiveBackends envParallelRequests := appConfig.ParallelBackendRequests == startupAppConfig.ParallelBackendRequests envMemoryReclaimerEnabled := appConfig.MemoryReclaimerEnabled == startupAppConfig.MemoryReclaimerEnabled envMemoryReclaimerThreshold := appConfig.MemoryReclaimerThreshold == startupAppConfig.MemoryReclaimerThreshold envThreads := appConfig.Threads == startupAppConfig.Threads envContextSize := appConfig.ContextSize == startupAppConfig.ContextSize envF16 := appConfig.F16 == startupAppConfig.F16 envDebug := appConfig.Debug == startupAppConfig.Debug envCORS := appConfig.CORS == startupAppConfig.CORS envCSRF := appConfig.DisableCSRF == startupAppConfig.DisableCSRF envCORSAllowOrigins := appConfig.CORSAllowOrigins == startupAppConfig.CORSAllowOrigins envP2PToken := appConfig.P2PToken == startupAppConfig.P2PToken envP2PNetworkID := appConfig.P2PNetworkID == startupAppConfig.P2PNetworkID envFederated := appConfig.Federated == startupAppConfig.Federated envGalleries := slices.Equal(appConfig.Galleries, startupAppConfig.Galleries) envBackendGalleries := slices.Equal(appConfig.BackendGalleries, startupAppConfig.BackendGalleries) envAutoloadGalleries := appConfig.AutoloadGalleries == startupAppConfig.AutoloadGalleries envAutoloadBackendGalleries := appConfig.AutoloadBackendGalleries == startupAppConfig.AutoloadBackendGalleries envAgentJobRetentionDays := appConfig.AgentJobRetentionDays == startupAppConfig.AgentJobRetentionDays envForceEvictionWhenBusy := 
appConfig.ForceEvictionWhenBusy == startupAppConfig.ForceEvictionWhenBusy envLRUEvictionMaxRetries := appConfig.LRUEvictionMaxRetries == startupAppConfig.LRUEvictionMaxRetries envLRUEvictionRetryInterval := appConfig.LRUEvictionRetryInterval == startupAppConfig.LRUEvictionRetryInterval if len(fileContent) > 0 { var settings config.RuntimeSettings err := json.Unmarshal(fileContent, &settings) if err != nil { return err } // Apply file settings only if they don't match startup values (i.e., not from env vars) if settings.WatchdogIdleEnabled != nil && !envWatchdogIdle { appConfig.WatchDogIdle = *settings.WatchdogIdleEnabled if appConfig.WatchDogIdle { appConfig.WatchDog = true } } if settings.WatchdogBusyEnabled != nil && !envWatchdogBusy { appConfig.WatchDogBusy = *settings.WatchdogBusyEnabled if appConfig.WatchDogBusy { appConfig.WatchDog = true } } if settings.WatchdogIdleTimeout != nil && !envWatchdogIdleTimeout { dur, err := time.ParseDuration(*settings.WatchdogIdleTimeout) if err == nil { appConfig.WatchDogIdleTimeout = dur } else { xlog.Warn("invalid watchdog idle timeout in runtime_settings.json", "error", err, "timeout", *settings.WatchdogIdleTimeout) } } if settings.WatchdogBusyTimeout != nil && !envWatchdogBusyTimeout { dur, err := time.ParseDuration(*settings.WatchdogBusyTimeout) if err == nil { appConfig.WatchDogBusyTimeout = dur } else { xlog.Warn("invalid watchdog busy timeout in runtime_settings.json", "error", err, "timeout", *settings.WatchdogBusyTimeout) } } // Handle MaxActiveBackends (new) and SingleBackend (deprecated) if settings.MaxActiveBackends != nil && !envMaxActiveBackends { appConfig.MaxActiveBackends = *settings.MaxActiveBackends // For backward compatibility, also set SingleBackend if MaxActiveBackends == 1 appConfig.SingleBackend = (*settings.MaxActiveBackends == 1) } else if settings.SingleBackend != nil && !envSingleBackend { // Legacy: SingleBackend maps to MaxActiveBackends = 1 appConfig.SingleBackend = *settings.SingleBackend if 
*settings.SingleBackend { appConfig.MaxActiveBackends = 1 } else { appConfig.MaxActiveBackends = 0 } } if settings.ParallelBackendRequests != nil && !envParallelRequests { appConfig.ParallelBackendRequests = *settings.ParallelBackendRequests } if settings.MemoryReclaimerEnabled != nil && !envMemoryReclaimerEnabled { appConfig.MemoryReclaimerEnabled = *settings.MemoryReclaimerEnabled if appConfig.MemoryReclaimerEnabled { appConfig.WatchDog = true // Memory reclaimer requires watchdog } } if settings.MemoryReclaimerThreshold != nil && !envMemoryReclaimerThreshold { appConfig.MemoryReclaimerThreshold = *settings.MemoryReclaimerThreshold } if settings.ForceEvictionWhenBusy != nil && !envForceEvictionWhenBusy { appConfig.ForceEvictionWhenBusy = *settings.ForceEvictionWhenBusy } if settings.LRUEvictionMaxRetries != nil && !envLRUEvictionMaxRetries { appConfig.LRUEvictionMaxRetries = *settings.LRUEvictionMaxRetries } if settings.LRUEvictionRetryInterval != nil && !envLRUEvictionRetryInterval { dur, err := time.ParseDuration(*settings.LRUEvictionRetryInterval) if err == nil { appConfig.LRUEvictionRetryInterval = dur } else { xlog.Warn("invalid LRU eviction retry interval in runtime_settings.json", "error", err, "interval", *settings.LRUEvictionRetryInterval) } } if settings.Threads != nil && !envThreads { appConfig.Threads = *settings.Threads } if settings.ContextSize != nil && !envContextSize { appConfig.ContextSize = *settings.ContextSize } if settings.F16 != nil && !envF16 { appConfig.F16 = *settings.F16 } if settings.Debug != nil && !envDebug { appConfig.Debug = *settings.Debug } if settings.CORS != nil && !envCORS { appConfig.CORS = *settings.CORS } if settings.CSRF != nil && !envCSRF { appConfig.DisableCSRF = *settings.CSRF } if settings.CORSAllowOrigins != nil && !envCORSAllowOrigins { appConfig.CORSAllowOrigins = *settings.CORSAllowOrigins } if settings.P2PToken != nil && !envP2PToken { appConfig.P2PToken = *settings.P2PToken } if settings.P2PNetworkID != nil && 
!envP2PNetworkID { appConfig.P2PNetworkID = *settings.P2PNetworkID } if settings.Federated != nil && !envFederated { appConfig.Federated = *settings.Federated } if settings.Galleries != nil && !envGalleries { appConfig.Galleries = *settings.Galleries } if settings.BackendGalleries != nil && !envBackendGalleries { appConfig.BackendGalleries = *settings.BackendGalleries } if settings.AutoloadGalleries != nil && !envAutoloadGalleries { appConfig.AutoloadGalleries = *settings.AutoloadGalleries } if settings.AutoloadBackendGalleries != nil && !envAutoloadBackendGalleries { appConfig.AutoloadBackendGalleries = *settings.AutoloadBackendGalleries } if settings.ApiKeys != nil { // API keys from env vars (startup) should be kept, runtime settings keys replace all runtime keys // If runtime_settings.json specifies ApiKeys (even if empty), it replaces all runtime keys // Start with env keys, then add runtime_settings.json keys (which may be empty to clear them) envKeys := startupAppConfig.ApiKeys runtimeKeys := *settings.ApiKeys // Replace all runtime keys with what's in runtime_settings.json appConfig.ApiKeys = append(envKeys, runtimeKeys...) 
} if settings.AgentJobRetentionDays != nil && !envAgentJobRetentionDays { appConfig.AgentJobRetentionDays = *settings.AgentJobRetentionDays } // If watchdog is enabled via file but not via env, ensure WatchDog flag is set if !envWatchdogIdle && !envWatchdogBusy { if settings.WatchdogEnabled != nil && *settings.WatchdogEnabled { appConfig.WatchDog = true } } } xlog.Debug("runtime settings loaded from runtime_settings.json") return nil } return handler } ================================================ FILE: core/application/p2p.go ================================================ package application import ( "context" "fmt" "net" "slices" "time" "github.com/google/uuid" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/p2p" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" "github.com/mudler/edgevpn/pkg/node" "github.com/mudler/xlog" ) func (a *Application) StopP2P() error { if a.p2pCancel != nil { a.p2pCancel() a.p2pCancel = nil a.p2pCtx = nil // Wait a bit for shutdown to complete time.Sleep(200 * time.Millisecond) } return nil } func (a *Application) StartP2P() error { // we need a p2p token if a.applicationConfig.P2PToken == "" { return fmt.Errorf("P2P token is not set") } networkID := a.applicationConfig.P2PNetworkID ctx, cancel := context.WithCancel(a.ApplicationConfig().Context) a.p2pCtx = ctx a.p2pCancel = cancel var n *node.Node // Here we are avoiding creating multiple nodes: // - if the federated mode is enabled, we create a federated node and expose a service // - exposing a service creates a node with specific options, and we don't want to create another node // If the federated mode is enabled, we expose a service to the local instance running // at r.Address if a.applicationConfig.Federated { _, port, err := net.SplitHostPort(a.applicationConfig.APIAddress) if err != nil { return err } // Here a new node is created and started // and a service is exposed by the node node, err := 
p2p.ExposeService(ctx, "localhost", port, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.FederatedID)) if err != nil { return err } if err := p2p.ServiceDiscoverer(ctx, node, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.FederatedID), nil, false); err != nil { return err } n = node // start node sync in the background if err := a.p2pSync(ctx, node); err != nil { return err } } // If a node wasn't created previously, create it if n == nil { node, err := p2p.NewNode(a.applicationConfig.P2PToken) if err != nil { return err } err = node.Start(ctx) if err != nil { return fmt.Errorf("starting new node: %w", err) } n = node } // Attach a ServiceDiscoverer to the p2p node for llama.cpp workers xlog.Info("Starting P2P server discovery...") if err := p2p.ServiceDiscoverer(ctx, n, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.LlamaCPPWorkerID), func(serviceID string, node schema.NodeData) { var tunnelAddresses []string for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.LlamaCPPWorkerID)) { if v.IsOnline() { tunnelAddresses = append(tunnelAddresses, v.TunnelAddress) } else { xlog.Info("Node is offline", "node", v.ID) } } if a.applicationConfig.LlamaCPPTunnelCallback != nil { a.applicationConfig.LlamaCPPTunnelCallback(tunnelAddresses) } }, true); err != nil { return err } // Attach a ServiceDiscoverer for MLX distributed workers xlog.Info("Starting MLX P2P worker discovery...") if err := p2p.ServiceDiscoverer(ctx, n, a.applicationConfig.P2PToken, p2p.NetworkID(networkID, p2p.MLXWorkerID), func(serviceID string, node schema.NodeData) { var tunnelAddresses []string for _, v := range p2p.GetAvailableNodes(p2p.NetworkID(networkID, p2p.MLXWorkerID)) { if v.IsOnline() { tunnelAddresses = append(tunnelAddresses, v.TunnelAddress) } else { xlog.Info("MLX node is offline", "node", v.ID) } } if a.applicationConfig.MLXTunnelCallback != nil { a.applicationConfig.MLXTunnelCallback(tunnelAddresses) } }, true); err != nil { return err } 
return nil } // RestartP2P restarts the P2P stack with current ApplicationConfig settings // Note: This method signals that P2P should be restarted, but the actual restart // is handled by the caller to avoid import cycles func (a *Application) RestartP2P() error { a.p2pMutex.Lock() defer a.p2pMutex.Unlock() // Stop existing P2P if running if a.p2pCancel != nil { a.p2pCancel() a.p2pCancel = nil a.p2pCtx = nil // Wait a bit for shutdown to complete time.Sleep(200 * time.Millisecond) } appConfig := a.ApplicationConfig() // Start P2P if token is set if appConfig.P2PToken == "" { return fmt.Errorf("P2P token is not set") } // Create new context for P2P ctx, cancel := context.WithCancel(appConfig.Context) a.p2pCtx = ctx a.p2pCancel = cancel // Get API address from config address := appConfig.APIAddress if address == "" { address = "127.0.0.1:8080" // default } // Start P2P stack in a goroutine go func() { if err := a.StartP2P(); err != nil { xlog.Error("Failed to start P2P stack", "error", err) cancel() // Cancel context on error } }() xlog.Info("P2P stack restarted with new settings") return nil } func syncState(ctx context.Context, n *node.Node, app *Application) error { xlog.Debug("[p2p-sync] Syncing state") whatWeHave := []string{} for _, model := range app.ModelConfigLoader().GetAllModelsConfigs() { whatWeHave = append(whatWeHave, model.Name) } ledger, _ := n.Ledger() currentData := ledger.CurrentData() xlog.Debug("[p2p-sync] Current data", "data", currentData) data, exists := ledger.GetKey("shared_state", "models") if !exists { ledger.AnnounceUpdate(ctx, time.Minute, "shared_state", "models", whatWeHave) xlog.Debug("No models found in the ledger, announced our models", "models", whatWeHave) } models := []string{} if err := data.Unmarshal(&models); err != nil { xlog.Warn("error unmarshalling models", "error", err) return nil } xlog.Debug("[p2p-sync] Models comparison", "ourModels", whatWeHave, "ledgerModels", models) // Sync with our state whatIsNotThere := 
[]string{} for _, model := range whatWeHave { if !slices.Contains(models, model) { whatIsNotThere = append(whatIsNotThere, model) } } if len(whatIsNotThere) > 0 { xlog.Debug("[p2p-sync] Announcing our models", "models", append(models, whatIsNotThere...)) ledger.AnnounceUpdate( ctx, 1*time.Minute, "shared_state", "models", append(models, whatIsNotThere...), ) } // Check if we have a model that is not in our state, otherwise install it for _, model := range models { if slices.Contains(whatWeHave, model) { xlog.Debug("[p2p-sync] Model is already present in this instance", "model", model) continue } // we install model xlog.Info("[p2p-sync] Installing model which is not present in this instance", "model", model) uuid, err := uuid.NewUUID() if err != nil { xlog.Error("error generating UUID", "error", err) continue } app.GalleryService().ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{ ID: uuid.String(), GalleryElementName: model, Galleries: app.ApplicationConfig().Galleries, BackendGalleries: app.ApplicationConfig().BackendGalleries, } } return nil } func (a *Application) p2pSync(ctx context.Context, n *node.Node) error { go func() { for { select { case <-ctx.Done(): return case <-time.After(1 * time.Minute): if err := syncState(ctx, n, a); err != nil { xlog.Error("error syncing state", "error", err) } } } }() return nil } ================================================ FILE: core/application/startup.go ================================================ package application import ( "crypto/rand" "encoding/hex" "encoding/json" "fmt" "os" "path/filepath" "time" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/services" coreStartup "github.com/mudler/LocalAI/core/startup" "github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/xsysinfo" 
"github.com/mudler/xlog" ) func New(opts ...config.AppOption) (*Application, error) { options := config.NewApplicationConfig(opts...) // Store a copy of the startup config (from env vars, before file loading) // This is used to determine if settings came from env vars vs file startupConfigCopy := *options application := newApplication(options) application.startupConfig = &startupConfigCopy xlog.Info("Starting LocalAI", "threads", options.Threads, "modelsPath", options.SystemState.Model.ModelsPath) xlog.Info("LocalAI version", "version", internal.PrintableVersion()) if err := application.start(); err != nil { return nil, err } caps, err := xsysinfo.CPUCapabilities() if err == nil { xlog.Debug("CPU capabilities", "capabilities", caps) } gpus, err := xsysinfo.GPUs() if err == nil { xlog.Debug("GPU count", "count", len(gpus)) for _, gpu := range gpus { xlog.Debug("GPU", "gpu", gpu.String()) } } // Make sure directories exists if options.SystemState.Model.ModelsPath == "" { return nil, fmt.Errorf("models path cannot be empty") } err = os.MkdirAll(options.SystemState.Model.ModelsPath, 0750) if err != nil { return nil, fmt.Errorf("unable to create ModelPath: %q", err) } if options.GeneratedContentDir != "" { err := os.MkdirAll(options.GeneratedContentDir, 0750) if err != nil { return nil, fmt.Errorf("unable to create ImageDir: %q", err) } } if options.UploadDir != "" { err := os.MkdirAll(options.UploadDir, 0750) if err != nil { return nil, fmt.Errorf("unable to create UploadDir: %q", err) } } // Create and migrate data directory if options.DataPath != "" { if err := os.MkdirAll(options.DataPath, 0750); err != nil { return nil, fmt.Errorf("unable to create DataPath: %q", err) } // Migrate data from DynamicConfigsDir to DataPath if needed if options.DynamicConfigsDir != "" && options.DataPath != options.DynamicConfigsDir { migrateDataFiles(options.DynamicConfigsDir, options.DataPath) } } // Initialize auth database if auth is enabled if options.Auth.Enabled { // 
Auto-generate HMAC secret if not provided if options.Auth.APIKeyHMACSecret == "" { secretFile := filepath.Join(options.DataPath, ".hmac_secret") secret, err := loadOrGenerateHMACSecret(secretFile) if err != nil { return nil, fmt.Errorf("failed to initialize HMAC secret: %w", err) } options.Auth.APIKeyHMACSecret = secret } authDB, err := auth.InitDB(options.Auth.DatabaseURL) if err != nil { return nil, fmt.Errorf("failed to initialize auth database: %w", err) } application.authDB = authDB xlog.Info("Auth enabled", "database", options.Auth.DatabaseURL) // Start session and expired API key cleanup goroutine go func() { ticker := time.NewTicker(1 * time.Hour) defer ticker.Stop() for { select { case <-options.Context.Done(): return case <-ticker.C: if err := auth.CleanExpiredSessions(authDB); err != nil { xlog.Error("failed to clean expired sessions", "error", err) } if err := auth.CleanExpiredAPIKeys(authDB); err != nil { xlog.Error("failed to clean expired API keys", "error", err) } } } }() } if err := coreStartup.InstallModels(options.Context, application.GalleryService(), options.Galleries, options.BackendGalleries, options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, nil, options.ModelsURL...); err != nil { xlog.Error("error installing models", "error", err) } for _, backend := range options.ExternalBackends { if err := services.InstallExternalBackend(options.Context, options.BackendGalleries, options.SystemState, application.ModelLoader(), nil, backend, "", ""); err != nil { xlog.Error("error installing external backend", "error", err) } } configLoaderOpts := options.ToConfigLoaderOptions() if err := application.ModelConfigLoader().LoadModelConfigsFromPath(options.SystemState.Model.ModelsPath, configLoaderOpts...); err != nil { xlog.Error("error loading config files", "error", err) } if err := gallery.RegisterBackends(options.SystemState, application.ModelLoader()); err != nil { xlog.Error("error 
registering external backends", "error", err) } if options.ConfigFile != "" { if err := application.ModelConfigLoader().LoadMultipleModelConfigsSingleFile(options.ConfigFile, configLoaderOpts...); err != nil { xlog.Error("error loading config file", "error", err) } } if err := application.ModelConfigLoader().Preload(options.SystemState.Model.ModelsPath); err != nil { xlog.Error("error downloading models", "error", err) } if options.PreloadJSONModels != "" { if err := services.ApplyGalleryFromString(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadJSONModels); err != nil { return nil, err } } if options.PreloadModelsFromPath != "" { if err := services.ApplyGalleryFromFile(options.SystemState, application.ModelLoader(), options.EnforcePredownloadScans, options.AutoloadBackendGalleries, options.Galleries, options.BackendGalleries, options.PreloadModelsFromPath); err != nil { return nil, err } } if options.Debug { for _, v := range application.ModelConfigLoader().GetAllModelsConfigs() { xlog.Debug("Model", "name", v.Name, "config", v) } } // Load runtime settings from file if DynamicConfigsDir is set // This applies file settings with env var precedence (env vars take priority) // Note: startupConfigCopy was already created above, so it has the original env var values if options.DynamicConfigsDir != "" { loadRuntimeSettingsFromFile(options) } application.ModelLoader().SetBackendLoggingEnabled(options.EnableBackendLogging) // turn off any process that was started by GRPC if the context is canceled go func() { <-options.Context.Done() xlog.Debug("Context canceled, shutting down") err := application.ModelLoader().StopAllGRPC() if err != nil { xlog.Error("error while stopping all grpc backends", "error", err) } }() // Initialize watchdog with current settings (after loading from file) initializeWatchdog(application, options) if options.LoadToMemory 
!= nil && !options.SingleBackend { for _, m := range options.LoadToMemory { cfg, err := application.ModelConfigLoader().LoadModelConfigFileByNameDefaultOptions(m, options) if err != nil { return nil, err } xlog.Debug("Auto loading model into memory from file", "model", m, "file", cfg.Model) o := backend.ModelOptions(*cfg, options) var backendErr error _, backendErr = application.ModelLoader().Load(o...) if backendErr != nil { return nil, err } } } // Watch the configuration directory startWatcher(options) xlog.Info("core/startup process completed!") return application, nil } func startWatcher(options *config.ApplicationConfig) { if options.DynamicConfigsDir == "" { // No need to start the watcher if the directory is not set return } if _, err := os.Stat(options.DynamicConfigsDir); err != nil { if os.IsNotExist(err) { // We try to create the directory if it does not exist and was specified if err := os.MkdirAll(options.DynamicConfigsDir, 0700); err != nil { xlog.Error("failed creating DynamicConfigsDir", "error", err) } } else { // something else happened, we log the error and don't start the watcher xlog.Error("failed to read DynamicConfigsDir, watcher will not be started", "error", err) return } } configHandler := newConfigFileHandler(options) if err := configHandler.Watch(); err != nil { xlog.Error("failed creating watcher", "error", err) } } // loadRuntimeSettingsFromFile loads settings from runtime_settings.json with env var precedence // This function is called at startup, before env vars are applied via AppOptions. // Since env vars are applied via AppOptions in run.go, we need to check if they're set. // We do this by checking if the current options values differ from defaults, which would // indicate they were set from env vars. However, a simpler approach is to just apply // file settings here, and let the AppOptions (which are applied after this) override them. 
// But actually, this is called AFTER AppOptions are applied in New(), so we need to check env vars. // The cleanest solution: Store original values before applying file, or check if values match // what would be set from env vars. For now, we'll apply file settings and they'll be // overridden by AppOptions if env vars were set (but AppOptions are already applied). // Actually, this function is called in New() before AppOptions are fully processed for watchdog. // Let's check the call order: New() -> loadRuntimeSettingsFromFile() -> initializeWatchdog() // But AppOptions are applied in NewApplicationConfig() which is called first. // So at this point, options already has values from env vars. We should compare against // defaults to see if env vars were set. But we don't have defaults stored. // Simplest: Just apply file settings. If env vars were set, they're already in options. // The file watcher handler will handle runtime changes properly by comparing with startupAppConfig. func loadRuntimeSettingsFromFile(options *config.ApplicationConfig) { settingsFile := filepath.Join(options.DynamicConfigsDir, "runtime_settings.json") fileContent, err := os.ReadFile(settingsFile) if err != nil { if os.IsNotExist(err) { xlog.Debug("runtime_settings.json not found, using defaults") return } xlog.Warn("failed to read runtime_settings.json", "error", err) return } var settings config.RuntimeSettings if err := json.Unmarshal(fileContent, &settings); err != nil { xlog.Warn("failed to parse runtime_settings.json", "error", err) return } // At this point, options already has values from env vars (via AppOptions in run.go). // To avoid env var duplication, we determine if env vars were set by checking if // current values differ from defaults. Defaults are: false for bools, 0 for durations. // If current value is at default, it likely wasn't set from env var, so we can apply file. // If current value is non-default, it was likely set from env var, so we preserve it. 
// Note: This means env vars explicitly setting to false/0 won't be distinguishable from defaults, // but that's an acceptable limitation to avoid env var duplication. if settings.WatchdogIdleEnabled != nil { // Only apply if current value is default (false), suggesting it wasn't set from env var if !options.WatchDogIdle { options.WatchDogIdle = *settings.WatchdogIdleEnabled if options.WatchDogIdle { options.WatchDog = true } } } if settings.WatchdogBusyEnabled != nil { if !options.WatchDogBusy { options.WatchDogBusy = *settings.WatchdogBusyEnabled if options.WatchDogBusy { options.WatchDog = true } } } if settings.WatchdogIdleTimeout != nil { // Only apply if current value is default (0), suggesting it wasn't set from env var if options.WatchDogIdleTimeout == 0 { dur, err := time.ParseDuration(*settings.WatchdogIdleTimeout) if err == nil { options.WatchDogIdleTimeout = dur } else { xlog.Warn("invalid watchdog idle timeout in runtime_settings.json", "error", err, "timeout", *settings.WatchdogIdleTimeout) } } } if settings.WatchdogBusyTimeout != nil { if options.WatchDogBusyTimeout == 0 { dur, err := time.ParseDuration(*settings.WatchdogBusyTimeout) if err == nil { options.WatchDogBusyTimeout = dur } else { xlog.Warn("invalid watchdog busy timeout in runtime_settings.json", "error", err, "timeout", *settings.WatchdogBusyTimeout) } } } if settings.WatchdogInterval != nil { if options.WatchDogInterval == 0 { dur, err := time.ParseDuration(*settings.WatchdogInterval) if err == nil { options.WatchDogInterval = dur } else { xlog.Warn("invalid watchdog interval in runtime_settings.json", "error", err, "interval", *settings.WatchdogInterval) options.WatchDogInterval = model.DefaultWatchdogInterval } } } // Handle MaxActiveBackends (new) and SingleBackend (deprecated) if settings.MaxActiveBackends != nil { // Only apply if current value is default (0), suggesting it wasn't set from env var if options.MaxActiveBackends == 0 { options.MaxActiveBackends = 
*settings.MaxActiveBackends // For backward compatibility, also set SingleBackend if MaxActiveBackends == 1 options.SingleBackend = (*settings.MaxActiveBackends == 1) } } else if settings.SingleBackend != nil { // Legacy: SingleBackend maps to MaxActiveBackends = 1 if !options.SingleBackend { options.SingleBackend = *settings.SingleBackend if *settings.SingleBackend { options.MaxActiveBackends = 1 } } } if settings.ParallelBackendRequests != nil { if !options.ParallelBackendRequests { options.ParallelBackendRequests = *settings.ParallelBackendRequests } } if settings.MemoryReclaimerEnabled != nil { // Only apply if current value is default (false), suggesting it wasn't set from env var if !options.MemoryReclaimerEnabled { options.MemoryReclaimerEnabled = *settings.MemoryReclaimerEnabled if options.MemoryReclaimerEnabled { options.WatchDog = true // Memory reclaimer requires watchdog } } } if settings.MemoryReclaimerThreshold != nil { // Only apply if current value is default (0), suggesting it wasn't set from env var if options.MemoryReclaimerThreshold == 0 { options.MemoryReclaimerThreshold = *settings.MemoryReclaimerThreshold } } if settings.ForceEvictionWhenBusy != nil { // Only apply if current value is default (false), suggesting it wasn't set from env var if !options.ForceEvictionWhenBusy { options.ForceEvictionWhenBusy = *settings.ForceEvictionWhenBusy } } if settings.LRUEvictionMaxRetries != nil { // Only apply if current value is default (30), suggesting it wasn't set from env var if options.LRUEvictionMaxRetries == 0 { options.LRUEvictionMaxRetries = *settings.LRUEvictionMaxRetries } } if settings.LRUEvictionRetryInterval != nil { // Only apply if current value is default (1s), suggesting it wasn't set from env var if options.LRUEvictionRetryInterval == 0 { dur, err := time.ParseDuration(*settings.LRUEvictionRetryInterval) if err == nil { options.LRUEvictionRetryInterval = dur } else { xlog.Warn("invalid LRU eviction retry interval in 
runtime_settings.json", "error", err, "interval", *settings.LRUEvictionRetryInterval) } } } if settings.AgentJobRetentionDays != nil { // Only apply if current value is default (0), suggesting it wasn't set from env var if options.AgentJobRetentionDays == 0 { options.AgentJobRetentionDays = *settings.AgentJobRetentionDays } } if !options.WatchDogIdle && !options.WatchDogBusy { if settings.WatchdogEnabled != nil && *settings.WatchdogEnabled { options.WatchDog = true } } // P2P settings if settings.P2PToken != nil { if options.P2PToken == "" { options.P2PToken = *settings.P2PToken } } if settings.P2PNetworkID != nil { if options.P2PNetworkID == "" { options.P2PNetworkID = *settings.P2PNetworkID } } if settings.Federated != nil { if !options.Federated { options.Federated = *settings.Federated } } if settings.EnableBackendLogging != nil { if !options.EnableBackendLogging { options.EnableBackendLogging = *settings.EnableBackendLogging } } // Tracing settings if settings.EnableTracing != nil { if !options.EnableTracing { options.EnableTracing = *settings.EnableTracing } } if settings.TracingMaxItems != nil { if options.TracingMaxItems == 0 { options.TracingMaxItems = *settings.TracingMaxItems } } xlog.Debug("Runtime settings loaded from runtime_settings.json") } // initializeWatchdog initializes the watchdog with current ApplicationConfig settings func initializeWatchdog(application *Application, options *config.ApplicationConfig) { // Get effective max active backends (considers both MaxActiveBackends and deprecated SingleBackend) lruLimit := options.GetEffectiveMaxActiveBackends() // Create watchdog if enabled OR if LRU limit is set OR if memory reclaimer is enabled if options.WatchDog || lruLimit > 0 || options.MemoryReclaimerEnabled { wd := model.NewWatchDog( model.WithProcessManager(application.ModelLoader()), model.WithBusyTimeout(options.WatchDogBusyTimeout), model.WithIdleTimeout(options.WatchDogIdleTimeout), model.WithWatchdogInterval(options.WatchDogInterval), 
model.WithBusyCheck(options.WatchDogBusy),
			model.WithIdleCheck(options.WatchDogIdle),
			model.WithLRULimit(lruLimit),
			model.WithMemoryReclaimer(options.MemoryReclaimerEnabled, options.MemoryReclaimerThreshold),
			model.WithForceEvictionWhenBusy(options.ForceEvictionWhenBusy),
		)
		application.ModelLoader().SetWatchDog(wd)

		// Initialize ModelLoader LRU eviction retry settings
		application.ModelLoader().SetLRUEvictionRetrySettings(
			options.LRUEvictionMaxRetries,
			options.LRUEvictionRetryInterval,
		)

		// Start watchdog goroutine if any periodic checks are enabled
		// LRU eviction doesn't need the Run() loop - it's triggered on model load
		// But memory reclaimer needs the Run() loop for periodic checking
		if options.WatchDogBusy || options.WatchDogIdle || options.MemoryReclaimerEnabled {
			go wd.Run()
		}

		go func() {
			<-options.Context.Done()
			xlog.Debug("Context canceled, shutting down")
			wd.Shutdown()
		}()
	}
}

// loadOrGenerateHMACSecret loads an HMAC secret from the given file path,
// or generates a random 32-byte secret and persists it if the file doesn't exist.
//
// The file content is used verbatim as the secret when it is at least 32 bytes
// long. A shorter secret is replaced by a freshly generated one (hex-encoded,
// written with mode 0600), and that replacement is logged.
//
// It returns an error when the file exists but cannot be read, when random
// bytes cannot be generated, or when the new secret cannot be written.
func loadOrGenerateHMACSecret(path string) (string, error) {
	data, err := os.ReadFile(path)
	switch {
	case err == nil:
		if secret := string(data); len(secret) >= 32 {
			return secret, nil
		}
		// Too short to be used: fall through and regenerate, but say so, because
		// the file on disk is about to be overwritten.
		xlog.Warn("Existing HMAC secret is shorter than 32 bytes, generating a new one", "path", path)
	case !os.IsNotExist(err):
		// The file is there but unreadable (permissions, I/O error, ...). Treating
		// this like "missing" would overwrite the existing secret, so surface it.
		return "", fmt.Errorf("failed to read HMAC secret: %w", err)
	}

	b := make([]byte, 32)
	if _, err := rand.Read(b); err != nil {
		return "", fmt.Errorf("failed to generate HMAC secret: %w", err)
	}
	secret := hex.EncodeToString(b)

	if err := os.WriteFile(path, []byte(secret), 0600); err != nil {
		return "", fmt.Errorf("failed to persist HMAC secret: %w", err)
	}

	xlog.Info("Generated new HMAC secret for API key hashing", "path", path)
	return secret, nil
}

// migrateDataFiles moves persistent data files from the old config directory
// to the new data directory. Only moves files that exist in src but not in dst.
func migrateDataFiles(srcDir, dstDir string) { // Files and directories to migrate items := []string{ "agent_tasks.json", "agent_jobs.json", "collections", "assets", } migrated := false for _, item := range items { srcPath := filepath.Join(srcDir, item) dstPath := filepath.Join(dstDir, item) // Only migrate if source exists and destination does not if _, err := os.Stat(srcPath); os.IsNotExist(err) { continue } if _, err := os.Stat(dstPath); err == nil { continue // destination already exists, skip } if err := os.Rename(srcPath, dstPath); err != nil { xlog.Warn("Failed to migrate data file, will copy instead", "src", srcPath, "dst", dstPath, "error", err) // os.Rename fails across filesystems, fall back to leaving in place // and log a warning for the user to manually move xlog.Warn("Data file remains in old location, please move manually", "src", srcPath, "dst", dstPath) continue } migrated = true xlog.Info("Migrated data file to new data path", "src", srcPath, "dst", dstPath) } if migrated { xlog.Info("Data migration complete", "from", srcDir, "to", dstDir) } } ================================================ FILE: core/application/watchdog.go ================================================ package application import ( "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) func (a *Application) StopWatchdog() error { if a.watchdogStop != nil { close(a.watchdogStop) a.watchdogStop = nil } return nil } // startWatchdog starts the watchdog with current ApplicationConfig settings // This is an internal method that assumes the caller holds the watchdogMutex func (a *Application) startWatchdog() error { appConfig := a.ApplicationConfig() // Get effective max active backends (considers both MaxActiveBackends and deprecated SingleBackend) lruLimit := appConfig.GetEffectiveMaxActiveBackends() // Create watchdog if enabled OR if LRU limit is set OR if memory reclaimer is enabled // LRU eviction requires watchdog infrastructure even without busy/idle checks if 
appConfig.WatchDog || lruLimit > 0 || appConfig.MemoryReclaimerEnabled { wd := model.NewWatchDog( model.WithProcessManager(a.modelLoader), model.WithBusyTimeout(appConfig.WatchDogBusyTimeout), model.WithIdleTimeout(appConfig.WatchDogIdleTimeout), model.WithWatchdogInterval(appConfig.WatchDogInterval), model.WithBusyCheck(appConfig.WatchDogBusy), model.WithIdleCheck(appConfig.WatchDogIdle), model.WithLRULimit(lruLimit), model.WithMemoryReclaimer(appConfig.MemoryReclaimerEnabled, appConfig.MemoryReclaimerThreshold), model.WithForceEvictionWhenBusy(appConfig.ForceEvictionWhenBusy), ) // Create new stop channel BEFORE setting up any goroutines // This prevents race conditions where the old shutdown handler might // receive the closed channel and try to shut down the new watchdog a.watchdogStop = make(chan bool, 1) // Set the watchdog on the model loader a.modelLoader.SetWatchDog(wd) // Start watchdog goroutine if any periodic checks are enabled // LRU eviction doesn't need the Run() loop - it's triggered on model load // But memory reclaimer needs the Run() loop for periodic checking if appConfig.WatchDogBusy || appConfig.WatchDogIdle || appConfig.MemoryReclaimerEnabled { go wd.Run() } // Setup shutdown handler - this goroutine will wait on a.watchdogStop // which is now a fresh channel, so it won't receive any stale signals // Note: We capture wd in a local variable to ensure this handler operates // on the correct watchdog instance (not a later one that gets assigned to wd) wdForShutdown := wd go func() { select { case <-a.watchdogStop: xlog.Debug("Watchdog stop signal received") wdForShutdown.Shutdown() case <-appConfig.Context.Done(): xlog.Debug("Context canceled, shutting down watchdog") wdForShutdown.Shutdown() } }() xlog.Info("Watchdog started with new settings", "lruLimit", lruLimit, "busyCheck", appConfig.WatchDogBusy, "idleCheck", appConfig.WatchDogIdle, "memoryReclaimer", appConfig.MemoryReclaimerEnabled, "memoryThreshold", 
appConfig.MemoryReclaimerThreshold, "interval", appConfig.WatchDogInterval) } else { xlog.Info("Watchdog disabled") } return nil } // StartWatchdog starts the watchdog with current ApplicationConfig settings func (a *Application) StartWatchdog() error { a.watchdogMutex.Lock() defer a.watchdogMutex.Unlock() return a.startWatchdog() } // RestartWatchdog restarts the watchdog with current ApplicationConfig settings func (a *Application) RestartWatchdog() error { a.watchdogMutex.Lock() defer a.watchdogMutex.Unlock() // Get the old watchdog before we shut it down oldWD := a.modelLoader.GetWatchDog() // Get the state from the old watchdog before shutting it down // This preserves information about loaded models var oldState model.WatchDogState if oldWD != nil { oldState = oldWD.GetState() } // Signal all handlers to stop by closing the stop channel // This will cause any goroutine waiting on <-a.watchdogStop to unblock if a.watchdogStop != nil { close(a.watchdogStop) a.watchdogStop = nil } // Shutdown existing watchdog - this triggers the stop signal if oldWD != nil { oldWD.Shutdown() // Wait for the old watchdog's Run() goroutine to fully shut down oldWD.WaitDone() } // Start watchdog with new settings if err := a.startWatchdog(); err != nil { return err } // Restore the model state from the old watchdog to the new one // This ensures the new watchdog knows about already-loaded models newWD := a.modelLoader.GetWatchDog() if newWD != nil && len(oldState.AddressModelMap) > 0 { newWD.RestoreState(oldState) } return nil } ================================================ FILE: core/backend/backend_suite_test.go ================================================ package backend_test import ( "testing" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" ) func TestBackend(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Backend test suite") } ================================================ FILE: core/backend/detection.go ================================================ package backend import ( "context" "fmt" "time" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/model" ) func Detection( sourceFile string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig, ) (*proto.DetectResponse, error) { opts := ModelOptions(modelConfig, appConfig) detectionModel, err := loader.Load(opts...) if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) return nil, err } if detectionModel == nil { return nil, fmt.Errorf("could not load detection model") } var startTime time.Time if appConfig.EnableTracing { trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems) startTime = time.Now() } res, err := detectionModel.Detect(context.Background(), &proto.DetectOptions{ Src: sourceFile, }) if appConfig.EnableTracing { errStr := "" if err != nil { errStr = err.Error() } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: time.Since(startTime), Type: trace.BackendTraceDetection, ModelName: modelConfig.Name, Backend: modelConfig.Backend, Summary: trace.TruncateString(sourceFile, 200), Error: errStr, Data: map[string]any{ "source_file": sourceFile, }, }) } return res, err } ================================================ FILE: core/backend/embeddings.go ================================================ package backend import ( "fmt" "time" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/pkg/grpc" model "github.com/mudler/LocalAI/pkg/model" ) func ModelEmbedding(s string, tokens []int, loader *model.ModelLoader, modelConfig 
config.ModelConfig, appConfig *config.ApplicationConfig) (func() ([]float32, error), error) { opts := ModelOptions(modelConfig, appConfig) inferenceModel, err := loader.Load(opts...) if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) return nil, err } var fn func() ([]float32, error) switch model := inferenceModel.(type) { case grpc.Backend: fn = func() ([]float32, error) { predictOptions := gRPCPredictOpts(modelConfig, loader.ModelPath) if len(tokens) > 0 { embeds := []int32{} for _, t := range tokens { embeds = append(embeds, int32(t)) } predictOptions.EmbeddingTokens = embeds res, err := model.Embeddings(appConfig.Context, predictOptions) if err != nil { return nil, err } return res.Embeddings, nil } predictOptions.Embeddings = s res, err := model.Embeddings(appConfig.Context, predictOptions) if err != nil { return nil, err } return res.Embeddings, nil } default: fn = func() ([]float32, error) { return nil, fmt.Errorf("embeddings not supported by the backend") } } wrappedFn := func() ([]float32, error) { embeds, err := fn() if err != nil { return embeds, err } // Return embeddings as-is to preserve full dimensionality // Trailing zeros may be valid values in some embedding models return embeds, nil } if appConfig.EnableTracing { trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems) traceData := map[string]any{ "input_text": trace.TruncateString(s, 1000), "input_tokens_count": len(tokens), } startTime := time.Now() originalFn := wrappedFn wrappedFn = func() ([]float32, error) { result, err := originalFn() duration := time.Since(startTime) traceData["embedding_dimensions"] = len(result) errStr := "" if err != nil { errStr = err.Error() } summary := trace.TruncateString(s, 200) if summary == "" { summary = fmt.Sprintf("tokens[%d]", len(tokens)) } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: duration, Type: trace.BackendTraceEmbedding, ModelName: modelConfig.Name, Backend: 
modelConfig.Backend, Summary: summary, Error: errStr, Data: traceData, }) return result, err } } return wrappedFn, nil } ================================================ FILE: core/backend/image.go ================================================ package backend import ( "time" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/pkg/grpc/proto" model "github.com/mudler/LocalAI/pkg/model" ) func ImageGeneration(height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) { opts := ModelOptions(modelConfig, appConfig) inferenceModel, err := loader.Load( opts..., ) if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) return nil, err } fn := func() error { _, err := inferenceModel.GenerateImage( appConfig.Context, &proto.GenerateImageRequest{ Height: int32(height), Width: int32(width), Step: int32(step), Seed: int32(seed), CLIPSkip: int32(modelConfig.Diffusers.ClipSkip), PositivePrompt: positive_prompt, NegativePrompt: negative_prompt, Dst: dst, Src: src, EnableParameters: modelConfig.Diffusers.EnableParameters, RefImages: refImages, }) return err } if appConfig.EnableTracing { trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems) traceData := map[string]any{ "positive_prompt": positive_prompt, "negative_prompt": negative_prompt, "height": height, "width": width, "step": step, "seed": seed, "source_image": src, "destination": dst, } startTime := time.Now() originalFn := fn fn = func() error { err := originalFn() duration := time.Since(startTime) errStr := "" if err != nil { errStr = err.Error() } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: duration, Type: trace.BackendTraceImageGeneration, ModelName: modelConfig.Name, Backend: modelConfig.Backend, Summary: 
trace.TruncateString(positive_prompt, 200), Error: errStr, Data: traceData, }) return err } } return fn, nil } // ImageGenerationFunc is a test-friendly indirection to call image generation logic. // Tests can override this variable to provide a stub implementation. var ImageGenerationFunc = ImageGeneration ================================================ FILE: core/backend/llm.go ================================================ package backend import ( "context" "encoding/json" "regexp" "slices" "strings" "sync" "time" "unicode/utf8" "github.com/mudler/xlog" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/grpc/proto" model "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/utils" ) type LLMResponse struct { Response string // should this be []byte? Usage TokenUsage AudioOutput string Logprobs *schema.Logprobs // Logprobs from the backend response ChatDeltas []*proto.ChatDelta // Pre-parsed tool calls/content from C++ autoparser } type TokenUsage struct { Prompt int Completion int TimingPromptProcessing float64 TimingTokenGeneration float64 } // ModelInferenceFunc is a test-friendly indirection to call model inference logic. // Tests can override this variable to provide a stub implementation. 
var ModelInferenceFunc = ModelInference func ModelInference(ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, TokenUsage) bool, tools string, toolChoice string, logprobs *int, topLogprobs *int, logitBias map[string]float64, metadata map[string]string) (func() (LLMResponse, error), error) { modelFile := c.Model // Check if the modelFile exists, if it doesn't try to load it from the gallery if o.AutoloadGalleries { // experimental modelNames, err := services.ListModels(cl, loader, nil, services.SKIP_ALWAYS) if err != nil { return nil, err } if !slices.Contains(modelNames, c.Name) { utils.ResetDownloadTimers() // if we failed to load the model, we try to download it err := gallery.InstallModelFromGallery(ctx, o.Galleries, o.BackendGalleries, o.SystemState, loader, c.Name, gallery.GalleryModel{}, utils.DisplayDownloadFunction, o.EnforcePredownloadScans, o.AutoloadBackendGalleries) if err != nil { xlog.Error("failed to install model from gallery", "error", err, "model", modelFile) //return nil, err } } } opts := ModelOptions(*c, o) inferenceModel, err := loader.Load(opts...) 
if err != nil { recordModelLoadFailure(o, c.Name, c.Backend, err, map[string]any{"model_file": modelFile}) return nil, err } // Detect thinking support after model load (only if not already detected) // This needs to happen after LoadModel succeeds so the backend can render templates if (c.ReasoningConfig.DisableReasoning == nil && c.ReasoningConfig.DisableReasoningTagPrefill == nil) && c.TemplateConfig.UseTokenizerTemplate { modelOpts := grpcModelOpts(*c, o.SystemState.Model.ModelsPath) config.DetectThinkingSupportFromBackend(ctx, c, inferenceModel, modelOpts) // Update the config in the loader so it persists for future requests cl.UpdateModelConfig(c.Name, func(cfg *config.ModelConfig) { cfg.ReasoningConfig.DisableReasoning = c.ReasoningConfig.DisableReasoning cfg.ReasoningConfig.DisableReasoningTagPrefill = c.ReasoningConfig.DisableReasoningTagPrefill }) } var protoMessages []*proto.Message // if we are using the tokenizer template, we need to convert the messages to proto messages // unless the prompt has already been tokenized (non-chat endpoints + functions) if c.TemplateConfig.UseTokenizerTemplate && len(messages) > 0 { protoMessages = messages.ToProto() } // in GRPC, the backend is supposed to answer to 1 single token if stream is not supported var capturedPredictOpts *proto.PredictOptions fn := func() (LLMResponse, error) { opts := gRPCPredictOpts(*c, loader.ModelPath) // Merge request-level metadata (overrides config defaults) for k, v := range metadata { opts.Metadata[k] = v } opts.Prompt = s opts.Messages = protoMessages opts.UseTokenizerTemplate = c.TemplateConfig.UseTokenizerTemplate opts.Images = images opts.Videos = videos opts.Audios = audios opts.Tools = tools opts.ToolChoice = toolChoice if logprobs != nil { opts.Logprobs = int32(*logprobs) } if topLogprobs != nil { opts.TopLogprobs = int32(*topLogprobs) } if len(logitBias) > 0 { // Serialize logit_bias map to JSON string for proto logitBiasJSON, err := json.Marshal(logitBias) if err == nil { 
opts.LogitBias = string(logitBiasJSON) } } capturedPredictOpts = opts tokenUsage := TokenUsage{} // check the per-model feature flag for usage, since tokenCallback may have a cost. // Defaults to off as for now it is still experimental if c.FeatureFlag.Enabled("usage") { userTokenCallback := tokenCallback if userTokenCallback == nil { userTokenCallback = func(token string, usage TokenUsage) bool { return true } } promptInfo, pErr := inferenceModel.TokenizeString(ctx, opts) if pErr == nil && promptInfo.Length > 0 { tokenUsage.Prompt = int(promptInfo.Length) } tokenCallback = func(token string, usage TokenUsage) bool { tokenUsage.Completion++ return userTokenCallback(token, tokenUsage) } } if tokenCallback != nil { if c.TemplateConfig.ReplyPrefix != "" { tokenCallback(c.TemplateConfig.ReplyPrefix, tokenUsage) } ss := "" var logprobs *schema.Logprobs var allChatDeltas []*proto.ChatDelta var partialRune []byte err := inferenceModel.PredictStream(ctx, opts, func(reply *proto.Reply) { msg := reply.Message partialRune = append(partialRune, msg...) tokenUsage.Prompt = int(reply.PromptTokens) tokenUsage.Completion = int(reply.Tokens) tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing // Collect chat deltas from C++ autoparser if len(reply.ChatDeltas) > 0 { allChatDeltas = append(allChatDeltas, reply.ChatDeltas...) } // Parse logprobs from reply if present (collect from last chunk that has them) if len(reply.Logprobs) > 0 { var parsedLogprobs schema.Logprobs if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil { logprobs = &parsedLogprobs } } // Process complete runes and accumulate them var completeRunes []byte for len(partialRune) > 0 { r, size := utf8.DecodeRune(partialRune) if r == utf8.RuneError { // incomplete rune, wait for more bytes break } completeRunes = append(completeRunes, partialRune[:size]...) 
partialRune = partialRune[size:] } // If we have complete runes, send them as a single token if len(completeRunes) > 0 { tokenCallback(string(completeRunes), tokenUsage) ss += string(completeRunes) } if len(msg) == 0 { tokenCallback("", tokenUsage) } }) if len(allChatDeltas) > 0 { xlog.Debug("[ChatDeltas] streaming completed, accumulated deltas from C++ autoparser", "total_deltas", len(allChatDeltas)) } return LLMResponse{ Response: ss, Usage: tokenUsage, Logprobs: logprobs, ChatDeltas: allChatDeltas, }, err } else { // TODO: Is the chicken bit the only way to get here? is that acceptable? reply, err := inferenceModel.Predict(ctx, opts) if err != nil { return LLMResponse{}, err } if tokenUsage.Prompt == 0 { tokenUsage.Prompt = int(reply.PromptTokens) } if tokenUsage.Completion == 0 { tokenUsage.Completion = int(reply.Tokens) } tokenUsage.TimingTokenGeneration = reply.TimingTokenGeneration tokenUsage.TimingPromptProcessing = reply.TimingPromptProcessing response := string(reply.Message) if c.TemplateConfig.ReplyPrefix != "" { response = c.TemplateConfig.ReplyPrefix + response } // Parse logprobs from reply if present var logprobs *schema.Logprobs if len(reply.Logprobs) > 0 { var parsedLogprobs schema.Logprobs if err := json.Unmarshal(reply.Logprobs, &parsedLogprobs); err == nil { logprobs = &parsedLogprobs } } if len(reply.ChatDeltas) > 0 { xlog.Debug("[ChatDeltas] non-streaming Predict received deltas from C++ autoparser", "total_deltas", len(reply.ChatDeltas)) } return LLMResponse{ Response: response, Usage: tokenUsage, Logprobs: logprobs, ChatDeltas: reply.ChatDeltas, }, err } } if o.EnableTracing { trace.InitBackendTracingIfEnabled(o.TracingMaxItems) traceData := map[string]any{ "chat_template": c.TemplateConfig.Chat, "function_template": c.TemplateConfig.Functions, "streaming": tokenCallback != nil, "images_count": len(images), "videos_count": len(videos), "audios_count": len(audios), } if len(messages) > 0 { if msgJSON, err := json.Marshal(messages); err == 
nil { traceData["messages"] = string(msgJSON) } } if reasoningJSON, err := json.Marshal(c.ReasoningConfig); err == nil { traceData["reasoning_config"] = string(reasoningJSON) } traceData["functions_config"] = map[string]any{ "grammar_disabled": c.FunctionsConfig.GrammarConfig.NoGrammar, "parallel_calls": c.FunctionsConfig.GrammarConfig.ParallelCalls, "mixed_mode": c.FunctionsConfig.GrammarConfig.MixedMode, "xml_format_preset": c.FunctionsConfig.XMLFormatPreset, } startTime := time.Now() originalFn := fn fn = func() (LLMResponse, error) { resp, err := originalFn() duration := time.Since(startTime) traceData["response"] = resp.Response traceData["token_usage"] = map[string]any{ "prompt": resp.Usage.Prompt, "completion": resp.Usage.Completion, } if len(resp.ChatDeltas) > 0 { chatDeltasInfo := map[string]any{ "total_deltas": len(resp.ChatDeltas), } var contentParts, reasoningParts []string toolCallCount := 0 for _, d := range resp.ChatDeltas { if d.Content != "" { contentParts = append(contentParts, d.Content) } if d.ReasoningContent != "" { reasoningParts = append(reasoningParts, d.ReasoningContent) } toolCallCount += len(d.ToolCalls) } if len(contentParts) > 0 { chatDeltasInfo["content"] = strings.Join(contentParts, "") } if len(reasoningParts) > 0 { chatDeltasInfo["reasoning_content"] = strings.Join(reasoningParts, "") } if toolCallCount > 0 { chatDeltasInfo["tool_call_count"] = toolCallCount } traceData["chat_deltas"] = chatDeltasInfo } if capturedPredictOpts != nil { if optsJSON, err := json.Marshal(capturedPredictOpts); err == nil { var optsMap map[string]any if err := json.Unmarshal(optsJSON, &optsMap); err == nil { traceData["predict_options"] = optsMap } } } errStr := "" if err != nil { errStr = err.Error() } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: duration, Type: trace.BackendTraceLLM, ModelName: c.Name, Backend: c.Backend, Summary: trace.GenerateLLMSummary(messages, s), Error: errStr, Data: traceData, }) return resp, err 
} } return fn, nil } var cutstrings map[string]*regexp.Regexp = make(map[string]*regexp.Regexp) var mu sync.Mutex = sync.Mutex{} func Finetune(config config.ModelConfig, input, prediction string) string { if config.Echo { prediction = input + prediction } for _, c := range config.Cutstrings { mu.Lock() reg, ok := cutstrings[c] if !ok { r, err := regexp.Compile(c) if err != nil { xlog.Fatal("failed to compile regex", "error", err) } cutstrings[c] = r reg = cutstrings[c] } mu.Unlock() prediction = reg.ReplaceAllString(prediction, "") } // extract results from the response which can be for instance inside XML tags var predResult string for _, r := range config.ExtractRegex { mu.Lock() reg, ok := cutstrings[r] if !ok { regex, err := regexp.Compile(r) if err != nil { xlog.Fatal("failed to compile regex", "error", err) } cutstrings[r] = regex reg = regex } mu.Unlock() predResult += reg.FindString(prediction) } if predResult != "" { prediction = predResult } for _, c := range config.TrimSpace { prediction = strings.TrimSpace(strings.TrimPrefix(prediction, c)) } for _, c := range config.TrimSuffix { prediction = strings.TrimSpace(strings.TrimSuffix(prediction, c)) } return prediction } ================================================ FILE: core/backend/llm_test.go ================================================ package backend_test import ( . "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" ) var _ = Describe("LLM tests", func() { Context("Finetune LLM output", func() { var ( testConfig config.ModelConfig input string prediction string result string ) BeforeEach(func() { testConfig = config.ModelConfig{ PredictionOptions: schema.PredictionOptions{ Echo: false, }, LLMConfig: config.LLMConfig{ Cutstrings: []string{`<.*?>`}, // Example regex for removing XML tags ExtractRegex: []string{`(.*?)`}, // Example regex to extract from tags TrimSpace: []string{" ", "\n"}, TrimSuffix: []string{".", "!"}, }, } }) Context("when echo is enabled", func() { BeforeEach(func() { testConfig.Echo = true input = "Hello" prediction = "World" }) It("should prepend input to prediction", func() { result = Finetune(testConfig, input, prediction) Expect(result).To(Equal("HelloWorld")) }) }) Context("when echo is disabled", func() { BeforeEach(func() { testConfig.Echo = false input = "Hello" prediction = "World" }) It("should not modify the prediction with input", func() { result = Finetune(testConfig, input, prediction) Expect(result).To(Equal("World")) }) }) Context("when cutstrings regex is applied", func() { BeforeEach(func() { input = "" prediction = "
Hello
World" }) It("should remove substrings matching cutstrings regex", func() { result = Finetune(testConfig, input, prediction) Expect(result).To(Equal("Hello World")) }) }) Context("when extract regex is applied", func() { BeforeEach(func() { input = "" prediction = "42" }) It("should extract substrings matching the extract regex", func() { result = Finetune(testConfig, input, prediction) Expect(result).To(Equal("42")) }) }) Context("when trimming spaces", func() { BeforeEach(func() { input = "" prediction = " Hello World " }) It("should trim spaces from the prediction", func() { result = Finetune(testConfig, input, prediction) Expect(result).To(Equal("Hello World")) }) }) Context("when trimming suffixes", func() { BeforeEach(func() { input = "" prediction = "Hello World." }) It("should trim suffixes from the prediction", func() { result = Finetune(testConfig, input, prediction) Expect(result).To(Equal("Hello World")) }) }) }) }) ================================================ FILE: core/backend/options.go ================================================ package backend import ( "math/rand" "os" "path/filepath" "strings" "time" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/trace" pb "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) // recordModelLoadFailure records a backend trace when model loading fails. 
func recordModelLoadFailure(appConfig *config.ApplicationConfig, modelName, backend string, err error, data map[string]any) { if !appConfig.EnableTracing { return } trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems) trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: time.Now(), Type: trace.BackendTraceModelLoad, ModelName: modelName, Backend: backend, Summary: "Model load failed", Error: err.Error(), Data: data, }) } func ModelOptions(c config.ModelConfig, so *config.ApplicationConfig, opts ...model.Option) []model.Option { name := c.Name if name == "" { name = c.Model } defOpts := []model.Option{ model.WithBackendString(c.Backend), model.WithModel(c.Model), model.WithContext(so.Context), model.WithModelID(name), } threads := 1 if c.Threads != nil { threads = *c.Threads } if so.Threads != 0 { threads = so.Threads } c.Threads = &threads grpcOpts := grpcModelOpts(c, so.SystemState.Model.ModelsPath) defOpts = append(defOpts, model.WithLoadGRPCLoadModelOpts(grpcOpts)) if so.ParallelBackendRequests { defOpts = append(defOpts, model.EnableParallelRequests) } if c.GRPC.Attempts != 0 { defOpts = append(defOpts, model.WithGRPCAttempts(c.GRPC.Attempts)) } if c.GRPC.AttemptsSleepTime != 0 { defOpts = append(defOpts, model.WithGRPCAttemptsDelay(c.GRPC.AttemptsSleepTime)) } for k, v := range so.ExternalGRPCBackends { defOpts = append(defOpts, model.WithExternalBackend(k, v)) } return append(defOpts, opts...) 
} func getSeed(c config.ModelConfig) int32 { var seed int32 = config.RAND_SEED if c.Seed != nil { seed = int32(*c.Seed) } if seed == config.RAND_SEED { seed = rand.Int31() } return seed } func grpcModelOpts(c config.ModelConfig, modelPath string) *pb.ModelOptions { b := 512 if c.Batch != 0 { b = c.Batch } flashAttention := "auto" if c.FlashAttention != nil { flashAttention = *c.FlashAttention } f16 := false if c.F16 != nil { f16 = *c.F16 } embeddings := false if c.Embeddings != nil { embeddings = *c.Embeddings } lowVRAM := false if c.LowVRAM != nil { lowVRAM = *c.LowVRAM } reranking := false if c.Reranking != nil { reranking = *c.Reranking } mmap := false if c.MMap != nil { mmap = *c.MMap } // Intel SYCL backend has issues with mmap enabled // See: https://github.com/mudler/LocalAI/issues/9012 // Automatically disable mmap for Intel SYCL backends if c.Backend != "" { if strings.Contains(strings.ToLower(c.Backend), "intel") || strings.Contains(strings.ToLower(c.Backend), "sycl") { mmap = false xlog.Info("Auto-disabling mmap for Intel SYCL backend", "backend", c.Backend) } } ctxSize := 4096 if c.ContextSize != nil { ctxSize = *c.ContextSize } mmlock := false if c.MMlock != nil { mmlock = *c.MMlock } nGPULayers := 9999999 if c.NGPULayers != nil { nGPULayers = *c.NGPULayers } triggers := make([]*pb.GrammarTrigger, 0) for _, t := range c.FunctionsConfig.GrammarConfig.GrammarTriggers { triggers = append(triggers, &pb.GrammarTrigger{ Word: t.Word, }) } opts := &pb.ModelOptions{ CUDA: c.CUDA || c.Diffusers.CUDA, SchedulerType: c.Diffusers.SchedulerType, GrammarTriggers: triggers, PipelineType: c.Diffusers.PipelineType, CFGScale: c.CFGScale, LoraAdapter: c.LoraAdapter, LoraScale: c.LoraScale, LoraAdapters: c.LoraAdapters, LoraScales: c.LoraScales, F16Memory: f16, LoraBase: c.LoraBase, IMG2IMG: c.Diffusers.IMG2IMG, CLIPModel: c.Diffusers.ClipModel, CLIPSubfolder: c.Diffusers.ClipSubFolder, Options: c.Options, Overrides: c.Overrides, CLIPSkip: int32(c.Diffusers.ClipSkip), 
ControlNet: c.Diffusers.ControlNet, ContextSize: int32(ctxSize), Seed: getSeed(c), NBatch: int32(b), NoMulMatQ: c.NoMulMatQ, DraftModel: c.DraftModel, AudioPath: c.AudioPath, Quantization: c.Quantization, LoadFormat: c.LoadFormat, GPUMemoryUtilization: c.GPUMemoryUtilization, TrustRemoteCode: c.TrustRemoteCode, EnforceEager: c.EnforceEager, SwapSpace: int32(c.SwapSpace), MaxModelLen: int32(c.MaxModelLen), TensorParallelSize: int32(c.TensorParallelSize), DisableLogStatus: c.DisableLogStatus, DType: c.DType, // LimitMMPerPrompt vLLM LimitImagePerPrompt: int32(c.LimitMMPerPrompt.LimitImagePerPrompt), LimitVideoPerPrompt: int32(c.LimitMMPerPrompt.LimitVideoPerPrompt), LimitAudioPerPrompt: int32(c.LimitMMPerPrompt.LimitAudioPerPrompt), FlashAttention: flashAttention, CacheTypeKey: c.CacheTypeK, CacheTypeValue: c.CacheTypeV, NoKVOffload: c.NoKVOffloading, YarnExtFactor: c.YarnExtFactor, YarnAttnFactor: c.YarnAttnFactor, YarnBetaFast: c.YarnBetaFast, YarnBetaSlow: c.YarnBetaSlow, NGQA: c.NGQA, RMSNormEps: c.RMSNormEps, MLock: mmlock, RopeFreqBase: c.RopeFreqBase, RopeScaling: c.RopeScaling, Type: c.ModelType, RopeFreqScale: c.RopeFreqScale, NUMA: c.NUMA, Embeddings: embeddings, Reranking: reranking, LowVRAM: lowVRAM, NGPULayers: int32(nGPULayers), MMap: mmap, MainGPU: c.MainGPU, Threads: int32(*c.Threads), TensorSplit: c.TensorSplit, // RWKV Tokenizer: c.Tokenizer, } if c.MMProj != "" { opts.MMProj = filepath.Join(modelPath, c.MMProj) } return opts } func gRPCPredictOpts(c config.ModelConfig, modelPath string) *pb.PredictOptions { promptCachePath := "" if c.PromptCachePath != "" { p := filepath.Join(modelPath, c.PromptCachePath) err := os.MkdirAll(filepath.Dir(p), 0750) if err == nil { promptCachePath = p } else { xlog.Error("error creating prompt cache folder", "error", err, "promptCachePath", promptCachePath) } } pbOpts := &pb.PredictOptions{ Temperature: float32(*c.Temperature), TopP: float32(*c.TopP), NDraft: c.NDraft, TopK: int32(*c.TopK), Tokens: 
int32(*c.Maxtokens), Threads: int32(*c.Threads), PromptCacheAll: c.PromptCacheAll, PromptCacheRO: c.PromptCacheRO, PromptCachePath: promptCachePath, F16KV: *c.F16, DebugMode: *c.Debug, Grammar: c.Grammar, NegativePromptScale: c.NegativePromptScale, RopeFreqBase: c.RopeFreqBase, RopeFreqScale: c.RopeFreqScale, NegativePrompt: c.NegativePrompt, Mirostat: int32(*c.LLMConfig.Mirostat), MirostatETA: float32(*c.LLMConfig.MirostatETA), MirostatTAU: float32(*c.LLMConfig.MirostatTAU), Debug: *c.Debug, StopPrompts: c.StopWords, Repeat: int32(c.RepeatLastN), FrequencyPenalty: float32(c.FrequencyPenalty), PresencePenalty: float32(c.PresencePenalty), Penalty: float32(c.RepeatPenalty), NKeep: int32(c.Keep), Batch: int32(c.Batch), IgnoreEOS: c.IgnoreEOS, Seed: getSeed(c), MLock: *c.MMlock, MMap: *c.MMap, MainGPU: c.MainGPU, TensorSplit: c.TensorSplit, TailFreeSamplingZ: float32(*c.TFZ), TypicalP: float32(*c.TypicalP), } metadata := map[string]string{} if c.ReasoningConfig.DisableReasoning != nil { if *c.ReasoningConfig.DisableReasoning { metadata["enable_thinking"] = "false" } else { metadata["enable_thinking"] = "true" } } pbOpts.Metadata = metadata // Logprobs and TopLogprobs are set by the caller if provided return pbOpts } ================================================ FILE: core/backend/rerank.go ================================================ package backend import ( "context" "fmt" "time" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/pkg/grpc/proto" model "github.com/mudler/LocalAI/pkg/model" ) func Rerank(request *proto.RerankRequest, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.RerankResult, error) { opts := ModelOptions(modelConfig, appConfig) rerankModel, err := loader.Load(opts...) 
if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) return nil, err } if rerankModel == nil { return nil, fmt.Errorf("could not load rerank model") } var startTime time.Time if appConfig.EnableTracing { trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems) startTime = time.Now() } res, err := rerankModel.Rerank(context.Background(), request) if appConfig.EnableTracing { errStr := "" if err != nil { errStr = err.Error() } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: time.Since(startTime), Type: trace.BackendTraceRerank, ModelName: modelConfig.Name, Backend: modelConfig.Backend, Summary: trace.TruncateString(request.Query, 200), Error: errStr, Data: map[string]any{ "query": request.Query, "documents_count": len(request.Documents), "top_n": request.TopN, }, }) } return res, err } ================================================ FILE: core/backend/soundgeneration.go ================================================ package backend import ( "context" "fmt" "os" "path/filepath" "time" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/utils" ) func SoundGeneration( text string, duration *float32, temperature *float32, doSample *bool, sourceFile *string, sourceDivisor *int32, think *bool, caption string, lyrics string, bpm *int32, keyscale string, language string, timesignature string, instrumental *bool, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig, ) (string, *proto.Result, error) { opts := ModelOptions(modelConfig, appConfig) soundGenModel, err := loader.Load(opts...) 
if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) return "", nil, err } if soundGenModel == nil { return "", nil, fmt.Errorf("could not load sound generation model") } if err := os.MkdirAll(appConfig.GeneratedContentDir, 0750); err != nil { return "", nil, fmt.Errorf("failed creating audio directory: %s", err) } audioDir := filepath.Join(appConfig.GeneratedContentDir, "audio") if err := os.MkdirAll(audioDir, 0750); err != nil { return "", nil, fmt.Errorf("failed creating audio directory: %s", err) } fileName := utils.GenerateUniqueFileName(audioDir, "sound_generation", ".wav") filePath := filepath.Join(audioDir, fileName) if filePath, err = filepath.Abs(filePath); err != nil { return "", nil, fmt.Errorf("failed resolving sound generation path: %w", err) } req := &proto.SoundGenerationRequest{ Text: text, Model: modelConfig.Model, Dst: filePath, Sample: doSample, Duration: duration, Temperature: temperature, Src: sourceFile, SrcDivisor: sourceDivisor, } if think != nil { req.Think = think } if caption != "" { req.Caption = &caption } if lyrics != "" { req.Lyrics = &lyrics } if bpm != nil { req.Bpm = bpm } if keyscale != "" { req.Keyscale = &keyscale } if language != "" { req.Language = &language } if timesignature != "" { req.Timesignature = ×ignature } if instrumental != nil { req.Instrumental = instrumental } var startTime time.Time if appConfig.EnableTracing { trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems) startTime = time.Now() } res, err := soundGenModel.SoundGeneration(context.Background(), req) if appConfig.EnableTracing { errStr := "" if err != nil { errStr = err.Error() } else if res != nil && !res.Success { errStr = fmt.Sprintf("sound generation error: %s", res.Message) } summary := trace.TruncateString(text, 200) if summary == "" && caption != "" { summary = trace.TruncateString(caption, 200) } traceData := map[string]any{ "text": text, "caption": caption, "lyrics": lyrics, } if duration != 
nil { traceData["duration"] = *duration } if temperature != nil { traceData["temperature"] = *temperature } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: time.Since(startTime), Type: trace.BackendTraceSoundGeneration, ModelName: modelConfig.Name, Backend: modelConfig.Backend, Summary: summary, Error: errStr, Data: traceData, }) } if err != nil { return "", nil, err } if res != nil && !res.Success { return "", nil, fmt.Errorf("error during sound generation: %s", res.Message) } return filePath, res, nil } ================================================ FILE: core/backend/stores.go ================================================ package backend import ( "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/grpc" "github.com/mudler/LocalAI/pkg/model" ) func StoreBackend(sl *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string, backend string) (grpc.Backend, error) { if backend == "" { backend = model.LocalStoreBackend } sc := []model.Option{ model.WithBackendString(backend), model.WithModel(storeName), } return sl.Load(sc...) } ================================================ FILE: core/backend/token_metrics.go ================================================ package backend import ( "context" "fmt" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/grpc/proto" model "github.com/mudler/LocalAI/pkg/model" ) func TokenMetrics( modelFile string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*proto.MetricsResponse, error) { opts := ModelOptions(modelConfig, appConfig, model.WithModel(modelFile)) model, err := loader.Load(opts...) 
if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) return nil, err } if model == nil { return nil, fmt.Errorf("could not loadmodel model") } res, err := model.GetTokenMetrics(context.Background(), &proto.MetricsRequest{}) return res, err } ================================================ FILE: core/backend/tokenize.go ================================================ package backend import ( "time" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/grpc" "github.com/mudler/LocalAI/pkg/model" ) func ModelTokenize(s string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (schema.TokenizeResponse, error) { var inferenceModel grpc.Backend var err error opts := ModelOptions(modelConfig, appConfig) inferenceModel, err = loader.Load(opts...) if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) return schema.TokenizeResponse{}, err } predictOptions := gRPCPredictOpts(modelConfig, loader.ModelPath) predictOptions.Prompt = s var startTime time.Time if appConfig.EnableTracing { trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems) startTime = time.Now() } // tokenize the string resp, err := inferenceModel.TokenizeString(appConfig.Context, predictOptions) if appConfig.EnableTracing { errStr := "" if err != nil { errStr = err.Error() } tokenCount := 0 if resp.Tokens != nil { tokenCount = len(resp.Tokens) } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: time.Since(startTime), Type: trace.BackendTraceTokenize, ModelName: modelConfig.Name, Backend: modelConfig.Backend, Summary: trace.TruncateString(s, 200), Error: errStr, Data: map[string]any{ "input_text": trace.TruncateString(s, 1000), "token_count": tokenCount, }, }) } if err != nil { return schema.TokenizeResponse{}, err } if resp.Tokens == nil { 
resp.Tokens = make([]int32, 0) } return schema.TokenizeResponse{ Tokens: resp.Tokens, }, nil } ================================================ FILE: core/backend/transcript.go ================================================ package backend import ( "context" "fmt" "maps" "time" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/model" ) func ModelTranscription(audio, language string, translate, diarize bool, prompt string, ml *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (*schema.TranscriptionResult, error) { if modelConfig.Backend == "" { modelConfig.Backend = model.WhisperBackend } opts := ModelOptions(modelConfig, appConfig) transcriptionModel, err := ml.Load(opts...) if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) return nil, err } if transcriptionModel == nil { return nil, fmt.Errorf("could not load transcription model") } var startTime time.Time var audioSnippet map[string]any if appConfig.EnableTracing { trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems) startTime = time.Now() // Capture audio before the backend call — the backend may delete the file. 
audioSnippet = trace.AudioSnippet(audio) } r, err := transcriptionModel.AudioTranscription(context.Background(), &proto.TranscriptRequest{ Dst: audio, Language: language, Translate: translate, Diarize: diarize, Threads: uint32(*modelConfig.Threads), Prompt: prompt, }) if err != nil { if appConfig.EnableTracing { errData := map[string]any{ "audio_file": audio, "language": language, "translate": translate, "diarize": diarize, "prompt": prompt, } if audioSnippet != nil { maps.Copy(errData, audioSnippet) } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: time.Since(startTime), Type: trace.BackendTraceTranscription, ModelName: modelConfig.Name, Backend: modelConfig.Backend, Summary: trace.TruncateString(audio, 200), Error: err.Error(), Data: errData, }) } return nil, err } tr := &schema.TranscriptionResult{ Text: r.Text, } for _, s := range r.Segments { var tks []int for _, t := range s.Tokens { tks = append(tks, int(t)) } tr.Segments = append(tr.Segments, schema.TranscriptionSegment{ Text: s.Text, Id: int(s.Id), Start: time.Duration(s.Start), End: time.Duration(s.End), Tokens: tks, Speaker: s.Speaker, }) } if appConfig.EnableTracing { data := map[string]any{ "audio_file": audio, "language": language, "translate": translate, "diarize": diarize, "prompt": prompt, "result_text": tr.Text, "segments_count": len(tr.Segments), } if audioSnippet != nil { maps.Copy(data, audioSnippet) } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: time.Since(startTime), Type: trace.BackendTraceTranscription, ModelName: modelConfig.Name, Backend: modelConfig.Backend, Summary: trace.TruncateString(audio+" -> "+tr.Text, 200), Data: data, }) } return tr, err } ================================================ FILE: core/backend/tts.go ================================================ package backend import ( "bytes" "context" "encoding/binary" "encoding/json" "fmt" "maps" "os" "path/filepath" "time" "github.com/mudler/LocalAI/core/config" 
"github.com/mudler/LocalAI/core/trace" laudio "github.com/mudler/LocalAI/pkg/audio" "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/utils" ) func ModelTTS( text, voice, language string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig, ) (string, *proto.Result, error) { opts := ModelOptions(modelConfig, appConfig) ttsModel, err := loader.Load(opts...) if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) return "", nil, err } if ttsModel == nil { return "", nil, fmt.Errorf("could not load tts model %q", modelConfig.Model) } audioDir := filepath.Join(appConfig.GeneratedContentDir, "audio") if err := os.MkdirAll(audioDir, 0750); err != nil { return "", nil, fmt.Errorf("failed creating audio directory: %s", err) } fileName := utils.GenerateUniqueFileName(audioDir, "tts", ".wav") filePath := filepath.Join(audioDir, fileName) // We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect. // This should be addressed in a follow up PR soon. // Copying it over nearly verbatim, as TTS backends are not functional without this. modelPath := "" // Checking first that it exists and is not outside ModelPath // TODO: we should actually first check if the modelFile is looking like // a FS path mp := filepath.Join(loader.ModelPath, modelConfig.Model) if _, err := os.Stat(mp); err == nil { if err := utils.VerifyPath(mp, appConfig.SystemState.Model.ModelsPath); err != nil { return "", nil, err } modelPath = mp } else { modelPath = modelConfig.Model // skip this step if it fails????? 
} var startTime time.Time if appConfig.EnableTracing { trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems) startTime = time.Now() } res, err := ttsModel.TTS(context.Background(), &proto.TTSRequest{ Text: text, Model: modelPath, Voice: voice, Dst: filePath, Language: &language, }) if appConfig.EnableTracing { errStr := "" if err != nil { errStr = err.Error() } else if !res.Success { errStr = fmt.Sprintf("TTS error: %s", res.Message) } data := map[string]any{ "text": text, "voice": voice, "language": language, } if err == nil && res.Success { if snippet := trace.AudioSnippet(filePath); snippet != nil { maps.Copy(data, snippet) } } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: time.Since(startTime), Type: trace.BackendTraceTTS, ModelName: modelConfig.Name, Backend: modelConfig.Backend, Summary: trace.TruncateString(text, 200), Error: errStr, Data: data, }) } if err != nil { return "", nil, err } // return RPC error if any if !res.Success { return "", nil, fmt.Errorf("error during TTS: %s", res.Message) } return filePath, res, err } func ModelTTSStream( text, voice, language string, loader *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig, audioCallback func([]byte) error, ) error { opts := ModelOptions(modelConfig, appConfig) ttsModel, err := loader.Load(opts...) if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) return err } if ttsModel == nil { return fmt.Errorf("could not load tts model %q", modelConfig.Model) } // We join the model name to the model path here. This seems to only be done for TTS and is HIGHLY suspect. // This should be addressed in a follow up PR soon. // Copying it over nearly verbatim, as TTS backends are not functional without this. 
modelPath := "" // Checking first that it exists and is not outside ModelPath // TODO: we should actually first check if the modelFile is looking like // a FS path mp := filepath.Join(loader.ModelPath, modelConfig.Model) if _, err := os.Stat(mp); err == nil { if err := utils.VerifyPath(mp, appConfig.SystemState.Model.ModelsPath); err != nil { return err } modelPath = mp } else { modelPath = modelConfig.Model // skip this step if it fails????? } var startTime time.Time if appConfig.EnableTracing { trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems) startTime = time.Now() } var sampleRate uint32 = 16000 // default headerSent := false var callbackErr error // Collect up to 30s of audio for tracing var snippetPCM []byte var totalPCMBytes int snippetCapped := false err = ttsModel.TTSStream(context.Background(), &proto.TTSRequest{ Text: text, Model: modelPath, Voice: voice, Language: &language, }, func(reply *proto.Reply) { // First message contains sample rate info if !headerSent && len(reply.Message) > 0 { var info map[string]any if json.Unmarshal(reply.Message, &info) == nil { if sr, ok := info["sample_rate"].(float64); ok { sampleRate = uint32(sr) } } // Send WAV header with placeholder size (0xFFFFFFFF for streaming) header := laudio.WAVHeader{ ChunkID: [4]byte{'R', 'I', 'F', 'F'}, ChunkSize: 0xFFFFFFFF, // Unknown size for streaming Format: [4]byte{'W', 'A', 'V', 'E'}, Subchunk1ID: [4]byte{'f', 'm', 't', ' '}, Subchunk1Size: 16, AudioFormat: 1, // PCM NumChannels: 1, // Mono SampleRate: sampleRate, ByteRate: sampleRate * 2, // SampleRate * BlockAlign BlockAlign: 2, // 16-bit = 2 bytes BitsPerSample: 16, Subchunk2ID: [4]byte{'d', 'a', 't', 'a'}, Subchunk2Size: 0xFFFFFFFF, // Unknown size for streaming } var buf bytes.Buffer if writeErr := binary.Write(&buf, binary.LittleEndian, header); writeErr != nil { callbackErr = writeErr return } if writeErr := audioCallback(buf.Bytes()); writeErr != nil { callbackErr = writeErr return } headerSent = true } // Stream 
audio chunks if len(reply.Audio) > 0 { if writeErr := audioCallback(reply.Audio); writeErr != nil { callbackErr = writeErr } // Accumulate PCM for tracing snippet totalPCMBytes += len(reply.Audio) if appConfig.EnableTracing && !snippetCapped { maxBytes := int(sampleRate) * 2 * trace.MaxSnippetSeconds // 16-bit mono if len(snippetPCM)+len(reply.Audio) <= maxBytes { snippetPCM = append(snippetPCM, reply.Audio...) } else { remaining := maxBytes - len(snippetPCM) if remaining > 0 { // Align to sample boundary (2 bytes per sample) remaining = remaining &^ 1 snippetPCM = append(snippetPCM, reply.Audio[:remaining]...) } snippetCapped = true } } } }) resultErr := err if callbackErr != nil { resultErr = callbackErr } if appConfig.EnableTracing { errStr := "" if resultErr != nil { errStr = resultErr.Error() } data := map[string]any{ "text": text, "voice": voice, "language": language, "streaming": true, } if resultErr == nil && len(snippetPCM) > 0 { if snippet := trace.AudioSnippetFromPCM(snippetPCM, int(sampleRate), totalPCMBytes); snippet != nil { maps.Copy(data, snippet) } } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: time.Since(startTime), Type: trace.BackendTraceTTS, ModelName: modelConfig.Name, Backend: modelConfig.Backend, Summary: trace.TruncateString(text, 200), Error: errStr, Data: data, }) } if callbackErr != nil { return callbackErr } return err } ================================================ FILE: core/backend/vad.go ================================================ package backend import ( "context" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/model" ) func VAD(request *schema.VADRequest, ctx context.Context, ml *model.ModelLoader, appConfig *config.ApplicationConfig, modelConfig config.ModelConfig) (*schema.VADResponse, error) { opts := ModelOptions(modelConfig, appConfig) vadModel, err := ml.Load(opts...) 
if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) return nil, err } req := proto.VADRequest{ Audio: request.Audio, } resp, err := vadModel.VAD(ctx, &req) if err != nil { return nil, err } segments := []schema.VADSegment{} for _, s := range resp.Segments { segments = append(segments, schema.VADSegment{Start: s.Start, End: s.End}) } return &schema.VADResponse{ Segments: segments, }, nil } ================================================ FILE: core/backend/video.go ================================================ package backend import ( "time" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/trace" "github.com/mudler/LocalAI/pkg/grpc/proto" model "github.com/mudler/LocalAI/pkg/model" ) func VideoGeneration(height, width int32, prompt, negativePrompt, startImage, endImage, dst string, numFrames, fps, seed int32, cfgScale float32, step int32, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig) (func() error, error) { opts := ModelOptions(modelConfig, appConfig) inferenceModel, err := loader.Load( opts..., ) if err != nil { recordModelLoadFailure(appConfig, modelConfig.Name, modelConfig.Backend, err, nil) return nil, err } fn := func() error { _, err := inferenceModel.GenerateVideo( appConfig.Context, &proto.GenerateVideoRequest{ Height: height, Width: width, Prompt: prompt, NegativePrompt: negativePrompt, StartImage: startImage, EndImage: endImage, NumFrames: numFrames, Fps: fps, Seed: seed, CfgScale: cfgScale, Step: step, Dst: dst, }) return err } if appConfig.EnableTracing { trace.InitBackendTracingIfEnabled(appConfig.TracingMaxItems) traceData := map[string]any{ "prompt": prompt, "negative_prompt": negativePrompt, "height": height, "width": width, "num_frames": numFrames, "fps": fps, "seed": seed, "cfg_scale": cfgScale, "step": step, } startTime := time.Now() originalFn := fn fn = func() error { err := originalFn() duration := time.Since(startTime) 
errStr := "" if err != nil { errStr = err.Error() } trace.RecordBackendTrace(trace.BackendTrace{ Timestamp: startTime, Duration: duration, Type: trace.BackendTraceVideoGeneration, ModelName: modelConfig.Name, Backend: modelConfig.Backend, Summary: trace.TruncateString(prompt, 200), Error: errStr, Data: traceData, }) return err } } return fn, nil } ================================================ FILE: core/cli/agent.go ================================================ package cli import ( "context" "encoding/json" "fmt" "os" "os/signal" "syscall" cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAGI/core/state" coreTypes "github.com/mudler/LocalAGI/core/types" "github.com/mudler/xlog" ) type AgentCMD struct { Run AgentRunCMD `cmd:"" help:"Run an agent standalone (without the full LocalAI server)"` List AgentListCMD `cmd:"" help:"List agents in the pool registry"` } type AgentRunCMD struct { Name string `arg:"" optional:"" help:"Agent name to run from the pool registry (pool.json)"` Config string `short:"c" help:"Path to a JSON agent config file (alternative to loading by name)" type:"path"` Prompt string `short:"p" help:"Run in foreground mode: send a single prompt and print the response"` // Agent pool settings (mirrors RunCMD agent flags) APIURL string `env:"LOCALAI_AGENT_POOL_API_URL" help:"API URL for the agent to call (e.g. 
http://127.0.0.1:8080)" group:"agents"` APIKey string `env:"LOCALAI_AGENT_POOL_API_KEY" help:"API key for the agent" group:"agents"` DefaultModel string `env:"LOCALAI_AGENT_POOL_DEFAULT_MODEL" help:"Default model for the agent" group:"agents"` MultimodalModel string `env:"LOCALAI_AGENT_POOL_MULTIMODAL_MODEL" help:"Multimodal model for the agent" group:"agents"` TranscriptionModel string `env:"LOCALAI_AGENT_POOL_TRANSCRIPTION_MODEL" help:"Transcription model for the agent" group:"agents"` TranscriptionLanguage string `env:"LOCALAI_AGENT_POOL_TRANSCRIPTION_LANGUAGE" help:"Transcription language for the agent" group:"agents"` TTSModel string `env:"LOCALAI_AGENT_POOL_TTS_MODEL" help:"TTS model for the agent" group:"agents"` StateDir string `env:"LOCALAI_AGENT_POOL_STATE_DIR" default:"agents" help:"State directory containing pool.json" type:"path" group:"agents"` Timeout string `env:"LOCALAI_AGENT_POOL_TIMEOUT" default:"5m" help:"Agent timeout" group:"agents"` EnableSkills bool `env:"LOCALAI_AGENT_POOL_ENABLE_SKILLS" default:"false" help:"Enable skills service" group:"agents"` EnableLogs bool `env:"LOCALAI_AGENT_POOL_ENABLE_LOGS" default:"false" help:"Enable agent logging" group:"agents"` CustomActionsDir string `env:"LOCALAI_AGENT_POOL_CUSTOM_ACTIONS_DIR" help:"Custom actions directory" group:"agents"` } func (r *AgentRunCMD) Run(ctx *cliContext.Context) error { if r.Name == "" && r.Config == "" { return fmt.Errorf("either an agent name or --config must be provided") } agentConfig, err := r.loadAgentConfig() if err != nil { return err } // Override agent config fields from CLI flags when provided r.applyOverrides(agentConfig) xlog.Info("Starting standalone agent", "name", agentConfig.Name) appConfig := r.buildAppConfig() poolService, err := services.NewAgentPoolService(appConfig) if err != nil { return fmt.Errorf("failed to create agent pool service: %w", err) } if err := poolService.Start(appConfig.Context); err != nil { return fmt.Errorf("failed to start agent pool 
service: %w", err) } defer poolService.Stop() pool := poolService.Pool() // Start the agent standalone (does not persist to pool.json) if err := pool.StartAgentStandalone(agentConfig.Name, agentConfig); err != nil { return fmt.Errorf("failed to start agent %q: %w", agentConfig.Name, err) } ag := pool.GetAgent(agentConfig.Name) if ag == nil { return fmt.Errorf("agent %q not found after start", agentConfig.Name) } // Foreground mode: send a single prompt and exit if r.Prompt != "" { xlog.Info("Sending prompt to agent", "agent", agentConfig.Name) result := ag.Ask(coreTypes.WithText(r.Prompt)) if result == nil { return fmt.Errorf("agent returned no result") } if result.Error != nil { return fmt.Errorf("agent error: %w", result.Error) } fmt.Println(result.Response) return nil } // Background mode: run until interrupted xlog.Info("Agent running in background mode. Press Ctrl+C to stop.", "agent", agentConfig.Name) sigCh := make(chan os.Signal, 1) signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM) <-sigCh xlog.Info("Shutting down agent", "agent", agentConfig.Name) return nil } func (r *AgentRunCMD) loadAgentConfig() (*state.AgentConfig, error) { // Load from JSON config file if r.Config != "" { data, err := os.ReadFile(r.Config) if err != nil { return nil, fmt.Errorf("failed to read config file %q: %w", r.Config, err) } var cfg state.AgentConfig if err := json.Unmarshal(data, &cfg); err != nil { return nil, fmt.Errorf("failed to parse config file %q: %w", r.Config, err) } if cfg.Name == "" { return nil, fmt.Errorf("agent config must have a name") } return &cfg, nil } // Load from pool.json by name poolFile := r.StateDir + "/pool.json" data, err := os.ReadFile(poolFile) if err != nil { return nil, fmt.Errorf("failed to read pool registry %q: %w", poolFile, err) } var pool map[string]state.AgentConfig if err := json.Unmarshal(data, &pool); err != nil { return nil, fmt.Errorf("failed to parse pool registry %q: %w", poolFile, err) } cfg, ok := pool[r.Name] if !ok { 
available := make([]string, 0, len(pool)) for name := range pool { available = append(available, name) } return nil, fmt.Errorf("agent %q not found in pool registry. Available agents: %v", r.Name, available) } cfg.Name = r.Name return &cfg, nil } func (r *AgentRunCMD) applyOverrides(cfg *state.AgentConfig) { if r.APIURL != "" { cfg.APIURL = r.APIURL } if r.APIKey != "" { cfg.APIKey = r.APIKey } if r.DefaultModel != "" && cfg.Model == "" { cfg.Model = r.DefaultModel } if r.MultimodalModel != "" && cfg.MultimodalModel == "" { cfg.MultimodalModel = r.MultimodalModel } if r.TranscriptionModel != "" && cfg.TranscriptionModel == "" { cfg.TranscriptionModel = r.TranscriptionModel } if r.TranscriptionLanguage != "" && cfg.TranscriptionLanguage == "" { cfg.TranscriptionLanguage = r.TranscriptionLanguage } if r.TTSModel != "" && cfg.TTSModel == "" { cfg.TTSModel = r.TTSModel } } func (r *AgentRunCMD) buildAppConfig() *config.ApplicationConfig { appConfig := &config.ApplicationConfig{ Context: context.Background(), } appConfig.AgentPool = config.AgentPoolConfig{ Enabled: true, APIURL: r.APIURL, APIKey: r.APIKey, DefaultModel: r.DefaultModel, MultimodalModel: r.MultimodalModel, TranscriptionModel: r.TranscriptionModel, TranscriptionLanguage: r.TranscriptionLanguage, TTSModel: r.TTSModel, StateDir: r.StateDir, Timeout: r.Timeout, EnableSkills: r.EnableSkills, EnableLogs: r.EnableLogs, CustomActionsDir: r.CustomActionsDir, } return appConfig } type AgentListCMD struct { StateDir string `env:"LOCALAI_AGENT_POOL_STATE_DIR" default:"agents" help:"State directory containing pool.json" type:"path" group:"agents"` } func (r *AgentListCMD) Run(ctx *cliContext.Context) error { poolFile := r.StateDir + "/pool.json" data, err := os.ReadFile(poolFile) if err != nil { if os.IsNotExist(err) { fmt.Println("No agents found (pool.json does not exist)") return nil } return fmt.Errorf("failed to read pool registry %q: %w", poolFile, err) } var pool map[string]state.AgentConfig if err := 
json.Unmarshal(data, &pool); err != nil { return fmt.Errorf("failed to parse pool registry %q: %w", poolFile, err) } if len(pool) == 0 { fmt.Println("No agents found in pool registry") return nil } fmt.Printf("Agents in %s:\n", poolFile) for name, cfg := range pool { model := cfg.Model if model == "" { model = "(default)" } desc := cfg.Description if desc == "" { desc = "(no description)" } fmt.Printf(" - %s [model: %s] %s\n", name, model, desc) } return nil } ================================================ FILE: core/cli/agent_test.go ================================================ package cli import ( "encoding/json" "os" "path/filepath" "testing" "github.com/mudler/LocalAGI/core/state" ) func TestAgentRunCMD_LoadAgentConfigFromFile(t *testing.T) { // Create a temporary agent config file tmpDir := t.TempDir() configFile := filepath.Join(tmpDir, "agent.json") cfg := state.AgentConfig{ Name: "test-agent", Model: "llama3", SystemPrompt: "You are a helpful assistant", } data, err := json.MarshalIndent(cfg, "", " ") if err != nil { t.Fatal(err) } if err := os.WriteFile(configFile, data, 0644); err != nil { t.Fatal(err) } cmd := &AgentRunCMD{ Config: configFile, StateDir: tmpDir, } loaded, err := cmd.loadAgentConfig() if err != nil { t.Fatalf("loadAgentConfig() error: %v", err) } if loaded.Name != "test-agent" { t.Errorf("expected name %q, got %q", "test-agent", loaded.Name) } if loaded.Model != "llama3" { t.Errorf("expected model %q, got %q", "llama3", loaded.Model) } } func TestAgentRunCMD_LoadAgentConfigFromPool(t *testing.T) { tmpDir := t.TempDir() pool := map[string]state.AgentConfig{ "my-agent": { Model: "gpt-4", Description: "A test agent", SystemPrompt: "Hello", }, "other-agent": { Model: "llama3", }, } data, err := json.MarshalIndent(pool, "", " ") if err != nil { t.Fatal(err) } if err := os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644); err != nil { t.Fatal(err) } cmd := &AgentRunCMD{ Name: "my-agent", StateDir: tmpDir, } loaded, err := 
cmd.loadAgentConfig() if err != nil { t.Fatalf("loadAgentConfig() error: %v", err) } if loaded.Name != "my-agent" { t.Errorf("expected name %q, got %q", "my-agent", loaded.Name) } if loaded.Model != "gpt-4" { t.Errorf("expected model %q, got %q", "gpt-4", loaded.Model) } } func TestAgentRunCMD_LoadAgentConfigFromPool_NotFound(t *testing.T) { tmpDir := t.TempDir() pool := map[string]state.AgentConfig{ "existing-agent": {Model: "llama3"}, } data, err := json.MarshalIndent(pool, "", " ") if err != nil { t.Fatal(err) } if err := os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644); err != nil { t.Fatal(err) } cmd := &AgentRunCMD{ Name: "nonexistent", StateDir: tmpDir, } _, err = cmd.loadAgentConfig() if err == nil { t.Fatal("expected error for missing agent, got nil") } } func TestAgentRunCMD_LoadAgentConfigNoNameOrConfig(t *testing.T) { cmd := &AgentRunCMD{ StateDir: t.TempDir(), } _, err := cmd.loadAgentConfig() if err == nil { t.Fatal("expected error when no pool.json exists, got nil") } } func TestAgentRunCMD_ApplyOverrides(t *testing.T) { cfg := &state.AgentConfig{ Name: "test", } cmd := &AgentRunCMD{ APIURL: "http://localhost:9090", APIKey: "secret", DefaultModel: "my-model", } cmd.applyOverrides(cfg) if cfg.APIURL != "http://localhost:9090" { t.Errorf("expected APIURL %q, got %q", "http://localhost:9090", cfg.APIURL) } if cfg.APIKey != "secret" { t.Errorf("expected APIKey %q, got %q", "secret", cfg.APIKey) } if cfg.Model != "my-model" { t.Errorf("expected Model %q, got %q", "my-model", cfg.Model) } } func TestAgentRunCMD_ApplyOverridesDoesNotOverwriteExisting(t *testing.T) { cfg := &state.AgentConfig{ Name: "test", Model: "existing-model", } cmd := &AgentRunCMD{ DefaultModel: "override-model", } cmd.applyOverrides(cfg) if cfg.Model != "existing-model" { t.Errorf("expected Model to remain %q, got %q", "existing-model", cfg.Model) } } func TestAgentRunCMD_LoadConfigMissingName(t *testing.T) { tmpDir := t.TempDir() configFile := filepath.Join(tmpDir, 
"agent.json") // Agent config with no name cfg := state.AgentConfig{ Model: "llama3", } data, _ := json.MarshalIndent(cfg, "", " ") os.WriteFile(configFile, data, 0644) cmd := &AgentRunCMD{ Config: configFile, StateDir: tmpDir, } _, err := cmd.loadAgentConfig() if err == nil { t.Fatal("expected error for config with no name, got nil") } } func TestAgentListCMD_NoPoolFile(t *testing.T) { cmd := &AgentListCMD{ StateDir: t.TempDir(), } // Should not error, just print "no agents found" err := cmd.Run(nil) if err != nil { t.Fatalf("expected no error, got: %v", err) } } func TestAgentListCMD_WithAgents(t *testing.T) { tmpDir := t.TempDir() pool := map[string]state.AgentConfig{ "agent-a": {Model: "llama3", Description: "First agent"}, "agent-b": {Model: "gpt-4"}, } data, _ := json.MarshalIndent(pool, "", " ") os.WriteFile(filepath.Join(tmpDir, "pool.json"), data, 0644) cmd := &AgentListCMD{ StateDir: tmpDir, } err := cmd.Run(nil) if err != nil { t.Fatalf("expected no error, got: %v", err) } } ================================================ FILE: core/cli/backends.go ================================================ package cli import ( "context" "encoding/json" "fmt" cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" "github.com/schollz/progressbar/v3" ) type BackendsCMDFlags struct { BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"` BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"` BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path 
containing system backends used for inferencing" group:"backends"` } type BackendsList struct { BackendsCMDFlags `embed:""` } type BackendsInstall struct { BackendArgs string `arg:"" optional:"" name:"backend" help:"Backend configuration URL to load"` Name string `arg:"" optional:"" name:"name" help:"Name of the backend"` Alias string `arg:"" optional:"" name:"alias" help:"Alias of the backend"` BackendsCMDFlags `embed:""` } type BackendsUninstall struct { BackendArgs []string `arg:"" name:"backends" help:"Backend names to uninstall"` BackendsCMDFlags `embed:""` } type BackendsCMD struct { List BackendsList `cmd:"" help:"List the backends available in your galleries" default:"withargs"` Install BackendsInstall `cmd:"" help:"Install a backend from the gallery"` Uninstall BackendsUninstall `cmd:"" help:"Uninstall a backend"` } func (bl *BackendsList) Run(ctx *cliContext.Context) error { var galleries []config.Gallery if err := json.Unmarshal([]byte(bl.BackendGalleries), &galleries); err != nil { xlog.Error("unable to load galleries", "error", err) } systemState, err := system.GetSystemState( system.WithBackendSystemPath(bl.BackendsSystemPath), system.WithBackendPath(bl.BackendsPath), ) if err != nil { return err } backends, err := gallery.AvailableBackends(galleries, systemState) if err != nil { return err } for _, backend := range backends { if backend.Installed { fmt.Printf(" * %s@%s (installed)\n", backend.Gallery.Name, backend.Name) } else { fmt.Printf(" - %s@%s\n", backend.Gallery.Name, backend.Name) } } return nil } func (bi *BackendsInstall) Run(ctx *cliContext.Context) error { var galleries []config.Gallery if err := json.Unmarshal([]byte(bi.BackendGalleries), &galleries); err != nil { xlog.Error("unable to load galleries", "error", err) } systemState, err := system.GetSystemState( system.WithBackendSystemPath(bi.BackendsSystemPath), system.WithBackendPath(bi.BackendsPath), ) if err != nil { return err } progressBar := progressbar.NewOptions( 1000, 
progressbar.OptionSetDescription(fmt.Sprintf("downloading backend %s", bi.BackendArgs)), progressbar.OptionShowBytes(false), progressbar.OptionClearOnFinish(), ) progressCallback := func(fileName string, current string, total string, percentage float64) { v := int(percentage * 10) err := progressBar.Set(v) if err != nil { xlog.Error("error while updating progress bar", "error", err, "filename", fileName, "value", v) } } modelLoader := model.NewModelLoader(systemState) err = services.InstallExternalBackend(context.Background(), galleries, systemState, modelLoader, progressCallback, bi.BackendArgs, bi.Name, bi.Alias) if err != nil { return err } return nil } func (bu *BackendsUninstall) Run(ctx *cliContext.Context) error { for _, backendName := range bu.BackendArgs { xlog.Info("uninstalling backend", "backend", backendName) systemState, err := system.GetSystemState( system.WithBackendSystemPath(bu.BackendsSystemPath), system.WithBackendPath(bu.BackendsPath), ) if err != nil { return err } err = gallery.DeleteBackendFromSystem(systemState, backendName) if err != nil { return err } fmt.Printf("Backend %s uninstalled successfully\n", backendName) } return nil } ================================================ FILE: core/cli/cli.go ================================================ package cli import ( cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/cli/worker" ) var CLI struct { cliContext.Context `embed:""` Run RunCMD `cmd:"" help:"Run LocalAI, this the default command if no other command is specified. 
Run 'local-ai run --help' for more information" default:"withargs"` Federated FederatedCLI `cmd:"" help:"Run LocalAI in federated mode"` Models ModelsCMD `cmd:"" help:"Manage LocalAI models and definitions"` Backends BackendsCMD `cmd:"" help:"Manage LocalAI backends and definitions"` TTS TTSCMD `cmd:"" help:"Convert text to speech"` SoundGeneration SoundGenerationCMD `cmd:"" help:"Generates audio files from text or audio"` Transcript TranscriptCMD `cmd:"" help:"Convert audio to text"` Worker worker.Worker `cmd:"" help:"Run workers to distribute workload (llama.cpp-only)"` Util UtilCMD `cmd:"" help:"Utility commands"` Agent AgentCMD `cmd:"" help:"Run agents standalone without the full LocalAI server"` Explorer ExplorerCMD `cmd:"" help:"Run p2p explorer"` Completion CompletionCMD `cmd:"" help:"Generate shell completion scripts for bash, zsh, or fish"` } ================================================ FILE: core/cli/completion.go ================================================ package cli import ( "fmt" "strings" "github.com/alecthomas/kong" cliContext "github.com/mudler/LocalAI/core/cli/context" ) type CompletionCMD struct { Shell string `arg:"" enum:"bash,zsh,fish" help:"Shell to generate completions for (bash, zsh, fish)"` app *kong.Application `kong:"-"` } func (c *CompletionCMD) SetApplication(app *kong.Application) { c.app = app } func (c *CompletionCMD) Run(_ *cliContext.Context) error { if c.app == nil { return fmt.Errorf("application model not available") } var script string switch c.Shell { case "bash": script = generateBashCompletion(c.app) case "zsh": script = generateZshCompletion(c.app) case "fish": script = generateFishCompletion(c.app) default: return fmt.Errorf("unsupported shell: %s", c.Shell) } fmt.Print(script) return nil } func collectCommands(node *kong.Node, prefix string) []commandInfo { var cmds []commandInfo for _, child := range node.Children { if child.Hidden { continue } name := child.Name fullName := name if prefix != "" { fullName = 
prefix + " " + name } help := child.Help cmds = append(cmds, commandInfo{ name: name, fullName: fullName, help: help, node: child, }) cmds = append(cmds, collectCommands(child, fullName)...) } return cmds } type commandInfo struct { name string fullName string help string node *kong.Node } func collectFlags(node *kong.Node) []flagInfo { var flags []flagInfo seen := make(map[string]bool) // Collect flags from this node and its ancestors for n := node; n != nil; n = n.Parent { for _, flag := range n.Flags { if flag.Hidden || seen[flag.Name] { continue } seen[flag.Name] = true flags = append(flags, flagInfo{ name: flag.Name, short: flag.Short, help: flag.Help, }) } } return flags } type flagInfo struct { name string short rune help string } func generateBashCompletion(app *kong.Application) string { var sb strings.Builder cmds := collectCommands(app.Node, "") topLevelCmds := []string{} for _, cmd := range cmds { if !strings.Contains(cmd.fullName, " ") { topLevelCmds = append(topLevelCmds, cmd.name) } } globalFlags := collectFlags(app.Node) globalFlagNames := []string{} for _, f := range globalFlags { globalFlagNames = append(globalFlagNames, "--"+f.name) if f.short != 0 { globalFlagNames = append(globalFlagNames, "-"+string(f.short)) } } sb.WriteString(`# bash completion for local-ai # Generated by local-ai completion bash _local_ai_completions() { local cur prev words cword _init_completion || return local commands="` + strings.Join(topLevelCmds, " ") + `" local global_flags="` + strings.Join(globalFlagNames, " ") + `" # Find the subcommand local subcmd="" local subcmd_idx=0 for ((i=1; i < cword; i++)); do case "${words[i]}" in -*) # Skip flags and their values ;; *) if [[ -z "$subcmd" ]]; then subcmd="${words[i]}" subcmd_idx=$i fi ;; esac done # If completing a flag value, don't suggest anything special if [[ "$cur" == -* ]]; then case "$subcmd" in `) // Generate flag completions per top-level command for _, cmd := range cmds { if strings.Contains(cmd.fullName, " ") 
{ continue } flags := collectFlags(cmd.node) flagNames := []string{} for _, f := range flags { flagNames = append(flagNames, "--"+f.name) if f.short != 0 { flagNames = append(flagNames, "-"+string(f.short)) } } sb.WriteString(fmt.Sprintf(" %s)\n", cmd.name)) sb.WriteString(fmt.Sprintf(" COMPREPLY=($(compgen -W \"%s\" -- \"$cur\"))\n", strings.Join(flagNames, " "))) sb.WriteString(" return\n") sb.WriteString(" ;;\n") } sb.WriteString(` *) COMPREPLY=($(compgen -W "$global_flags" -- "$cur")) return ;; esac fi # Complete subcommands for top-level commands case "$subcmd" in `) // Generate subcommand completions for _, cmd := range cmds { if strings.Contains(cmd.fullName, " ") { continue } subcmds := []string{} for _, sub := range cmds { parts := strings.SplitN(sub.fullName, " ", 2) if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") { subcmds = append(subcmds, parts[1]) } } if len(subcmds) > 0 { sb.WriteString(fmt.Sprintf(" %s)\n", cmd.name)) sb.WriteString(fmt.Sprintf(" COMPREPLY=($(compgen -W \"%s\" -- \"$cur\"))\n", strings.Join(subcmds, " "))) sb.WriteString(" return\n") sb.WriteString(" ;;\n") } } sb.WriteString(` "") COMPREPLY=($(compgen -W "$commands" -- "$cur")) return ;; esac } complete -F _local_ai_completions local-ai `) return sb.String() } func generateZshCompletion(app *kong.Application) string { var sb strings.Builder cmds := collectCommands(app.Node, "") globalFlags := collectFlags(app.Node) sb.WriteString(`#compdef local-ai # Generated by local-ai completion zsh _local_ai() { local -a commands local -a global_flags global_flags=( `) for _, f := range globalFlags { help := strings.ReplaceAll(f.help, "'", "'\\''") help = strings.ReplaceAll(help, "[", "\\[") help = strings.ReplaceAll(help, "]", "\\]") sb.WriteString(fmt.Sprintf(" '--%s[%s]'\n", f.name, help)) if f.short != 0 { sb.WriteString(fmt.Sprintf(" '-%s[%s]'\n", string(f.short), help)) } } sb.WriteString(` ) commands=( `) for _, cmd := range cmds { if 
strings.Contains(cmd.fullName, " ") { continue } help := strings.ReplaceAll(cmd.help, "'", "'\\''") help = strings.ReplaceAll(help, "[", "\\[") help = strings.ReplaceAll(help, "]", "\\]") sb.WriteString(fmt.Sprintf(" '%s:%s'\n", cmd.name, help)) } sb.WriteString(` ) _arguments -C \ $global_flags \ '1:command:->command' \ '*::arg:->args' case $state in command) _describe -t commands 'local-ai commands' commands ;; args) case $words[1] in `) // Per-command completions for _, cmd := range cmds { if strings.Contains(cmd.fullName, " ") { continue } sb.WriteString(fmt.Sprintf(" %s)\n", cmd.name)) // Check for subcommands subcmds := []commandInfo{} for _, sub := range cmds { parts := strings.SplitN(sub.fullName, " ", 2) if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") { subcmds = append(subcmds, sub) } } if len(subcmds) > 0 { sb.WriteString(" local -a subcmds\n") sb.WriteString(" subcmds=(\n") for _, sub := range subcmds { parts := strings.SplitN(sub.fullName, " ", 2) help := strings.ReplaceAll(sub.help, "'", "'\\''") help = strings.ReplaceAll(help, "[", "\\[") help = strings.ReplaceAll(help, "]", "\\]") sb.WriteString(fmt.Sprintf(" '%s:%s'\n", parts[1], help)) } sb.WriteString(" )\n") sb.WriteString(" _describe -t commands 'subcommands' subcmds\n") } flags := collectFlags(cmd.node) if len(flags) > 0 { sb.WriteString(" _arguments \\\n") for i, f := range flags { help := strings.ReplaceAll(f.help, "'", "'\\''") help = strings.ReplaceAll(help, "[", "\\[") help = strings.ReplaceAll(help, "]", "\\]") suffix := " \\" if i == len(flags)-1 { suffix = "" } sb.WriteString(fmt.Sprintf(" '--%s[%s]'%s\n", f.name, help, suffix)) } } sb.WriteString(" ;;\n") } sb.WriteString(` esac ;; esac } _local_ai "$@" `) return sb.String() } func generateFishCompletion(app *kong.Application) string { var sb strings.Builder cmds := collectCommands(app.Node, "") globalFlags := collectFlags(app.Node) sb.WriteString("# fish completion for local-ai\n") sb.WriteString("# 
Generated by local-ai completion fish\n\n") // Disable file completions by default sb.WriteString("complete -c local-ai -f\n\n") // Global flags for _, f := range globalFlags { help := strings.ReplaceAll(f.help, "'", "\\'") args := fmt.Sprintf("complete -c local-ai -l %s", f.name) if f.short != 0 { args += fmt.Sprintf(" -s %s", string(f.short)) } args += fmt.Sprintf(" -d '%s'", help) sb.WriteString(args + "\n") } sb.WriteString("\n") // Top-level commands (no condition means they show when no subcommand is given) topLevelCmds := []string{} for _, cmd := range cmds { if strings.Contains(cmd.fullName, " ") { continue } topLevelCmds = append(topLevelCmds, cmd.name) help := strings.ReplaceAll(cmd.help, "'", "\\'") sb.WriteString(fmt.Sprintf("complete -c local-ai -n '__fish_use_subcommand' -a %s -d '%s'\n", cmd.name, help)) } sb.WriteString("\n") // Subcommands and per-command flags for _, cmd := range cmds { if strings.Contains(cmd.fullName, " ") { continue } // Subcommands for _, sub := range cmds { parts := strings.SplitN(sub.fullName, " ", 2) if len(parts) == 2 && parts[0] == cmd.name && !strings.Contains(parts[1], " ") { help := strings.ReplaceAll(sub.help, "'", "\\'") sb.WriteString(fmt.Sprintf("complete -c local-ai -n '__fish_seen_subcommand_from %s' -a %s -d '%s'\n", cmd.name, parts[1], help)) } } // Per-command flags flags := collectFlags(cmd.node) for _, f := range flags { help := strings.ReplaceAll(f.help, "'", "\\'") args := fmt.Sprintf("complete -c local-ai -n '__fish_seen_subcommand_from %s' -l %s", cmd.name, f.name) if f.short != 0 { args += fmt.Sprintf(" -s %s", string(f.short)) } args += fmt.Sprintf(" -d '%s'", help) sb.WriteString(args + "\n") } } return sb.String() } ================================================ FILE: core/cli/completion_test.go ================================================ package cli import ( "strings" "testing" "github.com/alecthomas/kong" ) func getTestApp() *kong.Application { var testCLI struct { Run struct{} `cmd:"" 
help:"Run the server"` Models struct { List struct{} `cmd:"" help:"List models"` Install struct{} `cmd:"" help:"Install a model"` } `cmd:"" help:"Manage models"` Completion CompletionCMD `cmd:"" help:"Generate shell completions"` } k := kong.Must(&testCLI) return k.Model } func TestGenerateBashCompletion(t *testing.T) { app := getTestApp() script := generateBashCompletion(app) if !strings.Contains(script, "complete -F _local_ai_completions local-ai") { t.Error("bash completion missing complete command registration") } if !strings.Contains(script, "run") { t.Error("bash completion missing 'run' command") } if !strings.Contains(script, "models") { t.Error("bash completion missing 'models' command") } if !strings.Contains(script, "completion") { t.Error("bash completion missing 'completion' command") } } func TestGenerateZshCompletion(t *testing.T) { app := getTestApp() script := generateZshCompletion(app) if !strings.Contains(script, "#compdef local-ai") { t.Error("zsh completion missing compdef header") } if !strings.Contains(script, "run") { t.Error("zsh completion missing 'run' command") } if !strings.Contains(script, "models") { t.Error("zsh completion missing 'models' command") } } func TestGenerateFishCompletion(t *testing.T) { app := getTestApp() script := generateFishCompletion(app) if !strings.Contains(script, "complete -c local-ai") { t.Error("fish completion missing complete command") } if !strings.Contains(script, "__fish_use_subcommand") { t.Error("fish completion missing subcommand detection") } if !strings.Contains(script, "run") { t.Error("fish completion missing 'run' command") } if !strings.Contains(script, "models") { t.Error("fish completion missing 'models' command") } } func TestCollectCommands(t *testing.T) { app := getTestApp() cmds := collectCommands(app.Node, "") names := make(map[string]bool) for _, cmd := range cmds { names[cmd.fullName] = true } if !names["run"] { t.Error("missing 'run' command") } if !names["models"] { t.Error("missing 
'models' command") } if !names["models list"] { t.Error("missing 'models list' subcommand") } if !names["models install"] { t.Error("missing 'models install' subcommand") } } ================================================ FILE: core/cli/context/context.go ================================================ package cliContext type Context struct { Debug bool `env:"LOCALAI_DEBUG,DEBUG" default:"false" hidden:"" help:"DEPRECATED, use --log-level=debug instead. Enable debug logging"` LogLevel *string `env:"LOCALAI_LOG_LEVEL" enum:"error,warn,info,debug,trace" help:"Set the level of logs to output [${enum}]"` LogFormat *string `env:"LOCALAI_LOG_FORMAT" default:"default" enum:"default,text,json" help:"Set the format of logs to output [${enum}]"` } ================================================ FILE: core/cli/deprecations.go ================================================ package cli import ( "os" "strings" "github.com/mudler/xlog" ) // deprecatedFlags maps old flag names to their new replacements. var deprecatedFlags = map[string]string{ "--p2ptoken": "--p2p-token", } // warnDeprecatedFlags checks os.Args for any deprecated flag names and logs // a warning directing the user to the new name. Old flags continue to work // via kong aliases, so this is purely informational. 
func warnDeprecatedFlags() { for _, arg := range os.Args[1:] { // Strip any =value suffix to match flag names like --p2ptoken=xyz flag := arg if idx := strings.Index(flag, "="); idx != -1 { flag = flag[:idx] } if newName, ok := deprecatedFlags[flag]; ok { xlog.Warn("Deprecated flag used", "old", flag, "new", newName, "message", "please switch to the new flag name; the old name will be removed in a future release") } } } ================================================ FILE: core/cli/explorer.go ================================================ package cli import ( "context" "time" cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/explorer" "github.com/mudler/LocalAI/core/http" "github.com/mudler/LocalAI/pkg/signals" "github.com/mudler/xlog" ) type ExplorerCMD struct { Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"` PoolDatabase string `env:"LOCALAI_POOL_DATABASE,POOL_DATABASE" default:"explorer.json" help:"Path to the pool database" group:"api"` ConnectionTimeout string `env:"LOCALAI_CONNECTION_TIMEOUT,CONNECTION_TIMEOUT" default:"2m" help:"Connection timeout for the explorer" group:"api"` ConnectionErrorThreshold int `env:"LOCALAI_CONNECTION_ERROR_THRESHOLD,CONNECTION_ERROR_THRESHOLD" default:"3" help:"Connection failure threshold for the explorer" group:"api"` WithSync bool `env:"LOCALAI_WITH_SYNC,WITH_SYNC" default:"false" help:"Enable sync with the network" group:"api"` OnlySync bool `env:"LOCALAI_ONLY_SYNC,ONLY_SYNC" default:"false" help:"Only sync with the network" group:"api"` } func (e *ExplorerCMD) Run(ctx *cliContext.Context) error { db, err := explorer.NewDatabase(e.PoolDatabase) if err != nil { return err } dur, err := time.ParseDuration(e.ConnectionTimeout) if err != nil { return err } if e.WithSync { ds := explorer.NewDiscoveryServer(db, dur, e.ConnectionErrorThreshold) go ds.Start(context.Background(), true) } if e.OnlySync { ds := 
explorer.NewDiscoveryServer(db, dur, e.ConnectionErrorThreshold) ctx := context.Background() return ds.Start(ctx, false) } appHTTP := http.Explorer(db) signals.RegisterGracefulTerminationHandler(func() { ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() if err := appHTTP.Shutdown(ctx); err != nil { xlog.Error("error during shutdown", "error", err) } }) return appHTTP.Start(e.Address) } ================================================ FILE: core/cli/federated.go ================================================ package cli import ( "context" cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/p2p" "github.com/mudler/LocalAI/pkg/signals" ) type FederatedCLI struct { Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"` Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2p-token" aliases:"p2ptoken" help:"Token for P2P mode (optional; --p2ptoken is deprecated, use --p2p-token)" group:"p2p"` RandomWorker bool `env:"LOCALAI_RANDOM_WORKER,RANDOM_WORKER" default:"false" help:"Select a random worker from the pool" group:"p2p"` Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances." 
group:"p2p"` TargetWorker string `env:"LOCALAI_TARGET_WORKER,TARGET_WORKER" help:"Target worker to run the federated server on" group:"p2p"` } func (f *FederatedCLI) Run(ctx *cliContext.Context) error { warnDeprecatedFlags() fs := p2p.NewFederatedServer(f.Address, p2p.NetworkID(f.Peer2PeerNetworkID, p2p.FederatedID), f.Peer2PeerToken, !f.RandomWorker, f.TargetWorker) c, cancel := context.WithCancel(context.Background()) signals.RegisterGracefulTerminationHandler(func() { cancel() }) return fs.Start(c) } ================================================ FILE: core/cli/models.go ================================================ package cli import ( "context" "encoding/json" "errors" "fmt" cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/startup" "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" "github.com/schollz/progressbar/v3" ) type ModelsCMDFlags struct { Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"` BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"` ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"` } type ModelsList struct { ModelsCMDFlags `embed:""` } type ModelsInstall struct { DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." 
group:"hardening" default:"false"` AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES" help:"If true, automatically loads backend galleries" group:"backends" default:"true"` ModelArgs []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"` ModelsCMDFlags `embed:""` } type ModelsCMD struct { List ModelsList `cmd:"" help:"List the models available in your galleries" default:"withargs"` Install ModelsInstall `cmd:"" help:"Install a model from the gallery"` } func (ml *ModelsList) Run(ctx *cliContext.Context) error { var galleries []config.Gallery if err := json.Unmarshal([]byte(ml.Galleries), &galleries); err != nil { xlog.Error("unable to load galleries", "error", err) } systemState, err := system.GetSystemState( system.WithModelPath(ml.ModelsPath), system.WithBackendPath(ml.BackendsPath), ) if err != nil { return err } models, err := gallery.AvailableGalleryModels(galleries, systemState) if err != nil { return err } for _, model := range models { if model.Installed { fmt.Printf(" * %s@%s (installed)\n", model.Gallery.Name, model.Name) } else { fmt.Printf(" - %s@%s\n", model.Gallery.Name, model.Name) } } return nil } func (mi *ModelsInstall) Run(ctx *cliContext.Context) error { systemState, err := system.GetSystemState( system.WithModelPath(mi.ModelsPath), system.WithBackendPath(mi.BackendsPath), ) if err != nil { return err } galleryService := services.NewGalleryService(&config.ApplicationConfig{ SystemState: systemState, }, model.NewModelLoader(systemState)) err = galleryService.Start(context.Background(), config.NewModelConfigLoader(mi.ModelsPath), systemState) if err != nil { return err } var galleries []config.Gallery if err := json.Unmarshal([]byte(mi.Galleries), &galleries); err != nil { xlog.Error("unable to load galleries", "error", err) } var backendGalleries []config.Gallery if err := json.Unmarshal([]byte(mi.BackendGalleries), &backendGalleries); err != nil { xlog.Error("unable to load backend galleries", 
"error", err) } for _, modelName := range mi.ModelArgs { progressBar := progressbar.NewOptions( 1000, progressbar.OptionSetDescription(fmt.Sprintf("downloading model %s", modelName)), progressbar.OptionShowBytes(false), progressbar.OptionClearOnFinish(), ) progressCallback := func(fileName string, current string, total string, percentage float64) { v := int(percentage * 10) err := progressBar.Set(v) if err != nil { xlog.Error("error while updating progress bar", "error", err, "filename", fileName, "value", v) } } //startup.InstallModels() models, err := gallery.AvailableGalleryModels(galleries, systemState) if err != nil { return err } modelURI := downloader.URI(modelName) if !modelURI.LooksLikeOCI() { model := gallery.FindGalleryElement(models, modelName) if model == nil { xlog.Error("model not found", "model", modelName) return err } err = gallery.SafetyScanGalleryModel(model) if err != nil && !errors.Is(err, downloader.ErrNonHuggingFaceFile) { return err } } modelLoader := model.NewModelLoader(systemState) err = startup.InstallModels(context.Background(), galleryService, galleries, backendGalleries, systemState, modelLoader, !mi.DisablePredownloadScan, mi.AutoloadBackendGalleries, progressCallback, modelName) if err != nil { return err } } return nil } ================================================ FILE: core/cli/run.go ================================================ package cli import ( "context" "encoding/json" "fmt" "net" "os" "path/filepath" "strings" "time" "github.com/mudler/LocalAI/core/application" cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http" "github.com/mudler/LocalAI/core/p2p" "github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/pkg/signals" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" ) // CLI Flag Naming Convention: // All CLI flags use kebab-case (e.g., --backends-path, --p2p-token). 
// When renaming flags, add the old name as an alias for backward compatibility // and document the deprecation in the help text. type RunCMD struct { ModelArgs []string `arg:"" optional:"" name:"models" help:"Model configuration URLs to load"` ExternalBackends []string `env:"LOCALAI_EXTERNAL_BACKENDS,EXTERNAL_BACKENDS" help:"A list of external backends to load from gallery on boot" group:"backends"` BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"` BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"` ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` GeneratedContentPath string `env:"LOCALAI_GENERATED_CONTENT_PATH,GENERATED_CONTENT_PATH" type:"path" default:"/tmp/generated/content" help:"Location for generated content (e.g. images, audio, videos)" group:"storage"` UploadPath string `env:"LOCALAI_UPLOAD_PATH,UPLOAD_PATH" type:"path" default:"/tmp/localai/upload" help:"Path to store uploads from files api" group:"storage"` DataPath string `env:"LOCALAI_DATA_PATH" type:"path" default:"${basepath}/data" help:"Path for persistent data (collectiondb, agent state, tasks, jobs). 
Separates mutable data from configuration" group:"storage"` LocalaiConfigDir string `env:"LOCALAI_CONFIG_DIR" type:"path" default:"${basepath}/configuration" help:"Directory for dynamic loading of certain configuration files (currently api_keys.json and external_backends.json)" group:"storage"` LocalaiConfigDirPollInterval time.Duration `env:"LOCALAI_CONFIG_DIR_POLL_INTERVAL" help:"Typically the config path picks up changes automatically, but if your system has broken fsnotify events, set this to an interval to poll the LocalAI Config Dir (example: 1m)" group:"storage"` // The alias on this option is there to preserve functionality with the old `--config-file` parameter ModelsConfigFile string `env:"LOCALAI_MODELS_CONFIG_FILE,CONFIG_FILE" aliases:"config-file" help:"YAML file containing a list of model backend configs" group:"storage"` BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"` Galleries string `env:"LOCALAI_GALLERIES,GALLERIES" help:"JSON list of galleries" group:"models" default:"${galleries}"` AutoloadGalleries bool `env:"LOCALAI_AUTOLOAD_GALLERIES,AUTOLOAD_GALLERIES" group:"models" default:"true"` AutoloadBackendGalleries bool `env:"LOCALAI_AUTOLOAD_BACKEND_GALLERIES,AUTOLOAD_BACKEND_GALLERIES" group:"backends" default:"true"` BackendImagesReleaseTag string `env:"LOCALAI_BACKEND_IMAGES_RELEASE_TAG,BACKEND_IMAGES_RELEASE_TAG" help:"Fallback release tag for backend images" group:"backends" default:"latest"` BackendImagesBranchTag string `env:"LOCALAI_BACKEND_IMAGES_BRANCH_TAG,BACKEND_IMAGES_BRANCH_TAG" help:"Fallback branch tag for backend images" group:"backends" default:"master"` BackendDevSuffix string `env:"LOCALAI_BACKEND_DEV_SUFFIX,BACKEND_DEV_SUFFIX" help:"Development suffix for backend images" group:"backends" default:"development"` PreloadModels string `env:"LOCALAI_PRELOAD_MODELS,PRELOAD_MODELS" help:"A List of models to apply in JSON at start" 
group:"models"` Models []string `env:"LOCALAI_MODELS,MODELS" help:"A List of model configuration URLs to load" group:"models"` PreloadModelsConfig string `env:"LOCALAI_PRELOAD_MODELS_CONFIG,PRELOAD_MODELS_CONFIG" help:"A List of models to apply at startup. Path to a YAML config file" group:"models"` F16 bool `name:"f16" env:"LOCALAI_F16,F16" help:"Enable GPU acceleration" group:"performance"` Threads int `env:"LOCALAI_THREADS,THREADS" short:"t" help:"Number of threads used for parallel computation. Usage of the number of physical cores in the system is suggested" group:"performance"` ContextSize int `env:"LOCALAI_CONTEXT_SIZE,CONTEXT_SIZE" help:"Default context size for models" group:"performance"` Address string `env:"LOCALAI_ADDRESS,ADDRESS" default:":8080" help:"Bind address for the API server" group:"api"` CORS bool `env:"LOCALAI_CORS,CORS" help:"" group:"api"` CORSAllowOrigins string `env:"LOCALAI_CORS_ALLOW_ORIGINS,CORS_ALLOW_ORIGINS" group:"api"` DisableCSRF bool `env:"LOCALAI_DISABLE_CSRF" help:"Disable CSRF middleware (enabled by default)" group:"api"` UploadLimit int `env:"LOCALAI_UPLOAD_LIMIT,UPLOAD_LIMIT" default:"15" help:"Default upload-limit in MB" group:"api"` APIKeys []string `env:"LOCALAI_API_KEY,API_KEY" help:"List of API Keys to enable API authentication. When this is set, all the requests must be authenticated with one of these API keys" group:"api"` DisableWebUI bool `env:"LOCALAI_DISABLE_WEBUI,DISABLE_WEBUI" default:"false" help:"Disables the web user interface. When set to true, the server will only expose API endpoints without serving the web interface" group:"api"` DisableRuntimeSettings bool `env:"LOCALAI_DISABLE_RUNTIME_SETTINGS,DISABLE_RUNTIME_SETTINGS" default:"false" help:"Disables the runtime settings. 
When set to true, the server will not load the runtime settings from the runtime_settings.json file" group:"api"` DisablePredownloadScan bool `env:"LOCALAI_DISABLE_PREDOWNLOAD_SCAN" help:"If true, disables the best-effort security scanner before downloading any files." group:"hardening" default:"false"` OpaqueErrors bool `env:"LOCALAI_OPAQUE_ERRORS" default:"false" help:"If true, all error responses are replaced with blank 500 errors. This is intended only for hardening against information leaks and is normally not recommended." group:"hardening"` UseSubtleKeyComparison bool `env:"LOCALAI_SUBTLE_KEY_COMPARISON" default:"false" help:"If true, API Key validation comparisons will be performed using constant-time comparisons rather than simple equality. This trades off performance on each request for resiliancy against timing attacks." group:"hardening"` DisableApiKeyRequirementForHttpGet bool `env:"LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET" default:"false" help:"If true, a valid API key is not required to issue GET requests to portions of the web ui. This should only be enabled in secure testing environments" group:"hardening"` DisableMetricsEndpoint bool `env:"LOCALAI_DISABLE_METRICS_ENDPOINT,DISABLE_METRICS_ENDPOINT" default:"false" help:"Disable the /metrics endpoint" group:"api"` HttpGetExemptedEndpoints []string `env:"LOCALAI_HTTP_GET_EXEMPTED_ENDPOINTS" default:"^/$,^/app(/.*)?$,^/browse(/.*)?$,^/login/?$,^/explorer/?$,^/assets/.*$,^/static/.*$,^/swagger.*$" help:"If LOCALAI_DISABLE_API_KEY_REQUIREMENT_FOR_HTTP_GET is overriden to true, this is the list of endpoints to exempt. 
Only adjust this in case of a security incident or as a result of a personal security posture review" group:"hardening"` Peer2Peer bool `env:"LOCALAI_P2P,P2P" name:"p2p" default:"false" help:"Enable P2P mode" group:"p2p"` Peer2PeerDHTInterval int `env:"LOCALAI_P2P_DHT_INTERVAL,P2P_DHT_INTERVAL" default:"360" name:"p2p-dht-interval" help:"Interval for DHT refresh (used during token generation)" group:"p2p"` Peer2PeerOTPInterval int `env:"LOCALAI_P2P_OTP_INTERVAL,P2P_OTP_INTERVAL" default:"9000" name:"p2p-otp-interval" help:"Interval for OTP refresh (used during token generation)" group:"p2p"` Peer2PeerToken string `env:"LOCALAI_P2P_TOKEN,P2P_TOKEN,TOKEN" name:"p2p-token" aliases:"p2ptoken" help:"Token for P2P mode (optional; --p2ptoken is deprecated, use --p2p-token)" group:"p2p"` Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"` ParallelRequests bool `env:"LOCALAI_PARALLEL_REQUESTS,PARALLEL_REQUESTS" help:"Enable backends to handle multiple requests in parallel if they support it (e.g.: llama.cpp or vllm)" group:"backends"` SingleActiveBackend bool `env:"LOCALAI_SINGLE_ACTIVE_BACKEND,SINGLE_ACTIVE_BACKEND" help:"Allow only one backend to be run at a time (deprecated: use --max-active-backends=1 instead)" group:"backends"` MaxActiveBackends int `env:"LOCALAI_MAX_ACTIVE_BACKENDS,MAX_ACTIVE_BACKENDS" default:"0" help:"Maximum number of backends to keep loaded at once (0 = unlimited, 1 = single backend mode). 
Least recently used backends are evicted when limit is reached" group:"backends"` PreloadBackendOnly bool `env:"LOCALAI_PRELOAD_BACKEND_ONLY,PRELOAD_BACKEND_ONLY" default:"false" help:"Do not launch the API services, only the preloaded models / backends are started (useful for multi-node setups)" group:"backends"` ExternalGRPCBackends []string `env:"LOCALAI_EXTERNAL_GRPC_BACKENDS,EXTERNAL_GRPC_BACKENDS" help:"A list of external grpc backends" group:"backends"` EnableWatchdogIdle bool `env:"LOCALAI_WATCHDOG_IDLE,WATCHDOG_IDLE" default:"false" help:"Enable watchdog for stopping backends that are idle longer than the watchdog-idle-timeout" group:"backends"` WatchdogIdleTimeout string `env:"LOCALAI_WATCHDOG_IDLE_TIMEOUT,WATCHDOG_IDLE_TIMEOUT" default:"15m" help:"Threshold beyond which an idle backend should be stopped" group:"backends"` EnableWatchdogBusy bool `env:"LOCALAI_WATCHDOG_BUSY,WATCHDOG_BUSY" default:"false" help:"Enable watchdog for stopping backends that are busy longer than the watchdog-busy-timeout" group:"backends"` WatchdogBusyTimeout string `env:"LOCALAI_WATCHDOG_BUSY_TIMEOUT,WATCHDOG_BUSY_TIMEOUT" default:"5m" help:"Threshold beyond which a busy backend should be stopped" group:"backends"` WatchdogInterval string `env:"LOCALAI_WATCHDOG_INTERVAL,WATCHDOG_INTERVAL" default:"500ms" help:"Interval between watchdog checks (e.g., 500ms, 5s, 1m) (default: 500ms)" group:"backends"` EnableMemoryReclaimer bool `env:"LOCALAI_MEMORY_RECLAIMER,MEMORY_RECLAIMER,LOCALAI_GPU_RECLAIMER,GPU_RECLAIMER" default:"false" help:"Enable memory threshold monitoring to auto-evict backends when memory usage exceeds threshold (uses GPU VRAM if available, otherwise RAM)" group:"backends"` MemoryReclaimerThreshold float64 `env:"LOCALAI_MEMORY_RECLAIMER_THRESHOLD,MEMORY_RECLAIMER_THRESHOLD,LOCALAI_GPU_RECLAIMER_THRESHOLD,GPU_RECLAIMER_THRESHOLD" default:"0.95" help:"Memory usage threshold (0.0-1.0) that triggers backend eviction (default 0.95 = 95%%)" group:"backends"` 
ForceEvictionWhenBusy bool `env:"LOCALAI_FORCE_EVICTION_WHEN_BUSY,FORCE_EVICTION_WHEN_BUSY" default:"false" help:"Force eviction even when models have active API calls (default: false for safety)" group:"backends"` LRUEvictionMaxRetries int `env:"LOCALAI_LRU_EVICTION_MAX_RETRIES,LRU_EVICTION_MAX_RETRIES" default:"30" help:"Maximum number of retries when waiting for busy models to become idle before eviction (default: 30)" group:"backends"` LRUEvictionRetryInterval string `env:"LOCALAI_LRU_EVICTION_RETRY_INTERVAL,LRU_EVICTION_RETRY_INTERVAL" default:"1s" help:"Interval between retries when waiting for busy models to become idle (e.g., 1s, 2s) (default: 1s)" group:"backends"` Federated bool `env:"LOCALAI_FEDERATED,FEDERATED" help:"Enable federated instance" group:"federated"` DisableGalleryEndpoint bool `env:"LOCALAI_DISABLE_GALLERY_ENDPOINT,DISABLE_GALLERY_ENDPOINT" help:"Disable the gallery endpoints" group:"api"` DisableMCP bool `env:"LOCALAI_DISABLE_MCP,DISABLE_MCP" help:"Disable MCP (Model Context Protocol) support" group:"api" default:"false"` MachineTag string `env:"LOCALAI_MACHINE_TAG,MACHINE_TAG" help:"Add Machine-Tag header to each response which is useful to track the machine in the P2P network" group:"api"` LoadToMemory []string `env:"LOCALAI_LOAD_TO_MEMORY,LOAD_TO_MEMORY" help:"A list of models to load into memory at startup" group:"models"` EnableTracing bool `env:"LOCALAI_ENABLE_TRACING,ENABLE_TRACING" help:"Enable API tracing" group:"api"` TracingMaxItems int `env:"LOCALAI_TRACING_MAX_ITEMS" default:"1024" help:"Maximum number of traces to keep" group:"api"` AgentJobRetentionDays int `env:"LOCALAI_AGENT_JOB_RETENTION_DAYS,AGENT_JOB_RETENTION_DAYS" default:"30" help:"Number of days to keep agent job history (default: 30)" group:"api"` OpenResponsesStoreTTL string `env:"LOCALAI_OPEN_RESPONSES_STORE_TTL,OPEN_RESPONSES_STORE_TTL" default:"0" help:"TTL for Open Responses store (e.g., 1h, 30m, 0 = no expiration)" group:"api"` // Agent Pool (LocalAGI) 
DisableAgents bool `env:"LOCALAI_DISABLE_AGENTS" default:"false" help:"Disable the agent pool feature" group:"agents"` AgentPoolAPIURL string `env:"LOCALAI_AGENT_POOL_API_URL" help:"Default API URL for agents (defaults to self-referencing LocalAI)" group:"agents"` AgentPoolAPIKey string `env:"LOCALAI_AGENT_POOL_API_KEY" help:"Default API key for agents (defaults to first LocalAI API key)" group:"agents"` AgentPoolDefaultModel string `env:"LOCALAI_AGENT_POOL_DEFAULT_MODEL" help:"Default model for agents" group:"agents"` AgentPoolMultimodalModel string `env:"LOCALAI_AGENT_POOL_MULTIMODAL_MODEL" help:"Default multimodal model for agents" group:"agents"` AgentPoolTranscriptionModel string `env:"LOCALAI_AGENT_POOL_TRANSCRIPTION_MODEL" help:"Default transcription model for agents" group:"agents"` AgentPoolTranscriptionLanguage string `env:"LOCALAI_AGENT_POOL_TRANSCRIPTION_LANGUAGE" help:"Default transcription language for agents" group:"agents"` AgentPoolTTSModel string `env:"LOCALAI_AGENT_POOL_TTS_MODEL" help:"Default TTS model for agents" group:"agents"` AgentPoolStateDir string `env:"LOCALAI_AGENT_POOL_STATE_DIR" help:"State directory for agent pool" group:"agents"` AgentPoolTimeout string `env:"LOCALAI_AGENT_POOL_TIMEOUT" default:"5m" help:"Default agent timeout" group:"agents"` AgentPoolEnableSkills bool `env:"LOCALAI_AGENT_POOL_ENABLE_SKILLS" default:"false" help:"Enable skills service for agents" group:"agents"` AgentPoolVectorEngine string `env:"LOCALAI_AGENT_POOL_VECTOR_ENGINE" default:"chromem" help:"Vector engine type for agent knowledge base" group:"agents"` AgentPoolEmbeddingModel string `env:"LOCALAI_AGENT_POOL_EMBEDDING_MODEL" default:"granite-embedding-107m-multilingual" help:"Embedding model for agent knowledge base" group:"agents"` AgentPoolCustomActionsDir string `env:"LOCALAI_AGENT_POOL_CUSTOM_ACTIONS_DIR" help:"Custom actions directory for agents" group:"agents"` AgentPoolDatabaseURL string `env:"LOCALAI_AGENT_POOL_DATABASE_URL" help:"Database URL 
for agent collections" group:"agents"` AgentPoolMaxChunkingSize int `env:"LOCALAI_AGENT_POOL_MAX_CHUNKING_SIZE" default:"400" help:"Maximum chunking size for knowledge base documents" group:"agents"` AgentPoolChunkOverlap int `env:"LOCALAI_AGENT_POOL_CHUNK_OVERLAP" default:"0" help:"Chunk overlap size for knowledge base documents" group:"agents"` AgentPoolEnableLogs bool `env:"LOCALAI_AGENT_POOL_ENABLE_LOGS" default:"false" help:"Enable agent logging" group:"agents"` AgentPoolCollectionDBPath string `env:"LOCALAI_AGENT_POOL_COLLECTION_DB_PATH" help:"Database path for agent collections" group:"agents"` AgentHubURL string `env:"LOCALAI_AGENT_HUB_URL" default:"https://agenthub.localai.io" help:"URL for the agent hub where users can browse and download agent configurations" group:"agents"` // Authentication AuthEnabled bool `env:"LOCALAI_AUTH" default:"false" help:"Enable user authentication and authorization" group:"auth"` AuthDatabaseURL string `env:"LOCALAI_AUTH_DATABASE_URL,DATABASE_URL" help:"Database URL for auth (postgres:// or file path for SQLite). Defaults to {DataPath}/database.db" group:"auth"` GitHubClientID string `env:"GITHUB_CLIENT_ID" help:"GitHub OAuth App Client ID (auto-enables auth when set)" group:"auth"` GitHubClientSecret string `env:"GITHUB_CLIENT_SECRET" help:"GitHub OAuth App Client Secret" group:"auth"` OIDCIssuer string `env:"LOCALAI_OIDC_ISSUER" help:"OIDC issuer URL for auto-discovery" group:"auth"` OIDCClientID string `env:"LOCALAI_OIDC_CLIENT_ID" help:"OIDC Client ID (auto-enables auth)" group:"auth"` OIDCClientSecret string `env:"LOCALAI_OIDC_CLIENT_SECRET" help:"OIDC Client Secret" group:"auth"` AuthBaseURL string `env:"LOCALAI_BASE_URL" help:"Base URL for OAuth callbacks (e.g. 
http://localhost:8080)" group:"auth"` AuthAdminEmail string `env:"LOCALAI_ADMIN_EMAIL" help:"Email address to auto-promote to admin role" group:"auth"` AuthRegistrationMode string `env:"LOCALAI_REGISTRATION_MODE" default:"open" help:"Registration mode: 'open' (default), 'approval', or 'invite' (invite code required)" group:"auth"` DisableLocalAuth bool `env:"LOCALAI_DISABLE_LOCAL_AUTH" default:"false" help:"Disable local email/password registration and login (use with OAuth/OIDC-only setups)" group:"auth"` AuthAPIKeyHMACSecret string `env:"LOCALAI_AUTH_HMAC_SECRET" help:"HMAC secret for API key hashing (auto-generated if empty)" group:"auth"` DefaultAPIKeyExpiry string `env:"LOCALAI_DEFAULT_API_KEY_EXPIRY" help:"Default expiry for API keys (e.g. 90d, 1y; empty = no expiry)" group:"auth"` Version bool } func (r *RunCMD) Run(ctx *cliContext.Context) error { warnDeprecatedFlags() if r.Version { fmt.Println(internal.Version) return nil } os.MkdirAll(r.BackendsPath, 0750) os.MkdirAll(r.ModelsPath, 0750) systemState, err := system.GetSystemState( system.WithBackendSystemPath(r.BackendsSystemPath), system.WithModelPath(r.ModelsPath), system.WithBackendPath(r.BackendsPath), system.WithBackendImagesReleaseTag(r.BackendImagesReleaseTag), system.WithBackendImagesBranchTag(r.BackendImagesBranchTag), system.WithBackendDevSuffix(r.BackendDevSuffix), ) if err != nil { return err } opts := []config.AppOption{ config.WithContext(context.Background()), config.WithConfigFile(r.ModelsConfigFile), config.WithJSONStringPreload(r.PreloadModels), config.WithYAMLConfigPreload(r.PreloadModelsConfig), config.WithSystemState(systemState), config.WithContextSize(r.ContextSize), config.WithDebug(ctx.Debug || (ctx.LogLevel != nil && *ctx.LogLevel == "debug")), config.WithGeneratedContentDir(r.GeneratedContentPath), config.WithUploadDir(r.UploadPath), config.WithDataPath(r.DataPath), config.WithDynamicConfigDir(r.LocalaiConfigDir), 
config.WithDynamicConfigDirPollInterval(r.LocalaiConfigDirPollInterval), config.WithF16(r.F16), config.WithStringGalleries(r.Galleries), config.WithBackendGalleries(r.BackendGalleries), config.WithCors(r.CORS), config.WithCorsAllowOrigins(r.CORSAllowOrigins), config.WithDisableCSRF(r.DisableCSRF), config.WithThreads(r.Threads), config.WithUploadLimitMB(r.UploadLimit), config.WithApiKeys(r.APIKeys), config.WithModelsURL(append(r.Models, r.ModelArgs...)...), config.WithExternalBackends(r.ExternalBackends...), config.WithOpaqueErrors(r.OpaqueErrors), config.WithEnforcedPredownloadScans(!r.DisablePredownloadScan), config.WithSubtleKeyComparison(r.UseSubtleKeyComparison), config.WithDisableApiKeyRequirementForHttpGet(r.DisableApiKeyRequirementForHttpGet), config.WithHttpGetExemptedEndpoints(r.HttpGetExemptedEndpoints), config.WithP2PNetworkID(r.Peer2PeerNetworkID), config.WithLoadToMemory(r.LoadToMemory), config.WithMachineTag(r.MachineTag), config.WithAPIAddress(r.Address), config.WithAgentJobRetentionDays(r.AgentJobRetentionDays), config.WithLlamaCPPTunnelCallback(func(tunnels []string) { tunnelEnvVar := strings.Join(tunnels, ",") os.Setenv("LLAMACPP_GRPC_SERVERS", tunnelEnvVar) xlog.Debug("setting LLAMACPP_GRPC_SERVERS", "value", tunnelEnvVar) }), config.WithMLXTunnelCallback(func(tunnels []string) { hostfile := filepath.Join(os.TempDir(), "localai_mlx_hostfile.json") data, _ := json.Marshal(tunnels) os.WriteFile(hostfile, data, 0644) os.Setenv("MLX_DISTRIBUTED_HOSTFILE", hostfile) xlog.Debug("setting MLX_DISTRIBUTED_HOSTFILE", "value", hostfile, "tunnels", tunnels) }), } if r.DisableMetricsEndpoint { opts = append(opts, config.DisableMetricsEndpoint) } if r.DisableRuntimeSettings { opts = append(opts, config.DisableRuntimeSettings) } if r.EnableTracing { opts = append(opts, config.EnableTracing) } if r.EnableTracing { opts = append(opts, config.EnableTracing) } opts = append(opts, config.WithTracingMaxItems(r.TracingMaxItems)) token := "" if r.Peer2Peer || 
r.Peer2PeerToken != "" { xlog.Info("P2P mode enabled") token = r.Peer2PeerToken if token == "" { // IF no token is provided, and p2p is enabled, // we generate one and wait for the user to pick up the token (this is for interactive) xlog.Info("No token provided, generating one") token = p2p.GenerateToken(r.Peer2PeerDHTInterval, r.Peer2PeerOTPInterval) xlog.Info("Generated Token:") fmt.Println(token) xlog.Info("To use the token, you can run the following command in another node or terminal:") fmt.Printf("export TOKEN=\"%s\"\nlocal-ai worker p2p-llama-cpp-rpc\n", token) } opts = append(opts, config.WithP2PToken(token)) } if r.Federated { opts = append(opts, config.EnableFederated) } idleWatchDog := r.EnableWatchdogIdle busyWatchDog := r.EnableWatchdogBusy if r.DisableWebUI { opts = append(opts, config.DisableWebUI) } if r.DisableGalleryEndpoint { opts = append(opts, config.DisableGalleryEndpoint) } if r.DisableMCP { opts = append(opts, config.DisableMCP) } // Agent Pool if r.DisableAgents { opts = append(opts, config.DisableAgentPool) } if r.AgentPoolAPIURL != "" { opts = append(opts, config.WithAgentPoolAPIURL(r.AgentPoolAPIURL)) } if r.AgentPoolAPIKey != "" { opts = append(opts, config.WithAgentPoolAPIKey(r.AgentPoolAPIKey)) } if r.AgentPoolDefaultModel != "" { opts = append(opts, config.WithAgentPoolDefaultModel(r.AgentPoolDefaultModel)) } if r.AgentPoolMultimodalModel != "" { opts = append(opts, config.WithAgentPoolMultimodalModel(r.AgentPoolMultimodalModel)) } if r.AgentPoolTranscriptionModel != "" { opts = append(opts, config.WithAgentPoolTranscriptionModel(r.AgentPoolTranscriptionModel)) } if r.AgentPoolTranscriptionLanguage != "" { opts = append(opts, config.WithAgentPoolTranscriptionLanguage(r.AgentPoolTranscriptionLanguage)) } if r.AgentPoolTTSModel != "" { opts = append(opts, config.WithAgentPoolTTSModel(r.AgentPoolTTSModel)) } if r.AgentPoolStateDir != "" { opts = append(opts, config.WithAgentPoolStateDir(r.AgentPoolStateDir)) } if r.AgentPoolTimeout != 
"" { opts = append(opts, config.WithAgentPoolTimeout(r.AgentPoolTimeout)) } if r.AgentPoolEnableSkills { opts = append(opts, config.EnableAgentPoolSkills) } if r.AgentPoolVectorEngine != "" { opts = append(opts, config.WithAgentPoolVectorEngine(r.AgentPoolVectorEngine)) } if r.AgentPoolEmbeddingModel != "" { opts = append(opts, config.WithAgentPoolEmbeddingModel(r.AgentPoolEmbeddingModel)) } if r.AgentPoolCustomActionsDir != "" { opts = append(opts, config.WithAgentPoolCustomActionsDir(r.AgentPoolCustomActionsDir)) } if r.AgentPoolDatabaseURL != "" { opts = append(opts, config.WithAgentPoolDatabaseURL(r.AgentPoolDatabaseURL)) } if r.AgentPoolMaxChunkingSize > 0 { opts = append(opts, config.WithAgentPoolMaxChunkingSize(r.AgentPoolMaxChunkingSize)) } if r.AgentPoolChunkOverlap > 0 { opts = append(opts, config.WithAgentPoolChunkOverlap(r.AgentPoolChunkOverlap)) } if r.AgentPoolEnableLogs { opts = append(opts, config.EnableAgentPoolLogs) } if r.AgentPoolCollectionDBPath != "" { opts = append(opts, config.WithAgentPoolCollectionDBPath(r.AgentPoolCollectionDBPath)) } if r.AgentHubURL != "" { opts = append(opts, config.WithAgentHubURL(r.AgentHubURL)) } // Authentication authEnabled := r.AuthEnabled || r.GitHubClientID != "" || r.OIDCClientID != "" if authEnabled { opts = append(opts, config.WithAuthEnabled(true)) dbURL := r.AuthDatabaseURL if dbURL == "" { dbURL = filepath.Join(r.DataPath, "database.db") } opts = append(opts, config.WithAuthDatabaseURL(dbURL)) if r.GitHubClientID != "" { opts = append(opts, config.WithAuthGitHubClientID(r.GitHubClientID)) opts = append(opts, config.WithAuthGitHubClientSecret(r.GitHubClientSecret)) } if r.OIDCClientID != "" { opts = append(opts, config.WithAuthOIDCIssuer(r.OIDCIssuer)) opts = append(opts, config.WithAuthOIDCClientID(r.OIDCClientID)) opts = append(opts, config.WithAuthOIDCClientSecret(r.OIDCClientSecret)) } if r.AuthBaseURL != "" { opts = append(opts, config.WithAuthBaseURL(r.AuthBaseURL)) } if r.AuthAdminEmail != "" { opts 
= append(opts, config.WithAuthAdminEmail(r.AuthAdminEmail)) } if r.AuthRegistrationMode != "" { opts = append(opts, config.WithAuthRegistrationMode(r.AuthRegistrationMode)) } if r.DisableLocalAuth { opts = append(opts, config.WithAuthDisableLocalAuth(true)) } if r.AuthAPIKeyHMACSecret != "" { opts = append(opts, config.WithAuthAPIKeyHMACSecret(r.AuthAPIKeyHMACSecret)) } if r.DefaultAPIKeyExpiry != "" { opts = append(opts, config.WithAuthDefaultAPIKeyExpiry(r.DefaultAPIKeyExpiry)) } } if idleWatchDog || busyWatchDog { opts = append(opts, config.EnableWatchDog) if idleWatchDog { opts = append(opts, config.EnableWatchDogIdleCheck) dur, err := time.ParseDuration(r.WatchdogIdleTimeout) if err != nil { return err } opts = append(opts, config.SetWatchDogIdleTimeout(dur)) } if busyWatchDog { opts = append(opts, config.EnableWatchDogBusyCheck) dur, err := time.ParseDuration(r.WatchdogBusyTimeout) if err != nil { return err } opts = append(opts, config.SetWatchDogBusyTimeout(dur)) } if r.WatchdogInterval != "" { dur, err := time.ParseDuration(r.WatchdogInterval) if err != nil { return err } opts = append(opts, config.SetWatchDogInterval(dur)) } } // Handle memory reclaimer (uses GPU VRAM if available, otherwise RAM) if r.EnableMemoryReclaimer { opts = append(opts, config.WithMemoryReclaimer(true, r.MemoryReclaimerThreshold)) } if r.ParallelRequests { opts = append(opts, config.EnableParallelBackendRequests) } // Handle max active backends (LRU eviction) // MaxActiveBackends takes precedence over SingleActiveBackend if r.MaxActiveBackends > 0 { opts = append(opts, config.SetMaxActiveBackends(r.MaxActiveBackends)) } else if r.SingleActiveBackend { // Backward compatibility: --single-active-backend is equivalent to --max-active-backends=1 opts = append(opts, config.EnableSingleBackend) } // Handle LRU eviction settings if r.ForceEvictionWhenBusy { opts = append(opts, config.WithForceEvictionWhenBusy(true)) } if r.LRUEvictionMaxRetries > 0 { opts = append(opts, 
config.WithLRUEvictionMaxRetries(r.LRUEvictionMaxRetries)) } if r.LRUEvictionRetryInterval != "" { dur, err := time.ParseDuration(r.LRUEvictionRetryInterval) if err != nil { return fmt.Errorf("invalid LRU eviction retry interval: %w", err) } opts = append(opts, config.WithLRUEvictionRetryInterval(dur)) } // Handle Open Responses store TTL if r.OpenResponsesStoreTTL != "" && r.OpenResponsesStoreTTL != "0" { dur, err := time.ParseDuration(r.OpenResponsesStoreTTL) if err != nil { return fmt.Errorf("invalid Open Responses store TTL: %w", err) } opts = append(opts, config.WithOpenResponsesStoreTTL(dur)) } // split ":" to get backend name and the uri for _, v := range r.ExternalGRPCBackends { backend := v[:strings.IndexByte(v, ':')] uri := v[strings.IndexByte(v, ':')+1:] opts = append(opts, config.WithExternalBackend(backend, uri)) } if r.AutoloadGalleries { opts = append(opts, config.EnableGalleriesAutoload) } if r.AutoloadBackendGalleries { opts = append(opts, config.EnableBackendGalleriesAutoload) } if r.PreloadBackendOnly { _, err := application.New(opts...) return err } app, err := application.New(opts...) if err != nil { return fmt.Errorf("failed basic startup tasks with error %s", err.Error()) } appHTTP, err := http.API(app) if err != nil { xlog.Error("error during HTTP App construction", "error", err) return err } xlog.Info("LocalAI is started and running", "address", r.Address) // Start P2P if token was provided via CLI/env or loaded from runtime_settings.json if token != "" || app.ApplicationConfig().P2PToken != "" { if err := app.StartP2P(); err != nil { return err } } signals.RegisterGracefulTerminationHandler(func() { if err := app.ModelLoader().StopAllGRPC(); err != nil { xlog.Error("error while stopping all grpc backends", "error", err) } }) // Start the agent pool after the HTTP server is listening, because // backends like PostgreSQL need to call the embeddings API during // collection initialization. 
go func() { waitForServerReady(r.Address, app.ApplicationConfig().Context) app.StartAgentPool() }() return appHTTP.Start(r.Address) } // waitForServerReady polls the given address until the HTTP server is // accepting connections or the context is cancelled. func waitForServerReady(address string, ctx context.Context) { // Ensure the address has a host component for dialing. // Echo accepts ":8080" but net.Dial needs a resolvable host. host, port, err := net.SplitHostPort(address) if err == nil && host == "" { address = "127.0.0.1:" + port } for { select { case <-ctx.Done(): return default: } conn, err := net.DialTimeout("tcp", address, 500*time.Millisecond) if err == nil { conn.Close() return } time.Sleep(250 * time.Millisecond) } } ================================================ FILE: core/cli/soundgeneration.go ================================================ package cli import ( "context" "fmt" "os" "path/filepath" "strconv" "strings" "github.com/mudler/LocalAI/core/backend" cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" ) type SoundGenerationCMD struct { Text []string `arg:""` Backend string `short:"b" required:"" help:"Backend to run the SoundGeneration model"` Model string `short:"m" required:"" help:"Model name to run the SoundGeneration"` Duration string `short:"d" help:"If specified, the length of audio to generate in seconds"` Temperature string `short:"t" help:"If specified, the temperature of the generation"` InputFile string `short:"i" help:"If specified, the input file to condition generation upon"` InputFileSampleDivisor string `short:"f" help:"If InputFile and this divisor is specified, the first portion of the sample file will be used"` DoSample bool `short:"s" default:"true" help:"Enables sampling from the model. Better quality at the cost of speed. 
Defaults to enabled."` OutputFile string `short:"o" type:"path" help:"The path to write the output wav file"` ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` ExternalGRPCBackends []string `env:"LOCALAI_EXTERNAL_GRPC_BACKENDS,EXTERNAL_GRPC_BACKENDS" help:"A list of external grpc backends" group:"backends"` } func parseToFloat32Ptr(input string) *float32 { f, err := strconv.ParseFloat(input, 32) if err != nil { return nil } f2 := float32(f) return &f2 } func parseToInt32Ptr(input string) *int32 { i, err := strconv.ParseInt(input, 10, 32) if err != nil { return nil } i2 := int32(i) return &i2 } func (t *SoundGenerationCMD) Run(ctx *cliContext.Context) error { outputFile := t.OutputFile outputDir := os.TempDir() if outputFile != "" { outputDir = filepath.Dir(outputFile) } text := strings.Join(t.Text, " ") systemState, err := system.GetSystemState( system.WithModelPath(t.ModelsPath), ) if err != nil { return err } externalBackends := make(map[string]string) // split ":" to get backend name and the uri for _, v := range t.ExternalGRPCBackends { backend := v[:strings.IndexByte(v, ':')] uri := v[strings.IndexByte(v, ':')+1:] externalBackends[backend] = uri } opts := &config.ApplicationConfig{ SystemState: systemState, Context: context.Background(), GeneratedContentDir: outputDir, ExternalGRPCBackends: externalBackends, } ml := model.NewModelLoader(systemState) defer func() { err := ml.StopAllGRPC() if err != nil { xlog.Error("unable to stop all grpc processes", "error", err) } }() options := config.ModelConfig{} options.SetDefaults() options.Backend = t.Backend options.Model = t.Model var inputFile *string if t.InputFile != "" { inputFile = &t.InputFile } filePath, _, err := backend.SoundGeneration(text, parseToFloat32Ptr(t.Duration), parseToFloat32Ptr(t.Temperature), &t.DoSample, inputFile, parseToInt32Ptr(t.InputFileSampleDivisor), nil, "", "", nil, "", 
"", "", nil, ml, opts, options) if err != nil { return err } if outputFile != "" { if err := os.Rename(filePath, outputFile); err != nil { return err } fmt.Printf("Generate file %s\n", outputFile) } else { fmt.Printf("Generate file %s\n", filePath) } return nil } ================================================ FILE: core/cli/transcript.go ================================================ package cli import ( "context" "encoding/json" "errors" "fmt" "strings" "github.com/mudler/LocalAI/core/backend" cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/format" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" ) type TranscriptCMD struct { Filename string `arg:"" name:"file" help:"Audio file to transcribe" type:"path"` Backend string `short:"b" default:"whisper" help:"Backend to run the transcription model"` Model string `short:"m" required:"" help:"Model name to run the TTS"` Language string `short:"l" help:"Language of the audio file"` Translate bool `short:"c" help:"Translate the transcription to English"` Diarize bool `short:"d" help:"Mark speaker turns"` Threads int `short:"t" default:"1" help:"Number of threads used for parallel computation"` BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"storage"` ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"` Prompt string `short:"p" help:"Previous transcribed text or words that hint at what the model should expect"` ResponseFormat 
schema.TranscriptionResponseFormatType `short:"f" default:"" help:"Response format for Whisper models, can be one of (txt, lrc, srt, vtt, json, verbose_json)"` PrettyPrint bool `help:"Used with response_format json or verbose_json for pretty printing"` } func (t *TranscriptCMD) Run(ctx *cliContext.Context) error { systemState, err := system.GetSystemState( system.WithBackendPath(t.BackendsPath), system.WithModelPath(t.ModelsPath), ) if err != nil { return err } opts := &config.ApplicationConfig{ SystemState: systemState, Context: context.Background(), } cl := config.NewModelConfigLoader(t.ModelsPath) ml := model.NewModelLoader(systemState) if err := gallery.RegisterBackends(systemState, ml); err != nil { xlog.Error("error registering external backends", "error", err) } if err := cl.LoadModelConfigsFromPath(t.ModelsPath); err != nil { return err } c, exists := cl.GetModelConfig(t.Model) if !exists { return errors.New("model not found") } c.Threads = &t.Threads defer func() { err := ml.StopAllGRPC() if err != nil { xlog.Error("unable to stop all grpc processes", "error", err) } }() tr, err := backend.ModelTranscription(t.Filename, t.Language, t.Translate, t.Diarize, t.Prompt, ml, c, opts) if err != nil { return err } switch t.ResponseFormat { case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt, schema.TranscriptionResponseFormatText: fmt.Println(format.TranscriptionResponse(tr, t.ResponseFormat)) case schema.TranscriptionResponseFormatJson: tr.Segments = nil fallthrough case schema.TranscriptionResponseFormatJsonVerbose: var mtr []byte var err error if t.PrettyPrint { mtr, err = json.MarshalIndent(tr, "", " ") } else { mtr, err = json.Marshal(tr) } if err != nil { return err } fmt.Println(string(mtr)) default: for _, segment := range tr.Segments { fmt.Println(segment.Start.String(), "-", strings.TrimSpace(segment.Text)) } } return nil } ================================================ FILE: 
core/cli/tts.go ================================================ package cli import ( "context" "fmt" "os" "path/filepath" "strings" "github.com/mudler/LocalAI/core/backend" cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" ) type TTSCMD struct { Text []string `arg:""` Backend string `short:"b" default:"piper" help:"Backend to run the TTS model"` Model string `short:"m" required:"" help:"Model name to run the TTS"` Voice string `short:"v" help:"Voice name to run the TTS"` Language string `short:"l" help:"Language to use with the TTS"` OutputFile string `short:"o" type:"path" help:"The path to write the output wav file"` ModelsPath string `env:"LOCALAI_MODELS_PATH,MODELS_PATH" type:"path" default:"${basepath}/models" help:"Path containing models used for inferencing" group:"storage"` } func (t *TTSCMD) Run(ctx *cliContext.Context) error { outputFile := t.OutputFile outputDir := os.TempDir() if outputFile != "" { outputDir = filepath.Dir(outputFile) } text := strings.Join(t.Text, " ") systemState, err := system.GetSystemState( system.WithModelPath(t.ModelsPath), ) if err != nil { return err } opts := &config.ApplicationConfig{ SystemState: systemState, Context: context.Background(), GeneratedContentDir: outputDir, } ml := model.NewModelLoader(systemState) defer func() { err := ml.StopAllGRPC() if err != nil { xlog.Error("unable to stop all grpc processes", "error", err) } }() options := config.ModelConfig{} options.SetDefaults() options.Backend = t.Backend options.Model = t.Model filePath, _, err := backend.ModelTTS(text, t.Voice, t.Language, ml, opts, options) if err != nil { return err } if outputFile != "" { if err := os.Rename(filePath, outputFile); err != nil { return err } fmt.Printf("Generate file %s\n", outputFile) } else { fmt.Printf("Generate file %s\n", filePath) } return nil } 
// UtilCMD groups the utility sub-commands. Each field is one sub-command; its
// struct tag carries the command name and the help text shown to the user.
// NOTE(review): the `cmd:""` tags look like Kong CLI tags — confirm against
// the CLI parser used by the root command.
type UtilCMD struct {
	// gguf-info: print information about a GGUF file.
	GGUFInfo GGUFInfoCMD `cmd:"" name:"gguf-info" help:"Get information about a GGUF file"`
	// create-oci-image: package a file or directory as an OCI image.
	CreateOCIImage CreateOCIImageCMD `cmd:"" name:"create-oci-image" help:"Create an OCI image from a file or a directory"`
	// hf-scan: best-effort security scan of installed models or given URIs.
	HFScan HFScanCMD `cmd:"" name:"hf-scan" help:"Checks installed models for known security issues. WARNING: this is a best-effort feature and may not catch everything!"`
	// usecase-heuristic: list the usecases detected for one model config.
	UsecaseHeuristic UsecaseHeuristicCMD `cmd:"" name:"usecase-heuristic" help:"Checks a specific model config and prints what usecase LocalAI will offer for it."`
}
`arg:"" help:"Input file or directory to create an OCI image from"` Output string `default:"image.tar" help:"Output OCI image name"` ImageName string `default:"localai" help:"Image name"` Platform string `default:"linux/amd64" help:"Platform of the image"` } func (u *CreateOCIImageCMD) Run(ctx *cliContext.Context) error { xlog.Info("Creating OCI image from input") dir, err := os.MkdirTemp("", "localai") if err != nil { return err } defer os.RemoveAll(dir) err = archiver.Archive(u.Input, filepath.Join(dir, "archive.tar")) if err != nil { return err } xlog.Info("Creating OCI image", "output", u.Output, "input", u.Input) platform := strings.Split(u.Platform, "/") if len(platform) != 2 { return fmt.Errorf("invalid platform: %s", u.Platform) } return oci.CreateTar(filepath.Join(dir, "archive.tar"), u.Output, u.ImageName, platform[1], platform[0]) } func (u *GGUFInfoCMD) Run(ctx *cliContext.Context) error { if len(u.Args) == 0 { return fmt.Errorf("no GGUF file provided") } // We try to guess only if we don't have a template defined already f, err := gguf.ParseGGUFFile(u.Args[0]) if err != nil { // Only valid for gguf files xlog.Error("guessDefaultsFromFile: not a GGUF file") return err } xlog.Info("GGUF file loaded", "file", u.Args[0], "eosTokenID", f.Tokenizer().EOSTokenID, "bosTokenID", f.Tokenizer().BOSTokenID, "modelName", f.Metadata().Name, "architecture", f.Architecture().Architecture) xlog.Info("Tokenizer", "tokenizer", fmt.Sprintf("%+v", f.Tokenizer())) xlog.Info("Architecture", "architecture", fmt.Sprintf("%+v", f.Architecture())) v, exists := f.Header.MetadataKV.Get("tokenizer.chat_template") if exists { xlog.Info("chat_template", "template", v.ValueString()) } if u.Header { for _, metadata := range f.Header.MetadataKV { xlog.Info("metadata", "key", metadata.Key, "value", metadata.Value) } // log.Info().Any("header", fmt.Sprintf("%+v", f.Header)).Msg("Header") } return nil } func (hfscmd *HFScanCMD) Run(ctx *cliContext.Context) error { systemState, err := 
system.GetSystemState( system.WithModelPath(hfscmd.ModelsPath), ) if err != nil { return err } xlog.Info("LocalAI Security Scanner - This is BEST EFFORT functionality! Currently limited to huggingface models!") if len(hfscmd.ToScan) == 0 { xlog.Info("Checking all installed models against galleries") var galleries []config.Gallery if err := json.Unmarshal([]byte(hfscmd.Galleries), &galleries); err != nil { xlog.Error("unable to load galleries", "error", err) } err := gallery.SafetyScanGalleryModels(galleries, systemState) if err == nil { xlog.Info("No security warnings were detected for your installed models. Please note that this is a BEST EFFORT tool, and all issues may not be detected.") } else { xlog.Error("! WARNING ! A known-vulnerable model is installed!", "error", err) } return err } else { var errs error = nil for _, uri := range hfscmd.ToScan { xlog.Info("scanning specific uri", "uri", uri) scanResults, err := downloader.HuggingFaceScan(downloader.URI(uri)) if err != nil && errors.Is(err, downloader.ErrUnsafeFilesFound) { xlog.Error("! WARNING ! A known-vulnerable model is included in this repo!", "error", err, "clamAV", scanResults.ClamAVInfectedFiles, "pickles", scanResults.DangerousPickles) errs = errors.Join(errs, err) } } if errs != nil { return errs } xlog.Info("No security warnings were detected for your installed models. 
Please note that this is a BEST EFFORT tool, and all issues may not be detected.") return nil } } func (uhcmd *UsecaseHeuristicCMD) Run(ctx *cliContext.Context) error { if len(uhcmd.ConfigName) == 0 { xlog.Error("ConfigName is a required parameter") return fmt.Errorf("config name is a required parameter") } if len(uhcmd.ModelsPath) == 0 { xlog.Error("ModelsPath is a required parameter") return fmt.Errorf("model path is a required parameter") } bcl := config.NewModelConfigLoader(uhcmd.ModelsPath) err := bcl.ReadModelConfig(uhcmd.ConfigName) if err != nil { xlog.Error("error while loading backend", "error", err, "ConfigName", uhcmd.ConfigName) return err } bc, exists := bcl.GetModelConfig(uhcmd.ConfigName) if !exists { xlog.Error("ConfigName not found", "ConfigName", uhcmd.ConfigName) } for name, uc := range config.GetAllModelConfigUsecases() { if bc.HasUsecases(uc) { xlog.Info("Usecase", "usecase", name) } } xlog.Info("---") return nil } ================================================ FILE: core/cli/worker/worker.go ================================================ package worker type WorkerFlags struct { BackendsPath string `env:"LOCALAI_BACKENDS_PATH,BACKENDS_PATH" type:"path" default:"${basepath}/backends" help:"Path containing backends used for inferencing" group:"backends"` BackendGalleries string `env:"LOCALAI_BACKEND_GALLERIES,BACKEND_GALLERIES" help:"JSON list of backend galleries" group:"backends" default:"${backends}"` BackendsSystemPath string `env:"LOCALAI_BACKENDS_SYSTEM_PATH,BACKEND_SYSTEM_PATH" type:"path" default:"/var/lib/local-ai/backends" help:"Path containing system backends used for inferencing" group:"backends"` ExtraLLamaCPPArgs string `name:"llama-cpp-args" env:"LOCALAI_EXTRA_LLAMA_CPP_ARGS,EXTRA_LLAMA_CPP_ARGS" help:"Extra arguments to pass to llama-cpp-rpc-server"` } type Worker struct { P2P P2P `cmd:"" name:"p2p-llama-cpp-rpc" help:"Starts a LocalAI llama.cpp worker in P2P mode (requires a token)"` P2PMLX P2PMLX `cmd:"" name:"p2p-mlx" 
help:"Starts a LocalAI MLX distributed worker in P2P mode (requires a token)"` LLamaCPP LLamaCPP `cmd:"" name:"llama-cpp-rpc" help:"Starts a llama.cpp worker in standalone mode"` MLXDistributed MLXDistributed `cmd:"" name:"mlx-distributed" help:"Starts an MLX distributed worker in standalone mode (requires --hostfile and --rank)"` } ================================================ FILE: core/cli/worker/worker_llamacpp.go ================================================ package worker import ( "context" "encoding/json" "errors" "fmt" "os" "path/filepath" "strings" "syscall" cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" ) type LLamaCPP struct { WorkerFlags `embed:""` } const ( llamaCPPRPCBinaryName = "llama-cpp-rpc-server" llamaCPPGalleryName = "llama-cpp" ) func findLLamaCPPBackend(galleries string, systemState *system.SystemState) (string, error) { backends, err := gallery.ListSystemBackends(systemState) if err != nil { xlog.Warn("Failed listing system backends", "error", err) return "", err } xlog.Debug("System backends", "backends", backends) backend, ok := backends.Get(llamaCPPGalleryName) if !ok { ml := model.NewModelLoader(systemState) var gals []config.Gallery if err := json.Unmarshal([]byte(galleries), &gals); err != nil { xlog.Error("failed loading galleries", "error", err) return "", err } err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, llamaCPPGalleryName, nil, true) if err != nil { xlog.Error("llama-cpp backend not found, failed to install it", "error", err) return "", err } } backendPath := filepath.Dir(backend.RunFile) if backendPath == "" { return "", errors.New("llama-cpp backend not found, install it first") } grpcProcess := filepath.Join( backendPath, llamaCPPRPCBinaryName, ) return grpcProcess, nil } func (r *LLamaCPP) 
Run(ctx *cliContext.Context) error { if len(os.Args) < 4 { return fmt.Errorf("usage: local-ai worker llama-cpp-rpc -- ") } systemState, err := system.GetSystemState( system.WithBackendPath(r.BackendsPath), system.WithBackendSystemPath(r.BackendsSystemPath), ) if err != nil { return err } grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState) if err != nil { return err } args := strings.Split(r.ExtraLLamaCPPArgs, " ") args = append([]string{grpcProcess}, args...) return syscall.Exec( grpcProcess, args, os.Environ()) } ================================================ FILE: core/cli/worker/worker_mlx_common.go ================================================ package worker import ( "context" "encoding/json" "errors" "os/exec" "path/filepath" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" ) const mlxDistributedGalleryName = "mlx-distributed" // findMLXDistributedBackendPath finds or installs the mlx-distributed backend // and returns the directory containing run.sh. 
func findMLXDistributedBackendPath(galleries string, systemState *system.SystemState) (string, error) { backends, err := gallery.ListSystemBackends(systemState) if err != nil { return "", err } backend, ok := backends.Get(mlxDistributedGalleryName) if !ok { ml := model.NewModelLoader(systemState) var gals []config.Gallery if err := json.Unmarshal([]byte(galleries), &gals); err != nil { xlog.Error("failed loading galleries", "error", err) return "", err } if err := gallery.InstallBackendFromGallery(context.Background(), gals, systemState, ml, mlxDistributedGalleryName, nil, true); err != nil { xlog.Error("mlx-distributed backend not found, failed to install it", "error", err) return "", err } // Re-fetch after install backends, err = gallery.ListSystemBackends(systemState) if err != nil { return "", err } backend, ok = backends.Get(mlxDistributedGalleryName) if !ok { return "", errors.New("mlx-distributed backend not found after install") } } backendPath := filepath.Dir(backend.RunFile) if backendPath == "" { return "", errors.New("mlx-distributed backend not found, install it first") } return backendPath, nil } // buildMLXCommand builds the exec.Cmd to launch the mlx-distributed backend. // backendPath is the directory containing run.sh (empty string to fall back to // running backend.py directly via python3). func buildMLXCommand(backendPath string, args ...string) *exec.Cmd { if backendPath != "" { return exec.Command(filepath.Join(backendPath, "run.sh"), args...) } return exec.Command("python3", append([]string{"backend.py"}, args...)...) 
// MLXDistributed holds the options of the standalone "mlx-distributed"
// worker command. It embeds the shared WorkerFlags; the remaining fields are
// translated into command-line arguments for the mlx-distributed backend by
// its Run method.
type MLXDistributed struct {
	WorkerFlags `embed:""`
	// Passed through as --hostfile.
	Hostfile string `env:"MLX_DISTRIBUTED_HOSTFILE" required:"" help:"Path to hostfile JSON. Ring: array of 'ip:port' where entry i is rank i's listen address. JACCL: 2D matrix of RDMA device names."`
	// Passed through as --rank; rank 0 additionally gets --addr, every other
	// rank gets --worker.
	Rank int `env:"MLX_RANK" required:"" help:"Rank of this process (0 = gRPC server + ring participant, >0 = worker only)"`
	// Passed through as --backend.
	Backend string `env:"MLX_DISTRIBUTED_BACKEND" default:"ring" help:"MLX distributed backend: 'ring' (TCP pipeline parallelism) or 'jaccl' (RDMA tensor parallelism)"`
	// Only forwarded (as --addr) when Rank is 0.
	Addr string `env:"MLX_DISTRIBUTED_ADDR" default:"localhost:50051" help:"gRPC API listen address for LocalAI (rank 0 only, separate from ring communication)"`
	// Only forwarded (as --coordinator) when Backend is "jaccl" and the
	// value is non-empty.
	Coordinator string `env:"MLX_JACCL_COORDINATOR" default:"" help:"JACCL coordinator ip:port — rank 0's address where it accepts RDMA setup connections (all ranks must use the same value)"`
}
runSh := cmd.Path xlog.Info("Starting mlx-distributed", "rank", r.Rank, "backend", r.Backend, "hostfile", r.Hostfile) return syscall.Exec( runSh, append([]string{runSh}, args...), os.Environ(), ) } ================================================ FILE: core/cli/worker/worker_p2p.go ================================================ package worker import ( "context" "fmt" "os" "os/exec" "strings" "time" cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/p2p" "github.com/mudler/LocalAI/pkg/signals" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" "github.com/phayes/freeport" ) type P2P struct { WorkerFlags `embed:""` Token string `env:"LOCALAI_TOKEN,LOCALAI_P2P_TOKEN,TOKEN" help:"P2P token to use"` NoRunner bool `env:"LOCALAI_NO_RUNNER,NO_RUNNER" help:"Do not start the llama-cpp-rpc-server"` RunnerAddress string `env:"LOCALAI_RUNNER_ADDRESS,RUNNER_ADDRESS" help:"Address of the llama-cpp-rpc-server"` RunnerPort string `env:"LOCALAI_RUNNER_PORT,RUNNER_PORT" help:"Port of the llama-cpp-rpc-server"` Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode, can be set arbitrarly by the user for grouping a set of instances" group:"p2p"` } func (r *P2P) Run(ctx *cliContext.Context) error { systemState, err := system.GetSystemState( system.WithBackendPath(r.BackendsPath), system.WithBackendSystemPath(r.BackendsSystemPath), ) if err != nil { return err } // Check if the token is set // as we always need it. 
if r.Token == "" { return fmt.Errorf("Token is required") } port, err := freeport.GetFreePort() if err != nil { return err } address := "127.0.0.1" c, cancel := context.WithCancel(context.Background()) defer cancel() if r.NoRunner { // Let override which port and address to bind if the user // configure the llama-cpp service on its own p := fmt.Sprint(port) if r.RunnerAddress != "" { address = r.RunnerAddress } if r.RunnerPort != "" { p = r.RunnerPort } _, err = p2p.ExposeService(c, address, p, r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.LlamaCPPWorkerID)) if err != nil { return err } xlog.Info("You need to start llama-cpp-rpc-server", "address", address, "port", p) } else { // Start llama.cpp directly from the version we have pre-packaged go func() { for { xlog.Info("Starting llama-cpp-rpc-server", "address", address, "port", port) grpcProcess, err := findLLamaCPPBackend(r.BackendGalleries, systemState) if err != nil { xlog.Error("Failed to find llama-cpp-rpc-server", "error", err) return } var extraArgs []string if r.ExtraLLamaCPPArgs != "" { extraArgs = strings.Split(r.ExtraLLamaCPPArgs, " ") } args := append([]string{"--host", address, "--port", fmt.Sprint(port)}, extraArgs...) 
xlog.Debug("Starting llama-cpp-rpc-server", "address", address, "port", port, "args", args, "argCount", len(args)) cmd := exec.Command( grpcProcess, args..., ) cmd.Env = os.Environ() cmd.Stderr = os.Stdout cmd.Stdout = os.Stdout if err := cmd.Start(); err != nil { xlog.Error("Failed to start llama-cpp-rpc-server", "error", err, "grpcProcess", grpcProcess, "args", args) } cmd.Wait() } }() _, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.LlamaCPPWorkerID)) if err != nil { return err } } signals.RegisterGracefulTerminationHandler(func() { cancel() }) for { time.Sleep(1 * time.Second) } } ================================================ FILE: core/cli/worker/worker_p2p_mlx.go ================================================ package worker import ( "context" "fmt" "os" "time" cliContext "github.com/mudler/LocalAI/core/cli/context" "github.com/mudler/LocalAI/core/p2p" "github.com/mudler/LocalAI/pkg/signals" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" "github.com/phayes/freeport" ) type P2PMLX struct { WorkerFlags `embed:""` Token string `env:"LOCALAI_TOKEN,LOCALAI_P2P_TOKEN,TOKEN" help:"P2P token to use"` Peer2PeerNetworkID string `env:"LOCALAI_P2P_NETWORK_ID,P2P_NETWORK_ID" help:"Network ID for P2P mode" group:"p2p"` MLXListenPort string `env:"MLX_LISTEN_PORT" default:"5555" help:"Port for MLX distributed communication"` MLXBackend string `env:"MLX_DISTRIBUTED_BACKEND" default:"ring" help:"MLX distributed backend (ring or jaccl)"` } func (r *P2PMLX) Run(ctx *cliContext.Context) error { if r.Token == "" { return fmt.Errorf("token is required") } systemState, err := system.GetSystemState( system.WithBackendPath(r.BackendsPath), system.WithBackendSystemPath(r.BackendsSystemPath), ) if err != nil { return err } port, err := freeport.GetFreePort() if err != nil { return err } if r.MLXListenPort != "" { fmt.Sscanf(r.MLXListenPort, "%d", &port) } address := "127.0.0.1" c, cancel := 
context.WithCancel(context.Background()) defer cancel() backendPath, err := findMLXDistributedBackendPath(r.BackendGalleries, systemState) if err != nil { xlog.Warn("Could not find mlx-distributed backend from gallery, will try backend.py directly", "error", err) } go func() { for { hostfile := os.Getenv("MLX_DISTRIBUTED_HOSTFILE") if hostfile == "" { xlog.Info("Waiting for MLX_DISTRIBUTED_HOSTFILE to be set by P2P discovery...") time.Sleep(2 * time.Second) continue } xlog.Info("Starting mlx-distributed worker", "address", address, "port", port, "hostfile", hostfile) cmd := buildMLXCommand(backendPath, "--worker", "--backend", r.MLXBackend, "--hostfile", hostfile, "--rank", "0", ) cmd.Env = os.Environ() cmd.Stderr = os.Stderr cmd.Stdout = os.Stdout if err := cmd.Run(); err != nil { xlog.Error("mlx-distributed worker exited", "error", err) } time.Sleep(2 * time.Second) } }() _, err = p2p.ExposeService(c, address, fmt.Sprint(port), r.Token, p2p.NetworkID(r.Peer2PeerNetworkID, p2p.MLXWorkerID)) if err != nil { return err } xlog.Info("MLX distributed worker registered on P2P network", "address", address, "port", port) signals.RegisterGracefulTerminationHandler(func() { cancel() }) <-c.Done() return nil } ================================================ FILE: core/clients/store.go ================================================ package clients import ( "bytes" "encoding/json" "fmt" "io" "net/http" ) // Define a struct to hold the store API client type StoreClient struct { BaseURL string Client *http.Client } type SetRequest struct { Keys [][]float32 `json:"keys"` Values []string `json:"values"` } type GetRequest struct { Keys [][]float32 `json:"keys"` } type GetResponse struct { Keys [][]float32 `json:"keys"` Values []string `json:"values"` } type DeleteRequest struct { Keys [][]float32 `json:"keys"` } type FindRequest struct { TopK int `json:"topk"` Key []float32 `json:"key"` } type FindResponse struct { Keys [][]float32 `json:"keys"` Values []string `json:"values"` 
Similarities []float32 `json:"similarities"` } // Constructor for StoreClient func NewStoreClient(baseUrl string) *StoreClient { return &StoreClient{ BaseURL: baseUrl, Client: &http.Client{}, } } // Implement Set method func (c *StoreClient) Set(req SetRequest) error { return c.doRequest("stores/set", req) } // Implement Get method func (c *StoreClient) Get(req GetRequest) (*GetResponse, error) { body, err := c.doRequestWithResponse("stores/get", req) if err != nil { return nil, err } var resp GetResponse err = json.Unmarshal(body, &resp) if err != nil { return nil, err } return &resp, nil } // Implement Delete method func (c *StoreClient) Delete(req DeleteRequest) error { return c.doRequest("stores/delete", req) } // Implement Find method func (c *StoreClient) Find(req FindRequest) (*FindResponse, error) { body, err := c.doRequestWithResponse("stores/find", req) if err != nil { return nil, err } var resp FindResponse err = json.Unmarshal(body, &resp) if err != nil { return nil, err } return &resp, nil } // Helper function to perform a request without expecting a response body func (c *StoreClient) doRequest(path string, data interface{}) error { jsonData, err := json.Marshal(data) if err != nil { return err } req, err := http.NewRequest("POST", c.BaseURL+"/"+path, bytes.NewBuffer(jsonData)) if err != nil { return err } req.Header.Set("Content-Type", "application/json") resp, err := c.Client.Do(req) if err != nil { return err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return fmt.Errorf("API request to %s failed with status code %d", path, resp.StatusCode) } return nil } // Helper function to perform a request and parse the response body func (c *StoreClient) doRequestWithResponse(path string, data interface{}) ([]byte, error) { jsonData, err := json.Marshal(data) if err != nil { return nil, err } req, err := http.NewRequest("POST", c.BaseURL+"/"+path, bytes.NewBuffer(jsonData)) if err != nil { return nil, err } req.Header.Set("Content-Type", 
"application/json") resp, err := c.Client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("API request to %s failed with status code %d", path, resp.StatusCode) } body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } return body, nil } ================================================ FILE: core/config/application_config.go ================================================ package config import ( "context" "encoding/json" "regexp" "time" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/LocalAI/pkg/xsysinfo" "github.com/mudler/xlog" ) type ApplicationConfig struct { Context context.Context ConfigFile string SystemState *system.SystemState ExternalBackends []string UploadLimitMB, Threads, ContextSize int F16 bool Debug bool EnableTracing bool TracingMaxItems int EnableBackendLogging bool GeneratedContentDir string UploadDir string DataPath string // Persistent data directory for collectiondb, agents, etc. 
DynamicConfigsDir string DynamicConfigsDirPollInterval time.Duration CORS bool DisableCSRF bool PreloadJSONModels string PreloadModelsFromPath string CORSAllowOrigins string ApiKeys []string P2PToken string P2PNetworkID string Federated bool DisableWebUI bool EnforcePredownloadScans bool OpaqueErrors bool UseSubtleKeyComparison bool DisableApiKeyRequirementForHttpGet bool DisableMetrics bool HttpGetExemptedEndpoints []*regexp.Regexp DisableGalleryEndpoint bool DisableMCP bool LoadToMemory []string Galleries []Gallery BackendGalleries []Gallery ExternalGRPCBackends map[string]string AutoloadGalleries, AutoloadBackendGalleries bool SingleBackend bool // Deprecated: use MaxActiveBackends = 1 instead MaxActiveBackends int // Maximum number of active backends (0 = unlimited, 1 = single backend mode) ParallelBackendRequests bool WatchDogIdle bool WatchDogBusy bool WatchDog bool // Memory Reclaimer settings (works with GPU if available, otherwise RAM) MemoryReclaimerEnabled bool // Enable memory threshold monitoring MemoryReclaimerThreshold float64 // Threshold 0.0-1.0 (e.g., 0.95 = 95%) // Eviction settings ForceEvictionWhenBusy bool // Force eviction even when models have active API calls (default: false for safety) LRUEvictionMaxRetries int // Maximum number of retries when waiting for busy models to become idle (default: 30) LRUEvictionRetryInterval time.Duration // Interval between retries when waiting for busy models (default: 1s) ModelsURL []string WatchDogBusyTimeout, WatchDogIdleTimeout time.Duration WatchDogInterval time.Duration // Interval between watchdog checks MachineTag string APIAddress string LlamaCPPTunnelCallback func(tunnels []string) MLXTunnelCallback func(tunnels []string) DisableRuntimeSettings bool AgentJobRetentionDays int // Default: 30 days OpenResponsesStoreTTL time.Duration // TTL for Open Responses store (0 = no expiration) PathWithoutAuth []string // Agent Pool (LocalAGI integration) AgentPool AgentPoolConfig // Authentication & 
Authorization Auth AuthConfig } // AuthConfig holds configuration for user authentication and authorization. type AuthConfig struct { Enabled bool DatabaseURL string // "postgres://..." or file path for SQLite GitHubClientID string GitHubClientSecret string OIDCIssuer string // OIDC issuer URL for auto-discovery (e.g. https://accounts.google.com) OIDCClientID string OIDCClientSecret string BaseURL string // for OAuth callback URLs (e.g. "http://localhost:8080") AdminEmail string // auto-promote to admin on login RegistrationMode string // "open", "approval" (default when empty), "invite" DisableLocalAuth bool // disable local email/password registration and login APIKeyHMACSecret string // HMAC secret for API key hashing; auto-generated if empty DefaultAPIKeyExpiry string // default expiry duration for API keys (e.g. "90d"); empty = no expiry } // AgentPoolConfig holds configuration for the LocalAGI agent pool integration. type AgentPoolConfig struct { Enabled bool // default: true (disabled by LOCALAI_DISABLE_AGENTS=true) StateDir string // default: DynamicConfigsDir (LocalAI configuration folder) APIURL string // default: self-referencing LocalAI (http://127.0.0.1:) APIKey string // default: first API key from LocalAI config DefaultModel string MultimodalModel string TranscriptionModel string TranscriptionLanguage string TTSModel string Timeout string // default: "5m" EnableSkills bool EnableLogs bool CustomActionsDir string CollectionDBPath string VectorEngine string // default: "chromem" EmbeddingModel string // default: "granite-embedding-107m-multilingual" MaxChunkingSize int // default: 400 ChunkOverlap int // default: 0 DatabaseURL string AgentHubURL string // default: "https://agenthub.localai.io" } type AppOption func(*ApplicationConfig) func NewApplicationConfig(o ...AppOption) *ApplicationConfig { opt := &ApplicationConfig{ Context: context.Background(), UploadLimitMB: 15, Debug: true, AgentJobRetentionDays: 30, // Default: 30 days 
LRUEvictionMaxRetries: 30, // Default: 30 retries LRUEvictionRetryInterval: 1 * time.Second, // Default: 1 second WatchDogInterval: 500 * time.Millisecond, // Default: 500ms TracingMaxItems: 1024, AgentPool: AgentPoolConfig{ Enabled: true, Timeout: "5m", VectorEngine: "chromem", EmbeddingModel: "granite-embedding-107m-multilingual", MaxChunkingSize: 400, AgentHubURL: "https://agenthub.localai.io", }, PathWithoutAuth: []string{ "/static/", "/generated-audio/", "/generated-images/", "/generated-videos/", "/favicon.svg", "/readyz", "/healthz", "/api/auth/", "/assets/", }, } for _, oo := range o { oo(opt) } return opt } func WithModelsURL(urls ...string) AppOption { return func(o *ApplicationConfig) { o.ModelsURL = urls } } func WithSystemState(state *system.SystemState) AppOption { return func(o *ApplicationConfig) { o.SystemState = state } } func WithExternalBackends(backends ...string) AppOption { return func(o *ApplicationConfig) { o.ExternalBackends = backends } } func WithMachineTag(tag string) AppOption { return func(o *ApplicationConfig) { o.MachineTag = tag } } func WithCors(b bool) AppOption { return func(o *ApplicationConfig) { o.CORS = b } } func WithP2PNetworkID(s string) AppOption { return func(o *ApplicationConfig) { o.P2PNetworkID = s } } func WithDisableCSRF(b bool) AppOption { return func(o *ApplicationConfig) { o.DisableCSRF = b } } func WithP2PToken(s string) AppOption { return func(o *ApplicationConfig) { o.P2PToken = s } } var EnableWatchDog = func(o *ApplicationConfig) { o.WatchDog = true } var EnableTracing = func(o *ApplicationConfig) { o.EnableTracing = true } var EnableBackendLogging = func(o *ApplicationConfig) { o.EnableBackendLogging = true } var EnableWatchDogIdleCheck = func(o *ApplicationConfig) { o.WatchDog = true o.WatchDogIdle = true } var DisableGalleryEndpoint = func(o *ApplicationConfig) { o.DisableGalleryEndpoint = true } var DisableMCP = func(o *ApplicationConfig) { o.DisableMCP = true } var EnableWatchDogBusyCheck = func(o 
*ApplicationConfig) { o.WatchDog = true o.WatchDogBusy = true } var DisableWebUI = func(o *ApplicationConfig) { o.DisableWebUI = true } var DisableRuntimeSettings = func(o *ApplicationConfig) { o.DisableRuntimeSettings = true } func SetWatchDogBusyTimeout(t time.Duration) AppOption { return func(o *ApplicationConfig) { o.WatchDogBusyTimeout = t } } func SetWatchDogIdleTimeout(t time.Duration) AppOption { return func(o *ApplicationConfig) { o.WatchDogIdleTimeout = t } } func SetWatchDogInterval(t time.Duration) AppOption { return func(o *ApplicationConfig) { o.WatchDogInterval = t } } // EnableMemoryReclaimer enables memory threshold monitoring. // When enabled, the watchdog will evict backends if memory usage exceeds the threshold. // Works with GPU VRAM if available, otherwise uses system RAM. var EnableMemoryReclaimer = func(o *ApplicationConfig) { o.MemoryReclaimerEnabled = true o.WatchDog = true // Memory reclaimer requires watchdog infrastructure } // SetMemoryReclaimerThreshold sets the memory usage threshold (0.0-1.0). // When memory usage exceeds this threshold, backends will be evicted using LRU strategy. func SetMemoryReclaimerThreshold(threshold float64) AppOption { return func(o *ApplicationConfig) { if threshold > 0 && threshold <= 1.0 { o.MemoryReclaimerThreshold = threshold o.MemoryReclaimerEnabled = true o.WatchDog = true // Memory reclaimer requires watchdog infrastructure } } } // WithMemoryReclaimer configures the memory reclaimer with the given settings func WithMemoryReclaimer(enabled bool, threshold float64) AppOption { return func(o *ApplicationConfig) { o.MemoryReclaimerEnabled = enabled if threshold > 0 && threshold <= 1.0 { o.MemoryReclaimerThreshold = threshold } if enabled { o.WatchDog = true // Memory reclaimer requires watchdog infrastructure } } } // EnableSingleBackend is deprecated: use SetMaxActiveBackends(1) instead. // This is kept for backward compatibility. 
var EnableSingleBackend = func(o *ApplicationConfig) { o.SingleBackend = true o.MaxActiveBackends = 1 } // SetMaxActiveBackends sets the maximum number of active backends. // 0 = unlimited, 1 = single backend mode (replaces EnableSingleBackend) func SetMaxActiveBackends(n int) AppOption { return func(o *ApplicationConfig) { o.MaxActiveBackends = n // For backward compatibility, also set SingleBackend if n == 1 if n == 1 { o.SingleBackend = true } } } // GetEffectiveMaxActiveBackends returns the effective max active backends limit. // It considers both MaxActiveBackends and the deprecated SingleBackend setting. // If MaxActiveBackends is set (> 0), it takes precedence. // If SingleBackend is true and MaxActiveBackends is 0, returns 1. // Otherwise returns 0 (unlimited). func (o *ApplicationConfig) GetEffectiveMaxActiveBackends() int { if o.MaxActiveBackends > 0 { return o.MaxActiveBackends } if o.SingleBackend { return 1 } return 0 } // WithForceEvictionWhenBusy sets whether to force eviction even when models have active API calls func WithForceEvictionWhenBusy(enabled bool) AppOption { return func(o *ApplicationConfig) { o.ForceEvictionWhenBusy = enabled } } // WithLRUEvictionMaxRetries sets the maximum number of retries when waiting for busy models to become idle func WithLRUEvictionMaxRetries(maxRetries int) AppOption { return func(o *ApplicationConfig) { if maxRetries > 0 { o.LRUEvictionMaxRetries = maxRetries } } } // WithLRUEvictionRetryInterval sets the interval between retries when waiting for busy models func WithLRUEvictionRetryInterval(interval time.Duration) AppOption { return func(o *ApplicationConfig) { if interval > 0 { o.LRUEvictionRetryInterval = interval } } } var EnableParallelBackendRequests = func(o *ApplicationConfig) { o.ParallelBackendRequests = true } var EnableGalleriesAutoload = func(o *ApplicationConfig) { o.AutoloadGalleries = true } var EnableBackendGalleriesAutoload = func(o *ApplicationConfig) { o.AutoloadBackendGalleries = true } var 
EnableFederated = func(o *ApplicationConfig) { o.Federated = true } func WithExternalBackend(name string, uri string) AppOption { return func(o *ApplicationConfig) { if o.ExternalGRPCBackends == nil { o.ExternalGRPCBackends = make(map[string]string) } o.ExternalGRPCBackends[name] = uri } } func WithCorsAllowOrigins(b string) AppOption { return func(o *ApplicationConfig) { o.CORSAllowOrigins = b } } func WithStringGalleries(galls string) AppOption { return func(o *ApplicationConfig) { if galls == "" { o.Galleries = []Gallery{} return } var galleries []Gallery if err := json.Unmarshal([]byte(galls), &galleries); err != nil { xlog.Error("failed loading galleries", "error", err) } o.Galleries = append(o.Galleries, galleries...) } } func WithBackendGalleries(galls string) AppOption { return func(o *ApplicationConfig) { if galls == "" { o.BackendGalleries = []Gallery{} return } var galleries []Gallery if err := json.Unmarshal([]byte(galls), &galleries); err != nil { xlog.Error("failed loading galleries", "error", err) } o.BackendGalleries = append(o.BackendGalleries, galleries...) } } func WithGalleries(galleries []Gallery) AppOption { return func(o *ApplicationConfig) { o.Galleries = append(o.Galleries, galleries...) 
} } func WithContext(ctx context.Context) AppOption { return func(o *ApplicationConfig) { o.Context = ctx } } func WithYAMLConfigPreload(configFile string) AppOption { return func(o *ApplicationConfig) { o.PreloadModelsFromPath = configFile } } func WithJSONStringPreload(configFile string) AppOption { return func(o *ApplicationConfig) { o.PreloadJSONModels = configFile } } func WithConfigFile(configFile string) AppOption { return func(o *ApplicationConfig) { o.ConfigFile = configFile } } func WithUploadLimitMB(limit int) AppOption { return func(o *ApplicationConfig) { o.UploadLimitMB = limit } } func WithThreads(threads int) AppOption { return func(o *ApplicationConfig) { if threads == 0 { // 0 is not allowed threads = xsysinfo.CPUPhysicalCores() } o.Threads = threads } } func WithContextSize(ctxSize int) AppOption { return func(o *ApplicationConfig) { o.ContextSize = ctxSize } } func WithLlamaCPPTunnelCallback(callback func(tunnels []string)) AppOption { return func(o *ApplicationConfig) { o.LlamaCPPTunnelCallback = callback } } func WithMLXTunnelCallback(callback func(tunnels []string)) AppOption { return func(o *ApplicationConfig) { o.MLXTunnelCallback = callback } } func WithF16(f16 bool) AppOption { return func(o *ApplicationConfig) { o.F16 = f16 } } func WithDebug(debug bool) AppOption { return func(o *ApplicationConfig) { o.Debug = debug } } func WithTracingMaxItems(items int) AppOption { return func(o *ApplicationConfig) { o.TracingMaxItems = items } } func WithGeneratedContentDir(generatedContentDir string) AppOption { return func(o *ApplicationConfig) { o.GeneratedContentDir = generatedContentDir } } func WithUploadDir(uploadDir string) AppOption { return func(o *ApplicationConfig) { o.UploadDir = uploadDir } } func WithDataPath(dataPath string) AppOption { return func(o *ApplicationConfig) { o.DataPath = dataPath } } func WithDynamicConfigDir(dynamicConfigsDir string) AppOption { return func(o *ApplicationConfig) { o.DynamicConfigsDir = dynamicConfigsDir 
} } func WithDynamicConfigDirPollInterval(interval time.Duration) AppOption { return func(o *ApplicationConfig) { o.DynamicConfigsDirPollInterval = interval } } func WithApiKeys(apiKeys []string) AppOption { return func(o *ApplicationConfig) { o.ApiKeys = apiKeys } } func WithAgentJobRetentionDays(days int) AppOption { return func(o *ApplicationConfig) { o.AgentJobRetentionDays = days } } func WithOpenResponsesStoreTTL(ttl time.Duration) AppOption { return func(o *ApplicationConfig) { o.OpenResponsesStoreTTL = ttl } } func WithEnforcedPredownloadScans(enforced bool) AppOption { return func(o *ApplicationConfig) { o.EnforcePredownloadScans = enforced } } func WithOpaqueErrors(opaque bool) AppOption { return func(o *ApplicationConfig) { o.OpaqueErrors = opaque } } func WithLoadToMemory(models []string) AppOption { return func(o *ApplicationConfig) { o.LoadToMemory = models } } func WithSubtleKeyComparison(subtle bool) AppOption { return func(o *ApplicationConfig) { o.UseSubtleKeyComparison = subtle } } func WithDisableApiKeyRequirementForHttpGet(required bool) AppOption { return func(o *ApplicationConfig) { o.DisableApiKeyRequirementForHttpGet = required } } func WithAPIAddress(address string) AppOption { return func(o *ApplicationConfig) { o.APIAddress = address } } var DisableMetricsEndpoint AppOption = func(o *ApplicationConfig) { o.DisableMetrics = true } func WithHttpGetExemptedEndpoints(endpoints []string) AppOption { return func(o *ApplicationConfig) { o.HttpGetExemptedEndpoints = []*regexp.Regexp{} for _, epr := range endpoints { r, err := regexp.Compile(epr) if err == nil && r != nil { o.HttpGetExemptedEndpoints = append(o.HttpGetExemptedEndpoints, r) } else { xlog.Warn("Error while compiling HTTP Get Exemption regex, skipping this entry.", "error", err, "regex", epr) } } } } // Agent Pool options var DisableAgentPool = func(o *ApplicationConfig) { o.AgentPool.Enabled = false } func WithAgentPoolAPIURL(url string) AppOption { return func(o 
*ApplicationConfig) { o.AgentPool.APIURL = url } } func WithAgentPoolAPIKey(key string) AppOption { return func(o *ApplicationConfig) { o.AgentPool.APIKey = key } } func WithAgentPoolDefaultModel(model string) AppOption { return func(o *ApplicationConfig) { o.AgentPool.DefaultModel = model } } func WithAgentPoolMultimodalModel(model string) AppOption { return func(o *ApplicationConfig) { o.AgentPool.MultimodalModel = model } } func WithAgentPoolTranscriptionModel(model string) AppOption { return func(o *ApplicationConfig) { o.AgentPool.TranscriptionModel = model } } func WithAgentPoolTranscriptionLanguage(lang string) AppOption { return func(o *ApplicationConfig) { o.AgentPool.TranscriptionLanguage = lang } } func WithAgentPoolTTSModel(model string) AppOption { return func(o *ApplicationConfig) { o.AgentPool.TTSModel = model } } func WithAgentPoolStateDir(dir string) AppOption { return func(o *ApplicationConfig) { o.AgentPool.StateDir = dir } } func WithAgentPoolTimeout(timeout string) AppOption { return func(o *ApplicationConfig) { o.AgentPool.Timeout = timeout } } var EnableAgentPoolSkills = func(o *ApplicationConfig) { o.AgentPool.EnableSkills = true } func WithAgentPoolVectorEngine(engine string) AppOption { return func(o *ApplicationConfig) { o.AgentPool.VectorEngine = engine } } func WithAgentPoolEmbeddingModel(model string) AppOption { return func(o *ApplicationConfig) { o.AgentPool.EmbeddingModel = model } } func WithAgentPoolCustomActionsDir(dir string) AppOption { return func(o *ApplicationConfig) { o.AgentPool.CustomActionsDir = dir } } func WithAgentPoolDatabaseURL(url string) AppOption { return func(o *ApplicationConfig) { o.AgentPool.DatabaseURL = url } } func WithAgentPoolMaxChunkingSize(size int) AppOption { return func(o *ApplicationConfig) { o.AgentPool.MaxChunkingSize = size } } func WithAgentPoolChunkOverlap(overlap int) AppOption { return func(o *ApplicationConfig) { o.AgentPool.ChunkOverlap = overlap } } var EnableAgentPoolLogs = func(o 
*ApplicationConfig) { o.AgentPool.EnableLogs = true } func WithAgentPoolCollectionDBPath(path string) AppOption { return func(o *ApplicationConfig) { o.AgentPool.CollectionDBPath = path } } func WithAgentHubURL(url string) AppOption { return func(o *ApplicationConfig) { o.AgentPool.AgentHubURL = url } } // Auth options func WithAuthEnabled(enabled bool) AppOption { return func(o *ApplicationConfig) { o.Auth.Enabled = enabled } } func WithAuthDatabaseURL(url string) AppOption { return func(o *ApplicationConfig) { o.Auth.DatabaseURL = url } } func WithAuthGitHubClientID(clientID string) AppOption { return func(o *ApplicationConfig) { o.Auth.GitHubClientID = clientID } } func WithAuthGitHubClientSecret(clientSecret string) AppOption { return func(o *ApplicationConfig) { o.Auth.GitHubClientSecret = clientSecret } } func WithAuthBaseURL(baseURL string) AppOption { return func(o *ApplicationConfig) { o.Auth.BaseURL = baseURL } } func WithAuthAdminEmail(email string) AppOption { return func(o *ApplicationConfig) { o.Auth.AdminEmail = email } } func WithAuthRegistrationMode(mode string) AppOption { return func(o *ApplicationConfig) { o.Auth.RegistrationMode = mode } } func WithAuthDisableLocalAuth(disable bool) AppOption { return func(o *ApplicationConfig) { o.Auth.DisableLocalAuth = disable } } func WithAuthOIDCIssuer(issuer string) AppOption { return func(o *ApplicationConfig) { o.Auth.OIDCIssuer = issuer } } func WithAuthOIDCClientID(clientID string) AppOption { return func(o *ApplicationConfig) { o.Auth.OIDCClientID = clientID } } func WithAuthOIDCClientSecret(clientSecret string) AppOption { return func(o *ApplicationConfig) { o.Auth.OIDCClientSecret = clientSecret } } func WithAuthAPIKeyHMACSecret(secret string) AppOption { return func(o *ApplicationConfig) { o.Auth.APIKeyHMACSecret = secret } } func WithAuthDefaultAPIKeyExpiry(expiry string) AppOption { return func(o *ApplicationConfig) { o.Auth.DefaultAPIKeyExpiry = expiry } } // ToConfigLoaderOptions returns a 
slice of ConfigLoader Option. // Some options defined at the application level are going to be passed as defaults for // all the configuration for the models. // This includes for instance the context size or the number of threads. // If a model doesn't set configs directly to the config model file // it will use the defaults defined here. func (o *ApplicationConfig) ToConfigLoaderOptions() []ConfigLoaderOption { return []ConfigLoaderOption{ LoadOptionContextSize(o.ContextSize), LoadOptionDebug(o.Debug), LoadOptionF16(o.F16), LoadOptionThreads(o.Threads), ModelPath(o.SystemState.Model.ModelsPath), } } // ToRuntimeSettings converts ApplicationConfig to RuntimeSettings for API responses and JSON serialization. // This provides a single source of truth - ApplicationConfig holds the live values, // and this method creates a RuntimeSettings snapshot for external consumption. func (o *ApplicationConfig) ToRuntimeSettings() RuntimeSettings { // Create local copies for pointer fields watchdogEnabled := o.WatchDog watchdogIdle := o.WatchDogIdle watchdogBusy := o.WatchDogBusy singleBackend := o.SingleBackend maxActiveBackends := o.MaxActiveBackends parallelBackendRequests := o.ParallelBackendRequests memoryReclaimerEnabled := o.MemoryReclaimerEnabled memoryReclaimerThreshold := o.MemoryReclaimerThreshold forceEvictionWhenBusy := o.ForceEvictionWhenBusy lruEvictionMaxRetries := o.LRUEvictionMaxRetries threads := o.Threads contextSize := o.ContextSize f16 := o.F16 debug := o.Debug tracingMaxItems := o.TracingMaxItems enableTracing := o.EnableTracing enableBackendLogging := o.EnableBackendLogging cors := o.CORS csrf := o.DisableCSRF corsAllowOrigins := o.CORSAllowOrigins p2pToken := o.P2PToken p2pNetworkID := o.P2PNetworkID federated := o.Federated galleries := o.Galleries backendGalleries := o.BackendGalleries autoloadGalleries := o.AutoloadGalleries autoloadBackendGalleries := o.AutoloadBackendGalleries apiKeys := o.ApiKeys agentJobRetentionDays := o.AgentJobRetentionDays // 
Format timeouts as strings var idleTimeout, busyTimeout, watchdogInterval string if o.WatchDogIdleTimeout > 0 { idleTimeout = o.WatchDogIdleTimeout.String() } else { idleTimeout = "15m" // default } if o.WatchDogBusyTimeout > 0 { busyTimeout = o.WatchDogBusyTimeout.String() } else { busyTimeout = "5m" // default } if o.WatchDogInterval > 0 { watchdogInterval = o.WatchDogInterval.String() } else { watchdogInterval = "2s" // default } var lruEvictionRetryInterval string if o.LRUEvictionRetryInterval > 0 { lruEvictionRetryInterval = o.LRUEvictionRetryInterval.String() } else { lruEvictionRetryInterval = "1s" // default } var openResponsesStoreTTL string if o.OpenResponsesStoreTTL > 0 { openResponsesStoreTTL = o.OpenResponsesStoreTTL.String() } else { openResponsesStoreTTL = "0" // default: no expiration } // Agent Pool settings agentPoolEnabled := o.AgentPool.Enabled agentPoolDefaultModel := o.AgentPool.DefaultModel agentPoolEmbeddingModel := o.AgentPool.EmbeddingModel agentPoolMaxChunkingSize := o.AgentPool.MaxChunkingSize agentPoolChunkOverlap := o.AgentPool.ChunkOverlap agentPoolEnableLogs := o.AgentPool.EnableLogs agentPoolCollectionDBPath := o.AgentPool.CollectionDBPath return RuntimeSettings{ WatchdogEnabled: &watchdogEnabled, WatchdogIdleEnabled: &watchdogIdle, WatchdogBusyEnabled: &watchdogBusy, WatchdogIdleTimeout: &idleTimeout, WatchdogBusyTimeout: &busyTimeout, WatchdogInterval: &watchdogInterval, SingleBackend: &singleBackend, MaxActiveBackends: &maxActiveBackends, ParallelBackendRequests: ¶llelBackendRequests, MemoryReclaimerEnabled: &memoryReclaimerEnabled, MemoryReclaimerThreshold: &memoryReclaimerThreshold, ForceEvictionWhenBusy: &forceEvictionWhenBusy, LRUEvictionMaxRetries: &lruEvictionMaxRetries, LRUEvictionRetryInterval: &lruEvictionRetryInterval, Threads: &threads, ContextSize: &contextSize, F16: &f16, Debug: &debug, TracingMaxItems: &tracingMaxItems, EnableTracing: &enableTracing, EnableBackendLogging: &enableBackendLogging, CORS: &cors, CSRF: 
&csrf, CORSAllowOrigins: &corsAllowOrigins, P2PToken: &p2pToken, P2PNetworkID: &p2pNetworkID, Federated: &federated, Galleries: &galleries, BackendGalleries: &backendGalleries, AutoloadGalleries: &autoloadGalleries, AutoloadBackendGalleries: &autoloadBackendGalleries, ApiKeys: &apiKeys, AgentJobRetentionDays: &agentJobRetentionDays, OpenResponsesStoreTTL: &openResponsesStoreTTL, AgentPoolEnabled: &agentPoolEnabled, AgentPoolDefaultModel: &agentPoolDefaultModel, AgentPoolEmbeddingModel: &agentPoolEmbeddingModel, AgentPoolMaxChunkingSize: &agentPoolMaxChunkingSize, AgentPoolChunkOverlap: &agentPoolChunkOverlap, AgentPoolEnableLogs: &agentPoolEnableLogs, AgentPoolCollectionDBPath: &agentPoolCollectionDBPath, } } // ApplyRuntimeSettings applies RuntimeSettings to ApplicationConfig. // Only non-nil fields in RuntimeSettings are applied. // Returns true if watchdog-related settings changed (requiring restart). func (o *ApplicationConfig) ApplyRuntimeSettings(settings *RuntimeSettings) (requireRestart bool) { if settings == nil { return false } if settings.WatchdogEnabled != nil { o.WatchDog = *settings.WatchdogEnabled requireRestart = true } if settings.WatchdogIdleEnabled != nil { o.WatchDogIdle = *settings.WatchdogIdleEnabled if o.WatchDogIdle { o.WatchDog = true } requireRestart = true } if settings.WatchdogBusyEnabled != nil { o.WatchDogBusy = *settings.WatchdogBusyEnabled if o.WatchDogBusy { o.WatchDog = true } requireRestart = true } if settings.WatchdogIdleTimeout != nil { if dur, err := time.ParseDuration(*settings.WatchdogIdleTimeout); err == nil { o.WatchDogIdleTimeout = dur requireRestart = true } } if settings.WatchdogBusyTimeout != nil { if dur, err := time.ParseDuration(*settings.WatchdogBusyTimeout); err == nil { o.WatchDogBusyTimeout = dur requireRestart = true } } if settings.WatchdogInterval != nil { if dur, err := time.ParseDuration(*settings.WatchdogInterval); err == nil { o.WatchDogInterval = dur requireRestart = true } } if 
settings.MaxActiveBackends != nil { o.MaxActiveBackends = *settings.MaxActiveBackends o.SingleBackend = (*settings.MaxActiveBackends == 1) requireRestart = true } else if settings.SingleBackend != nil { o.SingleBackend = *settings.SingleBackend if *settings.SingleBackend { o.MaxActiveBackends = 1 } else { o.MaxActiveBackends = 0 } requireRestart = true } if settings.ParallelBackendRequests != nil { o.ParallelBackendRequests = *settings.ParallelBackendRequests } if settings.MemoryReclaimerEnabled != nil { o.MemoryReclaimerEnabled = *settings.MemoryReclaimerEnabled if *settings.MemoryReclaimerEnabled { o.WatchDog = true } requireRestart = true } if settings.MemoryReclaimerThreshold != nil { if *settings.MemoryReclaimerThreshold > 0 && *settings.MemoryReclaimerThreshold <= 1.0 { o.MemoryReclaimerThreshold = *settings.MemoryReclaimerThreshold requireRestart = true } } if settings.ForceEvictionWhenBusy != nil { o.ForceEvictionWhenBusy = *settings.ForceEvictionWhenBusy // This setting doesn't require restart, can be updated dynamically } if settings.LRUEvictionMaxRetries != nil { o.LRUEvictionMaxRetries = *settings.LRUEvictionMaxRetries // This setting doesn't require restart, can be updated dynamically } if settings.LRUEvictionRetryInterval != nil { if dur, err := time.ParseDuration(*settings.LRUEvictionRetryInterval); err == nil { o.LRUEvictionRetryInterval = dur // This setting doesn't require restart, can be updated dynamically } } if settings.Threads != nil { o.Threads = *settings.Threads } if settings.ContextSize != nil { o.ContextSize = *settings.ContextSize } if settings.F16 != nil { o.F16 = *settings.F16 } if settings.Debug != nil { o.Debug = *settings.Debug } if settings.EnableTracing != nil { o.EnableTracing = *settings.EnableTracing } if settings.TracingMaxItems != nil { o.TracingMaxItems = *settings.TracingMaxItems } if settings.EnableBackendLogging != nil { o.EnableBackendLogging = *settings.EnableBackendLogging } if settings.CORS != nil { o.CORS = 
*settings.CORS } if settings.CSRF != nil { o.DisableCSRF = *settings.CSRF } if settings.CORSAllowOrigins != nil { o.CORSAllowOrigins = *settings.CORSAllowOrigins } if settings.P2PToken != nil { o.P2PToken = *settings.P2PToken } if settings.P2PNetworkID != nil { o.P2PNetworkID = *settings.P2PNetworkID } if settings.Federated != nil { o.Federated = *settings.Federated } if settings.Galleries != nil { o.Galleries = *settings.Galleries } if settings.BackendGalleries != nil { o.BackendGalleries = *settings.BackendGalleries } if settings.AutoloadGalleries != nil { o.AutoloadGalleries = *settings.AutoloadGalleries } if settings.AutoloadBackendGalleries != nil { o.AutoloadBackendGalleries = *settings.AutoloadBackendGalleries } if settings.AgentJobRetentionDays != nil { o.AgentJobRetentionDays = *settings.AgentJobRetentionDays } if settings.OpenResponsesStoreTTL != nil { if *settings.OpenResponsesStoreTTL == "0" || *settings.OpenResponsesStoreTTL == "" { o.OpenResponsesStoreTTL = 0 // No expiration } else if dur, err := time.ParseDuration(*settings.OpenResponsesStoreTTL); err == nil { o.OpenResponsesStoreTTL = dur } // This setting doesn't require restart, can be updated dynamically } // Agent Pool settings if settings.AgentPoolEnabled != nil { o.AgentPool.Enabled = *settings.AgentPoolEnabled requireRestart = true } if settings.AgentPoolDefaultModel != nil { o.AgentPool.DefaultModel = *settings.AgentPoolDefaultModel requireRestart = true } if settings.AgentPoolEmbeddingModel != nil { o.AgentPool.EmbeddingModel = *settings.AgentPoolEmbeddingModel requireRestart = true } if settings.AgentPoolMaxChunkingSize != nil { o.AgentPool.MaxChunkingSize = *settings.AgentPoolMaxChunkingSize requireRestart = true } if settings.AgentPoolChunkOverlap != nil { o.AgentPool.ChunkOverlap = *settings.AgentPoolChunkOverlap requireRestart = true } if settings.AgentPoolEnableLogs != nil { o.AgentPool.EnableLogs = *settings.AgentPoolEnableLogs requireRestart = true } if 
settings.AgentPoolCollectionDBPath != nil { o.AgentPool.CollectionDBPath = *settings.AgentPoolCollectionDBPath requireRestart = true } // Note: ApiKeys requires special handling (merging with startup keys) - handled in caller return requireRestart } // func WithMetrics(meter *metrics.Metrics) AppOption { // return func(o *StartupOptions) { // o.Metrics = meter // } // } ================================================ FILE: core/config/application_config_test.go ================================================ package config import ( "time" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("ApplicationConfig RuntimeSettings Conversion", func() { Describe("ToRuntimeSettings", func() { It("should convert all fields correctly", func() { appConfig := &ApplicationConfig{ WatchDog: true, WatchDogIdle: true, WatchDogBusy: true, WatchDogIdleTimeout: 20 * time.Minute, WatchDogBusyTimeout: 10 * time.Minute, SingleBackend: false, MaxActiveBackends: 5, ParallelBackendRequests: true, MemoryReclaimerEnabled: true, MemoryReclaimerThreshold: 0.85, Threads: 8, ContextSize: 4096, F16: true, Debug: true, CORS: true, DisableCSRF: true, CORSAllowOrigins: "https://example.com", P2PToken: "test-token", P2PNetworkID: "test-network", Federated: true, Galleries: []Gallery{{Name: "test-gallery", URL: "https://example.com"}}, BackendGalleries: []Gallery{{Name: "backend-gallery", URL: "https://example.com/backend"}}, AutoloadGalleries: true, AutoloadBackendGalleries: true, ApiKeys: []string{"key1", "key2"}, AgentJobRetentionDays: 30, } rs := appConfig.ToRuntimeSettings() Expect(rs.WatchdogEnabled).ToNot(BeNil()) Expect(*rs.WatchdogEnabled).To(BeTrue()) Expect(rs.WatchdogIdleEnabled).ToNot(BeNil()) Expect(*rs.WatchdogIdleEnabled).To(BeTrue()) Expect(rs.WatchdogBusyEnabled).ToNot(BeNil()) Expect(*rs.WatchdogBusyEnabled).To(BeTrue()) Expect(rs.WatchdogIdleTimeout).ToNot(BeNil()) Expect(*rs.WatchdogIdleTimeout).To(Equal("20m0s")) 
Expect(rs.WatchdogBusyTimeout).ToNot(BeNil()) Expect(*rs.WatchdogBusyTimeout).To(Equal("10m0s")) Expect(rs.SingleBackend).ToNot(BeNil()) Expect(*rs.SingleBackend).To(BeFalse()) Expect(rs.MaxActiveBackends).ToNot(BeNil()) Expect(*rs.MaxActiveBackends).To(Equal(5)) Expect(rs.ParallelBackendRequests).ToNot(BeNil()) Expect(*rs.ParallelBackendRequests).To(BeTrue()) Expect(rs.MemoryReclaimerEnabled).ToNot(BeNil()) Expect(*rs.MemoryReclaimerEnabled).To(BeTrue()) Expect(rs.MemoryReclaimerThreshold).ToNot(BeNil()) Expect(*rs.MemoryReclaimerThreshold).To(Equal(0.85)) Expect(rs.Threads).ToNot(BeNil()) Expect(*rs.Threads).To(Equal(8)) Expect(rs.ContextSize).ToNot(BeNil()) Expect(*rs.ContextSize).To(Equal(4096)) Expect(rs.F16).ToNot(BeNil()) Expect(*rs.F16).To(BeTrue()) Expect(rs.Debug).ToNot(BeNil()) Expect(*rs.Debug).To(BeTrue()) Expect(rs.CORS).ToNot(BeNil()) Expect(*rs.CORS).To(BeTrue()) Expect(rs.CSRF).ToNot(BeNil()) Expect(*rs.CSRF).To(BeTrue()) Expect(rs.CORSAllowOrigins).ToNot(BeNil()) Expect(*rs.CORSAllowOrigins).To(Equal("https://example.com")) Expect(rs.P2PToken).ToNot(BeNil()) Expect(*rs.P2PToken).To(Equal("test-token")) Expect(rs.P2PNetworkID).ToNot(BeNil()) Expect(*rs.P2PNetworkID).To(Equal("test-network")) Expect(rs.Federated).ToNot(BeNil()) Expect(*rs.Federated).To(BeTrue()) Expect(rs.Galleries).ToNot(BeNil()) Expect(*rs.Galleries).To(HaveLen(1)) Expect((*rs.Galleries)[0].Name).To(Equal("test-gallery")) Expect(rs.BackendGalleries).ToNot(BeNil()) Expect(*rs.BackendGalleries).To(HaveLen(1)) Expect((*rs.BackendGalleries)[0].Name).To(Equal("backend-gallery")) Expect(rs.AutoloadGalleries).ToNot(BeNil()) Expect(*rs.AutoloadGalleries).To(BeTrue()) Expect(rs.AutoloadBackendGalleries).ToNot(BeNil()) Expect(*rs.AutoloadBackendGalleries).To(BeTrue()) Expect(rs.ApiKeys).ToNot(BeNil()) Expect(*rs.ApiKeys).To(HaveLen(2)) Expect(*rs.ApiKeys).To(ContainElements("key1", "key2")) Expect(rs.AgentJobRetentionDays).ToNot(BeNil()) Expect(*rs.AgentJobRetentionDays).To(Equal(30)) }) 
It("should use default timeouts when not set", func() { appConfig := &ApplicationConfig{} rs := appConfig.ToRuntimeSettings() Expect(rs.WatchdogIdleTimeout).ToNot(BeNil()) Expect(*rs.WatchdogIdleTimeout).To(Equal("15m")) Expect(rs.WatchdogBusyTimeout).ToNot(BeNil()) Expect(*rs.WatchdogBusyTimeout).To(Equal("5m")) }) }) Describe("ApplyRuntimeSettings", func() { It("should return false when settings is nil", func() { appConfig := &ApplicationConfig{} changed := appConfig.ApplyRuntimeSettings(nil) Expect(changed).To(BeFalse()) }) It("should only apply non-nil fields", func() { appConfig := &ApplicationConfig{ WatchDog: false, Threads: 4, ContextSize: 2048, } watchdogEnabled := true rs := &RuntimeSettings{ WatchdogEnabled: &watchdogEnabled, // Leave other fields nil } changed := appConfig.ApplyRuntimeSettings(rs) Expect(changed).To(BeTrue()) Expect(appConfig.WatchDog).To(BeTrue()) // Unchanged fields should remain Expect(appConfig.Threads).To(Equal(4)) Expect(appConfig.ContextSize).To(Equal(2048)) }) It("should apply watchdog settings and return changed=true", func() { appConfig := &ApplicationConfig{} watchdogEnabled := true watchdogIdle := true watchdogBusy := true idleTimeout := "30m" busyTimeout := "15m" rs := &RuntimeSettings{ WatchdogEnabled: &watchdogEnabled, WatchdogIdleEnabled: &watchdogIdle, WatchdogBusyEnabled: &watchdogBusy, WatchdogIdleTimeout: &idleTimeout, WatchdogBusyTimeout: &busyTimeout, } changed := appConfig.ApplyRuntimeSettings(rs) Expect(changed).To(BeTrue()) Expect(appConfig.WatchDog).To(BeTrue()) Expect(appConfig.WatchDogIdle).To(BeTrue()) Expect(appConfig.WatchDogBusy).To(BeTrue()) Expect(appConfig.WatchDogIdleTimeout).To(Equal(30 * time.Minute)) Expect(appConfig.WatchDogBusyTimeout).To(Equal(15 * time.Minute)) }) It("should enable watchdog when idle is enabled", func() { appConfig := &ApplicationConfig{WatchDog: false} watchdogIdle := true rs := &RuntimeSettings{ WatchdogIdleEnabled: &watchdogIdle, } appConfig.ApplyRuntimeSettings(rs) 
Expect(appConfig.WatchDog).To(BeTrue()) Expect(appConfig.WatchDogIdle).To(BeTrue()) }) It("should enable watchdog when busy is enabled", func() { appConfig := &ApplicationConfig{WatchDog: false} watchdogBusy := true rs := &RuntimeSettings{ WatchdogBusyEnabled: &watchdogBusy, } appConfig.ApplyRuntimeSettings(rs) Expect(appConfig.WatchDog).To(BeTrue()) Expect(appConfig.WatchDogBusy).To(BeTrue()) }) It("should handle MaxActiveBackends and update SingleBackend accordingly", func() { appConfig := &ApplicationConfig{} maxBackends := 1 rs := &RuntimeSettings{ MaxActiveBackends: &maxBackends, } appConfig.ApplyRuntimeSettings(rs) Expect(appConfig.MaxActiveBackends).To(Equal(1)) Expect(appConfig.SingleBackend).To(BeTrue()) // Test with multiple backends maxBackends = 5 rs = &RuntimeSettings{ MaxActiveBackends: &maxBackends, } appConfig.ApplyRuntimeSettings(rs) Expect(appConfig.MaxActiveBackends).To(Equal(5)) Expect(appConfig.SingleBackend).To(BeFalse()) }) It("should handle SingleBackend and update MaxActiveBackends accordingly", func() { appConfig := &ApplicationConfig{} singleBackend := true rs := &RuntimeSettings{ SingleBackend: &singleBackend, } appConfig.ApplyRuntimeSettings(rs) Expect(appConfig.SingleBackend).To(BeTrue()) Expect(appConfig.MaxActiveBackends).To(Equal(1)) // Test disabling single backend singleBackend = false rs = &RuntimeSettings{ SingleBackend: &singleBackend, } appConfig.ApplyRuntimeSettings(rs) Expect(appConfig.SingleBackend).To(BeFalse()) Expect(appConfig.MaxActiveBackends).To(Equal(0)) }) It("should enable watchdog when memory reclaimer is enabled", func() { appConfig := &ApplicationConfig{WatchDog: false} memoryEnabled := true threshold := 0.90 rs := &RuntimeSettings{ MemoryReclaimerEnabled: &memoryEnabled, MemoryReclaimerThreshold: &threshold, } changed := appConfig.ApplyRuntimeSettings(rs) Expect(changed).To(BeTrue()) Expect(appConfig.WatchDog).To(BeTrue()) Expect(appConfig.MemoryReclaimerEnabled).To(BeTrue()) 
Expect(appConfig.MemoryReclaimerThreshold).To(Equal(0.90)) }) It("should reject invalid memory threshold values", func() { appConfig := &ApplicationConfig{MemoryReclaimerThreshold: 0.50} // Test threshold > 1.0 invalidThreshold := 1.5 rs := &RuntimeSettings{ MemoryReclaimerThreshold: &invalidThreshold, } appConfig.ApplyRuntimeSettings(rs) Expect(appConfig.MemoryReclaimerThreshold).To(Equal(0.50)) // Should remain unchanged // Test threshold <= 0 invalidThreshold = 0.0 rs = &RuntimeSettings{ MemoryReclaimerThreshold: &invalidThreshold, } appConfig.ApplyRuntimeSettings(rs) Expect(appConfig.MemoryReclaimerThreshold).To(Equal(0.50)) // Should remain unchanged // Test negative threshold invalidThreshold = -0.5 rs = &RuntimeSettings{ MemoryReclaimerThreshold: &invalidThreshold, } appConfig.ApplyRuntimeSettings(rs) Expect(appConfig.MemoryReclaimerThreshold).To(Equal(0.50)) // Should remain unchanged }) It("should accept valid memory threshold at boundary", func() { appConfig := &ApplicationConfig{} // Test threshold = 1.0 (maximum valid) threshold := 1.0 rs := &RuntimeSettings{ MemoryReclaimerThreshold: &threshold, } appConfig.ApplyRuntimeSettings(rs) Expect(appConfig.MemoryReclaimerThreshold).To(Equal(1.0)) // Test threshold just above 0 threshold = 0.01 rs = &RuntimeSettings{ MemoryReclaimerThreshold: &threshold, } appConfig.ApplyRuntimeSettings(rs) Expect(appConfig.MemoryReclaimerThreshold).To(Equal(0.01)) }) It("should apply performance settings without triggering watchdog change", func() { appConfig := &ApplicationConfig{} threads := 16 contextSize := 8192 f16 := true debug := true rs := &RuntimeSettings{ Threads: &threads, ContextSize: &contextSize, F16: &f16, Debug: &debug, } changed := appConfig.ApplyRuntimeSettings(rs) // These settings don't require watchdog restart Expect(changed).To(BeFalse()) Expect(appConfig.Threads).To(Equal(16)) Expect(appConfig.ContextSize).To(Equal(8192)) Expect(appConfig.F16).To(BeTrue()) Expect(appConfig.Debug).To(BeTrue()) }) 
It("should apply CORS and security settings", func() { appConfig := &ApplicationConfig{} cors := true csrf := true origins := "https://example.com,https://other.com" rs := &RuntimeSettings{ CORS: &cors, CSRF: &csrf, CORSAllowOrigins: &origins, } appConfig.ApplyRuntimeSettings(rs) Expect(appConfig.CORS).To(BeTrue()) Expect(appConfig.DisableCSRF).To(BeTrue()) Expect(appConfig.CORSAllowOrigins).To(Equal("https://example.com,https://other.com")) }) It("should apply P2P settings", func() { appConfig := &ApplicationConfig{} token := "p2p-test-token" networkID := "p2p-test-network" federated := true rs := &RuntimeSettings{ P2PToken: &token, P2PNetworkID: &networkID, Federated: &federated, } appConfig.ApplyRuntimeSettings(rs) Expect(appConfig.P2PToken).To(Equal("p2p-test-token")) Expect(appConfig.P2PNetworkID).To(Equal("p2p-test-network")) Expect(appConfig.Federated).To(BeTrue()) }) It("should apply gallery settings", func() { appConfig := &ApplicationConfig{} galleries := []Gallery{ {Name: "gallery1", URL: "https://gallery1.com"}, {Name: "gallery2", URL: "https://gallery2.com"}, } backendGalleries := []Gallery{ {Name: "backend-gallery", URL: "https://backend.com"}, } autoload := true autoloadBackend := true rs := &RuntimeSettings{ Galleries: &galleries, BackendGalleries: &backendGalleries, AutoloadGalleries: &autoload, AutoloadBackendGalleries: &autoloadBackend, } appConfig.ApplyRuntimeSettings(rs) Expect(appConfig.Galleries).To(HaveLen(2)) Expect(appConfig.Galleries[0].Name).To(Equal("gallery1")) Expect(appConfig.BackendGalleries).To(HaveLen(1)) Expect(appConfig.AutoloadGalleries).To(BeTrue()) Expect(appConfig.AutoloadBackendGalleries).To(BeTrue()) }) It("should apply agent settings", func() { appConfig := &ApplicationConfig{} retentionDays := 14 rs := &RuntimeSettings{ AgentJobRetentionDays: &retentionDays, } appConfig.ApplyRuntimeSettings(rs) Expect(appConfig.AgentJobRetentionDays).To(Equal(14)) }) }) Describe("Round-trip conversion", func() { It("should maintain 
values through ToRuntimeSettings -> ApplyRuntimeSettings", func() { original := &ApplicationConfig{ WatchDog: true, WatchDogIdle: true, WatchDogBusy: false, WatchDogIdleTimeout: 25 * time.Minute, WatchDogBusyTimeout: 12 * time.Minute, SingleBackend: false, MaxActiveBackends: 3, ParallelBackendRequests: true, MemoryReclaimerEnabled: true, MemoryReclaimerThreshold: 0.92, Threads: 12, ContextSize: 6144, F16: true, Debug: false, CORS: true, DisableCSRF: false, CORSAllowOrigins: "https://test.com", P2PToken: "round-trip-token", P2PNetworkID: "round-trip-network", Federated: true, AutoloadGalleries: true, AutoloadBackendGalleries: false, AgentJobRetentionDays: 60, } // Convert to RuntimeSettings rs := original.ToRuntimeSettings() // Apply to a new ApplicationConfig target := &ApplicationConfig{} target.ApplyRuntimeSettings(&rs) // Verify all values match Expect(target.WatchDog).To(Equal(original.WatchDog)) Expect(target.WatchDogIdle).To(Equal(original.WatchDogIdle)) Expect(target.WatchDogBusy).To(Equal(original.WatchDogBusy)) Expect(target.WatchDogIdleTimeout).To(Equal(original.WatchDogIdleTimeout)) Expect(target.WatchDogBusyTimeout).To(Equal(original.WatchDogBusyTimeout)) Expect(target.MaxActiveBackends).To(Equal(original.MaxActiveBackends)) Expect(target.ParallelBackendRequests).To(Equal(original.ParallelBackendRequests)) Expect(target.MemoryReclaimerEnabled).To(Equal(original.MemoryReclaimerEnabled)) Expect(target.MemoryReclaimerThreshold).To(Equal(original.MemoryReclaimerThreshold)) Expect(target.Threads).To(Equal(original.Threads)) Expect(target.ContextSize).To(Equal(original.ContextSize)) Expect(target.F16).To(Equal(original.F16)) Expect(target.Debug).To(Equal(original.Debug)) Expect(target.CORS).To(Equal(original.CORS)) Expect(target.DisableCSRF).To(Equal(original.DisableCSRF)) Expect(target.CORSAllowOrigins).To(Equal(original.CORSAllowOrigins)) Expect(target.P2PToken).To(Equal(original.P2PToken)) Expect(target.P2PNetworkID).To(Equal(original.P2PNetworkID)) 
Expect(target.Federated).To(Equal(original.Federated)) Expect(target.AutoloadGalleries).To(Equal(original.AutoloadGalleries)) Expect(target.AutoloadBackendGalleries).To(Equal(original.AutoloadBackendGalleries)) Expect(target.AgentJobRetentionDays).To(Equal(original.AgentJobRetentionDays)) }) It("should handle empty galleries correctly in round-trip", func() { original := &ApplicationConfig{ Galleries: []Gallery{}, BackendGalleries: []Gallery{}, ApiKeys: []string{}, } rs := original.ToRuntimeSettings() target := &ApplicationConfig{} target.ApplyRuntimeSettings(&rs) Expect(target.Galleries).To(BeEmpty()) Expect(target.BackendGalleries).To(BeEmpty()) }) }) Describe("Edge cases", func() { It("should handle invalid timeout string in ApplyRuntimeSettings", func() { appConfig := &ApplicationConfig{ WatchDogIdleTimeout: 10 * time.Minute, } invalidTimeout := "not-a-duration" rs := &RuntimeSettings{ WatchdogIdleTimeout: &invalidTimeout, } appConfig.ApplyRuntimeSettings(rs) // Should remain unchanged due to parse error Expect(appConfig.WatchDogIdleTimeout).To(Equal(10 * time.Minute)) }) It("should handle zero values in ApplicationConfig", func() { appConfig := &ApplicationConfig{ // All zero values } rs := appConfig.ToRuntimeSettings() // Should still have non-nil pointers with zero/default values Expect(rs.WatchdogEnabled).ToNot(BeNil()) Expect(*rs.WatchdogEnabled).To(BeFalse()) Expect(rs.Threads).ToNot(BeNil()) Expect(*rs.Threads).To(Equal(0)) Expect(rs.MemoryReclaimerThreshold).ToNot(BeNil()) Expect(*rs.MemoryReclaimerThreshold).To(Equal(0.0)) }) It("should prefer MaxActiveBackends over SingleBackend when both are set", func() { appConfig := &ApplicationConfig{} maxBackends := 3 singleBackend := true rs := &RuntimeSettings{ MaxActiveBackends: &maxBackends, SingleBackend: &singleBackend, } appConfig.ApplyRuntimeSettings(rs) // MaxActiveBackends should take precedence Expect(appConfig.MaxActiveBackends).To(Equal(3)) Expect(appConfig.SingleBackend).To(BeFalse()) // 3 != 1, so 
single backend is false }) }) }) ================================================ FILE: core/config/config_suite_test.go ================================================ package config_test import ( "testing" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) func TestConfig(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Config test suite") } ================================================ FILE: core/config/gallery.go ================================================ package config type Gallery struct { URL string `json:"url" yaml:"url"` Name string `json:"name" yaml:"name"` } ================================================ FILE: core/config/gguf.go ================================================ package config import ( "context" "github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/pkg/grpc" pb "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/reasoning" "github.com/mudler/LocalAI/pkg/xsysinfo" "github.com/mudler/xlog" gguf "github.com/gpustack/gguf-parser-go" "github.com/gpustack/gguf-parser-go/util/ptr" ) const ( defaultContextSize = 1024 defaultNGPULayers = 99999999 ) func guessGGUFFromFile(cfg *ModelConfig, f *gguf.GGUFFile, defaultCtx int) { if defaultCtx == 0 && cfg.ContextSize == nil { ctxSize := f.EstimateLLaMACppRun().ContextSize if ctxSize > 0 { cSize := int(ctxSize) cfg.ContextSize = &cSize } else { defaultCtx = defaultContextSize cfg.ContextSize = &defaultCtx } } // GPU options if cfg.Options == nil { if xsysinfo.HasGPU("nvidia") || xsysinfo.HasGPU("amd") { cfg.Options = []string{"gpu"} } } if cfg.NGPULayers == nil { // we assume we want to offload all layers defaultHigh := defaultNGPULayers cfg.NGPULayers = &defaultHigh } xlog.Debug("[gguf] guessDefaultsFromFile: NGPULayers set", "NGPULayers", cfg.NGPULayers, "modelName", f.Metadata().Name) // identify from well known templates first, otherwise use the raw jinja template chatTemplate, found := f.Header.MetadataKV.Get("tokenizer.chat_template") 
if found { // fill jinja template cfg.modelTemplate = chatTemplate.ValueString() } // Thinking support detection is done after model load via DetectThinkingSupportFromBackend // template estimations if cfg.HasTemplate() { // nothing to guess here xlog.Debug("[gguf] guessDefaultsFromFile: template already set", "name", cfg.Name, "modelName", f.Metadata().Name) return } xlog.Debug("[gguf] Model file loaded", "file", cfg.ModelFileName(), "eosTokenID", f.Tokenizer().EOSTokenID, "bosTokenID", f.Tokenizer().BOSTokenID, "modelName", f.Metadata().Name, "architecture", f.Architecture().Architecture) // guess the name if cfg.Name == "" { cfg.Name = f.Metadata().Name } // Instruct to use template from llama.cpp cfg.TemplateConfig.UseTokenizerTemplate = true cfg.FunctionsConfig.GrammarConfig.NoGrammar = true cfg.Options = append(cfg.Options, "use_jinja:true") cfg.KnownUsecaseStrings = append(cfg.KnownUsecaseStrings, "FLAG_CHAT") } // DetectThinkingSupportFromBackend calls the ModelMetadata gRPC method to detect // if the model supports thinking mode and if the template ends with a thinking start token. // This should be called after the model is loaded. // The results are stored in cfg.SupportsThinking and cfg.ThinkingForcedOpen. 
func DetectThinkingSupportFromBackend(ctx context.Context, cfg *ModelConfig, backendClient grpc.Backend, modelOptions *pb.ModelOptions) { if backendClient == nil { xlog.Debug("[gguf] DetectThinkingSupportFromBackend: backend client is nil, skipping detection") return } if modelOptions == nil { xlog.Debug("[gguf] DetectThinkingSupportFromBackend: model options is nil, skipping detection") return } // Only detect for llama-cpp backend when using tokenizer templates if cfg.Backend != "llama-cpp" || !cfg.TemplateConfig.UseTokenizerTemplate { xlog.Debug("[gguf] DetectThinkingSupportFromBackend: skipping detection", "backend", cfg.Backend, "useTokenizerTemplate", cfg.TemplateConfig.UseTokenizerTemplate) return } metadata, err := backendClient.ModelMetadata(ctx, modelOptions) if err != nil { xlog.Warn("[gguf] DetectThinkingSupportFromBackend: failed to get model metadata", "error", err) return } if metadata != nil { cfg.ReasoningConfig.DisableReasoning = ptr.To(!metadata.SupportsThinking) // Use the rendered template to detect if thinking token is at the end // This reuses the existing DetectThinkingStartToken function if metadata.RenderedTemplate != "" { thinkingStartToken := reasoning.DetectThinkingStartToken(metadata.RenderedTemplate, &cfg.ReasoningConfig) thinkingForcedOpen := thinkingStartToken != "" cfg.ReasoningConfig.DisableReasoningTagPrefill = ptr.To(!thinkingForcedOpen) xlog.Debug("[gguf] DetectThinkingSupportFromBackend: thinking support detected", "supports_thinking", metadata.SupportsThinking, "thinking_forced_open", thinkingForcedOpen, "thinking_start_token", thinkingStartToken) } else { cfg.ReasoningConfig.DisableReasoningTagPrefill = ptr.To(true) xlog.Debug("[gguf] DetectThinkingSupportFromBackend: thinking support detected", "supports_thinking", metadata.SupportsThinking, "thinking_forced_open", false) } // Extract tool format markers from autoparser analysis if tf := metadata.GetToolFormat(); tf != nil && tf.FormatType != "" { 
cfg.FunctionsConfig.ToolFormatMarkers = &functions.ToolFormatMarkers{ FormatType: tf.FormatType, SectionStart: tf.SectionStart, SectionEnd: tf.SectionEnd, PerCallStart: tf.PerCallStart, PerCallEnd: tf.PerCallEnd, FuncNamePrefix: tf.FuncNamePrefix, FuncNameSuffix: tf.FuncNameSuffix, FuncClose: tf.FuncClose, ArgNamePrefix: tf.ArgNamePrefix, ArgNameSuffix: tf.ArgNameSuffix, ArgValuePrefix: tf.ArgValuePrefix, ArgValueSuffix: tf.ArgValueSuffix, ArgSeparator: tf.ArgSeparator, ArgsStart: tf.ArgsStart, ArgsEnd: tf.ArgsEnd, NameField: tf.NameField, ArgsField: tf.ArgsField, IDField: tf.IdField, FunNameIsKey: tf.FunNameIsKey, ToolsArrayWrapped: tf.ToolsArrayWrapped, UsesPythonDicts: tf.UsesPythonDicts, FunctionField: tf.FunctionField, ParameterOrder: tf.ParameterOrder, GenIDField: tf.GenIdField, CallIDPosition: tf.CallIdPosition, CallIDPrefix: tf.CallIdPrefix, CallIDSuffix: tf.CallIdSuffix, ReasoningStart: tf.ReasoningStart, ReasoningEnd: tf.ReasoningEnd, ContentStart: tf.ContentStart, ContentEnd: tf.ContentEnd, } xlog.Debug("[gguf] DetectThinkingSupportFromBackend: tool format markers detected", "format_type", tf.FormatType, "section_start", tf.SectionStart, "func_name_prefix", tf.FuncNamePrefix) } } } ================================================ FILE: core/config/guesser.go ================================================ package config import ( "os" "path/filepath" gguf "github.com/gpustack/gguf-parser-go" "github.com/mudler/xlog" ) func guessDefaultsFromFile(cfg *ModelConfig, modelPath string, defaultCtx int) { if os.Getenv("LOCALAI_DISABLE_GUESSING") == "true" { xlog.Debug("guessDefaultsFromFile: guessing disabled with LOCALAI_DISABLE_GUESSING") return } if modelPath == "" { xlog.Debug("guessDefaultsFromFile: modelPath is empty") return } // We try to guess only if we don't have a template defined already guessPath := filepath.Join(modelPath, cfg.ModelFileName()) defer func() { if r := recover(); r != nil { xlog.Error("guessDefaultsFromFile: panic while parsing gguf 
// @Description ModelConfig represents a model configuration
//
// Fields tagged `yaml:"-" json:"-"` carry runtime state only and are never
// read from or written to a configuration file.
type ModelConfig struct {
	// modelConfigFile is exposed read-only through GetModelConfigFile.
	modelConfigFile string `yaml:"-" json:"-"`
	// modelTemplate holds the raw chat template taken from the GGUF
	// "tokenizer.chat_template" metadata (see guessGGUFFromFile); it is
	// exposed through GetModelTemplate.
	modelTemplate string `yaml:"-" json:"-"`

	// Prediction parameters, nested under the "parameters" key.
	schema.PredictionOptions `yaml:"parameters,omitempty" json:"parameters,omitempty"`
	Name string `yaml:"name,omitempty" json:"name,omitempty"`

	// F16, Threads and Debug are optional; SetDefaults fills them in when nil.
	F16 *bool `yaml:"f16,omitempty" json:"f16,omitempty"`
	Threads *int `yaml:"threads,omitempty" json:"threads,omitempty"`
	Debug *bool `yaml:"debug,omitempty" json:"debug,omitempty"`
	Roles map[string]string `yaml:"roles,omitempty" json:"roles,omitempty"`
	// Embeddings defaults to false in SetDefaults and is consulted by
	// GuessUsecases for FLAG_EMBEDDINGS.
	Embeddings *bool `yaml:"embeddings,omitempty" json:"embeddings,omitempty"`
	// Backend is checked by Validate: letters, digits, '-' and '_' only.
	Backend string `yaml:"backend,omitempty" json:"backend,omitempty"`
	TemplateConfig TemplateConfig `yaml:"template,omitempty" json:"template,omitempty"`
	// KnownUsecaseStrings is the serialized form of KnownUsecases; both are
	// kept in sync by syncKnownUsecasesFromString.
	KnownUsecaseStrings []string `yaml:"known_usecases,omitempty" json:"known_usecases,omitempty"`
	KnownUsecases *ModelConfigUsecase `yaml:"-" json:"-"`
	Pipeline Pipeline `yaml:"pipeline,omitempty" json:"pipeline,omitempty"`

	PromptStrings, InputStrings []string `yaml:"-" json:"-"`
	InputToken [][]int `yaml:"-" json:"-"`
	// Set through SetFunctionCallString / SetFunctionCallNameString and read
	// by ShouldUseFunctions, ShouldCallSpecificFunction and FunctionToCall.
	functionCallString, functionCallNameString string `yaml:"-" json:"-"`
	ResponseFormat string `yaml:"-" json:"-"`
	ResponseFormatMap map[string]interface{} `yaml:"-" json:"-"`

	FunctionsConfig functions.FunctionsConfig `yaml:"function,omitempty" json:"function,omitempty"`
	ReasoningConfig reasoning.Config `yaml:"reasoning,omitempty" json:"reasoning,omitempty"`

	FeatureFlag FeatureFlag `yaml:"feature_flags,omitempty" json:"feature_flags,omitempty"` // Feature Flag registry. We move fast, and features may break on a per model/backend basis. Registry for (usually temporary) flags that indicate aborting something early.
	// LLM configs (GPT4ALL, Llama.cpp, ...)
	LLMConfig `yaml:",inline" json:",inline"`

	// Diffusers
	Diffusers Diffusers `yaml:"diffusers,omitempty" json:"diffusers,omitempty"`
	Step int `yaml:"step,omitempty" json:"step,omitempty"`

	// GRPC Options
	GRPC GRPC `yaml:"grpc,omitempty" json:"grpc,omitempty"`

	// TTS specifics
	TTSConfig `yaml:"tts,omitempty" json:"tts,omitempty"`

	// CUDA
	// Explicitly enable CUDA or not (some backends might need it)
	CUDA bool `yaml:"cuda,omitempty" json:"cuda,omitempty"`

	// DownloadFiles lists extra files to fetch; their filenames are
	// path-checked by Validate.
	DownloadFiles []File `yaml:"download_files,omitempty" json:"download_files,omitempty"`

	Description string `yaml:"description,omitempty" json:"description,omitempty"`
	Usage string `yaml:"usage,omitempty" json:"usage,omitempty"`
	// Options may be extended by guessGGUFFromFile ("gpu", "use_jinja:true").
	Options []string `yaml:"options,omitempty" json:"options,omitempty"`
	Overrides []string `yaml:"overrides,omitempty" json:"overrides,omitempty"`

	// MCP holds YAML documents parsed on demand by MCPConfigFromYAML.
	MCP MCPConfig `yaml:"mcp,omitempty" json:"mcp,omitempty"`
	// Agent settings are translated into cogito options by BuildCogitoOptions.
	Agent AgentConfig `yaml:"agent,omitempty" json:"agent,omitempty"`
}
MaxIterations int `yaml:"max_iterations,omitempty" json:"max_iterations,omitempty"` EnableReasoning bool `yaml:"enable_reasoning,omitempty" json:"enable_reasoning,omitempty"` EnablePlanning bool `yaml:"enable_planning,omitempty" json:"enable_planning,omitempty"` EnableMCPPrompts bool `yaml:"enable_mcp_prompts,omitempty" json:"enable_mcp_prompts,omitempty"` EnablePlanReEvaluator bool `yaml:"enable_plan_re_evaluator,omitempty" json:"enable_plan_re_evaluator,omitempty"` DisableSinkState bool `yaml:"disable_sink_state,omitempty" json:"disable_sink_state,omitempty"` LoopDetection int `yaml:"loop_detection,omitempty" json:"loop_detection,omitempty"` MaxAdjustmentAttempts int `yaml:"max_adjustment_attempts,omitempty" json:"max_adjustment_attempts,omitempty"` ForceReasoningTool bool `yaml:"force_reasoning_tool,omitempty" json:"force_reasoning_tool,omitempty"` } func (c *MCPConfig) MCPConfigFromYAML() (MCPGenericConfig[MCPRemoteServers], MCPGenericConfig[MCPSTDIOServers], error) { var remote MCPGenericConfig[MCPRemoteServers] var stdio MCPGenericConfig[MCPSTDIOServers] if err := yaml.Unmarshal([]byte(c.Servers), &remote); err != nil { return remote, stdio, err } if err := yaml.Unmarshal([]byte(c.Stdio), &stdio); err != nil { return remote, stdio, err } return remote, stdio, nil } // @Description MCP generic configuration type MCPGenericConfig[T any] struct { Servers T `yaml:"mcpServers,omitempty" json:"mcpServers,omitempty"` } type MCPRemoteServers map[string]MCPRemoteServer type MCPSTDIOServers map[string]MCPSTDIOServer // @Description MCP remote server configuration type MCPRemoteServer struct { URL string `json:"url,omitempty"` Token string `json:"token,omitempty"` } // @Description MCP STDIO server configuration type MCPSTDIOServer struct { Args []string `json:"args,omitempty"` Env map[string]string `json:"env,omitempty"` Command string `json:"command,omitempty"` } // @Description Pipeline defines other models to use for audio-to-audio type Pipeline struct { TTS string 
// FeatureFlag maps a flag name to an optional boolean value.
type FeatureFlag map[string]*bool

// Enabled reports whether the flag named s is present, non-nil and true.
// Unknown flags, nil entries and a nil map all report false.
func (ff FeatureFlag) Enabled(s string) bool {
	// Indexing a missing key (or a nil map) yields a nil pointer.
	value := ff[s]
	return value != nil && *value
}
// LLMConfig is embedded inline into ModelConfig, so its keys sit at the top
// level of a model YAML/JSON document. Several pointer-typed fields are
// optional and receive a default in SetDefaults when left nil.
type LLMConfig struct {
	SystemPrompt string `yaml:"system_prompt,omitempty" json:"system_prompt,omitempty"`
	TensorSplit string `yaml:"tensor_split,omitempty" json:"tensor_split,omitempty"`
	MainGPU string `yaml:"main_gpu,omitempty" json:"main_gpu,omitempty"`
	RMSNormEps float32 `yaml:"rms_norm_eps,omitempty" json:"rms_norm_eps,omitempty"`
	NGQA int32 `yaml:"ngqa,omitempty" json:"ngqa,omitempty"`
	PromptCachePath string `yaml:"prompt_cache_path,omitempty" json:"prompt_cache_path,omitempty"`
	PromptCacheAll bool `yaml:"prompt_cache_all,omitempty" json:"prompt_cache_all,omitempty"`
	PromptCacheRO bool `yaml:"prompt_cache_ro,omitempty" json:"prompt_cache_ro,omitempty"`
	// Mirostat settings; SetDefaults uses ETA 0.1, TAU 5.0 and mode 0 when nil.
	MirostatETA *float64 `yaml:"mirostat_eta,omitempty" json:"mirostat_eta,omitempty"`
	MirostatTAU *float64 `yaml:"mirostat_tau,omitempty" json:"mirostat_tau,omitempty"`
	Mirostat *int `yaml:"mirostat,omitempty" json:"mirostat,omitempty"`
	// NGPULayers, when nil, is set to defaultNGPULayers by guessGGUFFromFile
	// (i.e. only when the model file parses as GGUF).
	NGPULayers *int `yaml:"gpu_layers,omitempty" json:"gpu_layers,omitempty"`
	// MMap defaults to true in SetDefaults, or to false when the XPU
	// environment variable is non-empty.
	MMap *bool `yaml:"mmap,omitempty" json:"mmap,omitempty"`
	// MMlock, LowVRAM and Reranking default to false in SetDefaults.
	MMlock *bool `yaml:"mmlock,omitempty" json:"mmlock,omitempty"`
	LowVRAM *bool `yaml:"low_vram,omitempty" json:"low_vram,omitempty"`
	// Reranking is also consulted by GuessUsecases for FLAG_RERANK.
	Reranking *bool `yaml:"reranking,omitempty" json:"reranking,omitempty"`
	Grammar string `yaml:"grammar,omitempty" json:"grammar,omitempty"`

	StopWords []string `yaml:"stopwords,omitempty" json:"stopwords,omitempty"`
	Cutstrings []string `yaml:"cutstrings,omitempty" json:"cutstrings,omitempty"`
	ExtractRegex []string `yaml:"extract_regex,omitempty" json:"extract_regex,omitempty"`
	TrimSpace []string `yaml:"trimspace,omitempty" json:"trimspace,omitempty"`
	TrimSuffix []string `yaml:"trimsuffix,omitempty" json:"trimsuffix,omitempty"`

	// ContextSize, when nil, is filled in by guessDefaultsFromFile.
	ContextSize *int `yaml:"context_size,omitempty" json:"context_size,omitempty"`
	NUMA bool `yaml:"numa,omitempty" json:"numa,omitempty"`
	LoraAdapter string `yaml:"lora_adapter,omitempty" json:"lora_adapter,omitempty"`
	LoraBase string `yaml:"lora_base,omitempty" json:"lora_base,omitempty"`
	LoraAdapters []string `yaml:"lora_adapters,omitempty" json:"lora_adapters,omitempty"`
	LoraScales []float32 `yaml:"lora_scales,omitempty" json:"lora_scales,omitempty"`
	LoraScale float32 `yaml:"lora_scale,omitempty" json:"lora_scale,omitempty"`
	NoMulMatQ bool `yaml:"no_mulmatq,omitempty" json:"no_mulmatq,omitempty"`
	DraftModel string `yaml:"draft_model,omitempty" json:"draft_model,omitempty"`
	NDraft int32 `yaml:"n_draft,omitempty" json:"n_draft,omitempty"`
	Quantization string `yaml:"quantization,omitempty" json:"quantization,omitempty"`
	LoadFormat string `yaml:"load_format,omitempty" json:"load_format,omitempty"`
	GPUMemoryUtilization float32 `yaml:"gpu_memory_utilization,omitempty" json:"gpu_memory_utilization,omitempty"` // vLLM
	TrustRemoteCode bool `yaml:"trust_remote_code,omitempty" json:"trust_remote_code,omitempty"` // vLLM
	EnforceEager bool `yaml:"enforce_eager,omitempty" json:"enforce_eager,omitempty"` // vLLM
	SwapSpace int `yaml:"swap_space,omitempty" json:"swap_space,omitempty"` // vLLM
	MaxModelLen int `yaml:"max_model_len,omitempty" json:"max_model_len,omitempty"` // vLLM
	TensorParallelSize int `yaml:"tensor_parallel_size,omitempty" json:"tensor_parallel_size,omitempty"` // vLLM
	// NOTE(review): the Go name says "LogStatus" while the serialized key is
	// "disable_log_stats"; the key is what configuration files use.
	DisableLogStatus bool `yaml:"disable_log_stats,omitempty" json:"disable_log_stats,omitempty"` // vLLM
	DType string `yaml:"dtype,omitempty" json:"dtype,omitempty"` // vLLM
	LimitMMPerPrompt LimitMMPerPrompt `yaml:"limit_mm_per_prompt,omitempty" json:"limit_mm_per_prompt,omitempty"` // vLLM
	// MMProj may be a plain file name or a URL (see MMProjFileName and
	// IsMMProjURL); it is path-checked by Validate.
	MMProj string `yaml:"mmproj,omitempty" json:"mmproj,omitempty"`

	FlashAttention *string `yaml:"flash_attention,omitempty" json:"flash_attention,omitempty"`
	NoKVOffloading bool `yaml:"no_kv_offloading,omitempty" json:"no_kv_offloading,omitempty"`
	CacheTypeK string `yaml:"cache_type_k,omitempty" json:"cache_type_k,omitempty"`
	CacheTypeV string `yaml:"cache_type_v,omitempty" json:"cache_type_v,omitempty"`

	RopeScaling string `yaml:"rope_scaling,omitempty" json:"rope_scaling,omitempty"`
	// ModelType is serialized under the short key "type".
	ModelType string `yaml:"type,omitempty" json:"type,omitempty"`

	YarnExtFactor float32 `yaml:"yarn_ext_factor,omitempty" json:"yarn_ext_factor,omitempty"`
	YarnAttnFactor float32 `yaml:"yarn_attn_factor,omitempty" json:"yarn_attn_factor,omitempty"`
	YarnBetaFast float32 `yaml:"yarn_beta_fast,omitempty" json:"yarn_beta_fast,omitempty"`
	YarnBetaSlow float32 `yaml:"yarn_beta_slow,omitempty" json:"yarn_beta_slow,omitempty"`

	CFGScale float32 `yaml:"cfg_scale,omitempty" json:"cfg_scale,omitempty"` // Classifier-Free Guidance Scale
}
// Note: this is mostly consumed for backends such as vllm and transformers // that can use the tokenizers specified in the JSON config files of the models UseTokenizerTemplate bool `yaml:"use_tokenizer_template,omitempty" json:"use_tokenizer_template,omitempty"` // JoinChatMessagesByCharacter is a string that will be used to join chat messages together. // It defaults to \n JoinChatMessagesByCharacter *string `yaml:"join_chat_messages_by_character,omitempty" json:"join_chat_messages_by_character,omitempty"` Multimodal string `yaml:"multimodal,omitempty" json:"multimodal,omitempty"` ReplyPrefix string `yaml:"reply_prefix,omitempty" json:"reply_prefix,omitempty"` } func (c *ModelConfig) syncKnownUsecasesFromString() { c.KnownUsecases = GetUsecasesFromYAML(c.KnownUsecaseStrings) // Make sure the usecases are valid, we rewrite with what we identified c.KnownUsecaseStrings = []string{} for k, usecase := range GetAllModelConfigUsecases() { if c.HasUsecases(usecase) { c.KnownUsecaseStrings = append(c.KnownUsecaseStrings, k) } } } func (c *ModelConfig) UnmarshalYAML(value *yaml.Node) error { type BCAlias ModelConfig var aux BCAlias if err := value.Decode(&aux); err != nil { return err } mc := ModelConfig(aux) *c = mc c.syncKnownUsecasesFromString() return nil } func (c *ModelConfig) SetFunctionCallString(s string) { c.functionCallString = s } func (c *ModelConfig) SetFunctionCallNameString(s string) { c.functionCallNameString = s } func (c *ModelConfig) ShouldUseFunctions() bool { return ((c.functionCallString != "none" || c.functionCallString == "") || c.ShouldCallSpecificFunction()) } func (c *ModelConfig) ShouldCallSpecificFunction() bool { return len(c.functionCallNameString) > 0 } // MMProjFileName returns the filename of the MMProj file // If the MMProj is a URL, it will return the MD5 of the URL which is the filename func (c *ModelConfig) MMProjFileName() string { uri := downloader.URI(c.MMProj) if uri.LooksLikeURL() { f, _ := uri.FilenameFromUrl() return f } 
return c.MMProj } func (c *ModelConfig) IsMMProjURL() bool { uri := downloader.URI(c.MMProj) return uri.LooksLikeURL() } func (c *ModelConfig) IsModelURL() bool { uri := downloader.URI(c.Model) return uri.LooksLikeURL() } // ModelFileName returns the filename of the model // If the model is a URL, it will return the MD5 of the URL which is the filename func (c *ModelConfig) ModelFileName() string { uri := downloader.URI(c.Model) if uri.LooksLikeURL() { f, _ := uri.FilenameFromUrl() return f } return c.Model } func (c *ModelConfig) FunctionToCall() string { if c.functionCallNameString != "" && c.functionCallNameString != "none" && c.functionCallNameString != "auto" { return c.functionCallNameString } return c.functionCallString } func (cfg *ModelConfig) SetDefaults(opts ...ConfigLoaderOption) { lo := &LoadOptions{} lo.Apply(opts...) ctx := lo.ctxSize threads := lo.threads f16 := lo.f16 debug := lo.debug // https://github.com/ggerganov/llama.cpp/blob/75cd4c77292034ecec587ecb401366f57338f7c0/common/sampling.h#L22 defaultTopP := 0.95 defaultTopK := 40 defaultTemp := 0.9 // https://github.com/mudler/LocalAI/issues/2780 defaultMirostat := 0 defaultMirostatTAU := 5.0 defaultMirostatETA := 0.1 defaultTypicalP := 1.0 defaultTFZ := 1.0 defaultZero := 0 trueV := true falseV := false if cfg.Seed == nil { // random number generator seed defaultSeed := RAND_SEED cfg.Seed = &defaultSeed } if cfg.TopK == nil { cfg.TopK = &defaultTopK } if cfg.TypicalP == nil { cfg.TypicalP = &defaultTypicalP } if cfg.TFZ == nil { cfg.TFZ = &defaultTFZ } if cfg.MMap == nil { // MMap is enabled by default // Only exception is for Intel GPUs if os.Getenv("XPU") != "" { cfg.MMap = &falseV } else { cfg.MMap = &trueV } } if cfg.MMlock == nil { // MMlock is disabled by default cfg.MMlock = &falseV } if cfg.TopP == nil { cfg.TopP = &defaultTopP } if cfg.Temperature == nil { cfg.Temperature = &defaultTemp } if cfg.Maxtokens == nil { cfg.Maxtokens = &defaultZero } if cfg.Mirostat == nil { cfg.Mirostat = 
&defaultMirostat } if cfg.MirostatETA == nil { cfg.MirostatETA = &defaultMirostatETA } if cfg.MirostatTAU == nil { cfg.MirostatTAU = &defaultMirostatTAU } if cfg.LowVRAM == nil { cfg.LowVRAM = &falseV } if cfg.Embeddings == nil { cfg.Embeddings = &falseV } if cfg.Reranking == nil { cfg.Reranking = &falseV } if threads == 0 { // Threads can't be 0 threads = 4 } if cfg.Threads == nil { cfg.Threads = &threads } if cfg.F16 == nil { cfg.F16 = &f16 } if cfg.Debug == nil { cfg.Debug = &falseV } if debug { cfg.Debug = &trueV } guessDefaultsFromFile(cfg, lo.modelPath, ctx) cfg.syncKnownUsecasesFromString() } func (c *ModelConfig) Validate() (bool, error) { downloadedFileNames := []string{} for _, f := range c.DownloadFiles { downloadedFileNames = append(downloadedFileNames, f.Filename) } validationTargets := []string{c.Backend, c.Model, c.MMProj} validationTargets = append(validationTargets, downloadedFileNames...) // Simple validation to make sure the model can be correctly loaded for _, n := range validationTargets { if n == "" { continue } if strings.HasPrefix(n, string(os.PathSeparator)) || strings.Contains(n, "..") { return false, fmt.Errorf("invalid file path: %s", n) } } if c.Backend != "" { // a regex that checks that is a string name with no special characters, except '-' and '_' re := regexp.MustCompile(`^[a-zA-Z0-9-_]+$`) if !re.MatchString(c.Backend) { return false, fmt.Errorf("invalid backend name: %s", c.Backend) } } // Validate MCP configuration if present if c.MCP.Servers != "" || c.MCP.Stdio != "" { if _, _, err := c.MCP.MCPConfigFromYAML(); err != nil { return false, fmt.Errorf("invalid MCP configuration: %w", err) } } return true, nil } func (c *ModelConfig) HasTemplate() bool { return c.TemplateConfig.Completion != "" || c.TemplateConfig.Edit != "" || c.TemplateConfig.Chat != "" || c.TemplateConfig.ChatMessage != "" || c.TemplateConfig.UseTokenizerTemplate } func (c *ModelConfig) GetModelConfigFile() string { return c.modelConfigFile } // 
// ModelConfigUsecase is a bit set describing which kinds of requests a model
// configuration can serve. Individual flags can be OR-ed together.
type ModelConfigUsecase int

// Each usecase occupies its own bit; FLAG_ANY is the empty set.
const (
	FLAG_ANY              ModelConfigUsecase = 0
	FLAG_CHAT             ModelConfigUsecase = 1 << 0
	FLAG_COMPLETION       ModelConfigUsecase = 1 << 1
	FLAG_EDIT             ModelConfigUsecase = 1 << 2
	FLAG_EMBEDDINGS       ModelConfigUsecase = 1 << 3
	FLAG_RERANK           ModelConfigUsecase = 1 << 4
	FLAG_IMAGE            ModelConfigUsecase = 1 << 5
	FLAG_TRANSCRIPT       ModelConfigUsecase = 1 << 6
	FLAG_TTS              ModelConfigUsecase = 1 << 7
	FLAG_SOUND_GENERATION ModelConfigUsecase = 1 << 8
	FLAG_TOKENIZE         ModelConfigUsecase = 1 << 9
	FLAG_VAD              ModelConfigUsecase = 1 << 10
	FLAG_VIDEO            ModelConfigUsecase = 1 << 11
	FLAG_DETECTION        ModelConfigUsecase = 1 << 12

	// Common Subsets
	FLAG_LLM ModelConfigUsecase = FLAG_CHAT | FLAG_COMPLETION | FLAG_EDIT
)

// GetAllModelConfigUsecases returns a fresh name -> flag table of every
// usecase that can be referenced by name.
func GetAllModelConfigUsecases() map[string]ModelConfigUsecase {
	// Note: FLAG_ANY is intentionally excluded from this map
	// because it's 0 and would always match in HasUsecases checks
	return map[string]ModelConfigUsecase{
		"FLAG_CHAT":             FLAG_CHAT,
		"FLAG_COMPLETION":       FLAG_COMPLETION,
		"FLAG_EDIT":             FLAG_EDIT,
		"FLAG_EMBEDDINGS":       FLAG_EMBEDDINGS,
		"FLAG_RERANK":           FLAG_RERANK,
		"FLAG_IMAGE":            FLAG_IMAGE,
		"FLAG_TRANSCRIPT":       FLAG_TRANSCRIPT,
		"FLAG_TTS":              FLAG_TTS,
		"FLAG_SOUND_GENERATION": FLAG_SOUND_GENERATION,
		"FLAG_TOKENIZE":         FLAG_TOKENIZE,
		"FLAG_VAD":              FLAG_VAD,
		"FLAG_LLM":              FLAG_LLM,
		"FLAG_VIDEO":            FLAG_VIDEO,
		"FLAG_DETECTION":        FLAG_DETECTION,
	}
}

// stringToFlag turns a short usecase name such as "chat" into its flag name
// ("FLAG_CHAT").
func stringToFlag(s string) string {
	upper := strings.ToUpper(s)
	return "FLAG_" + upper
}

// GetUsecasesFromYAML combines the named usecases into a single bit set.
// Each entry is looked up both in its short form ("chat") and verbatim
// ("FLAG_CHAT"); unknown names are ignored. An empty input yields nil.
func GetUsecasesFromYAML(input []string) *ModelConfigUsecase {
	if len(input) == 0 {
		return nil
	}

	known := GetAllModelConfigUsecases()
	combined := FLAG_ANY
	for _, raw := range input {
		if flag, ok := known[stringToFlag(raw)]; ok {
			combined |= flag
		}
		if flag, ok := known[raw]; ok {
			combined |= flag
		}
	}
	return &combined
}
examines a ModelConfig and determines which endpoints have a chance of success. func (c *ModelConfig) HasUsecases(u ModelConfigUsecase) bool { if (c.KnownUsecases != nil) && ((u & *c.KnownUsecases) == u) { return true } return c.GuessUsecases(u) } // GuessUsecases is a **heuristic based** function, as the backend in question may not be loaded yet, and the config may not record what it's useful at. // In its current state, this function should ideally check for properties of the config like templates, rather than the direct backend name checks for the lower half. // This avoids the maintenance burden of updating this list for each new backend - but unfortunately, that's the best option for some services currently. func (c *ModelConfig) GuessUsecases(u ModelConfigUsecase) bool { if (u & FLAG_CHAT) == FLAG_CHAT { if c.TemplateConfig.Chat == "" && c.TemplateConfig.ChatMessage == "" && !c.TemplateConfig.UseTokenizerTemplate { return false } } if (u & FLAG_COMPLETION) == FLAG_COMPLETION { if c.TemplateConfig.Completion == "" { return false } } if (u & FLAG_EDIT) == FLAG_EDIT { if c.TemplateConfig.Edit == "" { return false } } if (u & FLAG_EMBEDDINGS) == FLAG_EMBEDDINGS { if c.Embeddings == nil || !*c.Embeddings { return false } } if (u & FLAG_IMAGE) == FLAG_IMAGE { imageBackends := []string{"diffusers", "stablediffusion", "stablediffusion-ggml"} if !slices.Contains(imageBackends, c.Backend) { return false } if c.Backend == "diffusers" && c.Diffusers.PipelineType == "" { return false } } if (u & FLAG_VIDEO) == FLAG_VIDEO { videoBackends := []string{"diffusers", "stablediffusion", "vllm-omni"} if !slices.Contains(videoBackends, c.Backend) { return false } if c.Backend == "diffusers" && c.Diffusers.PipelineType == "" { return false } } if (u & FLAG_RERANK) == FLAG_RERANK { if c.Backend != "rerankers" && (c.Reranking == nil || !*c.Reranking) { return false } } if (u & FLAG_TRANSCRIPT) == FLAG_TRANSCRIPT { if c.Backend != "whisper" { return false } } if (u & FLAG_TTS) == 
FLAG_TTS { ttsBackends := []string{"piper", "transformers-musicgen", "kokoro"} if !slices.Contains(ttsBackends, c.Backend) { return false } } if (u & FLAG_DETECTION) == FLAG_DETECTION { if c.Backend != "rfdetr" { return false } } if (u & FLAG_SOUND_GENERATION) == FLAG_SOUND_GENERATION { soundGenBackends := []string{"transformers-musicgen", "ace-step", "acestep-cpp", "mock-backend"} if !slices.Contains(soundGenBackends, c.Backend) { return false } } if (u & FLAG_TOKENIZE) == FLAG_TOKENIZE { tokenizeCapableBackends := []string{"llama.cpp", "rwkv"} if !slices.Contains(tokenizeCapableBackends, c.Backend) { return false } } if (u & FLAG_VAD) == FLAG_VAD { if c.Backend != "silero-vad" { return false } } return true } // BuildCogitoOptions generates cogito options from the model configuration // It accepts a context, MCP sessions, and optional callback functions for status, reasoning, tool calls, and tool results func (c *ModelConfig) BuildCogitoOptions() []cogito.Option { cogitoOpts := []cogito.Option{ cogito.WithIterations(3), // default to 3 iterations cogito.WithMaxAttempts(3), // default to 3 attempts cogito.WithForceReasoning(), } // Apply agent configuration options if c.Agent.EnableReasoning { cogitoOpts = append(cogitoOpts, cogito.WithForceReasoning()) } if c.Agent.EnablePlanning { cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlan) } if c.Agent.EnableMCPPrompts { cogitoOpts = append(cogitoOpts, cogito.EnableMCPPrompts) } if c.Agent.EnablePlanReEvaluator { cogitoOpts = append(cogitoOpts, cogito.EnableAutoPlanReEvaluator) } if c.Agent.MaxIterations != 0 { cogitoOpts = append(cogitoOpts, cogito.WithIterations(c.Agent.MaxIterations)) } if c.Agent.MaxAttempts != 0 { cogitoOpts = append(cogitoOpts, cogito.WithMaxAttempts(c.Agent.MaxAttempts)) } if c.Agent.DisableSinkState { cogitoOpts = append(cogitoOpts, cogito.DisableSinkState) } if c.Agent.LoopDetection != 0 { cogitoOpts = append(cogitoOpts, cogito.WithLoopDetection(c.Agent.LoopDetection)) } if 
c.Agent.MaxAdjustmentAttempts != 0 { cogitoOpts = append(cogitoOpts, cogito.WithMaxAdjustmentAttempts(c.Agent.MaxAdjustmentAttempts)) } if c.Agent.ForceReasoningTool { cogitoOpts = append(cogitoOpts, cogito.WithForceReasoningTool()) } return cogitoOpts } ================================================ FILE: core/config/model_config_filter.go ================================================ package config import "regexp" type ModelConfigFilterFn func(string, *ModelConfig) bool func NoFilterFn(_ string, _ *ModelConfig) bool { return true } func BuildNameFilterFn(filter string) (ModelConfigFilterFn, error) { if filter == "" { return NoFilterFn, nil } rxp, err := regexp.Compile(filter) if err != nil { return nil, err } return func(name string, config *ModelConfig) bool { if config != nil { return rxp.MatchString(config.Name) } return rxp.MatchString(name) }, nil } func BuildUsecaseFilterFn(usecases ModelConfigUsecase) ModelConfigFilterFn { if usecases == FLAG_ANY { return NoFilterFn } return func(name string, config *ModelConfig) bool { if config == nil { return false // TODO: Potentially make this a param, for now, no known usecase to include } return config.HasUsecases(usecases) } } ================================================ FILE: core/config/model_config_loader.go ================================================ package config import ( "errors" "fmt" "io/fs" "os" "path/filepath" "sort" "strings" "sync" "github.com/charmbracelet/glamour" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/xlog" "gopkg.in/yaml.v3" ) type ModelConfigLoader struct { configs map[string]ModelConfig modelPath string sync.Mutex } func NewModelConfigLoader(modelPath string) *ModelConfigLoader { return &ModelConfigLoader{ configs: make(map[string]ModelConfig), modelPath: modelPath, } } type LoadOptions struct { modelPath string debug bool threads, ctxSize int f16 bool } func LoadOptionDebug(debug 
bool) ConfigLoaderOption { return func(o *LoadOptions) { o.debug = debug } }

// LoadOptionThreads sets the thread count on the load options.
func LoadOptionThreads(threads int) ConfigLoaderOption {
	return func(o *LoadOptions) {
		o.threads = threads
	}
}

// LoadOptionContextSize sets the context size on the load options.
func LoadOptionContextSize(ctxSize int) ConfigLoaderOption {
	return func(o *LoadOptions) {
		o.ctxSize = ctxSize
	}
}

// ModelPath sets the models directory on the load options.
func ModelPath(modelPath string) ConfigLoaderOption {
	return func(o *LoadOptions) {
		o.modelPath = modelPath
	}
}

// LoadOptionF16 sets the f16 flag on the load options.
func LoadOptionF16(f16 bool) ConfigLoaderOption {
	return func(o *LoadOptions) {
		o.f16 = f16
	}
}

// ConfigLoaderOption mutates a LoadOptions value.
type ConfigLoaderOption func(*LoadOptions)

// Apply runs every option against lo, in order.
func (lo *LoadOptions) Apply(options ...ConfigLoaderOption) {
	for _, l := range options {
		l(lo)
	}
}

// readModelConfigsFromFile reads a config file that may contain either a single
// ModelConfig or an array of ModelConfigs. It tries to unmarshal as an array first,
// then falls back to a single config if that fails.
func readModelConfigsFromFile(file string, opts ...ConfigLoaderOption) ([]*ModelConfig, error) {
	f, err := os.ReadFile(file)
	if err != nil {
		return nil, fmt.Errorf("readModelConfigsFromFile cannot read config file %q: %w", file, err)
	}

	// Try to unmarshal as array first
	var configs []*ModelConfig
	if err := yaml.Unmarshal(f, &configs); err == nil && len(configs) > 0 {
		for _, cc := range configs {
			cc.modelConfigFile = file
			cc.SetDefaults(opts...)
			cc.syncKnownUsecasesFromString()
		}
		return configs, nil
	}

	// Fall back to single config
	c := &ModelConfig{}
	if err := yaml.Unmarshal(f, c); err != nil {
		return nil, fmt.Errorf("readModelConfigsFromFile cannot unmarshal config file %q: %w", file, err)
	}
	c.modelConfigFile = file
	c.syncKnownUsecasesFromString()
	c.SetDefaults(opts...)

	return []*ModelConfig{c}, nil
}

// LoadModelConfigFileByName returns the config for modelName with defaults
// applied.
//
// Lookup order: a config already registered in the loader; otherwise
// "<modelPath>/<modelName>.yaml" if that file exists (it is read and
// registered first); otherwise a bare config whose Model is modelName.
// An error is returned only when the YAML file exists but cannot be loaded.
func (bcl *ModelConfigLoader) LoadModelConfigFileByName(modelName, modelPath string, opts ...ConfigLoaderOption) (*ModelConfig, error) {
	// Fallback used when nothing is registered or on disk for modelName
	cfg := &ModelConfig{
		PredictionOptions: schema.PredictionOptions{
			BasicModelRequest: schema.BasicModelRequest{
				Model: modelName,
			},
		},
	}

	cfgExisting, exists := bcl.GetModelConfig(modelName)
	if exists {
		cfg = &cfgExisting
	} else {
		// Try loading a model config file
		modelConfig := filepath.Join(modelPath, modelName+".yaml")
		if _, err := os.Stat(modelConfig); err == nil {
			if err := bcl.ReadModelConfig(modelConfig, opts...); err != nil {
				// %w keeps the message text identical while making the cause
				// available to errors.Is / errors.As.
				return nil, fmt.Errorf("failed loading model config (%s) %w", modelConfig, err)
			}
			cfgExisting, exists = bcl.GetModelConfig(modelName)
			if exists {
				cfg = &cfgExisting
			}
		}
	}

	// Cap the slice so append always allocates: a plain append(opts, ...)
	// could write into spare capacity of the caller's backing array.
	allOpts := append(opts[:len(opts):len(opts)], ModelPath(modelPath))
	cfg.SetDefaults(allOpts...)

	return cfg, nil
}

// LoadModelConfigFileByNameDefaultOptions calls LoadModelConfigFileByName
// with the models path, debug, threads, context size and f16 settings taken
// from appConfig.
func (bcl *ModelConfigLoader) LoadModelConfigFileByNameDefaultOptions(modelName string, appConfig *ApplicationConfig) (*ModelConfig, error) {
	return bcl.LoadModelConfigFileByName(modelName, appConfig.SystemState.Model.ModelsPath,
		LoadOptionDebug(appConfig.Debug),
		LoadOptionThreads(appConfig.Threads),
		LoadOptionContextSize(appConfig.ContextSize),
		LoadOptionF16(appConfig.F16),
		ModelPath(appConfig.SystemState.Model.ModelsPath))
}

// This format is currently only used when reading a single file at startup, passed in via ApplicationConfig.ConfigFile func (bcl *ModelConfigLoader) LoadMultipleModelConfigsSingleFile(file string, opts ...ConfigLoaderOption) error { bcl.Lock() defer bcl.Unlock() c, err := readModelConfigsFromFile(file, opts...) 
if err != nil { return fmt.Errorf("cannot load config file: %w", err) } for _, cc := range c { if valid, err := cc.Validate(); valid { bcl.configs[cc.Name] = *cc } else { xlog.Warn("skipping invalid model config", "name", cc.Name, "error", err) } } return nil } func (bcl *ModelConfigLoader) ReadModelConfig(file string, opts ...ConfigLoaderOption) error { bcl.Lock() defer bcl.Unlock() configs, err := readModelConfigsFromFile(file, opts...) if err != nil { return fmt.Errorf("ReadModelConfig cannot read config file %q: %w", file, err) } if len(configs) == 0 { return fmt.Errorf("ReadModelConfig: no configs found in file %q", file) } if len(configs) > 1 { xlog.Warn("ReadModelConig: read more than one config from file, only using first", "file", file, "configs", len(configs)) } c := configs[0] if valid, err := c.Validate(); valid { bcl.configs[c.Name] = *c } else { if err != nil { return fmt.Errorf("config is not valid: %w", err) } return fmt.Errorf("config is not valid") } return nil } func (bcl *ModelConfigLoader) GetModelConfig(m string) (ModelConfig, bool) { bcl.Lock() defer bcl.Unlock() v, exists := bcl.configs[m] return v, exists } func (bcl *ModelConfigLoader) GetAllModelsConfigs() []ModelConfig { bcl.Lock() defer bcl.Unlock() var res []ModelConfig for _, v := range bcl.configs { res = append(res, v) } sort.SliceStable(res, func(i, j int) bool { return res[i].Name < res[j].Name }) return res } func (bcl *ModelConfigLoader) GetModelConfigsByFilter(filter ModelConfigFilterFn) []ModelConfig { bcl.Lock() defer bcl.Unlock() var res []ModelConfig if filter == nil { filter = NoFilterFn } for n, v := range bcl.configs { if filter(n, &v) { res = append(res, v) } } // TODO: I don't think this one needs to Sort on name... but we'll see what breaks. return res } func (bcl *ModelConfigLoader) RemoveModelConfig(m string) { bcl.Lock() defer bcl.Unlock() delete(bcl.configs, m) } // UpdateModelConfig updates an existing model config in the loader. 
// This is useful for updating runtime-detected properties like thinking support. func (bcl *ModelConfigLoader) UpdateModelConfig(m string, updater func(*ModelConfig)) { bcl.Lock() defer bcl.Unlock() if cfg, exists := bcl.configs[m]; exists { updater(&cfg) bcl.configs[m] = cfg } } // Preload prepare models if they are not local but url or huggingface repositories func (bcl *ModelConfigLoader) Preload(modelPath string) error { bcl.Lock() defer bcl.Unlock() status := func(fileName, current, total string, percent float64) { utils.DisplayDownloadFunction(fileName, current, total, percent) } xlog.Info("Preloading models", "path", modelPath) renderMode := "dark" if os.Getenv("COLOR") != "" { renderMode = os.Getenv("COLOR") } glamText := func(t string) { out, err := glamour.Render(t, renderMode) if err == nil && os.Getenv("NO_COLOR") == "" { fmt.Println(out) } else { fmt.Println(t) } } for i, config := range bcl.configs { // Download files and verify their SHA for i, file := range config.DownloadFiles { xlog.Debug("Checking file exists and matches SHA", "filename", file.Filename) if err := utils.VerifyPath(file.Filename, modelPath); err != nil { return err } // Create file path filePath := filepath.Join(modelPath, file.Filename) if err := file.URI.DownloadFile(filePath, file.SHA256, i, len(config.DownloadFiles), status); err != nil { return err } } // If the model is an URL, expand it, and download the file if config.IsModelURL() { modelFileName := config.ModelFileName() uri := downloader.URI(config.Model) if uri.ResolveURL() != config.Model { // check if file exists if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) { err := uri.DownloadFile(filepath.Join(modelPath, modelFileName), "", 0, 0, status) if err != nil { return err } } cc := bcl.configs[i] c := &cc c.PredictionOptions.Model = modelFileName bcl.configs[i] = *c } } if config.IsMMProjURL() { modelFileName := config.MMProjFileName() uri := downloader.URI(config.MMProj) // 
check if file exists if _, err := os.Stat(filepath.Join(modelPath, modelFileName)); errors.Is(err, os.ErrNotExist) { err := uri.DownloadFile(filepath.Join(modelPath, modelFileName), "", 0, 0, status) if err != nil { return err } } cc := bcl.configs[i] c := &cc c.MMProj = modelFileName bcl.configs[i] = *c } if bcl.configs[i].Name != "" { glamText(fmt.Sprintf("**Model name**: _%s_", bcl.configs[i].Name)) } if bcl.configs[i].Description != "" { //glamText("**Description**") glamText(bcl.configs[i].Description) } if bcl.configs[i].Usage != "" { //glamText("**Usage**") glamText(bcl.configs[i].Usage) } } return nil } // LoadModelConfigsFromPath reads all the configurations of the models from a path // (non-recursive) func (bcl *ModelConfigLoader) LoadModelConfigsFromPath(path string, opts ...ConfigLoaderOption) error { bcl.Lock() defer bcl.Unlock() entries, err := os.ReadDir(path) if err != nil { return fmt.Errorf("LoadModelConfigsFromPath cannot read directory '%s': %w", path, err) } files := make([]fs.FileInfo, 0, len(entries)) for _, entry := range entries { info, err := entry.Info() if err != nil { return err } files = append(files, info) } for _, file := range files { // Skip templates, YAML and .keep files if !strings.Contains(file.Name(), ".yaml") && !strings.Contains(file.Name(), ".yml") || strings.HasPrefix(file.Name(), ".") { continue } filePath := filepath.Join(path, file.Name()) // Read config(s) - handles both single and array formats configs, err := readModelConfigsFromFile(filePath, opts...) 
if err != nil { xlog.Error("LoadModelConfigsFromPath cannot read config file", "error", err, "File Name", file.Name()) continue } // Validate and store each config for _, c := range configs { if valid, validationErr := c.Validate(); valid { bcl.configs[c.Name] = *c } else { xlog.Error("config is not valid", "error", validationErr, "Name", c.Name) } } } return nil } ================================================ FILE: core/config/model_config_test.go ================================================ package config import ( "io" "net/http" "os" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("Test cases for config related functions", func() { Context("Test Read configuration functions", func() { It("Test Validate", func() { tmp, err := os.CreateTemp("", "config.yaml") Expect(err).To(BeNil()) defer os.Remove(tmp.Name()) _, err = tmp.WriteString( `backend: "../foo-bar" name: "foo" parameters: model: "foo-bar" known_usecases: - chat - COMPLETION `) Expect(err).ToNot(HaveOccurred()) configs, err := readModelConfigsFromFile(tmp.Name()) config := configs[0] Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) valid, err := config.Validate() Expect(err).To(HaveOccurred()) Expect(valid).To(BeFalse()) Expect(config.KnownUsecases).ToNot(BeNil()) }) It("Test Validate", func() { tmp, err := os.CreateTemp("", "config.yaml") Expect(err).To(BeNil()) defer os.Remove(tmp.Name()) _, err = tmp.WriteString( `name: bar-baz backend: "foo-bar" parameters: model: "foo-bar"`) Expect(err).ToNot(HaveOccurred()) configs, err := readModelConfigsFromFile(tmp.Name()) config := configs[0] Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) // two configs in config.yaml Expect(config.Name).To(Equal("bar-baz")) valid, err := config.Validate() Expect(err).To(BeNil()) Expect(valid).To(BeTrue()) // download https://raw.githubusercontent.com/mudler/LocalAI/v2.25.0/embedded/models/hermes-2-pro-mistral.yaml httpClient := http.Client{} resp, err := 
httpClient.Get("https://raw.githubusercontent.com/mudler/LocalAI/v2.25.0/embedded/models/hermes-2-pro-mistral.yaml") Expect(err).To(BeNil()) defer resp.Body.Close() tmp, err = os.CreateTemp("", "config.yaml") Expect(err).To(BeNil()) defer os.Remove(tmp.Name()) _, err = io.Copy(tmp, resp.Body) Expect(err).To(BeNil()) configs, err = readModelConfigsFromFile(tmp.Name()) config = configs[0] Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) // two configs in config.yaml Expect(config.Name).To(Equal("hermes-2-pro-mistral")) valid, err = config.Validate() Expect(err).To(BeNil()) Expect(valid).To(BeTrue()) }) }) It("Properly handles backend usecase matching", func() { a := ModelConfig{ Name: "a", } Expect(a.HasUsecases(FLAG_ANY)).To(BeTrue()) // FLAG_ANY just means the config _exists_ essentially. b := ModelConfig{ Name: "b", Backend: "stablediffusion", } Expect(b.HasUsecases(FLAG_ANY)).To(BeTrue()) Expect(b.HasUsecases(FLAG_IMAGE)).To(BeTrue()) Expect(b.HasUsecases(FLAG_CHAT)).To(BeFalse()) c := ModelConfig{ Name: "c", Backend: "llama-cpp", TemplateConfig: TemplateConfig{ Chat: "chat", }, } Expect(c.HasUsecases(FLAG_ANY)).To(BeTrue()) Expect(c.HasUsecases(FLAG_IMAGE)).To(BeFalse()) Expect(c.HasUsecases(FLAG_COMPLETION)).To(BeFalse()) Expect(c.HasUsecases(FLAG_CHAT)).To(BeTrue()) d := ModelConfig{ Name: "d", Backend: "llama-cpp", TemplateConfig: TemplateConfig{ Chat: "chat", Completion: "completion", }, } Expect(d.HasUsecases(FLAG_ANY)).To(BeTrue()) Expect(d.HasUsecases(FLAG_IMAGE)).To(BeFalse()) Expect(d.HasUsecases(FLAG_COMPLETION)).To(BeTrue()) Expect(d.HasUsecases(FLAG_CHAT)).To(BeTrue()) trueValue := true e := ModelConfig{ Name: "e", Backend: "llama-cpp", TemplateConfig: TemplateConfig{ Completion: "completion", }, Embeddings: &trueValue, } Expect(e.HasUsecases(FLAG_ANY)).To(BeTrue()) Expect(e.HasUsecases(FLAG_IMAGE)).To(BeFalse()) Expect(e.HasUsecases(FLAG_COMPLETION)).To(BeTrue()) Expect(e.HasUsecases(FLAG_CHAT)).To(BeFalse()) 
Expect(e.HasUsecases(FLAG_EMBEDDINGS)).To(BeTrue()) f := ModelConfig{ Name: "f", Backend: "piper", } Expect(f.HasUsecases(FLAG_ANY)).To(BeTrue()) Expect(f.HasUsecases(FLAG_TTS)).To(BeTrue()) Expect(f.HasUsecases(FLAG_CHAT)).To(BeFalse()) g := ModelConfig{ Name: "g", Backend: "whisper", } Expect(g.HasUsecases(FLAG_ANY)).To(BeTrue()) Expect(g.HasUsecases(FLAG_TRANSCRIPT)).To(BeTrue()) Expect(g.HasUsecases(FLAG_TTS)).To(BeFalse()) h := ModelConfig{ Name: "h", Backend: "transformers-musicgen", } Expect(h.HasUsecases(FLAG_ANY)).To(BeTrue()) Expect(h.HasUsecases(FLAG_TRANSCRIPT)).To(BeFalse()) Expect(h.HasUsecases(FLAG_TTS)).To(BeTrue()) Expect(h.HasUsecases(FLAG_SOUND_GENERATION)).To(BeTrue()) knownUsecases := FLAG_CHAT | FLAG_COMPLETION i := ModelConfig{ Name: "i", Backend: "whisper", // Earlier test checks parsing, this just needs to set final values KnownUsecases: &knownUsecases, } Expect(i.HasUsecases(FLAG_ANY)).To(BeTrue()) Expect(i.HasUsecases(FLAG_TRANSCRIPT)).To(BeTrue()) Expect(i.HasUsecases(FLAG_TTS)).To(BeFalse()) Expect(i.HasUsecases(FLAG_COMPLETION)).To(BeTrue()) Expect(i.HasUsecases(FLAG_CHAT)).To(BeTrue()) }) It("Test Validate with invalid MCP config", func() { tmp, err := os.CreateTemp("", "config.yaml") Expect(err).To(BeNil()) defer os.Remove(tmp.Name()) _, err = tmp.WriteString( `name: test-mcp backend: "llama-cpp" mcp: stdio: | { "mcpServers": { "ddg": { "command": "/docker/docker", "args": ["run", "-i"] } "weather": { "command": "/docker/docker", "args": ["run", "-i"] } } }`) Expect(err).ToNot(HaveOccurred()) configs, err := readModelConfigsFromFile(tmp.Name()) config := configs[0] Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) valid, err := config.Validate() Expect(err).To(HaveOccurred()) Expect(valid).To(BeFalse()) Expect(err.Error()).To(ContainSubstring("invalid MCP configuration")) }) It("Test Validate with valid MCP config", func() { tmp, err := os.CreateTemp("", "config.yaml") Expect(err).To(BeNil()) defer os.Remove(tmp.Name()) _, err = 
tmp.WriteString( `name: test-mcp-valid backend: "llama-cpp" mcp: stdio: | { "mcpServers": { "ddg": { "command": "/docker/docker", "args": ["run", "-i"] }, "weather": { "command": "/docker/docker", "args": ["run", "-i"] } } }`) Expect(err).ToNot(HaveOccurred()) configs, err := readModelConfigsFromFile(tmp.Name()) config := configs[0] Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) valid, err := config.Validate() Expect(err).To(BeNil()) Expect(valid).To(BeTrue()) }) }) ================================================ FILE: core/config/model_test.go ================================================ package config import ( "os" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("Test cases for config related functions", func() { var ( configFile string ) Context("Test Read configuration functions", func() { configFile = os.Getenv("CONFIG_FILE") It("Test readConfigFile", func() { config, err := readModelConfigsFromFile(configFile) Expect(err).To(BeNil()) Expect(config).ToNot(BeNil()) // two configs in config.yaml Expect(config[0].Name).To(Equal("list1")) Expect(config[1].Name).To(Equal("list2")) }) It("Test LoadConfigs", func() { bcl := NewModelConfigLoader(os.Getenv("MODELS_PATH")) err := bcl.LoadModelConfigsFromPath(os.Getenv("MODELS_PATH")) Expect(err).To(BeNil()) configs := bcl.GetAllModelsConfigs() loadedModelNames := []string{} for _, v := range configs { loadedModelNames = append(loadedModelNames, v.Name) } Expect(configs).ToNot(BeNil()) Expect(loadedModelNames).To(ContainElements("code-search-ada-code-001")) // config should includes text-embedding-ada-002 models's api.config Expect(loadedModelNames).To(ContainElements("text-embedding-ada-002")) // config should includes rwkv_test models's api.config Expect(loadedModelNames).To(ContainElements("rwkv_test")) // config should includes whisper-1 models's api.config Expect(loadedModelNames).To(ContainElements("whisper-1")) }) It("Test new loadconfig", func() { bcl := 
NewModelConfigLoader(os.Getenv("MODELS_PATH")) err := bcl.LoadModelConfigsFromPath(os.Getenv("MODELS_PATH")) Expect(err).To(BeNil()) configs := bcl.GetAllModelsConfigs() loadedModelNames := []string{} for _, v := range configs { loadedModelNames = append(loadedModelNames, v.Name) } Expect(configs).ToNot(BeNil()) totalModels := len(loadedModelNames) Expect(loadedModelNames).To(ContainElements("code-search-ada-code-001")) // config should includes text-embedding-ada-002 models's api.config Expect(loadedModelNames).To(ContainElements("text-embedding-ada-002")) // config should includes rwkv_test models's api.config Expect(loadedModelNames).To(ContainElements("rwkv_test")) // config should includes whisper-1 models's api.config Expect(loadedModelNames).To(ContainElements("whisper-1")) // create a temp directory and store a temporary model tmpdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) defer os.RemoveAll(tmpdir) // create a temporary model model := `name: "test-model" description: "test model" options: - foo - bar - baz ` modelFile := tmpdir + "/test-model.yaml" err = os.WriteFile(modelFile, []byte(model), 0644) Expect(err).ToNot(HaveOccurred()) err = bcl.LoadModelConfigsFromPath(tmpdir) Expect(err).ToNot(HaveOccurred()) configs = bcl.GetAllModelsConfigs() Expect(len(configs)).ToNot(Equal(totalModels)) loadedModelNames = []string{} var testModel ModelConfig for _, v := range configs { loadedModelNames = append(loadedModelNames, v.Name) if v.Name == "test-model" { testModel = v } } Expect(loadedModelNames).To(ContainElements("test-model")) Expect(testModel.Description).To(Equal("test model")) Expect(testModel.Options).To(ContainElements("foo", "bar", "baz")) }) }) }) ================================================ FILE: core/config/runtime_settings.go ================================================ package config // RuntimeSettings represents runtime configuration that can be changed dynamically. 
// This struct is used for: // - API responses (GET /api/settings) // - API requests (POST /api/settings) // - Persisting to runtime_settings.json // - Loading from runtime_settings.json on startup // // All fields are pointers to distinguish between "not set" and "set to zero/false value". type RuntimeSettings struct { // Watchdog settings WatchdogEnabled *bool `json:"watchdog_enabled,omitempty"` WatchdogIdleEnabled *bool `json:"watchdog_idle_enabled,omitempty"` WatchdogBusyEnabled *bool `json:"watchdog_busy_enabled,omitempty"` WatchdogIdleTimeout *string `json:"watchdog_idle_timeout,omitempty"` WatchdogBusyTimeout *string `json:"watchdog_busy_timeout,omitempty"` WatchdogInterval *string `json:"watchdog_interval,omitempty"` // Interval between watchdog checks (e.g., 2s, 30s) // Backend management SingleBackend *bool `json:"single_backend,omitempty"` // Deprecated: use MaxActiveBackends = 1 instead MaxActiveBackends *int `json:"max_active_backends,omitempty"` // Maximum number of active backends (0 = unlimited, 1 = single backend mode) ParallelBackendRequests *bool `json:"parallel_backend_requests,omitempty"` // Memory Reclaimer settings (works with GPU if available, otherwise RAM) MemoryReclaimerEnabled *bool `json:"memory_reclaimer_enabled,omitempty"` // Enable memory threshold monitoring MemoryReclaimerThreshold *float64 `json:"memory_reclaimer_threshold,omitempty"` // Threshold 0.0-1.0 (e.g., 0.95 = 95%) // Eviction settings ForceEvictionWhenBusy *bool `json:"force_eviction_when_busy,omitempty"` // Force eviction even when models have active API calls (default: false for safety) LRUEvictionMaxRetries *int `json:"lru_eviction_max_retries,omitempty"` // Maximum number of retries when waiting for busy models to become idle (default: 30) LRUEvictionRetryInterval *string `json:"lru_eviction_retry_interval,omitempty"` // Interval between retries when waiting for busy models (e.g., 1s, 2s) (default: 1s) // Performance settings Threads *int `json:"threads,omitempty"` 
ContextSize *int `json:"context_size,omitempty"` F16 *bool `json:"f16,omitempty"` Debug *bool `json:"debug,omitempty"` EnableTracing *bool `json:"enable_tracing,omitempty"` TracingMaxItems *int `json:"tracing_max_items,omitempty"` EnableBackendLogging *bool `json:"enable_backend_logging,omitempty"` // Security/CORS settings CORS *bool `json:"cors,omitempty"` CSRF *bool `json:"csrf,omitempty"` CORSAllowOrigins *string `json:"cors_allow_origins,omitempty"` // P2P settings P2PToken *string `json:"p2p_token,omitempty"` P2PNetworkID *string `json:"p2p_network_id,omitempty"` Federated *bool `json:"federated,omitempty"` // Gallery settings Galleries *[]Gallery `json:"galleries,omitempty"` BackendGalleries *[]Gallery `json:"backend_galleries,omitempty"` AutoloadGalleries *bool `json:"autoload_galleries,omitempty"` AutoloadBackendGalleries *bool `json:"autoload_backend_galleries,omitempty"` // API keys - No omitempty as we need to save empty arrays to clear keys ApiKeys *[]string `json:"api_keys"` // Agent settings AgentJobRetentionDays *int `json:"agent_job_retention_days,omitempty"` // Open Responses settings OpenResponsesStoreTTL *string `json:"open_responses_store_ttl,omitempty"` // TTL for stored responses (e.g., "1h", "30m", "0" = no expiration) // Agent Pool settings AgentPoolEnabled *bool `json:"agent_pool_enabled,omitempty"` AgentPoolDefaultModel *string `json:"agent_pool_default_model,omitempty"` AgentPoolEmbeddingModel *string `json:"agent_pool_embedding_model,omitempty"` AgentPoolMaxChunkingSize *int `json:"agent_pool_max_chunking_size,omitempty"` AgentPoolChunkOverlap *int `json:"agent_pool_chunk_overlap,omitempty"` AgentPoolEnableLogs *bool `json:"agent_pool_enable_logs,omitempty"` AgentPoolCollectionDBPath *string `json:"agent_pool_collection_db_path,omitempty"` } ================================================ FILE: core/dependencies_manager/manager.go ================================================ // DEPRECATED: This tool downloads static assets for the 
legacy Alpine.js UI. // The new React UI (core/http/react-ui/) bundles all dependencies via npm. // Remove this file when the legacy UI (core/http/views/) is removed. package main import ( "fmt" "os" "path/filepath" "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/utils" "gopkg.in/yaml.v3" ) type Asset struct { FileName string `yaml:"filename"` URL string `yaml:"url"` SHA string `yaml:"sha"` } func main() { // read the YAML file which contains a list of assets // and download them in the asset path assets := []Asset{} assetFile := os.Args[1] destPath := os.Args[2] // read the YAML file f, err := os.ReadFile(assetFile) if err != nil { panic(err) } // unmarshal the YAML data into a struct if err := yaml.Unmarshal(f, &assets); err != nil { panic(err) } // download the assets for _, asset := range assets { uri := downloader.URI(asset.URL) if err := uri.DownloadFile(filepath.Join(destPath, asset.FileName), asset.SHA, 1, 1, utils.DisplayDownloadFunction); err != nil { panic(err) } } fmt.Println("Finished downloading assets") } ================================================ FILE: core/explorer/database.go ================================================ package explorer // A simple JSON database for storing and retrieving p2p network tokens and a name and description. import ( "encoding/json" "os" "sort" "sync" "github.com/gofrs/flock" ) // Database is a simple JSON database for storing and retrieving p2p network tokens and a name and description. type Database struct { path string data map[string]TokenData flock *flock.Flock sync.Mutex } // TokenData is a p2p network token with a name and description. type TokenData struct { Name string `json:"name"` Description string `json:"description"` Clusters []ClusterData Failures int } type ClusterData struct { Workers []string Type string NetworkID string } // NewDatabase creates a new Database with the given path. 
func NewDatabase(path string) (*Database, error) { fileLock := flock.New(path + ".lock") db := &Database{ data: make(map[string]TokenData), path: path, flock: fileLock, } return db, db.load() } // Get retrieves a Token from the Database by its token. func (db *Database) Get(token string) (TokenData, bool) { db.flock.Lock() // we are making sure that the file is not being written to defer db.flock.Unlock() db.Lock() // we are making sure that is safe if called by another instance in the same process defer db.Unlock() db.load() t, ok := db.data[token] return t, ok } // Set stores a Token in the Database by its token. func (db *Database) Set(token string, t TokenData) error { db.flock.Lock() defer db.flock.Unlock() db.Lock() defer db.Unlock() db.load() db.data[token] = t return db.save() } // Delete removes a Token from the Database by its token. func (db *Database) Delete(token string) error { db.flock.Lock() defer db.flock.Unlock() db.Lock() defer db.Unlock() db.load() delete(db.data, token) return db.save() } func (db *Database) TokenList() []string { db.flock.Lock() defer db.flock.Unlock() db.Lock() defer db.Unlock() db.load() tokens := []string{} for k := range db.data { tokens = append(tokens, k) } sort.Slice(tokens, func(i, j int) bool { // sort by token return tokens[i] < tokens[j] }) return tokens } // load reads the Database from disk. func (db *Database) load() error { if _, err := os.Stat(db.path); os.IsNotExist(err) { return nil } // Read the file from disk // Unmarshal the JSON into db.data f, err := os.ReadFile(db.path) if err != nil { return err } return json.Unmarshal(f, &db.data) } // Save writes the Database to disk. 
func (db *Database) save() error { // Marshal db.data into JSON // Write the JSON to the file f, err := os.Create(db.path) if err != nil { return err } defer f.Close() return json.NewEncoder(f).Encode(db.data) } ================================================ FILE: core/explorer/database_test.go ================================================ package explorer_test import ( "os" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/mudler/LocalAI/core/explorer" ) var _ = Describe("Database", func() { var ( dbPath string db *explorer.Database err error ) BeforeEach(func() { // Create a temporary file path for the database dbPath = "test_db.json" db, err = explorer.NewDatabase(dbPath) Expect(err).To(BeNil()) }) AfterEach(func() { // Clean up the temporary database file os.Remove(dbPath) }) Context("when managing tokens", func() { It("should add and retrieve a token", func() { token := "token123" t := explorer.TokenData{Name: "TokenName", Description: "A test token"} err = db.Set(token, t) Expect(err).To(BeNil()) retrievedToken, exists := db.Get(token) Expect(exists).To(BeTrue()) Expect(retrievedToken).To(Equal(t)) }) It("should delete a token", func() { token := "token123" t := explorer.TokenData{Name: "TokenName", Description: "A test token"} err = db.Set(token, t) Expect(err).To(BeNil()) err = db.Delete(token) Expect(err).To(BeNil()) _, exists := db.Get(token) Expect(exists).To(BeFalse()) }) It("should persist data to disk", func() { token := "token123" t := explorer.TokenData{Name: "TokenName", Description: "A test token"} err = db.Set(token, t) Expect(err).To(BeNil()) // Recreate the database object to simulate reloading from disk db, err = explorer.NewDatabase(dbPath) Expect(err).To(BeNil()) retrievedToken, exists := db.Get(token) Expect(exists).To(BeTrue()) Expect(retrievedToken).To(Equal(t)) // Check the token list tokenList := db.TokenList() Expect(tokenList).To(ContainElement(token)) }) }) Context("when loading an empty or non-existent file", 
func() { It("should start with an empty database", func() { dbPath = "empty_db.json" db, err = explorer.NewDatabase(dbPath) Expect(err).To(BeNil()) _, exists := db.Get("nonexistent") Expect(exists).To(BeFalse()) // Clean up os.Remove(dbPath) }) }) }) ================================================ FILE: core/explorer/discovery.go ================================================ package explorer import ( "context" "fmt" "strings" "sync" "time" "github.com/mudler/xlog" "github.com/mudler/LocalAI/core/p2p" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/edgevpn/pkg/blockchain" ) type DiscoveryServer struct { sync.Mutex database *Database connectionTime time.Duration errorThreshold int } // NewDiscoveryServer creates a new DiscoveryServer with the given Database. // it keeps the db state in sync with the network state func NewDiscoveryServer(db *Database, dur time.Duration, failureThreshold int) *DiscoveryServer { if dur == 0 { dur = 50 * time.Second } if failureThreshold == 0 { failureThreshold = 3 } return &DiscoveryServer{ database: db, connectionTime: dur, errorThreshold: failureThreshold, } } type Network struct { Clusters []ClusterData } func (s *DiscoveryServer) runBackground() { if len(s.database.TokenList()) == 0 { time.Sleep(5 * time.Second) // avoid busy loop return } for _, token := range s.database.TokenList() { c, cancel := context.WithTimeout(context.Background(), s.connectionTime) defer cancel() // Connect to the network // Get the number of nodes // save it in the current state (mutex) // do not do in parallel n, err := p2p.NewNode(token) if err != nil { xlog.Error("Failed to create node", "error", err) s.failedToken(token) continue } err = n.Start(c) if err != nil { xlog.Error("Failed to start node", "error", err) s.failedToken(token) continue } ledger, err := n.Ledger() if err != nil { xlog.Error("Failed to start ledger", "error", err) s.failedToken(token) continue } networkData := make(chan ClusterData) // get the network data - it takes 
the whole timeout // as we might not be connected to the network yet, // and few attempts would have to be made before bailing out go s.retrieveNetworkData(c, ledger, networkData) hasWorkers := false ledgerK := []ClusterData{} for key := range networkData { ledgerK = append(ledgerK, key) if len(key.Workers) > 0 { hasWorkers = true } } xlog.Debug("Network clusters", "network", token, "count", len(ledgerK)) if len(ledgerK) != 0 { for _, k := range ledgerK { xlog.Debug("Clusterdata", "network", token, "cluster", k) } } if hasWorkers { s.Lock() data, _ := s.database.Get(token) (&data).Clusters = ledgerK (&data).Failures = 0 s.database.Set(token, data) s.Unlock() } else { s.failedToken(token) } } s.deleteFailedConnections() } func (s *DiscoveryServer) failedToken(token string) { s.Lock() defer s.Unlock() data, _ := s.database.Get(token) (&data).Failures++ s.database.Set(token, data) } func (s *DiscoveryServer) deleteFailedConnections() { s.Lock() defer s.Unlock() for _, t := range s.database.TokenList() { data, _ := s.database.Get(t) if data.Failures > s.errorThreshold { xlog.Info("Token has been removed from the database", "token", t) s.database.Delete(t) } } } func (s *DiscoveryServer) retrieveNetworkData(c context.Context, ledger *blockchain.Ledger, networkData chan ClusterData) { clusters := map[string]ClusterData{} defer func() { for _, n := range clusters { networkData <- n } close(networkData) }() for { select { case <-c.Done(): return default: time.Sleep(5 * time.Second) data := ledger.LastBlock().Storage LEDGER: for d := range data { toScanForWorkers := false cd := ClusterData{} isWorkerCluster := d == p2p.LlamaCPPWorkerID || (strings.Contains(d, "_") && strings.Contains(d, p2p.LlamaCPPWorkerID)) isFederatedCluster := d == p2p.FederatedID || (strings.Contains(d, "_") && strings.Contains(d, p2p.FederatedID)) switch { case isWorkerCluster: toScanForWorkers = true cd.Type = "worker" case isFederatedCluster: toScanForWorkers = true cd.Type = "federated" } if 
strings.Contains(d, "_") { cd.NetworkID = strings.Split(d, "_")[0] } if !toScanForWorkers { continue LEDGER } atLeastOneWorker := false DATA: for _, v := range data[d] { nd := &schema.NodeData{} if err := v.Unmarshal(nd); err != nil { continue DATA } if nd.IsOnline() { atLeastOneWorker = true (&cd).Workers = append(cd.Workers, nd.ID) } } if atLeastOneWorker { clusters[d] = cd } } } } } // Start the discovery server. This is meant to be run in to a goroutine. func (s *DiscoveryServer) Start(ctx context.Context, keepRunning bool) error { for { select { case <-ctx.Done(): return fmt.Errorf("context cancelled") default: // Collect data s.runBackground() if !keepRunning { return nil } } } } ================================================ FILE: core/explorer/explorer_suite_test.go ================================================ package explorer_test import ( "testing" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) func TestExplorer(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Explorer test suite") } ================================================ FILE: core/gallery/backend_resolve.go ================================================ package gallery import ( "os" "path/filepath" "strings" "sync" "time" "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/xsync" "github.com/mudler/xlog" "gopkg.in/yaml.v3" ) // modelConfigCacheEntry holds a cached parsed config_file map from a URL-referenced model config. type modelConfigCacheEntry struct { configMap map[string]interface{} lastUpdated time.Time } func (e modelConfigCacheEntry) hasExpired() bool { return e.lastUpdated.Before(time.Now().Add(-1 * time.Hour)) } // modelConfigCache caches parsed model config maps keyed by URL. var modelConfigCache = xsync.NewSyncedMap[string, modelConfigCacheEntry]() // resolveBackend determines the backend for a GalleryModel by checking (in priority order): // 1. Overrides["backend"] — highest priority, same as install-time merge // 2. 
Inline ConfigFile["backend"] — for models with inline config maps // 3. URL-referenced config file — fetched, parsed, and cached // // The model's URL should already be resolved (local override applied) before calling this. func resolveBackend(m *GalleryModel, basePath string) string { // 1. Overrides take priority (matches install-time mergo.WithOverride behavior) if b, ok := m.Overrides["backend"].(string); ok && b != "" { return b } // 2. Inline config_file map if b, ok := m.ConfigFile["backend"].(string); ok && b != "" { return b } // 3. Fetch and parse the URL-referenced config if m.URL != "" { configMap := fetchModelConfigMap(m.URL, basePath) if b, ok := configMap["backend"].(string); ok && b != "" { return b } } return "" } // fetchModelConfigMap fetches a model config URL, parses the config_file YAML string // inside it, and returns the result as a map. Results are cached for 1 hour. // Local file:// URLs skip the cache so edits are picked up immediately. func fetchModelConfigMap(modelURL, basePath string) map[string]interface{} { // Check cache (skip for file:// URLs so local edits are picked up immediately) isLocal := strings.HasPrefix(modelURL, downloader.LocalPrefix) if !isLocal && modelConfigCache.Exists(modelURL) { entry := modelConfigCache.Get(modelURL) if !entry.hasExpired() { return entry.configMap } modelConfigCache.Delete(modelURL) } // Reuse existing gallery config fetcher modelConfig, err := GetGalleryConfigFromURL[ModelConfig](modelURL, basePath) if err != nil { xlog.Debug("Failed to fetch model config for backend resolution", "url", modelURL, "error", err) // Cache the failure for remote URLs to avoid repeated fetch attempts if !isLocal { modelConfigCache.Set(modelURL, modelConfigCacheEntry{ configMap: map[string]interface{}{}, lastUpdated: time.Now(), }) } return map[string]interface{}{} } // Parse the config_file YAML string into a map configMap := make(map[string]interface{}) if modelConfig.ConfigFile != "" { if err := 
yaml.Unmarshal([]byte(modelConfig.ConfigFile), &configMap); err != nil { xlog.Debug("Failed to parse config_file for backend resolution", "url", modelURL, "error", err) } } // Cache for remote URLs if !isLocal { modelConfigCache.Set(modelURL, modelConfigCacheEntry{ configMap: configMap, lastUpdated: time.Now(), }) } return configMap } // prefetchModelConfigs fetches model config URLs in parallel to warm the cache. // This avoids sequential HTTP requests on cold start (~50 unique gallery files). func prefetchModelConfigs(urls []string, basePath string) { const maxConcurrency = 10 sem := make(chan struct{}, maxConcurrency) var wg sync.WaitGroup for _, url := range urls { wg.Add(1) go func(u string) { defer wg.Done() sem <- struct{}{} defer func() { <-sem }() fetchModelConfigMap(u, basePath) }(url) } wg.Wait() } // resolveModelURLLocally attempts to resolve a github: model URL to a local file:// // path when the gallery itself was loaded from a local path. This supports development // workflows where new model files are added locally before being pushed to GitHub. // // For example, if the gallery was loaded from file:///path/to/gallery/index.yaml // and a model references github:mudler/LocalAI/gallery/foo.yaml@master, this will // check if /path/to/gallery/foo.yaml exists locally and return file:///path/to/gallery/foo.yaml. // // This is applied to model.URL in AvailableGalleryModels so that both listing (backend // resolution) and installation use the same resolved URL. 
func resolveModelURLLocally(modelURL, galleryURL string) string { galleryDir := localGalleryDir(galleryURL) if galleryDir == "" { return modelURL } // Only handle github: URLs if !strings.HasPrefix(modelURL, downloader.GithubURI) && !strings.HasPrefix(modelURL, downloader.GithubURI2) { return modelURL } // Extract the filename from the github URL // Format: github:org/repo/path/to/file.yaml@branch raw := strings.TrimPrefix(modelURL, downloader.GithubURI2) raw = strings.TrimPrefix(raw, downloader.GithubURI) // Remove @branch suffix if idx := strings.LastIndex(raw, "@"); idx >= 0 { raw = raw[:idx] } filename := filepath.Base(raw) localPath := filepath.Join(galleryDir, filename) if _, err := os.Stat(localPath); err == nil { return downloader.LocalPrefix + localPath } return modelURL } // localGalleryDir returns the directory of a gallery URL if it's local, or "" if remote. func localGalleryDir(galleryURL string) string { if strings.HasPrefix(galleryURL, downloader.LocalPrefix) { return filepath.Dir(strings.TrimPrefix(galleryURL, downloader.LocalPrefix)) } // Plain path (no scheme) that exists on disk if !strings.Contains(galleryURL, "://") && !strings.HasPrefix(galleryURL, downloader.GithubURI) { if info, err := os.Stat(galleryURL); err == nil && !info.IsDir() { return filepath.Dir(galleryURL) } } return "" } ================================================ FILE: core/gallery/backend_types.go ================================================ package gallery import ( "fmt" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" ) // BackendMetadata represents the metadata stored in a JSON file for each installed backend type BackendMetadata struct { // Alias is an optional alternative name for the backend Alias string `json:"alias,omitempty"` // MetaBackendFor points to the concrete backend if this is a meta backend MetaBackendFor string `json:"meta_backend_for,omitempty"` // Name is the original name from the gallery Name 
string `json:"name,omitempty"` // GalleryURL is the URL of the gallery this backend came from GalleryURL string `json:"gallery_url,omitempty"` // InstalledAt is the timestamp when the backend was installed InstalledAt string `json:"installed_at,omitempty"` } type GalleryBackend struct { Metadata `json:",inline" yaml:",inline"` Alias string `json:"alias,omitempty" yaml:"alias,omitempty"` URI string `json:"uri,omitempty" yaml:"uri,omitempty"` Mirrors []string `json:"mirrors,omitempty" yaml:"mirrors,omitempty"` CapabilitiesMap map[string]string `json:"capabilities,omitempty" yaml:"capabilities,omitempty"` } func (backend *GalleryBackend) FindBestBackendFromMeta(systemState *system.SystemState, backends GalleryElements[*GalleryBackend]) *GalleryBackend { if systemState == nil { return nil } realBackend := backend.CapabilitiesMap[systemState.Capability(backend.CapabilitiesMap)] if realBackend == "" { xlog.Debug("No backend found for reported capability", "backend", backend.Name, "reportedCapability", systemState.Capability(backend.CapabilitiesMap)) return nil } xlog.Debug("Found backend for reported capability", "backend", backend.Name, "reportedCapability", systemState.Capability(backend.CapabilitiesMap)) return backends.FindByName(realBackend) } func (m *GalleryBackend) GetInstalled() bool { return m.Installed } func (m *GalleryBackend) GetLicense() string { return m.License } type GalleryBackends []*GalleryBackend func (m *GalleryBackend) SetGallery(gallery config.Gallery) { m.Gallery = gallery } func (m *GalleryBackend) IsMeta() bool { return len(m.CapabilitiesMap) > 0 && m.URI == "" } // IsCompatibleWith checks if the backend is compatible with the current system capability. // For meta backends, it checks if any of the capabilities in the map match the system capability. // For concrete backends, it delegates to SystemState.IsBackendCompatible. 
func (m *GalleryBackend) IsCompatibleWith(systemState *system.SystemState) bool { if systemState == nil { return true } // Meta backends are compatible if the system capability matches one of the keys if m.IsMeta() { capability := systemState.Capability(m.CapabilitiesMap) _, exists := m.CapabilitiesMap[capability] return exists } // For concrete backends, delegate to the system package return systemState.IsBackendCompatible(m.Name, m.URI) } func (m *GalleryBackend) SetInstalled(installed bool) { m.Installed = installed } func (m *GalleryBackend) GetName() string { return m.Name } func (m *GalleryBackend) GetGallery() config.Gallery { return m.Gallery } func (m *GalleryBackend) GetDescription() string { return m.Description } func (m *GalleryBackend) GetTags() []string { return m.Tags } func (m GalleryBackend) ID() string { return fmt.Sprintf("%s@%s", m.Gallery.Name, m.Name) } ================================================ FILE: core/gallery/backends.go ================================================ // Package gallery provides installation and registration utilities for LocalAI backends, // including meta-backend resolution based on system capabilities. 
package gallery

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"path/filepath"
	"strings"
	"time"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/pkg/downloader"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/system"
	"github.com/mudler/xlog"
	cp "github.com/otiai10/copy"
)

const (
	metadataFile = "metadata.json"
	runFile      = "run.sh"
)

// Default fallback tag values
const (
	defaultLatestTag = "latest"
	defaultMasterTag = "master"
	defaultDevSuffix = "development"
)

// getFallbackTagValues returns the configurable fallback tag values from SystemState
func getFallbackTagValues(systemState *system.SystemState) (latestTag, masterTag, devSuffix string) {
	// Use SystemState fields if set, otherwise use defaults
	latestTag, masterTag, devSuffix = defaultLatestTag, defaultMasterTag, defaultDevSuffix
	if systemState.BackendImagesReleaseTag != "" {
		latestTag = systemState.BackendImagesReleaseTag
	}
	if systemState.BackendImagesBranchTag != "" {
		masterTag = systemState.BackendImagesBranchTag
	}
	if systemState.BackendDevSuffix != "" {
		devSuffix = systemState.BackendDevSuffix
	}
	return latestTag, masterTag, devSuffix
}

// backendCandidate represents an installed concrete backend option for a given alias
type backendCandidate struct {
	name    string
	runFile string
}

// readBackendMetadata reads the metadata JSON file for a backend
// It returns (nil, nil) when the backend directory has no metadata file.
func readBackendMetadata(backendPath string) (*BackendMetadata, error) {
	metadataPath := filepath.Join(backendPath, metadataFile)

	// If metadata file doesn't exist, return nil (for backward compatibility)
	if _, err := os.Stat(metadataPath); os.IsNotExist(err) {
		return nil, nil
	}

	data, err := os.ReadFile(metadataPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read metadata file %q: %v", metadataPath, err)
	}

	var metadata BackendMetadata
	if err := json.Unmarshal(data, &metadata); err != nil {
		return nil, fmt.Errorf("failed to unmarshal metadata file %q: %v", metadataPath, err)
	}

	return &metadata, nil
}

// writeBackendMetadata writes the metadata JSON file for a backend
func writeBackendMetadata(backendPath string, metadata *BackendMetadata) error {
	metadataPath := filepath.Join(backendPath, metadataFile)

	data, err := json.MarshalIndent(metadata, "", " ")
	if err != nil {
		return fmt.Errorf("failed to marshal metadata: %v", err)
	}

	if err := os.WriteFile(metadataPath, data, 0644); err != nil {
		return fmt.Errorf("failed to write metadata file %q: %v", metadataPath, err)
	}

	return nil
}

// InstallBackendFromGallery installs a backend from the gallery.
//
// Unless force is set, an already installed backend is left untouched. For a
// meta backend the concrete backend matching the system capability is
// installed, and a directory carrying only metadata is created under the meta
// name so the pair can be listed and removed together.
func InstallBackendFromGallery(ctx context.Context, galleries []config.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, downloadStatus func(string, string, string, float64), force bool) error {
	// Reject an empty name before touching the filesystem or the galleries.
	if name == "" {
		return fmt.Errorf("backend name is empty")
	}

	if !force {
		// check if we already have the backend installed
		backends, err := ListSystemBackends(systemState)
		if err != nil {
			return err
		}
		if backends.Exists(name) {
			return nil
		}
	}

	xlog.Debug("Installing backend from gallery", "galleries", galleries, "name", name)

	backends, err := AvailableBackends(galleries, systemState)
	if err != nil {
		return err
	}

	backend := FindGalleryElement(backends, name)
	if backend == nil {
		return fmt.Errorf("no backend found with name %q", name)
	}

	if !backend.IsMeta() {
		return InstallBackend(ctx, systemState, modelLoader, backend, downloadStatus)
	}

	xlog.Debug("Backend is a meta backend", "systemState", systemState, "name", name)

	// Then, let's try to find the best backend based on the capabilities map
	bestBackend := backend.FindBestBackendFromMeta(systemState, backends)
	if bestBackend == nil {
		return fmt.Errorf("no backend found with capabilities %q", backend.CapabilitiesMap)
	}

	xlog.Debug("Installing backend from meta backend", "name", name, "bestBackend", bestBackend.Name)

	// Then, let's install the best backend
	if err := InstallBackend(ctx, systemState, modelLoader, bestBackend, downloadStatus); err != nil {
		return err
	}

	// we need now to create a path for the meta backend, with the alias to the installed ones so it can be used to remove it
	metaBackendPath := filepath.Join(systemState.Backend.BackendsPath, name)
	if err := os.MkdirAll(metaBackendPath, 0750); err != nil {
		return fmt.Errorf("failed to create meta backend path %q: %v", metaBackendPath, err)
	}

	// Create metadata for the meta backend
	metaMetadata := &BackendMetadata{
		MetaBackendFor: bestBackend.Name,
		Name:           name,
		GalleryURL:     backend.Gallery.URL,
		InstalledAt:    time.Now().Format(time.RFC3339),
	}

	if err := writeBackendMetadata(metaBackendPath, metaMetadata); err != nil {
		return fmt.Errorf("failed to write metadata for meta backend %q: %v", name, err)
	}

	return nil
}

// downloadBackendWithFallbacks downloads the backend into backendPath, trying
// in order: the configured URI, each mirror, the URI with the release tag
// replaced by the branch tag, and finally the development variant of that
// fallback URI. It stops at the first source that succeeds.
//
// When every source fails, the error of the primary download is reported.
// When ctx is cancelled between attempts, ctx.Err() is returned.
func downloadBackendWithFallbacks(ctx context.Context, config *GalleryBackend, backendPath, latestTag, masterTag, devSuffix string, downloadStatus func(string, string, string, float64)) error {
	xlog.Debug("Downloading backend", "uri", config.URI, "backendPath", backendPath)

	primaryErr := downloader.URI(config.URI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus)
	if primaryErr == nil {
		xlog.Debug("Downloaded backend", "uri", config.URI, "backendPath", backendPath)
		return nil
	}

	// Clean up the partially downloaded backend directory on failure
	xlog.Debug("Backend download failed, cleaning up", "backendPath", backendPath, "error", primaryErr)
	if cleanupErr := os.RemoveAll(backendPath); cleanupErr != nil {
		xlog.Warn("Failed to clean up backend directory", "backendPath", backendPath, "error", cleanupErr)
	}

	// Try to download from mirrors
	for _, mirror := range config.Mirrors {
		// Check for cancellation before trying next mirror
		if err := ctx.Err(); err != nil {
			return err
		}
		if err := downloader.URI(mirror).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
			xlog.Debug("Downloaded backend", "uri", config.URI, "backendPath", backendPath)
			// A mirror delivered the backend: the tag fallbacks below must
			// not run, or they would download over the working install.
			return nil
		}
	}

	if err := ctx.Err(); err != nil {
		return err
	}

	if tryFallbackTags(ctx, config.URI, backendPath, latestTag, masterTag, devSuffix, downloadStatus) {
		return nil
	}

	xlog.Error("Failed to download backend", "uri", config.URI, "backendPath", backendPath, "error", primaryErr)
	return fmt.Errorf("failed to download backend %q: %v", config.URI, primaryErr)
}

// tryFallbackTags retries the download with alternative tags derived from
// originalURI and reports whether one of them succeeded.
func tryFallbackTags(ctx context.Context, originalURI, backendPath, latestTag, masterTag, devSuffix string, downloadStatus func(string, string, string, float64)) bool {
	// Try fallback: replace latestTag + "-" with masterTag + "-" in the URI
	fallbackURI := strings.Replace(originalURI, latestTag+"-", masterTag+"-", 1)
	if fallbackURI == originalURI {
		// The URI does not carry the release tag, nothing to fall back to.
		return false
	}

	xlog.Debug("Trying fallback URI", "original", originalURI, "fallback", fallbackURI)
	if err := downloader.URI(fallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
		xlog.Debug("Downloaded backend using fallback URI", "uri", fallbackURI, "backendPath", backendPath)
		return true
	}

	// Try another fallback based on the development suffix, unless the URI
	// already carries it.
	if strings.Contains(fallbackURI, "-"+devSuffix) {
		return false
	}
	lastDash := strings.LastIndex(fallbackURI, "-")
	if lastDash <= 0 {
		return false
	}

	// NOTE(review): this replaces the text after the last dash with the
	// development suffix ("...-ace-step" becomes "...-ace-development"),
	// whereas the original comment described appending it
	// ("...-ace-step-development"). Behavior is kept as it was; confirm which
	// form the image tags actually use.
	devFallbackURI := fallbackURI[:lastDash] + "-" + devSuffix
	xlog.Debug("Trying development fallback URI", "fallback", devFallbackURI)
	if err := downloader.URI(devFallbackURI).DownloadFileWithContext(ctx, backendPath, "", 1, 1, downloadStatus); err == nil {
		xlog.Debug("Downloaded backend using development fallback URI", "uri", devFallbackURI, "backendPath", backendPath)
		return true
	}
	return false
}

// InstallBackend installs a concrete (non-meta) backend into the backends
// path: it copies or downloads the backend, verifies that a run file is
// present, writes the backend metadata and re-registers all backends with the
// model loader.
func InstallBackend(ctx context.Context, systemState *system.SystemState, modelLoader *model.ModelLoader, config *GalleryBackend, downloadStatus func(string, string, string, float64)) error {
	// Get configurable fallback tag values from SystemState
	latestTag, masterTag, devSuffix := getFallbackTagValues(systemState)

	// Create base path if it doesn't exist
	err := os.MkdirAll(systemState.Backend.BackendsPath, 0750)
	if err != nil {
		return fmt.Errorf("failed to create base path: %v", err)
	}

	if config.IsMeta() {
		return fmt.Errorf("meta backends cannot be installed directly")
	}

	name := config.Name
	backendPath := filepath.Join(systemState.Backend.BackendsPath, name)
	err = os.MkdirAll(backendPath, 0750)
	if err != nil {
		return fmt.Errorf("failed to create base path: %v", err)
	}

	uri := downloader.URI(config.URI)
	// Check if it is a directory
	if uri.LooksLikeDir() {
		// It is a directory, we just copy it over in the backend folder
		if err := cp.Copy(config.URI, backendPath); err != nil {
			return fmt.Errorf("failed copying: %w", err)
		}
	} else if err := downloadBackendWithFallbacks(ctx, config, backendPath, latestTag, masterTag, devSuffix, downloadStatus); err != nil {
		return err
	}

	// sanity check - check if runfile is present
	runPath := filepath.Join(backendPath, runFile)
	if _, err := os.Stat(runPath); os.IsNotExist(err) {
		xlog.Error("Run file not found", "runFile", runPath)
		return fmt.Errorf("not a valid backend: run file not found %q", runPath)
	}

	// Create metadata for the backend
	metadata := &BackendMetadata{
		Name:        name,
		GalleryURL:  config.Gallery.URL,
		InstalledAt: time.Now().Format(time.RFC3339),
	}

	if config.Alias != "" {
		metadata.Alias = config.Alias
	}

	if err := writeBackendMetadata(backendPath, metadata); err != nil {
		return fmt.Errorf("failed to write metadata for backend %q: %v", name, err)
	}

	return RegisterBackends(systemState, modelLoader)
}

// DeleteBackendFromSystem removes the backend called name from the backends
// path. name may be a directory name or an alias recorded in the metadata of
// an installed backend. Deleting a meta backend also deletes the concrete
// backend it points to. System backends cannot be deleted.
func DeleteBackendFromSystem(systemState *system.SystemState, name string) error {
	backends, err := ListSystemBackends(systemState)
	if err != nil {
		return err
	}

	backend, ok := backends.Get(name)
	if !ok {
		return fmt.Errorf("backend %q not found", name)
	}

	if backend.IsSystem {
		return fmt.Errorf("system backend %q cannot be deleted", name)
	}

	backendDirectory := filepath.Join(systemState.Backend.BackendsPath, name)

	// check if the backend dir exists
	if _, err := os.Stat(backendDirectory); os.IsNotExist(err) {
		// if doesn't exist, it might be an alias, so we need to check if we have a matching alias in
		// all the backends in the basePath
		entries, err := os.ReadDir(systemState.Backend.BackendsPath)
		if err != nil {
			return err
		}
		foundBackend := false

		for _, entry := range entries {
			if !entry.IsDir() {
				continue
			}
			metadata, err := readBackendMetadata(filepath.Join(systemState.Backend.BackendsPath, entry.Name()))
			if err != nil {
				return err
			}
			if metadata != nil && metadata.Alias == name {
				backendDirectory = filepath.Join(systemState.Backend.BackendsPath, entry.Name())
				foundBackend = true
				break
			}
		}

		// Neither a directory nor an alias matches the name: report it.
		if !foundBackend {
			return fmt.Errorf("no backend found with name %q", name)
		}
	}

	// If it's a meta backend, delete also associated backend
	metadata, err := readBackendMetadata(backendDirectory)
	if err != nil {
		return err
	}

	if metadata != nil && metadata.MetaBackendFor != "" {
		metaBackendDirectory := filepath.Join(systemState.Backend.BackendsPath, metadata.MetaBackendFor)
		xlog.Debug("Deleting meta backend", "backendDirectory", metaBackendDirectory)
		if _, err := os.Stat(metaBackendDirectory); os.IsNotExist(err) {
			return fmt.Errorf("meta backend %q not found", metadata.MetaBackendFor)
		}
		// Do not go on to delete the meta directory when the concrete backend
		// could not be removed, otherwise the leftover loses its owner.
		if err := os.RemoveAll(metaBackendDirectory); err != nil {
			return fmt.Errorf("failed to delete backend %q: %v", metadata.MetaBackendFor, err)
		}
	}

	return os.RemoveAll(backendDirectory)
}

// SystemBackend is a backend available on this system, together with the
// path of its run file.
type SystemBackend struct {
	Name     string
	RunFile  string
	IsMeta   bool
	IsSystem bool
	Metadata *BackendMetadata
}

// SystemBackends maps backend names (directory names, meta names and aliases)
// to the backend they resolve to.
type SystemBackends map[string]SystemBackend

// Exists reports whether a backend is registered under name.
func (b SystemBackends) Exists(name string) bool {
	_, ok := b[name]
	return ok
}

// Get returns the backend registered under name, if any.
func (b SystemBackends) Get(name string) (SystemBackend, bool) {
	backend, ok := b[name]
	return backend, ok
}

// GetAll returns all backends as a slice, in no particular order.
func (b SystemBackends) GetAll() []SystemBackend {
	backends := make([]SystemBackend, 0, len(b))
	for _, backend := range b {
		backends = append(backends, backend)
	}
	return backends
}

// ListSystemBackends lists the backends found in the system path and in the
// user-managed backends path. Concrete backends are keyed by directory name,
// meta backends by the name stored in their metadata, and every alias by the
// alias itself, pointing at the candidate preferred for this system.
func ListSystemBackends(systemState *system.SystemState) (SystemBackends, error) {
	// Gather backends from system and user paths, then resolve alias conflicts by capability.
	backends := make(SystemBackends)

	// System-provided backends
	if systemBackends, err := os.ReadDir(systemState.Backend.BackendsSystemPath); err == nil {
		for _, systemBackend := range systemBackends {
			if !systemBackend.IsDir() {
				continue
			}
			run := filepath.Join(systemState.Backend.BackendsSystemPath, systemBackend.Name(), runFile)
			if _, err := os.Stat(run); err == nil {
				backends[systemBackend.Name()] = SystemBackend{
					Name:     systemBackend.Name(),
					RunFile:  run,
					IsMeta:   false,
					IsSystem: true,
					Metadata: nil,
				}
			}
		}
	} else if !errors.Is(err, os.ErrNotExist) {
		xlog.Warn("Failed to read system backends, proceeding with user-managed backends", "error", err)
	} else {
		xlog.Debug("No system backends found")
	}

	// User-managed backends and alias collection
	entries, err := os.ReadDir(systemState.Backend.BackendsPath)
	if err != nil {
		return nil, err
	}

	aliasGroups := make(map[string][]backendCandidate)
	metaMap := make(map[string]*BackendMetadata)

	for _, e := range entries {
		if !e.IsDir() {
			continue
		}
		dir := e.Name()
		run := filepath.Join(systemState.Backend.BackendsPath, dir, runFile)

		// readBackendMetadata yields nil when there is no metadata file; such
		// directories get minimal metadata carrying only their name.
		metadata, rerr := readBackendMetadata(filepath.Join(systemState.Backend.BackendsPath, dir))
		if rerr != nil {
			return nil, rerr
		}
		if metadata == nil {
			metadata = &BackendMetadata{Name: dir}
		}

		metaMap[dir] = metadata

		// Concrete-backend entry
		if _, err := os.Stat(run); err == nil {
			backends[dir] = SystemBackend{
				Name:     dir,
				RunFile:  run,
				IsMeta:   false,
				Metadata: metadata,
			}
		}

		// Alias candidates
		// NOTE(review): the candidate is recorded even when its run file does
		// not exist on disk, so the runFile checks during alias resolution
		// below never filter anything out — confirm whether that is intended.
		if metadata.Alias != "" {
			aliasGroups[metadata.Alias] = append(aliasGroups[metadata.Alias], backendCandidate{name: dir, runFile: run})
		}

		// Meta backends indirection
		if metadata.MetaBackendFor != "" {
			backends[metadata.Name] = SystemBackend{
				Name:     metadata.Name,
				RunFile:  filepath.Join(systemState.Backend.BackendsPath, metadata.MetaBackendFor, runFile),
				IsMeta:   true,
				Metadata: metadata,
			}
		}
	}

	// Resolve aliases using system capability preferences
	tokens := systemState.BackendPreferenceTokens()
	for alias, cands := range aliasGroups {
		chosen := backendCandidate{}
		// Try preference tokens
		for _, t := range tokens {
			for _, c := range cands {
				if strings.Contains(strings.ToLower(c.name), t) && c.runFile != "" {
					chosen = c
					break
				}
			}
			if chosen.runFile != "" {
				break
			}
		}
		// Fallback: first runnable
		if chosen.runFile == "" {
			for _, c := range cands {
				if c.runFile != "" {
					chosen = c
					break
				}
			}
		}
		if chosen.runFile == "" {
			continue
		}
		md := metaMap[chosen.name]
		backends[alias] = SystemBackend{
			Name:     alias,
			RunFile:  chosen.runFile,
			IsMeta:   false,
			Metadata: md,
		}
	}

	return backends, nil
}

// RegisterBackends registers every backend returned by ListSystemBackends as
// an external backend of the model loader.
func RegisterBackends(systemState *system.SystemState, modelLoader *model.ModelLoader) error {
	backends, err := ListSystemBackends(systemState)
	if err != nil {
		return err
	}

	for _, backend := range backends {
		xlog.Debug("Registering backend", "name", backend.Name, "runFile", backend.RunFile)
		modelLoader.SetExternalBackend(backend.Name, backend.RunFile)
	}

	return nil
}

================================================
FILE: core/gallery/backends_test.go
================================================
package gallery

import (
	"context"
	"encoding/json"
	"os"
	"path/filepath"
	"runtime"

	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/system"
	. "github.com/onsi/ginkgo/v2"
	.
"github.com/onsi/gomega" "gopkg.in/yaml.v3" ) const ( testImage = "quay.io/mudler/tests:localai-backend-test" ) var _ = Describe("Runtime capability-based backend selection", func() { var tempDir string BeforeEach(func() { var err error tempDir, err = os.MkdirTemp("", "gallery-caps-*") Expect(err).NotTo(HaveOccurred()) }) AfterEach(func() { os.RemoveAll(tempDir) }) It("ListSystemBackends prefers optimal alias candidate", func() { // Arrange two installed backends sharing the same alias must := func(err error) { Expect(err).NotTo(HaveOccurred()) } cpuDir := filepath.Join(tempDir, "cpu-llama-cpp") must(os.MkdirAll(cpuDir, 0o750)) cpuMeta := &BackendMetadata{Alias: "llama-cpp", Name: "cpu-llama-cpp"} b, _ := json.Marshal(cpuMeta) must(os.WriteFile(filepath.Join(cpuDir, "metadata.json"), b, 0o644)) must(os.WriteFile(filepath.Join(cpuDir, "run.sh"), []byte(""), 0o755)) cudaDir := filepath.Join(tempDir, "cuda12-llama-cpp") must(os.MkdirAll(cudaDir, 0o750)) cudaMeta := &BackendMetadata{Alias: "llama-cpp", Name: "cuda12-llama-cpp"} b, _ = json.Marshal(cudaMeta) must(os.WriteFile(filepath.Join(cudaDir, "metadata.json"), b, 0o644)) must(os.WriteFile(filepath.Join(cudaDir, "run.sh"), []byte(""), 0o755)) // Default system: alias should point to CPU sysDefault, err := system.GetSystemState( system.WithBackendPath(tempDir), ) must(err) sysDefault.GPUVendor = "" // force default selection backs, err := ListSystemBackends(sysDefault) must(err) aliasBack, ok := backs.Get("llama-cpp") Expect(ok).To(BeTrue()) Expect(aliasBack.RunFile).To(Equal(filepath.Join(cpuDir, "run.sh"))) // concrete entries remain _, ok = backs.Get("cpu-llama-cpp") Expect(ok).To(BeTrue()) _, ok = backs.Get("cuda12-llama-cpp") Expect(ok).To(BeTrue()) // NVIDIA system: alias should point to CUDA // Force capability to nvidia to make the test deterministic on platforms like darwin/arm64 (which default to metal) os.Setenv("LOCALAI_FORCE_META_BACKEND_CAPABILITY", "nvidia") defer 
os.Unsetenv("LOCALAI_FORCE_META_BACKEND_CAPABILITY") sysNvidia, err := system.GetSystemState( system.WithBackendPath(tempDir), ) must(err) sysNvidia.GPUVendor = "nvidia" sysNvidia.VRAM = 8 * 1024 * 1024 * 1024 backs, err = ListSystemBackends(sysNvidia) must(err) aliasBack, ok = backs.Get("llama-cpp") Expect(ok).To(BeTrue()) Expect(aliasBack.RunFile).To(Equal(filepath.Join(cudaDir, "run.sh"))) }) }) var _ = Describe("Gallery Backends", func() { var ( tempDir string galleries []config.Gallery ml *model.ModelLoader systemState *system.SystemState ) BeforeEach(func() { var err error tempDir, err = os.MkdirTemp("", "gallery-test-*") Expect(err).NotTo(HaveOccurred()) // Setup test galleries galleries = []config.Gallery{ { Name: "test-gallery", URL: "https://gist.githubusercontent.com/mudler/71d5376bc2aa168873fa519fa9f4bd56/raw/0557f9c640c159fa8e4eab29e8d98df6a3d6e80f/backend-gallery.yaml", }, } systemState, err = system.GetSystemState(system.WithBackendPath(tempDir)) Expect(err).NotTo(HaveOccurred()) ml = model.NewModelLoader(systemState) }) AfterEach(func() { os.RemoveAll(tempDir) }) Describe("InstallBackendFromGallery", func() { It("should return error when backend is not found", func() { err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "non-existent", nil, true) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("no backend found with name \"non-existent\"")) }) It("should install backend from gallery", func() { err := InstallBackendFromGallery(context.TODO(), galleries, systemState, ml, "test-backend", nil, true) Expect(err).ToNot(HaveOccurred()) Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile()) }) }) Describe("Meta Backends", func() { It("should identify meta backends correctly", func() { metaBackend := &GalleryBackend{ Metadata: Metadata{ Name: "meta-backend", }, CapabilitiesMap: map[string]string{ "nvidia": "nvidia-backend", "amd": "amd-backend", "intel": "intel-backend", }, } 
Expect(metaBackend.IsMeta()).To(BeTrue()) regularBackend := &GalleryBackend{ Metadata: Metadata{ Name: "regular-backend", }, URI: testImage, } Expect(regularBackend.IsMeta()).To(BeFalse()) emptyMetaBackend := &GalleryBackend{ Metadata: Metadata{ Name: "empty-meta-backend", }, CapabilitiesMap: map[string]string{}, } Expect(emptyMetaBackend.IsMeta()).To(BeFalse()) nilMetaBackend := &GalleryBackend{ Metadata: Metadata{ Name: "nil-meta-backend", }, CapabilitiesMap: nil, } Expect(nilMetaBackend.IsMeta()).To(BeFalse()) }) It("should check IsCompatibleWith correctly for meta backends", func() { metaBackend := &GalleryBackend{ Metadata: Metadata{ Name: "meta-backend", }, CapabilitiesMap: map[string]string{ "nvidia": "nvidia-backend", "amd": "amd-backend", "default": "default-backend", }, } // Test with nil state - should be compatible Expect(metaBackend.IsCompatibleWith(nil)).To(BeTrue()) // Test with NVIDIA system - should be compatible (has nvidia key) nvidiaState := &system.SystemState{GPUVendor: "nvidia", VRAM: 8 * 1024 * 1024 * 1024} Expect(metaBackend.IsCompatibleWith(nvidiaState)).To(BeTrue()) // Test with default (no GPU) - should be compatible (has default key) defaultState := &system.SystemState{} Expect(metaBackend.IsCompatibleWith(defaultState)).To(BeTrue()) }) Describe("IsCompatibleWith for concrete backends", func() { Context("CPU backends", func() { It("should be compatible on all systems", func() { cpuBackend := &GalleryBackend{ Metadata: Metadata{ Name: "cpu-llama-cpp", }, URI: "quay.io/go-skynet/local-ai-backends:latest-cpu-llama-cpp", } Expect(cpuBackend.IsCompatibleWith(&system.SystemState{})).To(BeTrue()) Expect(cpuBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Nvidia, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue()) Expect(cpuBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.AMD, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue()) }) }) Context("Darwin/Metal backends", func() { When("running on darwin", func() { BeforeEach(func() { 
if runtime.GOOS != "darwin" { Skip("Skipping darwin-specific tests on non-darwin system") } }) It("should be compatible for MLX backend", func() { mlxBackend := &GalleryBackend{ Metadata: Metadata{ Name: "mlx", }, URI: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx", } Expect(mlxBackend.IsCompatibleWith(&system.SystemState{})).To(BeTrue()) }) It("should be compatible for metal-llama-cpp backend", func() { metalBackend := &GalleryBackend{ Metadata: Metadata{ Name: "metal-llama-cpp", }, URI: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-llama-cpp", } Expect(metalBackend.IsCompatibleWith(&system.SystemState{})).To(BeTrue()) }) }) When("running on non-darwin", func() { BeforeEach(func() { if runtime.GOOS == "darwin" { Skip("Skipping non-darwin-specific tests on darwin system") } }) It("should NOT be compatible for MLX backend", func() { mlxBackend := &GalleryBackend{ Metadata: Metadata{ Name: "mlx", }, URI: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-mlx", } Expect(mlxBackend.IsCompatibleWith(&system.SystemState{})).To(BeFalse()) }) It("should NOT be compatible for metal-llama-cpp backend", func() { metalBackend := &GalleryBackend{ Metadata: Metadata{ Name: "metal-llama-cpp", }, URI: "quay.io/go-skynet/local-ai-backends:latest-metal-darwin-arm64-llama-cpp", } Expect(metalBackend.IsCompatibleWith(&system.SystemState{})).To(BeFalse()) }) }) }) Context("NVIDIA/CUDA backends", func() { When("running on non-darwin", func() { BeforeEach(func() { if runtime.GOOS == "darwin" { Skip("Skipping CUDA tests on darwin system") } }) It("should NOT be compatible without nvidia GPU", func() { cudaBackend := &GalleryBackend{ Metadata: Metadata{ Name: "cuda12-llama-cpp", }, URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-llama-cpp", } Expect(cudaBackend.IsCompatibleWith(&system.SystemState{})).To(BeFalse()) Expect(cudaBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.AMD, VRAM: 8 * 1024 * 1024 * 
1024})).To(BeFalse()) }) It("should be compatible with nvidia GPU", func() { cudaBackend := &GalleryBackend{ Metadata: Metadata{ Name: "cuda12-llama-cpp", }, URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-12-llama-cpp", } Expect(cudaBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Nvidia, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue()) }) It("should be compatible with cuda13 backend on nvidia GPU", func() { cuda13Backend := &GalleryBackend{ Metadata: Metadata{ Name: "cuda13-llama-cpp", }, URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-nvidia-cuda-13-llama-cpp", } Expect(cuda13Backend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Nvidia, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue()) }) }) }) Context("AMD/ROCm backends", func() { When("running on non-darwin", func() { BeforeEach(func() { if runtime.GOOS == "darwin" { Skip("Skipping AMD/ROCm tests on darwin system") } }) It("should NOT be compatible without AMD GPU", func() { rocmBackend := &GalleryBackend{ Metadata: Metadata{ Name: "rocm-llama-cpp", }, URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-llama-cpp", } Expect(rocmBackend.IsCompatibleWith(&system.SystemState{})).To(BeFalse()) Expect(rocmBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Nvidia, VRAM: 8 * 1024 * 1024 * 1024})).To(BeFalse()) }) It("should be compatible with AMD GPU", func() { rocmBackend := &GalleryBackend{ Metadata: Metadata{ Name: "rocm-llama-cpp", }, URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-rocm-hipblas-llama-cpp", } Expect(rocmBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.AMD, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue()) }) It("should be compatible with hipblas backend on AMD GPU", func() { hipBackend := &GalleryBackend{ Metadata: Metadata{ Name: "hip-llama-cpp", }, URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-hip-llama-cpp", } Expect(hipBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.AMD, VRAM: 8 * 1024 * 1024 * 
1024})).To(BeTrue()) }) }) }) Context("Intel/SYCL backends", func() { When("running on non-darwin", func() { BeforeEach(func() { if runtime.GOOS == "darwin" { Skip("Skipping Intel/SYCL tests on darwin system") } }) It("should NOT be compatible without Intel GPU", func() { intelBackend := &GalleryBackend{ Metadata: Metadata{ Name: "intel-sycl-f16-llama-cpp", }, URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-llama-cpp", } Expect(intelBackend.IsCompatibleWith(&system.SystemState{})).To(BeFalse()) Expect(intelBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Nvidia, VRAM: 8 * 1024 * 1024 * 1024})).To(BeFalse()) }) It("should be compatible with Intel GPU", func() { intelBackend := &GalleryBackend{ Metadata: Metadata{ Name: "intel-sycl-f16-llama-cpp", }, URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f16-llama-cpp", } Expect(intelBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Intel, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue()) }) It("should be compatible with intel-sycl-f32 backend on Intel GPU", func() { intelF32Backend := &GalleryBackend{ Metadata: Metadata{ Name: "intel-sycl-f32-llama-cpp", }, URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-intel-sycl-f32-llama-cpp", } Expect(intelF32Backend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Intel, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue()) }) It("should be compatible with intel-transformers backend on Intel GPU", func() { intelTransformersBackend := &GalleryBackend{ Metadata: Metadata{ Name: "intel-transformers", }, URI: "quay.io/go-skynet/local-ai-backends:latest-intel-transformers", } Expect(intelTransformersBackend.IsCompatibleWith(&system.SystemState{GPUVendor: system.Intel, VRAM: 8 * 1024 * 1024 * 1024})).To(BeTrue()) }) }) }) Context("Vulkan backends", func() { It("should be compatible on CPU-only systems", func() { // Vulkan backends don't have a specific GPU vendor requirement in the current logic // They are compatible if no other 
GPU-specific pattern matches vulkanBackend := &GalleryBackend{ Metadata: Metadata{ Name: "vulkan-llama-cpp", }, URI: "quay.io/go-skynet/local-ai-backends:latest-gpu-vulkan-llama-cpp", } // Vulkan doesn't have vendor-specific filtering in current implementation Expect(vulkanBackend.IsCompatibleWith(&system.SystemState{})).To(BeTrue()) }) }) }) It("should find best backend from meta based on system capabilities", func() { metaBackend := &GalleryBackend{ Metadata: Metadata{ Name: "meta-backend", }, CapabilitiesMap: map[string]string{ "nvidia": "nvidia-backend", "amd": "amd-backend", "intel": "intel-backend", "metal": "metal-backend", "default": "default-backend", }, } nvidiaBackend := &GalleryBackend{ Metadata: Metadata{ Name: "nvidia-backend", }, URI: testImage, } amdBackend := &GalleryBackend{ Metadata: Metadata{ Name: "amd-backend", }, URI: testImage, } metalBackend := &GalleryBackend{ Metadata: Metadata{ Name: "metal-backend", }, URI: testImage, } defaultBackend := &GalleryBackend{ Metadata: Metadata{ Name: "default-backend", }, URI: testImage, } backends := GalleryElements[*GalleryBackend]{nvidiaBackend, amdBackend, metalBackend, defaultBackend} if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" { metal := &system.SystemState{} bestBackend := metaBackend.FindBestBackendFromMeta(metal, backends) Expect(bestBackend).To(Equal(metalBackend)) } else { // Test with NVIDIA system state nvidiaSystemState := &system.SystemState{GPUVendor: "nvidia", VRAM: 1000000000000} bestBackend := metaBackend.FindBestBackendFromMeta(nvidiaSystemState, backends) Expect(bestBackend).To(Equal(nvidiaBackend)) // Test with AMD system state amdSystemState := &system.SystemState{GPUVendor: "amd", VRAM: 1000000000000} bestBackend = metaBackend.FindBestBackendFromMeta(amdSystemState, backends) Expect(bestBackend).To(Equal(amdBackend)) // Test with default system state (not enough VRAM) defaultSystemState := &system.SystemState{GPUVendor: "amd"} bestBackend = 
metaBackend.FindBestBackendFromMeta(defaultSystemState, backends) Expect(bestBackend).To(Equal(defaultBackend)) // Test with default system state defaultSystemState = &system.SystemState{GPUVendor: "default"} bestBackend = metaBackend.FindBestBackendFromMeta(defaultSystemState, backends) Expect(bestBackend).To(Equal(defaultBackend)) backends = GalleryElements[*GalleryBackend]{nvidiaBackend, amdBackend, metalBackend} // Test with unsupported GPU vendor unsupportedSystemState := &system.SystemState{GPUVendor: "unsupported"} bestBackend = metaBackend.FindBestBackendFromMeta(unsupportedSystemState, backends) Expect(bestBackend).To(BeNil()) } }) It("should handle meta backend deletion correctly", func() { if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" { Skip("Skipping test on darwin/arm64") } metaBackend := &GalleryBackend{ Metadata: Metadata{ Name: "meta-backend", }, CapabilitiesMap: map[string]string{ "nvidia": "nvidia-backend", "amd": "amd-backend", "intel": "intel-backend", }, } nvidiaBackend := &GalleryBackend{ Metadata: Metadata{ Name: "nvidia-backend", }, URI: testImage, } amdBackend := &GalleryBackend{ Metadata: Metadata{ Name: "amd-backend", }, URI: testImage, } gallery := config.Gallery{ Name: "test-gallery", URL: "file://" + filepath.Join(tempDir, "backend-gallery.yaml"), } galleryBackend := GalleryBackends{amdBackend, nvidiaBackend, metaBackend} dat, err := yaml.Marshal(galleryBackend) Expect(err).NotTo(HaveOccurred()) err = os.WriteFile(filepath.Join(tempDir, "backend-gallery.yaml"), dat, 0644) Expect(err).NotTo(HaveOccurred()) // Test with NVIDIA system state nvidiaSystemState := &system.SystemState{ GPUVendor: "nvidia", VRAM: 1000000000000, Backend: system.Backend{BackendsPath: tempDir}, } err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true) Expect(err).NotTo(HaveOccurred()) metaBackendPath := filepath.Join(tempDir, "meta-backend") Expect(metaBackendPath).To(BeADirectory()) 
concreteBackendPath := filepath.Join(tempDir, "nvidia-backend") Expect(concreteBackendPath).To(BeADirectory()) systemState, err := system.GetSystemState( system.WithBackendPath(tempDir), ) Expect(err).NotTo(HaveOccurred()) allBackends, err := ListSystemBackends(systemState) Expect(err).NotTo(HaveOccurred()) Expect(allBackends).To(HaveKey("meta-backend")) Expect(allBackends).To(HaveKey("nvidia-backend")) // Delete meta backend by name err = DeleteBackendFromSystem(systemState, "meta-backend") Expect(err).NotTo(HaveOccurred()) // Verify meta backend directory is deleted Expect(metaBackendPath).NotTo(BeADirectory()) // Verify concrete backend directory is deleted Expect(concreteBackendPath).NotTo(BeADirectory()) }) It("should handle meta backend deletion correctly with aliases", func() { if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" { Skip("Skipping test on darwin/arm64") } metaBackend := &GalleryBackend{ Metadata: Metadata{ Name: "meta-backend", }, Alias: "backend-alias", CapabilitiesMap: map[string]string{ "nvidia": "nvidia-backend", "amd": "amd-backend", "intel": "intel-backend", }, } nvidiaBackend := &GalleryBackend{ Metadata: Metadata{ Name: "nvidia-backend", }, Alias: "backend-alias", URI: testImage, } amdBackend := &GalleryBackend{ Metadata: Metadata{ Name: "amd-backend", }, Alias: "backend-alias", URI: testImage, } gallery := config.Gallery{ Name: "test-gallery", URL: "file://" + filepath.Join(tempDir, "backend-gallery.yaml"), } galleryBackend := GalleryBackends{amdBackend, nvidiaBackend, metaBackend} dat, err := yaml.Marshal(galleryBackend) Expect(err).NotTo(HaveOccurred()) err = os.WriteFile(filepath.Join(tempDir, "backend-gallery.yaml"), dat, 0644) Expect(err).NotTo(HaveOccurred()) // Test with NVIDIA system state nvidiaSystemState := &system.SystemState{ GPUVendor: "nvidia", VRAM: 1000000000000, Backend: system.Backend{BackendsPath: tempDir}, } err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, 
"meta-backend", nil, true) Expect(err).NotTo(HaveOccurred()) metaBackendPath := filepath.Join(tempDir, "meta-backend") Expect(metaBackendPath).To(BeADirectory()) concreteBackendPath := filepath.Join(tempDir, "nvidia-backend") Expect(concreteBackendPath).To(BeADirectory()) systemState, err := system.GetSystemState( system.WithBackendPath(tempDir), ) Expect(err).NotTo(HaveOccurred()) allBackends, err := ListSystemBackends(systemState) Expect(err).NotTo(HaveOccurred()) Expect(allBackends).To(HaveKey("meta-backend")) Expect(allBackends).To(HaveKey("nvidia-backend")) mback, exists := allBackends.Get("meta-backend") Expect(exists).To(BeTrue()) Expect(mback.IsMeta).To(BeTrue()) Expect(mback.Metadata.MetaBackendFor).To(Equal("nvidia-backend")) // Delete meta backend by name err = DeleteBackendFromSystem(systemState, "meta-backend") Expect(err).NotTo(HaveOccurred()) // Verify meta backend directory is deleted Expect(metaBackendPath).NotTo(BeADirectory()) // Verify concrete backend directory is deleted Expect(concreteBackendPath).NotTo(BeADirectory()) }) It("should handle meta backend deletion correctly with aliases pointing to the same backend", func() { if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" { Skip("Skipping test on darwin/arm64") } metaBackend := &GalleryBackend{ Metadata: Metadata{ Name: "meta-backend", }, Alias: "meta-backend", CapabilitiesMap: map[string]string{ "nvidia": "nvidia-backend", "amd": "amd-backend", "intel": "intel-backend", }, } nvidiaBackend := &GalleryBackend{ Metadata: Metadata{ Name: "nvidia-backend", }, Alias: "meta-backend", URI: testImage, } amdBackend := &GalleryBackend{ Metadata: Metadata{ Name: "amd-backend", }, Alias: "meta-backend", URI: testImage, } gallery := config.Gallery{ Name: "test-gallery", URL: "file://" + filepath.Join(tempDir, "backend-gallery.yaml"), } galleryBackend := GalleryBackends{amdBackend, nvidiaBackend, metaBackend} dat, err := yaml.Marshal(galleryBackend) Expect(err).NotTo(HaveOccurred()) err = 
os.WriteFile(filepath.Join(tempDir, "backend-gallery.yaml"), dat, 0644) Expect(err).NotTo(HaveOccurred()) // Test with NVIDIA system state nvidiaSystemState := &system.SystemState{ GPUVendor: "nvidia", VRAM: 1000000000000, Backend: system.Backend{BackendsPath: tempDir}, } err = InstallBackendFromGallery(context.TODO(), []config.Gallery{gallery}, nvidiaSystemState, ml, "meta-backend", nil, true) Expect(err).NotTo(HaveOccurred()) metaBackendPath := filepath.Join(tempDir, "meta-backend") Expect(metaBackendPath).To(BeADirectory()) concreteBackendPath := filepath.Join(tempDir, "nvidia-backend") Expect(concreteBackendPath).To(BeADirectory()) systemState, err := system.GetSystemState( system.WithBackendPath(tempDir), ) Expect(err).NotTo(HaveOccurred()) allBackends, err := ListSystemBackends(systemState) Expect(err).NotTo(HaveOccurred()) Expect(allBackends).To(HaveKey("meta-backend")) Expect(allBackends).To(HaveKey("nvidia-backend")) mback, exists := allBackends.Get("meta-backend") Expect(exists).To(BeTrue()) Expect(mback.RunFile).To(Equal(filepath.Join(tempDir, "nvidia-backend", "run.sh"))) // Delete meta backend by name err = DeleteBackendFromSystem(systemState, "meta-backend") Expect(err).NotTo(HaveOccurred()) // Verify meta backend directory is deleted Expect(metaBackendPath).NotTo(BeADirectory()) // Verify concrete backend directory is deleted Expect(concreteBackendPath).NotTo(BeADirectory()) }) It("should list meta backends correctly in system backends", func() { // Create a meta backend directory with metadata metaBackendPath := filepath.Join(tempDir, "meta-backend") err := os.MkdirAll(metaBackendPath, 0750) Expect(err).NotTo(HaveOccurred()) // Create metadata file pointing to concrete backend metadata := &BackendMetadata{ MetaBackendFor: "concrete-backend", Name: "meta-backend", InstalledAt: "2023-01-01T00:00:00Z", } metadataData, err := json.Marshal(metadata) Expect(err).NotTo(HaveOccurred()) err = os.WriteFile(filepath.Join(metaBackendPath, "metadata.json"), 
metadataData, 0644) Expect(err).NotTo(HaveOccurred()) // Create the concrete backend directory with run.sh concreteBackendPath := filepath.Join(tempDir, "concrete-backend") err = os.MkdirAll(concreteBackendPath, 0750) Expect(err).NotTo(HaveOccurred()) err = os.WriteFile(filepath.Join(concreteBackendPath, "metadata.json"), []byte("{}"), 0755) Expect(err).NotTo(HaveOccurred()) err = os.WriteFile(filepath.Join(concreteBackendPath, "run.sh"), []byte(""), 0755) Expect(err).NotTo(HaveOccurred()) // List system backends systemState, err := system.GetSystemState( system.WithBackendPath(tempDir), ) Expect(err).NotTo(HaveOccurred()) backends, err := ListSystemBackends(systemState) Expect(err).NotTo(HaveOccurred()) metaBackend, exists := backends.Get("meta-backend") concreteBackendRunFile := filepath.Join(tempDir, "concrete-backend", "run.sh") // Should include both the meta backend name and concrete backend name Expect(exists).To(BeTrue()) Expect(backends.Exists("concrete-backend")).To(BeTrue()) // meta-backend should be empty Expect(metaBackend.IsMeta).To(BeTrue()) Expect(metaBackend.RunFile).To(Equal(concreteBackendRunFile)) // concrete-backend should point to its own run.sh concreteBackend, exists := backends.Get("concrete-backend") Expect(exists).To(BeTrue()) Expect(concreteBackend.RunFile).To(Equal(concreteBackendRunFile)) }) }) Describe("InstallBackend", func() { It("should create base path if it doesn't exist", func() { newPath := filepath.Join(tempDir, "new-path") backend := GalleryBackend{ Metadata: Metadata{ Name: "test-backend", }, URI: "test-uri", } systemState, err := system.GetSystemState( system.WithBackendPath(newPath), ) Expect(err).NotTo(HaveOccurred()) err = InstallBackend(context.TODO(), systemState, ml, &backend, nil) Expect(newPath).To(BeADirectory()) Expect(err).To(HaveOccurred()) // Will fail due to invalid URI, but path should be created }) It("should overwrite existing backend", func() { if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" { 
Skip("Skipping test on darwin/arm64") } newPath := filepath.Join(tempDir, "test-backend") // Create a dummy backend directory err := os.MkdirAll(newPath, 0750) Expect(err).NotTo(HaveOccurred()) err = os.WriteFile(filepath.Join(newPath, "metadata.json"), []byte("foo"), 0644) Expect(err).NotTo(HaveOccurred()) err = os.WriteFile(filepath.Join(newPath, "run.sh"), []byte(""), 0644) Expect(err).NotTo(HaveOccurred()) backend := GalleryBackend{ Metadata: Metadata{ Name: "test-backend", }, URI: "quay.io/mudler/tests:localai-backend-test", Alias: "test-alias", } systemState, err := system.GetSystemState( system.WithBackendPath(tempDir), ) Expect(err).NotTo(HaveOccurred()) err = InstallBackend(context.TODO(), systemState, ml, &backend, nil) Expect(err).ToNot(HaveOccurred()) Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile()) dat, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json")) Expect(err).ToNot(HaveOccurred()) Expect(string(dat)).ToNot(Equal("foo")) }) It("should overwrite existing backend", func() { if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" { Skip("Skipping test on darwin/arm64") } newPath := filepath.Join(tempDir, "test-backend") // Create a dummy backend directory err := os.MkdirAll(newPath, 0750) Expect(err).NotTo(HaveOccurred()) backend := GalleryBackend{ Metadata: Metadata{ Name: "test-backend", }, URI: "quay.io/mudler/tests:localai-backend-test", Alias: "test-alias", } systemState, err := system.GetSystemState( system.WithBackendPath(tempDir), ) Expect(err).NotTo(HaveOccurred()) Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).ToNot(BeARegularFile()) err = InstallBackend(context.TODO(), systemState, ml, &backend, nil) Expect(err).ToNot(HaveOccurred()) Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile()) }) It("should create alias file when specified", func() { if runtime.GOOS == "darwin" && runtime.GOARCH == "arm64" { Skip("Skipping test on 
darwin/arm64") } backend := GalleryBackend{ Metadata: Metadata{ Name: "test-backend", }, URI: "quay.io/mudler/tests:localai-backend-test", Alias: "test-alias", } systemState, err := system.GetSystemState( system.WithBackendPath(tempDir), ) Expect(err).NotTo(HaveOccurred()) err = InstallBackend(context.TODO(), systemState, ml, &backend, nil) Expect(err).ToNot(HaveOccurred()) Expect(filepath.Join(tempDir, "test-backend", "metadata.json")).To(BeARegularFile()) // Read and verify metadata metadataData, err := os.ReadFile(filepath.Join(tempDir, "test-backend", "metadata.json")) Expect(err).ToNot(HaveOccurred()) var metadata BackendMetadata err = json.Unmarshal(metadataData, &metadata) Expect(err).ToNot(HaveOccurred()) Expect(metadata.Alias).To(Equal("test-alias")) Expect(metadata.Name).To(Equal("test-backend")) Expect(filepath.Join(tempDir, "test-backend", "run.sh")).To(BeARegularFile()) // Check that the alias was recognized backends, err := ListSystemBackends(systemState) Expect(err).ToNot(HaveOccurred()) aliasBackend, exists := backends.Get("test-alias") Expect(exists).To(BeTrue()) Expect(aliasBackend.RunFile).To(Equal(filepath.Join(tempDir, "test-backend", "run.sh"))) testB, exists := backends.Get("test-backend") Expect(exists).To(BeTrue()) Expect(testB.RunFile).To(Equal(filepath.Join(tempDir, "test-backend", "run.sh"))) }) }) Describe("DeleteBackendFromSystem", func() { It("should delete backend directory", func() { backendName := "test-backend" backendPath := filepath.Join(tempDir, backendName) // Create a dummy backend directory err := os.MkdirAll(backendPath, 0750) Expect(err).NotTo(HaveOccurred()) err = os.WriteFile(filepath.Join(backendPath, "metadata.json"), []byte("{}"), 0644) Expect(err).NotTo(HaveOccurred()) err = os.WriteFile(filepath.Join(backendPath, "run.sh"), []byte(""), 0644) Expect(err).NotTo(HaveOccurred()) systemState, err := system.GetSystemState( system.WithBackendPath(tempDir), ) Expect(err).NotTo(HaveOccurred()) err = 
DeleteBackendFromSystem(systemState, backendName) Expect(err).NotTo(HaveOccurred()) Expect(backendPath).NotTo(BeADirectory()) }) It("should not error when backend doesn't exist", func() { systemState, err := system.GetSystemState( system.WithBackendPath(tempDir), ) Expect(err).NotTo(HaveOccurred()) err = DeleteBackendFromSystem(systemState, "non-existent") Expect(err).To(HaveOccurred()) }) }) Describe("ListSystemBackends", func() { It("should list backends without aliases", func() { // Create some dummy backend directories backendNames := []string{"backend1", "backend2", "backend3"} for _, name := range backendNames { err := os.MkdirAll(filepath.Join(tempDir, name), 0750) Expect(err).NotTo(HaveOccurred()) err = os.WriteFile(filepath.Join(tempDir, name, "metadata.json"), []byte("{}"), 0755) Expect(err).NotTo(HaveOccurred()) err = os.WriteFile(filepath.Join(tempDir, name, "run.sh"), []byte(""), 0755) Expect(err).NotTo(HaveOccurred()) } systemState, err := system.GetSystemState( system.WithBackendPath(tempDir), ) Expect(err).NotTo(HaveOccurred()) backends, err := ListSystemBackends(systemState) Expect(err).NotTo(HaveOccurred()) Expect(backends).To(HaveLen(len(backendNames))) for _, name := range backendNames { backend, exists := backends.Get(name) Expect(exists).To(BeTrue()) Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, name, "run.sh"))) } }) It("should handle backends with aliases", func() { backendName := "backend1" alias := "alias1" backendPath := filepath.Join(tempDir, backendName) // Create backend directory err := os.MkdirAll(backendPath, 0750) Expect(err).NotTo(HaveOccurred()) // Create metadata file with alias metadata := &BackendMetadata{ Alias: alias, Name: backendName, InstalledAt: "2023-01-01T00:00:00Z", } metadataData, err := json.Marshal(metadata) Expect(err).NotTo(HaveOccurred()) err = os.WriteFile(filepath.Join(backendPath, "metadata.json"), metadataData, 0644) Expect(err).NotTo(HaveOccurred()) err = os.WriteFile(filepath.Join(backendPath, 
"run.sh"), []byte(""), 0755) Expect(err).NotTo(HaveOccurred()) systemState, err := system.GetSystemState( system.WithBackendPath(tempDir), ) Expect(err).NotTo(HaveOccurred()) backends, err := ListSystemBackends(systemState) Expect(err).NotTo(HaveOccurred()) backend, exists := backends.Get(alias) Expect(exists).To(BeTrue()) Expect(backend.RunFile).To(Equal(filepath.Join(tempDir, backendName, "run.sh"))) }) It("should return error when base path doesn't exist", func() { systemState, err := system.GetSystemState( system.WithBackendPath("foobardir"), ) Expect(err).NotTo(HaveOccurred()) _, err = ListSystemBackends(systemState) Expect(err).To(HaveOccurred()) }) }) }) ================================================ FILE: core/gallery/gallery.go ================================================ package gallery import ( "context" "fmt" "os" "path/filepath" "sort" "strings" "time" "github.com/lithammer/fuzzysearch/fuzzy" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/LocalAI/pkg/xsync" "github.com/mudler/xlog" "gopkg.in/yaml.v3" ) func GetGalleryConfigFromURL[T any](url string, basePath string) (T, error) { var config T uri := downloader.URI(url) err := uri.ReadWithCallback(basePath, func(url string, d []byte) error { return yaml.Unmarshal(d, &config) }) if err != nil { xlog.Error("failed to get gallery config for url", "error", err, "url", url) return config, err } return config, nil } func GetGalleryConfigFromURLWithContext[T any](ctx context.Context, url string, basePath string) (T, error) { var config T uri := downloader.URI(url) err := uri.ReadWithAuthorizationAndCallback(ctx, basePath, "", func(url string, d []byte) error { return yaml.Unmarshal(d, &config) }) if err != nil { xlog.Error("failed to get gallery config for url", "error", err, "url", url) return config, err } return config, nil } func ReadConfigFile[T any](filePath string) (*T, error) { // Read the YAML file 
yamlFile, err := os.ReadFile(filePath) if err != nil { return nil, fmt.Errorf("failed to read YAML file: %v", err) } // Unmarshal YAML data into a Config struct var config T err = yaml.Unmarshal(yamlFile, &config) if err != nil { return nil, fmt.Errorf("failed to unmarshal YAML: %v", err) } return &config, nil } type GalleryElement interface { SetGallery(gallery config.Gallery) SetInstalled(installed bool) GetName() string GetDescription() string GetTags() []string GetInstalled() bool GetLicense() string GetGallery() config.Gallery } type GalleryElements[T GalleryElement] []T func (gm GalleryElements[T]) Search(term string) GalleryElements[T] { var filteredModels GalleryElements[T] term = strings.ToLower(term) for _, m := range gm { if fuzzy.Match(term, strings.ToLower(m.GetName())) || fuzzy.Match(term, strings.ToLower(m.GetGallery().Name)) || strings.Contains(strings.ToLower(m.GetName()), term) || strings.Contains(strings.ToLower(m.GetDescription()), term) || strings.Contains(strings.ToLower(m.GetGallery().Name), term) || strings.Contains(strings.ToLower(strings.Join(m.GetTags(), ",")), term) { filteredModels = append(filteredModels, m) } } return filteredModels } func (gm GalleryElements[T]) SortByName(sortOrder string) GalleryElements[T] { sort.Slice(gm, func(i, j int) bool { if sortOrder == "asc" { return strings.ToLower(gm[i].GetName()) < strings.ToLower(gm[j].GetName()) } else { return strings.ToLower(gm[i].GetName()) > strings.ToLower(gm[j].GetName()) } }) return gm } func (gm GalleryElements[T]) SortByRepository(sortOrder string) GalleryElements[T] { sort.Slice(gm, func(i, j int) bool { if sortOrder == "asc" { return strings.ToLower(gm[i].GetGallery().Name) < strings.ToLower(gm[j].GetGallery().Name) } else { return strings.ToLower(gm[i].GetGallery().Name) > strings.ToLower(gm[j].GetGallery().Name) } }) return gm } func (gm GalleryElements[T]) SortByLicense(sortOrder string) GalleryElements[T] { sort.Slice(gm, func(i, j int) bool { licenseI := 
gm[i].GetLicense() licenseJ := gm[j].GetLicense() var result bool if licenseI == "" && licenseJ != "" { return sortOrder == "desc" } else if licenseI != "" && licenseJ == "" { return sortOrder == "asc" } else if licenseI == "" && licenseJ == "" { return false } else { result = strings.ToLower(licenseI) < strings.ToLower(licenseJ) } if sortOrder == "desc" { return !result } else { return result } }) return gm } func (gm GalleryElements[T]) SortByInstalled(sortOrder string) GalleryElements[T] { sort.Slice(gm, func(i, j int) bool { var result bool // Sort by installed status: installed items first (true > false) if gm[i].GetInstalled() != gm[j].GetInstalled() { result = gm[i].GetInstalled() } else { result = strings.ToLower(gm[i].GetName()) < strings.ToLower(gm[j].GetName()) } if sortOrder == "desc" { return !result } else { return result } }) return gm } func (gm GalleryElements[T]) FindByName(name string) T { for _, m := range gm { if strings.EqualFold(m.GetName(), name) { return m } } var zero T return zero } func (gm GalleryElements[T]) Paginate(pageNum int, itemsNum int) GalleryElements[T] { start := (pageNum - 1) * itemsNum end := start + itemsNum if start > len(gm) { start = len(gm) } if end > len(gm) { end = len(gm) } return gm[start:end] } func FindGalleryElement[T GalleryElement](models []T, name string) T { var model T name = strings.ReplaceAll(name, string(os.PathSeparator), "__") if !strings.Contains(name, "@") { for _, m := range models { if strings.EqualFold(strings.ToLower(m.GetName()), strings.ToLower(name)) { model = m break } } } else { for _, m := range models { if strings.EqualFold(strings.ToLower(name), strings.ToLower(fmt.Sprintf("%s@%s", m.GetGallery().Name, m.GetName()))) { model = m break } } } return model } // List available models // Models galleries are a list of yaml files that are hosted on a remote server (for example github). 
// Each yaml file contains a list of models that can be downloaded and optionally overrides to define a new model setting. func AvailableGalleryModels(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryModel], error) { var models []*GalleryModel // Get models from galleries for _, gallery := range galleries { galleryModels, err := getGalleryElements(gallery, systemState.Model.ModelsPath, func(model *GalleryModel) bool { if _, err := os.Stat(filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", model.GetName()))); err == nil { return true } return false }) if err != nil { return nil, err } // Resolve model URLs locally (for local galleries) and collect unique // URLs that need fetching for backend resolution. uniqueURLs := map[string]struct{}{} for _, m := range galleryModels { if m.URL != "" { m.URL = resolveModelURLLocally(m.URL, gallery.URL) } if m.Backend == "" && m.URL != "" { uniqueURLs[m.URL] = struct{}{} } } // Pre-warm cache with parallel fetches to avoid sequential HTTP // requests on cold start (~50 unique gallery config files). if len(uniqueURLs) > 0 { urls := make([]string, 0, len(uniqueURLs)) for u := range uniqueURLs { urls = append(urls, u) } prefetchModelConfigs(urls, systemState.Model.ModelsPath) } // Resolve backends from warm cache. for _, m := range galleryModels { if m.Backend == "" { m.Backend = resolveBackend(m, systemState.Model.ModelsPath) } } models = append(models, galleryModels...) } return models, nil } // List available backends func AvailableBackends(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryBackend], error) { return availableBackendsWithFilter(galleries, systemState, true) } // AvailableBackendsUnfiltered returns all available backends without filtering by system capability. 
func AvailableBackendsUnfiltered(galleries []config.Gallery, systemState *system.SystemState) (GalleryElements[*GalleryBackend], error) { return availableBackendsWithFilter(galleries, systemState, false) } // availableBackendsWithFilter is a helper function that lists available backends with optional filtering. func availableBackendsWithFilter(galleries []config.Gallery, systemState *system.SystemState, filterByCapability bool) (GalleryElements[*GalleryBackend], error) { var backends []*GalleryBackend systemBackends, err := ListSystemBackends(systemState) if err != nil { return nil, err } // Get backends from galleries for _, gallery := range galleries { galleryBackends, err := getGalleryElements(gallery, systemState.Backend.BackendsPath, func(backend *GalleryBackend) bool { return systemBackends.Exists(backend.GetName()) }) if err != nil { return nil, err } // Filter backends by system capability if requested if filterByCapability { for _, backend := range galleryBackends { if backend.IsCompatibleWith(systemState) { backends = append(backends, backend) } } } else { backends = append(backends, galleryBackends...) 
} } return backends, nil } func findGalleryURLFromReferenceURL(url string, basePath string) (string, error) { var refFile string uri := downloader.URI(url) err := uri.ReadWithCallback(basePath, func(url string, d []byte) error { refFile = string(d) if len(refFile) == 0 { return fmt.Errorf("invalid reference file at url %s: %s", url, d) } cutPoint := strings.LastIndex(url, "/") refFile = url[:cutPoint+1] + refFile return nil }) return refFile, err } type galleryCacheEntry struct { yamlEntry []byte lastUpdated time.Time } func (entry galleryCacheEntry) hasExpired() bool { return entry.lastUpdated.Before(time.Now().Add(-1 * time.Hour)) } var galleryCache = xsync.NewSyncedMap[string, galleryCacheEntry]() func getGalleryElements[T GalleryElement](gallery config.Gallery, basePath string, isInstalledCallback func(T) bool) ([]T, error) { var models []T = []T{} if strings.HasSuffix(gallery.URL, ".ref") { var err error gallery.URL, err = findGalleryURLFromReferenceURL(gallery.URL, basePath) if err != nil { return models, err } } cacheKey := fmt.Sprintf("%s-%s", gallery.Name, gallery.URL) if galleryCache.Exists(cacheKey) { entry := galleryCache.Get(cacheKey) // refresh if last updated is more than 1 hour ago if !entry.hasExpired() { err := yaml.Unmarshal(entry.yamlEntry, &models) if err != nil { return models, err } } else { galleryCache.Delete(cacheKey) } } uri := downloader.URI(gallery.URL) if len(models) == 0 { err := uri.ReadWithCallback(basePath, func(url string, d []byte) error { galleryCache.Set(cacheKey, galleryCacheEntry{ yamlEntry: d, lastUpdated: time.Now(), }) return yaml.Unmarshal(d, &models) }) if err != nil { if yamlErr, ok := err.(*yaml.TypeError); ok { xlog.Debug("YAML errors", "errors", strings.Join(yamlErr.Errors, "\n"), "models", models) } return models, fmt.Errorf("failed to read gallery elements: %w", err) } } // Add gallery to models for _, model := range models { model.SetGallery(gallery) model.SetInstalled(isInstalledCallback(model)) } return models, 
nil } ================================================ FILE: core/gallery/gallery_suite_test.go ================================================ package gallery_test import ( "testing" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) func TestGallery(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Gallery test suite") } ================================================ FILE: core/gallery/gallery_test.go ================================================ package gallery_test import ( "os" "path/filepath" "dario.cat/mergo" "github.com/mudler/LocalAI/core/config" . "github.com/mudler/LocalAI/core/gallery" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "gopkg.in/yaml.v3" ) var _ = Describe("Gallery", func() { var tempDir string BeforeEach(func() { var err error tempDir, err = os.MkdirTemp("", "gallery-test-*") Expect(err).NotTo(HaveOccurred()) }) AfterEach(func() { os.RemoveAll(tempDir) }) Describe("ReadConfigFile", func() { It("should read and unmarshal a valid YAML file", func() { testConfig := map[string]interface{}{ "name": "test-model", "description": "A test model", "license": "MIT", } yamlData, err := yaml.Marshal(testConfig) Expect(err).NotTo(HaveOccurred()) filePath := filepath.Join(tempDir, "test.yaml") err = os.WriteFile(filePath, yamlData, 0644) Expect(err).NotTo(HaveOccurred()) var result map[string]interface{} config, err := ReadConfigFile[map[string]interface{}](filePath) Expect(err).NotTo(HaveOccurred()) Expect(config).NotTo(BeNil()) result = *config Expect(result["name"]).To(Equal("test-model")) Expect(result["description"]).To(Equal("A test model")) Expect(result["license"]).To(Equal("MIT")) }) It("should return error when file does not exist", func() { _, err := ReadConfigFile[map[string]interface{}]("nonexistent.yaml") Expect(err).To(HaveOccurred()) }) It("should return error when YAML is invalid", func() { filePath := filepath.Join(tempDir, "invalid.yaml") err := os.WriteFile(filePath, []byte("invalid: yaml: content: 
[unclosed"), 0644) Expect(err).NotTo(HaveOccurred()) _, err = ReadConfigFile[map[string]interface{}](filePath) Expect(err).To(HaveOccurred()) }) }) Describe("GalleryElements Search", func() { var elements GalleryElements[*GalleryModel] BeforeEach(func() { elements = GalleryElements[*GalleryModel]{ { Metadata: Metadata{ Name: "bert-embeddings", Description: "BERT model for embeddings", Tags: []string{"embeddings", "bert", "nlp"}, License: "Apache-2.0", Gallery: config.Gallery{ Name: "huggingface", }, }, }, { Metadata: Metadata{ Name: "gpt-2", Description: "GPT-2 language model", Tags: []string{"gpt", "language-model"}, License: "MIT", Gallery: config.Gallery{ Name: "openai", }, }, }, { Metadata: Metadata{ Name: "llama-7b", Description: "LLaMA 7B model", Tags: []string{"llama", "llm"}, License: "LLaMA", Gallery: config.Gallery{ Name: "meta", }, }, }, } }) It("should find elements by exact name match", func() { results := elements.Search("bert-embeddings") Expect(results).To(HaveLen(1)) Expect(results[0].GetName()).To(Equal("bert-embeddings")) }) It("should find elements by partial name match", func() { results := elements.Search("bert") Expect(results).To(HaveLen(1)) Expect(results[0].GetName()).To(Equal("bert-embeddings")) }) It("should find elements by description", func() { results := elements.Search("embeddings") Expect(results).To(HaveLen(1)) Expect(results[0].GetName()).To(Equal("bert-embeddings")) }) It("should find elements by gallery name", func() { results := elements.Search("huggingface") Expect(results).To(HaveLen(1)) Expect(results[0].GetGallery().Name).To(Equal("huggingface")) }) It("should find elements by tags", func() { results := elements.Search("nlp") Expect(results).To(HaveLen(1)) Expect(results[0].GetName()).To(Equal("bert-embeddings")) }) It("should be case insensitive", func() { results := elements.Search("BERT") Expect(results).To(HaveLen(1)) Expect(results[0].GetName()).To(Equal("bert-embeddings")) }) It("should find multiple elements", 
func() { results := elements.Search("gpt") Expect(results).To(HaveLen(1)) Expect(results[0].GetName()).To(Equal("gpt-2")) }) It("should return empty results for no matches", func() { results := elements.Search("nonexistent") Expect(results).To(HaveLen(0)) }) It("should use fuzzy matching", func() { results := elements.Search("bert-emb") Expect(results).To(HaveLen(1)) Expect(results[0].GetName()).To(Equal("bert-embeddings")) }) }) Describe("GalleryElements SortByName", func() { var elements GalleryElements[*GalleryModel] BeforeEach(func() { elements = GalleryElements[*GalleryModel]{ {Metadata: Metadata{Name: "zebra"}}, {Metadata: Metadata{Name: "alpha"}}, {Metadata: Metadata{Name: "beta"}}, } }) It("should sort ascending", func() { sorted := elements.SortByName("asc") Expect(sorted).To(HaveLen(3)) Expect(sorted[0].GetName()).To(Equal("alpha")) Expect(sorted[1].GetName()).To(Equal("beta")) Expect(sorted[2].GetName()).To(Equal("zebra")) }) It("should sort descending", func() { sorted := elements.SortByName("desc") Expect(sorted).To(HaveLen(3)) Expect(sorted[0].GetName()).To(Equal("zebra")) Expect(sorted[1].GetName()).To(Equal("beta")) Expect(sorted[2].GetName()).To(Equal("alpha")) }) It("should be case insensitive", func() { elements = GalleryElements[*GalleryModel]{ {Metadata: Metadata{Name: "Zebra"}}, {Metadata: Metadata{Name: "alpha"}}, {Metadata: Metadata{Name: "Beta"}}, } sorted := elements.SortByName("asc") Expect(sorted[0].GetName()).To(Equal("alpha")) Expect(sorted[1].GetName()).To(Equal("Beta")) Expect(sorted[2].GetName()).To(Equal("Zebra")) }) }) Describe("GalleryElements SortByRepository", func() { var elements GalleryElements[*GalleryModel] BeforeEach(func() { elements = GalleryElements[*GalleryModel]{ { Metadata: Metadata{ Gallery: config.Gallery{Name: "zebra-repo"}, }, }, { Metadata: Metadata{ Gallery: config.Gallery{Name: "alpha-repo"}, }, }, { Metadata: Metadata{ Gallery: config.Gallery{Name: "beta-repo"}, }, }, } }) It("should sort ascending", func() 
{ sorted := elements.SortByRepository("asc") Expect(sorted).To(HaveLen(3)) Expect(sorted[0].GetGallery().Name).To(Equal("alpha-repo")) Expect(sorted[1].GetGallery().Name).To(Equal("beta-repo")) Expect(sorted[2].GetGallery().Name).To(Equal("zebra-repo")) }) It("should sort descending", func() { sorted := elements.SortByRepository("desc") Expect(sorted).To(HaveLen(3)) Expect(sorted[0].GetGallery().Name).To(Equal("zebra-repo")) Expect(sorted[1].GetGallery().Name).To(Equal("beta-repo")) Expect(sorted[2].GetGallery().Name).To(Equal("alpha-repo")) }) }) Describe("GalleryElements SortByLicense", func() { var elements GalleryElements[*GalleryModel] BeforeEach(func() { elements = GalleryElements[*GalleryModel]{ {Metadata: Metadata{License: "MIT"}}, {Metadata: Metadata{License: "Apache-2.0"}}, {Metadata: Metadata{License: ""}}, {Metadata: Metadata{License: "GPL-3.0"}}, } }) It("should sort ascending with empty licenses at end", func() { sorted := elements.SortByLicense("asc") Expect(sorted).To(HaveLen(4)) Expect(sorted[0].GetLicense()).To(Equal("Apache-2.0")) Expect(sorted[1].GetLicense()).To(Equal("GPL-3.0")) Expect(sorted[2].GetLicense()).To(Equal("MIT")) Expect(sorted[3].GetLicense()).To(Equal("")) }) It("should sort descending with empty licenses at beginning", func() { sorted := elements.SortByLicense("desc") Expect(sorted).To(HaveLen(4)) Expect(sorted[0].GetLicense()).To(Equal("")) Expect(sorted[1].GetLicense()).To(Equal("MIT")) Expect(sorted[2].GetLicense()).To(Equal("GPL-3.0")) Expect(sorted[3].GetLicense()).To(Equal("Apache-2.0")) }) It("should handle all empty licenses", func() { elements = GalleryElements[*GalleryModel]{ {Metadata: Metadata{License: ""}}, {Metadata: Metadata{License: ""}}, } sorted := elements.SortByLicense("asc") Expect(sorted).To(HaveLen(2)) }) }) Describe("GalleryElements SortByInstalled", func() { var elements GalleryElements[*GalleryModel] BeforeEach(func() { elements = GalleryElements[*GalleryModel]{ {Metadata: Metadata{Name: "installed-2", 
Installed: true}}, {Metadata: Metadata{Name: "not-installed-1", Installed: false}}, {Metadata: Metadata{Name: "installed-1", Installed: true}}, {Metadata: Metadata{Name: "not-installed-2", Installed: false}}, } }) It("should sort ascending with installed first, then by name", func() { sorted := elements.SortByInstalled("asc") Expect(sorted).To(HaveLen(4)) Expect(sorted[0].GetInstalled()).To(BeTrue()) Expect(sorted[0].GetName()).To(Equal("installed-1")) Expect(sorted[1].GetInstalled()).To(BeTrue()) Expect(sorted[1].GetName()).To(Equal("installed-2")) Expect(sorted[2].GetInstalled()).To(BeFalse()) Expect(sorted[2].GetName()).To(Equal("not-installed-1")) Expect(sorted[3].GetInstalled()).To(BeFalse()) Expect(sorted[3].GetName()).To(Equal("not-installed-2")) }) It("should sort descending with not-installed first, then by name", func() { sorted := elements.SortByInstalled("desc") Expect(sorted).To(HaveLen(4)) Expect(sorted[0].GetInstalled()).To(BeFalse()) Expect(sorted[0].GetName()).To(Equal("not-installed-2")) Expect(sorted[1].GetInstalled()).To(BeFalse()) Expect(sorted[1].GetName()).To(Equal("not-installed-1")) Expect(sorted[2].GetInstalled()).To(BeTrue()) Expect(sorted[2].GetName()).To(Equal("installed-2")) Expect(sorted[3].GetInstalled()).To(BeTrue()) Expect(sorted[3].GetName()).To(Equal("installed-1")) }) }) Describe("GalleryElements FindByName", func() { var elements GalleryElements[*GalleryModel] BeforeEach(func() { elements = GalleryElements[*GalleryModel]{ {Metadata: Metadata{Name: "bert-embeddings"}}, {Metadata: Metadata{Name: "gpt-2"}}, {Metadata: Metadata{Name: "llama-7b"}}, } }) It("should find element by exact name", func() { result := elements.FindByName("bert-embeddings") Expect(result).NotTo(BeNil()) Expect(result.GetName()).To(Equal("bert-embeddings")) }) It("should be case insensitive", func() { result := elements.FindByName("BERT-EMBEDDINGS") Expect(result).NotTo(BeNil()) Expect(result.GetName()).To(Equal("bert-embeddings")) }) It("should return zero 
value when not found", func() { result := elements.FindByName("nonexistent") Expect(result).To(BeNil()) }) }) Describe("GalleryElements Paginate", func() { var elements GalleryElements[*GalleryModel] BeforeEach(func() { elements = GalleryElements[*GalleryModel]{ {Metadata: Metadata{Name: "model-1"}}, {Metadata: Metadata{Name: "model-2"}}, {Metadata: Metadata{Name: "model-3"}}, {Metadata: Metadata{Name: "model-4"}}, {Metadata: Metadata{Name: "model-5"}}, } }) It("should return first page", func() { page := elements.Paginate(1, 2) Expect(page).To(HaveLen(2)) Expect(page[0].GetName()).To(Equal("model-1")) Expect(page[1].GetName()).To(Equal("model-2")) }) It("should return second page", func() { page := elements.Paginate(2, 2) Expect(page).To(HaveLen(2)) Expect(page[0].GetName()).To(Equal("model-3")) Expect(page[1].GetName()).To(Equal("model-4")) }) It("should return partial last page", func() { page := elements.Paginate(3, 2) Expect(page).To(HaveLen(1)) Expect(page[0].GetName()).To(Equal("model-5")) }) It("should handle page beyond range", func() { page := elements.Paginate(10, 2) Expect(page).To(HaveLen(0)) }) It("should handle empty elements", func() { empty := GalleryElements[*GalleryModel]{} page := empty.Paginate(1, 10) Expect(page).To(HaveLen(0)) }) }) Describe("FindGalleryElement", func() { var models []*GalleryModel BeforeEach(func() { models = []*GalleryModel{ { Metadata: Metadata{ Name: "bert-embeddings", Gallery: config.Gallery{ Name: "huggingface", }, }, }, { Metadata: Metadata{ Name: "gpt-2", Gallery: config.Gallery{ Name: "openai", }, }, }, } }) It("should find element by name without @ notation", func() { result := FindGalleryElement(models, "bert-embeddings") Expect(result).NotTo(BeNil()) Expect(result.GetName()).To(Equal("bert-embeddings")) }) It("should find element by name with @ notation", func() { result := FindGalleryElement(models, "huggingface@bert-embeddings") Expect(result).NotTo(BeNil()) Expect(result.GetName()).To(Equal("bert-embeddings")) 
Expect(result.GetGallery().Name).To(Equal("huggingface")) }) It("should be case insensitive", func() { result := FindGalleryElement(models, "BERT-EMBEDDINGS") Expect(result).NotTo(BeNil()) Expect(result.GetName()).To(Equal("bert-embeddings")) }) It("should handle path separators in name", func() { // Path separators are replaced with __, so bert/embeddings becomes bert__embeddings // This test verifies the replacement happens, but won't match unless model name has __ modelsWithPath := []*GalleryModel{ { Metadata: Metadata{ Name: "bert__embeddings", Gallery: config.Gallery{ Name: "huggingface", }, }, }, } result := FindGalleryElement(modelsWithPath, "bert/embeddings") Expect(result).NotTo(BeNil()) Expect(result.GetName()).To(Equal("bert__embeddings")) }) It("should return zero value when not found", func() { result := FindGalleryElement(models, "nonexistent") Expect(result).To(BeNil()) }) It("should return zero value when gallery@name not found", func() { result := FindGalleryElement(models, "nonexistent@model") Expect(result).To(BeNil()) }) }) Describe("YAML merge with nested maps", func() { It("should handle YAML anchors and merges with nested overrides (regression test for nanbeige4.1)", func() { // This tests the fix for the panic that occurred with yaml.v2: // yaml.v2 produces map[interface{}]interface{} for nested maps // which caused mergo.Merge to panic with "value of type interface {} is not assignable to type string" // The exact YAML structure from gallery/index.yaml nanbeige4.1 entries yamlContent := `--- - &nanbeige4 name: "nanbeige4.1-3b-q8" overrides: parameters: model: nanbeige4.1-3b-q8_0.gguf - !!merge <<: *nanbeige4 name: "nanbeige4.1-3b-q4" overrides: parameters: model: nanbeige4.1-3b-q4_k_m.gguf ` var models []GalleryModel err := yaml.Unmarshal([]byte(yamlContent), &models) Expect(err).NotTo(HaveOccurred()) Expect(models).To(HaveLen(2)) // Verify first model Expect(models[0].Name).To(Equal("nanbeige4.1-3b-q8")) 
Expect(models[0].Overrides).NotTo(BeNil()) Expect(models[0].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]interface{}{})) params := models[0].Overrides["parameters"].(map[string]interface{}) Expect(params["model"]).To(Equal("nanbeige4.1-3b-q8_0.gguf")) // Verify second model (merged) Expect(models[1].Name).To(Equal("nanbeige4.1-3b-q4")) Expect(models[1].Overrides).NotTo(BeNil()) Expect(models[1].Overrides["parameters"]).To(BeAssignableToTypeOf(map[string]interface{}{})) params = models[1].Overrides["parameters"].(map[string]interface{}) Expect(params["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf")) // Simulate the mergo.Merge call that was failing in models.go:251 // This should not panic with yaml.v3 configMap := make(map[string]interface{}) configMap["name"] = "test" configMap["backend"] = "llama-cpp" configMap["parameters"] = map[string]interface{}{ "model": "original.gguf", } err = mergo.Merge(&configMap, models[1].Overrides, mergo.WithOverride) Expect(err).NotTo(HaveOccurred()) Expect(configMap["parameters"]).NotTo(BeNil()) // Verify the merge worked correctly mergedParams := configMap["parameters"].(map[string]interface{}) Expect(mergedParams["model"]).To(Equal("nanbeige4.1-3b-q4_k_m.gguf")) }) }) }) ================================================ FILE: core/gallery/importers/diffuser.go ================================================ package importers import ( "encoding/json" "path/filepath" "strings" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/schema" "gopkg.in/yaml.v3" ) var _ Importer = &DiffuserImporter{} type DiffuserImporter struct{} func (i *DiffuserImporter) Match(details Details) bool { preferences, err := details.Preferences.MarshalJSON() if err != nil { return false } preferencesMap := make(map[string]any) err = json.Unmarshal(preferences, &preferencesMap) if err != nil { return false } b, ok := preferencesMap["backend"].(string) if ok && b == "diffusers" { 
return true } if details.HuggingFace != nil { for _, file := range details.HuggingFace.Files { if strings.Contains(file.Path, "model_index.json") || strings.Contains(file.Path, "scheduler/scheduler_config.json") { return true } } } return false } func (i *DiffuserImporter) Import(details Details) (gallery.ModelConfig, error) { preferences, err := details.Preferences.MarshalJSON() if err != nil { return gallery.ModelConfig{}, err } preferencesMap := make(map[string]any) err = json.Unmarshal(preferences, &preferencesMap) if err != nil { return gallery.ModelConfig{}, err } name, ok := preferencesMap["name"].(string) if !ok { name = filepath.Base(details.URI) } description, ok := preferencesMap["description"].(string) if !ok { description = "Imported from " + details.URI } backend := "diffusers" b, ok := preferencesMap["backend"].(string) if ok { backend = b } pipelineType, ok := preferencesMap["pipeline_type"].(string) if !ok { pipelineType = "StableDiffusionPipeline" } schedulerType, ok := preferencesMap["scheduler_type"].(string) if !ok { schedulerType = "" } enableParameters, ok := preferencesMap["enable_parameters"].(string) if !ok { enableParameters = "negative_prompt,num_inference_steps" } cuda := false if cudaVal, ok := preferencesMap["cuda"].(bool); ok { cuda = cudaVal } modelConfig := config.ModelConfig{ Name: name, Description: description, KnownUsecaseStrings: []string{"image"}, Backend: backend, PredictionOptions: schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{ Model: details.URI, }, }, Diffusers: config.Diffusers{ PipelineType: pipelineType, SchedulerType: schedulerType, EnableParameters: enableParameters, CUDA: cuda, }, } data, err := yaml.Marshal(modelConfig) if err != nil { return gallery.ModelConfig{}, err } return gallery.ModelConfig{ Name: name, Description: description, ConfigFile: string(data), }, nil } ================================================ FILE: core/gallery/importers/diffuser_test.go 
================================================ package importers_test import ( "encoding/json" "github.com/mudler/LocalAI/core/gallery/importers" . "github.com/mudler/LocalAI/core/gallery/importers" hfapi "github.com/mudler/LocalAI/pkg/huggingface-api" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("DiffuserImporter", func() { var importer *DiffuserImporter BeforeEach(func() { importer = &DiffuserImporter{} }) Context("Match", func() { It("should match when backend preference is diffusers", func() { preferences := json.RawMessage(`{"backend": "diffusers"}`) details := Details{ URI: "https://example.com/model", Preferences: preferences, } result := importer.Match(details) Expect(result).To(BeTrue()) }) It("should match when HuggingFace details contain model_index.json", func() { hfDetails := &hfapi.ModelDetails{ Files: []hfapi.ModelFile{ {Path: "model_index.json"}, }, } details := Details{ URI: "https://huggingface.co/test/model", HuggingFace: hfDetails, } result := importer.Match(details) Expect(result).To(BeTrue()) }) It("should match when HuggingFace details contain scheduler config", func() { hfDetails := &hfapi.ModelDetails{ Files: []hfapi.ModelFile{ {Path: "scheduler/scheduler_config.json"}, }, } details := Details{ URI: "https://huggingface.co/test/model", HuggingFace: hfDetails, } result := importer.Match(details) Expect(result).To(BeTrue()) }) It("should not match when URI has no diffuser files and no backend preference", func() { details := Details{ URI: "https://example.com/model.bin", } result := importer.Match(details) Expect(result).To(BeFalse()) }) It("should not match when backend preference is different", func() { preferences := json.RawMessage(`{"backend": "llama-cpp"}`) details := Details{ URI: "https://example.com/model", Preferences: preferences, } result := importer.Match(details) Expect(result).To(BeFalse()) }) It("should return false when JSON preferences are invalid", func() { preferences := 
json.RawMessage(`invalid json`) details := Details{ URI: "https://example.com/model", Preferences: preferences, } result := importer.Match(details) Expect(result).To(BeFalse()) }) }) Context("Import", func() { It("should import model config with default name and description", func() { details := Details{ URI: "https://huggingface.co/test/my-diffuser-model", } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("my-diffuser-model")) Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-diffuser-model")) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: diffusers")) Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-diffuser-model")) Expect(modelConfig.ConfigFile).To(ContainSubstring("pipeline_type: StableDiffusionPipeline")) Expect(modelConfig.ConfigFile).To(ContainSubstring("enable_parameters: negative_prompt,num_inference_steps")) }) It("should import model config with custom name and description from preferences", func() { preferences := json.RawMessage(`{"name": "custom-diffuser", "description": "Custom diffuser model"}`) details := Details{ URI: "https://huggingface.co/test/my-model", Preferences: preferences, } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("custom-diffuser")) Expect(modelConfig.Description).To(Equal("Custom diffuser model")) }) It("should use custom pipeline_type from preferences", func() { preferences := json.RawMessage(`{"pipeline_type": "StableDiffusion3Pipeline"}`) details := Details{ URI: "https://huggingface.co/test/my-model", Preferences: preferences, } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.ConfigFile).To(ContainSubstring("pipeline_type: StableDiffusion3Pipeline")) }) It("should use default pipeline_type when not specified", func() { details := Details{ URI: 
"https://huggingface.co/test/my-model", } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.ConfigFile).To(ContainSubstring("pipeline_type: StableDiffusionPipeline")) }) It("should use custom scheduler_type from preferences", func() { preferences := json.RawMessage(`{"scheduler_type": "k_dpmpp_2m"}`) details := Details{ URI: "https://huggingface.co/test/my-model", Preferences: preferences, } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.ConfigFile).To(ContainSubstring("scheduler_type: k_dpmpp_2m")) }) It("should use cuda setting from preferences", func() { preferences := json.RawMessage(`{"cuda": true}`) details := Details{ URI: "https://huggingface.co/test/my-model", Preferences: preferences, } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.ConfigFile).To(ContainSubstring("cuda: true")) }) It("should use custom enable_parameters from preferences", func() { preferences := json.RawMessage(`{"enable_parameters": "num_inference_steps,guidance_scale"}`) details := Details{ URI: "https://huggingface.co/test/my-model", Preferences: preferences, } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.ConfigFile).To(ContainSubstring("enable_parameters: num_inference_steps,guidance_scale")) }) It("should use custom backend from preferences", func() { preferences := json.RawMessage(`{"backend": "diffusers"}`) details := Details{ URI: "https://huggingface.co/test/my-model", Preferences: preferences, } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: diffusers")) }) It("should handle invalid JSON preferences", func() { preferences := json.RawMessage(`invalid json`) details := Details{ URI: "https://huggingface.co/test/my-model", Preferences: preferences, } _, err := importer.Import(details) 
Expect(err).To(HaveOccurred()) }) It("should extract filename correctly from URI with path", func() { details := importers.Details{ URI: "https://huggingface.co/test/path/to/model", } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("model")) }) It("should include known_usecases as image in config", func() { details := Details{ URI: "https://huggingface.co/test/my-model", } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:")) Expect(modelConfig.ConfigFile).To(ContainSubstring("- image")) }) It("should include diffusers configuration in config", func() { details := Details{ URI: "https://huggingface.co/test/my-model", } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.ConfigFile).To(ContainSubstring("diffusers:")) }) }) }) ================================================ FILE: core/gallery/importers/importers.go ================================================ package importers import ( "encoding/json" "fmt" "os" "strings" "github.com/mudler/xlog" "gopkg.in/yaml.v3" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/downloader" hfapi "github.com/mudler/LocalAI/pkg/huggingface-api" ) var defaultImporters = []Importer{ &LlamaCPPImporter{}, &MLXImporter{}, &VLLMImporter{}, &TransformersImporter{}, &DiffuserImporter{}, } type Details struct { HuggingFace *hfapi.ModelDetails URI string Preferences json.RawMessage } type Importer interface { Match(details Details) bool Import(details Details) (gallery.ModelConfig, error) } func hasYAMLExtension(uri string) bool { return strings.HasSuffix(uri, ".yaml") || strings.HasSuffix(uri, ".yml") } func DiscoverModelConfig(uri string, preferences json.RawMessage) (gallery.ModelConfig, error) { var err error var modelConfig gallery.ModelConfig hf := hfapi.NewClient() 
hfrepoID := strings.ReplaceAll(uri, "huggingface://", "") hfrepoID = strings.ReplaceAll(hfrepoID, "hf://", "") hfrepoID = strings.ReplaceAll(hfrepoID, "https://huggingface.co/", "") hfDetails, err := hf.GetModelDetails(hfrepoID) if err != nil { // maybe not a HF repository // TODO: maybe we can check if the URI is a valid HF repository xlog.Debug("Failed to get model details, maybe not a HF repository", "uri", uri, "hfrepoID", hfrepoID) } else { xlog.Debug("Got model details", "uri", uri) xlog.Debug("Model details", "details", hfDetails) } // handle local config files ("/my-model.yaml" or "file://my-model.yaml") localURI := uri if strings.HasPrefix(uri, downloader.LocalPrefix) { localURI = strings.TrimPrefix(uri, downloader.LocalPrefix) } // if a file exists or it's an url that ends with .yaml or .yml, read the config file directly if _, e := os.Stat(localURI); hasYAMLExtension(localURI) && (e == nil || downloader.URI(localURI).LooksLikeURL()) { var modelYAML []byte if downloader.URI(localURI).LooksLikeURL() { err := downloader.URI(localURI).ReadWithCallback(localURI, func(url string, i []byte) error { modelYAML = i return nil }) if err != nil { xlog.Error("error reading model definition", "error", err, "filepath", localURI) return gallery.ModelConfig{}, err } } else { modelYAML, err = os.ReadFile(localURI) if err != nil { xlog.Error("error reading model definition", "error", err, "filepath", localURI) return gallery.ModelConfig{}, err } } var modelConfig config.ModelConfig if e := yaml.Unmarshal(modelYAML, &modelConfig); e != nil { return gallery.ModelConfig{}, e } configFile, err := yaml.Marshal(modelConfig) return gallery.ModelConfig{ Description: modelConfig.Description, Name: modelConfig.Name, ConfigFile: string(configFile), }, err } details := Details{ HuggingFace: hfDetails, URI: uri, Preferences: preferences, } importerMatched := false for _, importer := range defaultImporters { if importer.Match(details) { importerMatched = true modelConfig, err = 
importer.Import(details) if err != nil { continue } break } } if !importerMatched { return gallery.ModelConfig{}, fmt.Errorf("no importer matched for %s", uri) } return modelConfig, nil } ================================================ FILE: core/gallery/importers/importers_suite_test.go ================================================ package importers_test import ( "testing" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) func TestImporters(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Importers test suite") } ================================================ FILE: core/gallery/importers/importers_test.go ================================================ package importers_test import ( "encoding/json" "fmt" "os" "path/filepath" "github.com/mudler/LocalAI/core/gallery/importers" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("DiscoverModelConfig", func() { Context("With only a repository URI", func() { It("should discover and import using LlamaCPPImporter", func() { uri := "https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF" preferences := json.RawMessage(`{}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err)) Expect(modelConfig.Name).To(Equal("LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF/localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), 
fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/mudler/LocalAI-functioncall-qwen2.5-7b-v0.5-Q4_K_M-GGUF/resolve/main/localai-functioncall-qwen2.5-7b-v0.5-q4_k_m.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[0].SHA256).To(Equal("4e7b7fe1d54b881f1ef90799219dc6cc285d29db24f559c8998d1addb35713d4"), fmt.Sprintf("Model config: %+v", modelConfig)) }) It("should discover and import using LlamaCPPImporter", func() { uri := "https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF" preferences := json.RawMessage(`{}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err)) Expect(modelConfig.Name).To(Equal("Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/Qwen3-VL-2B-Instruct-GGUF/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/Qwen3-VL-2B-Instruct-GGUF/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(len(modelConfig.Files)).To(Equal(2), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/Qwen3-VL-2B-Instruct-GGUF/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/Qwen3VL-2B-Instruct-Q4_K_M.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) 
Expect(modelConfig.Files[0].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[1].Filename).To(Equal("llama-cpp/mmproj/Qwen3-VL-2B-Instruct-GGUF/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[1].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/mmproj-Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[1].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig)) }) It("should discover and import using LlamaCPPImporter", func() { uri := "https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF" preferences := json.RawMessage(`{ "quantizations": "Q8_0", "mmproj_quantizations": "f16" }`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).ToNot(HaveOccurred(), fmt.Sprintf("Error: %v", err)) Expect(modelConfig.Name).To(Equal("Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.ConfigFile).To(ContainSubstring("mmproj: llama-cpp/mmproj/Qwen3-VL-2B-Instruct-GGUF/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.ConfigFile).To(ContainSubstring("model: llama-cpp/models/Qwen3-VL-2B-Instruct-GGUF/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(len(modelConfig.Files)).To(Equal(2), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[0].Filename).To(Equal("llama-cpp/models/Qwen3-VL-2B-Instruct-GGUF/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) 
Expect(modelConfig.Files[0].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/Qwen3VL-2B-Instruct-Q8_0.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[0].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[1].Filename).To(Equal("llama-cpp/mmproj/Qwen3-VL-2B-Instruct-GGUF/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[1].URI).To(Equal("https://huggingface.co/Qwen/Qwen3-VL-2B-Instruct-GGUF/resolve/main/mmproj-Qwen3VL-2B-Instruct-F16.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[1].SHA256).ToNot(BeEmpty(), fmt.Sprintf("Model config: %+v", modelConfig)) }) }) Context("with .gguf URI", func() { It("should discover and import using LlamaCPPImporter", func() { uri := "https://example.com/my-model.gguf" preferences := json.RawMessage(`{}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("my-model.gguf")) Expect(modelConfig.Description).To(Equal("Imported from https://example.com/my-model.gguf")) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp")) }) It("should use custom preferences when provided", func() { uri := "https://example.com/my-model.gguf" preferences := json.RawMessage(`{"name": "custom-name", "description": "Custom description"}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("custom-name")) Expect(modelConfig.Description).To(Equal("Custom description")) }) }) Context("with mlx-community URI", func() { It("should discover and import using MLXImporter", func() { uri := "https://huggingface.co/mlx-community/test-model" preferences := json.RawMessage(`{}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).ToNot(HaveOccurred()) 
Expect(modelConfig.Name).To(Equal("test-model")) Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mlx-community/test-model")) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx")) }) It("should use custom preferences when provided", func() { uri := "https://huggingface.co/mlx-community/test-model" preferences := json.RawMessage(`{"name": "custom-mlx", "description": "Custom MLX description"}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("custom-mlx")) Expect(modelConfig.Description).To(Equal("Custom MLX description")) }) }) Context("with backend preference", func() { It("should use llama-cpp backend when specified", func() { uri := "https://example.com/model" preferences := json.RawMessage(`{"backend": "llama-cpp"}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp")) }) It("should use mlx backend when specified", func() { uri := "https://example.com/model" preferences := json.RawMessage(`{"backend": "mlx"}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx")) }) It("should use mlx-vlm backend when specified", func() { uri := "https://example.com/model" preferences := json.RawMessage(`{"backend": "mlx-vlm"}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx-vlm")) }) }) Context("with HuggingFace URI formats", func() { It("should handle huggingface:// prefix", func() { uri := "huggingface://mlx-community/test-model" preferences := json.RawMessage(`{}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).ToNot(HaveOccurred()) 
Expect(modelConfig.Name).To(Equal("test-model")) }) It("should handle hf:// prefix", func() { uri := "hf://mlx-community/test-model" preferences := json.RawMessage(`{}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("test-model")) }) It("should handle https://huggingface.co/ prefix", func() { uri := "https://huggingface.co/mlx-community/test-model" preferences := json.RawMessage(`{}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("test-model")) }) }) Context("with invalid or non-matching URI", func() { It("should return error when no importer matches", func() { uri := "https://example.com/unknown-model.bin" preferences := json.RawMessage(`{}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) // When no importer matches, the function returns empty config and error // The exact behavior depends on implementation, but typically an error is returned Expect(modelConfig.Name).To(BeEmpty()) Expect(err).To(HaveOccurred()) }) }) Context("with invalid JSON preferences", func() { It("should return error when JSON is invalid even if URI matches", func() { uri := "https://example.com/model.gguf" preferences := json.RawMessage(`invalid json`) // Even though Match() returns true for .gguf extension, // Import() will fail when trying to unmarshal invalid JSON preferences modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).To(HaveOccurred()) Expect(modelConfig.Name).To(BeEmpty()) }) }) Context("with local YAML config files", func() { var tempDir string BeforeEach(func() { var err error tempDir, err = os.MkdirTemp("", "importers-test-*") Expect(err).ToNot(HaveOccurred()) }) AfterEach(func() { os.RemoveAll(tempDir) }) It("should read local YAML file with file:// prefix", func() { yamlContent := `name: test-model backend: llama-cpp description: Test model from 
local YAML parameters: model: /path/to/model.gguf temperature: 0.7 ` yamlFile := filepath.Join(tempDir, "test-model.yaml") err := os.WriteFile(yamlFile, []byte(yamlContent), 0644) Expect(err).ToNot(HaveOccurred()) uri := "file://" + yamlFile preferences := json.RawMessage(`{}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("test-model")) Expect(modelConfig.Description).To(Equal("Test model from local YAML")) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp")) Expect(modelConfig.ConfigFile).To(ContainSubstring("name: test-model")) }) It("should read local YAML file without file:// prefix (direct path)", func() { yamlContent := `name: direct-path-model backend: mlx description: Test model from direct path parameters: model: /path/to/model.safetensors ` yamlFile := filepath.Join(tempDir, "direct-model.yaml") err := os.WriteFile(yamlFile, []byte(yamlContent), 0644) Expect(err).ToNot(HaveOccurred()) uri := yamlFile preferences := json.RawMessage(`{}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("direct-path-model")) Expect(modelConfig.Description).To(Equal("Test model from direct path")) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx")) }) It("should read local YAML file with .yml extension", func() { yamlContent := `name: yml-extension-model backend: transformers description: Test model with .yml extension parameters: model: /path/to/model ` yamlFile := filepath.Join(tempDir, "test-model.yml") err := os.WriteFile(yamlFile, []byte(yamlContent), 0644) Expect(err).ToNot(HaveOccurred()) uri := "file://" + yamlFile preferences := json.RawMessage(`{}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("yml-extension-model")) Expect(modelConfig.Description).To(Equal("Test 
model with .yml extension")) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers")) }) It("should ignore preferences when reading YAML files directly", func() { yamlContent := `name: yaml-model backend: llama-cpp description: Original description parameters: model: /path/to/model.gguf ` yamlFile := filepath.Join(tempDir, "prefs-test.yaml") err := os.WriteFile(yamlFile, []byte(yamlContent), 0644) Expect(err).ToNot(HaveOccurred()) uri := "file://" + yamlFile // Preferences should be ignored when reading YAML directly preferences := json.RawMessage(`{"name": "custom-name", "description": "Custom description", "backend": "mlx"}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).ToNot(HaveOccurred()) // Should use values from YAML file, not preferences Expect(modelConfig.Name).To(Equal("yaml-model")) Expect(modelConfig.Description).To(Equal("Original description")) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp")) }) It("should return error when local YAML file doesn't exist", func() { nonExistentFile := filepath.Join(tempDir, "nonexistent.yaml") uri := "file://" + nonExistentFile preferences := json.RawMessage(`{}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).To(HaveOccurred()) Expect(modelConfig.Name).To(BeEmpty()) }) It("should return error when YAML file is invalid/malformed", func() { invalidYaml := `name: invalid-model backend: llama-cpp invalid: yaml: content: [unclosed bracket ` yamlFile := filepath.Join(tempDir, "invalid.yaml") err := os.WriteFile(yamlFile, []byte(invalidYaml), 0644) Expect(err).ToNot(HaveOccurred()) uri := "file://" + yamlFile preferences := json.RawMessage(`{}`) modelConfig, err := importers.DiscoverModelConfig(uri, preferences) Expect(err).To(HaveOccurred()) Expect(modelConfig.Name).To(BeEmpty()) }) }) }) ================================================ FILE: core/gallery/importers/llama-cpp.go 
================================================ package importers import ( "encoding/json" "path/filepath" "slices" "strings" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/xlog" "go.yaml.in/yaml/v2" ) var _ Importer = &LlamaCPPImporter{} type LlamaCPPImporter struct{} func (i *LlamaCPPImporter) Match(details Details) bool { preferences, err := details.Preferences.MarshalJSON() if err != nil { xlog.Error("failed to marshal preferences", "error", err) return false } preferencesMap := make(map[string]any) if len(preferences) > 0 { err = json.Unmarshal(preferences, &preferencesMap) if err != nil { xlog.Error("failed to unmarshal preferences", "error", err) return false } } uri := downloader.URI(details.URI) if preferencesMap["backend"] == "llama-cpp" { return true } if strings.HasSuffix(details.URI, ".gguf") { return true } if uri.LooksLikeOCI() { return true } if details.HuggingFace != nil { for _, file := range details.HuggingFace.Files { if strings.HasSuffix(file.Path, ".gguf") { return true } } } return false } func (i *LlamaCPPImporter) Import(details Details) (gallery.ModelConfig, error) { xlog.Debug("llama.cpp importer matched", "uri", details.URI) preferences, err := details.Preferences.MarshalJSON() if err != nil { return gallery.ModelConfig{}, err } preferencesMap := make(map[string]any) if len(preferences) > 0 { err = json.Unmarshal(preferences, &preferencesMap) if err != nil { return gallery.ModelConfig{}, err } } name, ok := preferencesMap["name"].(string) if !ok { name = filepath.Base(details.URI) } description, ok := preferencesMap["description"].(string) if !ok { description = "Imported from " + details.URI } preferedQuantizations, _ := preferencesMap["quantizations"].(string) quants := []string{"q4_k_m"} if preferedQuantizations != "" { quants = 
strings.Split(preferedQuantizations, ",") } mmprojQuants, _ := preferencesMap["mmproj_quantizations"].(string) mmprojQuantsList := []string{"fp16"} if mmprojQuants != "" { mmprojQuantsList = strings.Split(mmprojQuants, ",") } embeddings, _ := preferencesMap["embeddings"].(string) modelConfig := config.ModelConfig{ Name: name, Description: description, KnownUsecaseStrings: []string{"chat"}, Options: []string{"use_jinja:true"}, Backend: "llama-cpp", TemplateConfig: config.TemplateConfig{ UseTokenizerTemplate: true, }, FunctionsConfig: functions.FunctionsConfig{ GrammarConfig: functions.GrammarConfig{ NoGrammar: true, }, }, } if embeddings != "" && strings.ToLower(embeddings) == "true" || strings.ToLower(embeddings) == "yes" { trueV := true modelConfig.Embeddings = &trueV } cfg := gallery.ModelConfig{ Name: name, Description: description, } uri := downloader.URI(details.URI) switch { case uri.LooksLikeOCI(): ociName := strings.TrimPrefix(string(uri), downloader.OCIPrefix) ociName = strings.TrimPrefix(ociName, downloader.OllamaPrefix) ociName = strings.ReplaceAll(ociName, "/", "__") ociName = strings.ReplaceAll(ociName, ":", "__") cfg.Files = append(cfg.Files, gallery.File{ URI: details.URI, Filename: ociName, }) modelConfig.PredictionOptions = schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{ Model: ociName, }, } case uri.LooksLikeURL() && strings.HasSuffix(details.URI, ".gguf"): // Extract filename from URL fileName, e := uri.FilenameFromUrl() if e != nil { return gallery.ModelConfig{}, e } cfg.Files = append(cfg.Files, gallery.File{ URI: details.URI, Filename: fileName, }) modelConfig.PredictionOptions = schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{ Model: fileName, }, } case strings.HasSuffix(details.URI, ".gguf"): cfg.Files = append(cfg.Files, gallery.File{ URI: details.URI, Filename: filepath.Base(details.URI), }) modelConfig.PredictionOptions = schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{ 
Model: filepath.Base(details.URI), }, } case details.HuggingFace != nil: // We want to: // Get first the chosen quants that match filenames // OR the first mmproj/gguf file found var lastMMProjFile *gallery.File var lastGGUFFile *gallery.File foundPreferedQuant := false foundPreferedMMprojQuant := false for _, file := range details.HuggingFace.Files { // Get the mmproj prefered quants if strings.Contains(strings.ToLower(file.Path), "mmproj") { lastMMProjFile = &gallery.File{ URI: file.URL, Filename: filepath.Join("llama-cpp", "mmproj", name, filepath.Base(file.Path)), SHA256: file.SHA256, } if slices.ContainsFunc(mmprojQuantsList, func(quant string) bool { return strings.Contains(strings.ToLower(file.Path), strings.ToLower(quant)) }) { cfg.Files = append(cfg.Files, *lastMMProjFile) foundPreferedMMprojQuant = true } } else if strings.HasSuffix(strings.ToLower(file.Path), "gguf") { lastGGUFFile = &gallery.File{ URI: file.URL, Filename: filepath.Join("llama-cpp", "models", name, filepath.Base(file.Path)), SHA256: file.SHA256, } // get the files of the prefered quants if slices.ContainsFunc(quants, func(quant string) bool { return strings.Contains(strings.ToLower(file.Path), strings.ToLower(quant)) }) { foundPreferedQuant = true cfg.Files = append(cfg.Files, *lastGGUFFile) } } } // Make sure to add at least one file if not already present (which is the latest one) if lastMMProjFile != nil && !foundPreferedMMprojQuant { if !slices.ContainsFunc(cfg.Files, func(f gallery.File) bool { return f.Filename == lastMMProjFile.Filename }) { cfg.Files = append(cfg.Files, *lastMMProjFile) } } if lastGGUFFile != nil && !foundPreferedQuant { if !slices.ContainsFunc(cfg.Files, func(f gallery.File) bool { return f.Filename == lastGGUFFile.Filename }) { cfg.Files = append(cfg.Files, *lastGGUFFile) } } // Find first mmproj file and configure it in the config file for _, file := range cfg.Files { if !strings.Contains(strings.ToLower(file.Filename), "mmproj") { continue } 
modelConfig.MMProj = file.Filename break } // Find first non-mmproj file and configure it in the config file for _, file := range cfg.Files { if strings.Contains(strings.ToLower(file.Filename), "mmproj") { continue } modelConfig.PredictionOptions = schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{ Model: file.Filename, }, } break } } data, err := yaml.Marshal(modelConfig) if err != nil { return gallery.ModelConfig{}, err } cfg.ConfigFile = string(data) return cfg, nil } ================================================ FILE: core/gallery/importers/llama-cpp_test.go ================================================ package importers_test import ( "encoding/json" "fmt" "github.com/mudler/LocalAI/core/gallery/importers" . "github.com/mudler/LocalAI/core/gallery/importers" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("LlamaCPPImporter", func() { var importer *LlamaCPPImporter BeforeEach(func() { importer = &LlamaCPPImporter{} }) Context("Match", func() { It("should match when URI ends with .gguf", func() { details := Details{ URI: "https://example.com/model.gguf", } result := importer.Match(details) Expect(result).To(BeTrue()) }) It("should match when backend preference is llama-cpp", func() { preferences := json.RawMessage(`{"backend": "llama-cpp"}`) details := Details{ URI: "https://example.com/model", Preferences: preferences, } result := importer.Match(details) Expect(result).To(BeTrue()) }) It("should not match when URI does not end with .gguf and no backend preference", func() { details := Details{ URI: "https://example.com/model.bin", } result := importer.Match(details) Expect(result).To(BeFalse()) }) It("should not match when backend preference is different", func() { preferences := json.RawMessage(`{"backend": "mlx"}`) details := Details{ URI: "https://example.com/model", Preferences: preferences, } result := importer.Match(details) Expect(result).To(BeFalse()) }) It("should return false when JSON preferences 
are invalid", func() { preferences := json.RawMessage(`invalid json`) details := Details{ URI: "https://example.com/model.gguf", Preferences: preferences, } // Invalid JSON causes Match to return false early result := importer.Match(details) Expect(result).To(BeFalse()) }) }) Context("Import", func() { It("should import model config with default name and description", func() { details := Details{ URI: "https://example.com/my-model.gguf", } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("my-model.gguf")) Expect(modelConfig.Description).To(Equal("Imported from https://example.com/my-model.gguf")) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: llama-cpp")) Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[0].Filename).To(Equal("my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) }) It("should import model config with custom name and description from preferences", func() { preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`) details := Details{ URI: "https://example.com/my-model.gguf", Preferences: preferences, } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("custom-model")) Expect(modelConfig.Description).To(Equal("Custom description")) Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[0].Filename).To(Equal("my-model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) }) It("should handle invalid JSON preferences", func() { preferences := json.RawMessage(`invalid json`) details := 
Details{ URI: "https://example.com/my-model.gguf", Preferences: preferences, } _, err := importer.Import(details) Expect(err).To(HaveOccurred()) }) It("should extract filename correctly from URI with path", func() { details := importers.Details{ URI: "https://example.com/path/to/model.gguf", } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(len(modelConfig.Files)).To(Equal(1), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[0].URI).To(Equal("https://example.com/path/to/model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) Expect(modelConfig.Files[0].Filename).To(Equal("model.gguf"), fmt.Sprintf("Model config: %+v", modelConfig)) }) }) }) ================================================ FILE: core/gallery/importers/mlx.go ================================================ package importers import ( "encoding/json" "path/filepath" "strings" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/schema" "go.yaml.in/yaml/v2" ) var _ Importer = &MLXImporter{} type MLXImporter struct{} func (i *MLXImporter) Match(details Details) bool { preferences, err := details.Preferences.MarshalJSON() if err != nil { return false } preferencesMap := make(map[string]any) err = json.Unmarshal(preferences, &preferencesMap) if err != nil { return false } b, ok := preferencesMap["backend"].(string) if ok && b == "mlx" || b == "mlx-vlm" { return true } // All https://huggingface.co/mlx-community/* if strings.Contains(details.URI, "mlx-community/") { return true } return false } func (i *MLXImporter) Import(details Details) (gallery.ModelConfig, error) { preferences, err := details.Preferences.MarshalJSON() if err != nil { return gallery.ModelConfig{}, err } preferencesMap := make(map[string]any) err = json.Unmarshal(preferences, &preferencesMap) if err != nil { return gallery.ModelConfig{}, err } name, ok := preferencesMap["name"].(string) if !ok { name = 
filepath.Base(details.URI) } description, ok := preferencesMap["description"].(string) if !ok { description = "Imported from " + details.URI } backend := "mlx" b, ok := preferencesMap["backend"].(string) if ok { backend = b } modelConfig := config.ModelConfig{ Name: name, Description: description, KnownUsecaseStrings: []string{"chat"}, Backend: backend, PredictionOptions: schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{ Model: details.URI, }, }, TemplateConfig: config.TemplateConfig{ UseTokenizerTemplate: true, }, } data, err := yaml.Marshal(modelConfig) if err != nil { return gallery.ModelConfig{}, err } return gallery.ModelConfig{ Name: name, Description: description, ConfigFile: string(data), }, nil } ================================================ FILE: core/gallery/importers/mlx_test.go ================================================ package importers_test import ( "encoding/json" "github.com/mudler/LocalAI/core/gallery/importers" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" ) var _ = Describe("MLXImporter", func() { var importer *importers.MLXImporter BeforeEach(func() { importer = &importers.MLXImporter{} }) Context("Match", func() { It("should match when URI contains mlx-community/", func() { details := importers.Details{ URI: "https://huggingface.co/mlx-community/test-model", } result := importer.Match(details) Expect(result).To(BeTrue()) }) It("should match when backend preference is mlx", func() { preferences := json.RawMessage(`{"backend": "mlx"}`) details := importers.Details{ URI: "https://example.com/model", Preferences: preferences, } result := importer.Match(details) Expect(result).To(BeTrue()) }) It("should match when backend preference is mlx-vlm", func() { preferences := json.RawMessage(`{"backend": "mlx-vlm"}`) details := importers.Details{ URI: "https://example.com/model", Preferences: preferences, } result := importer.Match(details) Expect(result).To(BeTrue()) }) It("should not match when URI does not contain mlx-community/ and no backend preference", func() { details := importers.Details{ URI: "https://huggingface.co/other-org/test-model", } result := importer.Match(details) Expect(result).To(BeFalse()) }) It("should not match when backend preference is different", func() { preferences := json.RawMessage(`{"backend": "llama-cpp"}`) details := importers.Details{ URI: "https://example.com/model", Preferences: preferences, } result := importer.Match(details) Expect(result).To(BeFalse()) }) It("should return false when JSON preferences are invalid", func() { preferences := json.RawMessage(`invalid json`) details := importers.Details{ URI: "https://huggingface.co/mlx-community/test-model", Preferences: preferences, } // Invalid JSON causes Match to return false early result := importer.Match(details) Expect(result).To(BeFalse()) }) }) Context("Import", func() { It("should import model config with default name and description", func() { details := importers.Details{ URI: 
"https://huggingface.co/mlx-community/test-model", } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("test-model")) Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/mlx-community/test-model")) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx")) Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/mlx-community/test-model")) }) It("should import model config with custom name and description from preferences", func() { preferences := json.RawMessage(`{"name": "custom-mlx-model", "description": "Custom MLX description"}`) details := importers.Details{ URI: "https://huggingface.co/mlx-community/test-model", Preferences: preferences, } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("custom-mlx-model")) Expect(modelConfig.Description).To(Equal("Custom MLX description")) }) It("should use custom backend from preferences", func() { preferences := json.RawMessage(`{"backend": "mlx-vlm"}`) details := importers.Details{ URI: "https://huggingface.co/mlx-community/test-model", Preferences: preferences, } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: mlx-vlm")) }) It("should handle invalid JSON preferences", func() { preferences := json.RawMessage(`invalid json`) details := importers.Details{ URI: "https://huggingface.co/mlx-community/test-model", Preferences: preferences, } _, err := importer.Import(details) Expect(err).To(HaveOccurred()) }) It("should extract filename correctly from URI with path", func() { details := importers.Details{ URI: "https://huggingface.co/mlx-community/path/to/model", } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("model")) }) }) }) ================================================ FILE: 
core/gallery/importers/transformers.go ================================================ package importers import ( "encoding/json" "path/filepath" "strings" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/schema" "go.yaml.in/yaml/v2" ) var _ Importer = &TransformersImporter{} type TransformersImporter struct{} func (i *TransformersImporter) Match(details Details) bool { preferences, err := details.Preferences.MarshalJSON() if err != nil { return false } preferencesMap := make(map[string]any) err = json.Unmarshal(preferences, &preferencesMap) if err != nil { return false } b, ok := preferencesMap["backend"].(string) if ok && b == "transformers" { return true } if details.HuggingFace != nil { for _, file := range details.HuggingFace.Files { if strings.Contains(file.Path, "tokenizer.json") || strings.Contains(file.Path, "tokenizer_config.json") { return true } } } return false } func (i *TransformersImporter) Import(details Details) (gallery.ModelConfig, error) { preferences, err := details.Preferences.MarshalJSON() if err != nil { return gallery.ModelConfig{}, err } preferencesMap := make(map[string]any) err = json.Unmarshal(preferences, &preferencesMap) if err != nil { return gallery.ModelConfig{}, err } name, ok := preferencesMap["name"].(string) if !ok { name = filepath.Base(details.URI) } description, ok := preferencesMap["description"].(string) if !ok { description = "Imported from " + details.URI } backend := "transformers" b, ok := preferencesMap["backend"].(string) if ok { backend = b } modelType, ok := preferencesMap["type"].(string) if !ok { modelType = "AutoModelForCausalLM" } quantization, ok := preferencesMap["quantization"].(string) if !ok { quantization = "" } modelConfig := config.ModelConfig{ Name: name, Description: description, KnownUsecaseStrings: []string{"chat"}, Backend: backend, PredictionOptions: schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{ Model: 
details.URI, }, }, TemplateConfig: config.TemplateConfig{ UseTokenizerTemplate: true, }, } modelConfig.ModelType = modelType modelConfig.Quantization = quantization data, err := yaml.Marshal(modelConfig) if err != nil { return gallery.ModelConfig{}, err } return gallery.ModelConfig{ Name: name, Description: description, ConfigFile: string(data), }, nil } ================================================ FILE: core/gallery/importers/transformers_test.go ================================================ package importers_test import ( "encoding/json" "github.com/mudler/LocalAI/core/gallery/importers" . "github.com/mudler/LocalAI/core/gallery/importers" hfapi "github.com/mudler/LocalAI/pkg/huggingface-api" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("TransformersImporter", func() { var importer *TransformersImporter BeforeEach(func() { importer = &TransformersImporter{} }) Context("Match", func() { It("should match when backend preference is transformers", func() { preferences := json.RawMessage(`{"backend": "transformers"}`) details := Details{ URI: "https://example.com/model", Preferences: preferences, } result := importer.Match(details) Expect(result).To(BeTrue()) }) It("should match when HuggingFace details contain tokenizer.json", func() { hfDetails := &hfapi.ModelDetails{ Files: []hfapi.ModelFile{ {Path: "tokenizer.json"}, }, } details := Details{ URI: "https://huggingface.co/test/model", HuggingFace: hfDetails, } result := importer.Match(details) Expect(result).To(BeTrue()) }) It("should match when HuggingFace details contain tokenizer_config.json", func() { hfDetails := &hfapi.ModelDetails{ Files: []hfapi.ModelFile{ {Path: "tokenizer_config.json"}, }, } details := Details{ URI: "https://huggingface.co/test/model", HuggingFace: hfDetails, } result := importer.Match(details) Expect(result).To(BeTrue()) }) It("should not match when URI has no tokenizer files and no backend preference", func() { details := Details{ URI: 
"https://example.com/model.bin", } result := importer.Match(details) Expect(result).To(BeFalse()) }) It("should not match when backend preference is different", func() { preferences := json.RawMessage(`{"backend": "llama-cpp"}`) details := Details{ URI: "https://example.com/model", Preferences: preferences, } result := importer.Match(details) Expect(result).To(BeFalse()) }) It("should return false when JSON preferences are invalid", func() { preferences := json.RawMessage(`invalid json`) details := Details{ URI: "https://example.com/model", Preferences: preferences, } result := importer.Match(details) Expect(result).To(BeFalse()) }) }) Context("Import", func() { It("should import model config with default name and description", func() { details := Details{ URI: "https://huggingface.co/test/my-model", } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("my-model")) Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-model")) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers")) Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-model")) Expect(modelConfig.ConfigFile).To(ContainSubstring("type: AutoModelForCausalLM")) }) It("should import model config with custom name and description from preferences", func() { preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`) details := Details{ URI: "https://huggingface.co/test/my-model", Preferences: preferences, } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("custom-model")) Expect(modelConfig.Description).To(Equal("Custom description")) }) It("should use custom model type from preferences", func() { preferences := json.RawMessage(`{"type": "SentenceTransformer"}`) details := Details{ URI: "https://huggingface.co/test/my-model", Preferences: preferences, } 
modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.ConfigFile).To(ContainSubstring("type: SentenceTransformer")) }) It("should use default model type when not specified", func() { details := Details{ URI: "https://huggingface.co/test/my-model", } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.ConfigFile).To(ContainSubstring("type: AutoModelForCausalLM")) }) It("should use custom backend from preferences", func() { preferences := json.RawMessage(`{"backend": "transformers"}`) details := Details{ URI: "https://huggingface.co/test/my-model", Preferences: preferences, } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: transformers")) }) It("should use quantization from preferences", func() { preferences := json.RawMessage(`{"quantization": "int8"}`) details := Details{ URI: "https://huggingface.co/test/my-model", Preferences: preferences, } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.ConfigFile).To(ContainSubstring("quantization: int8")) }) It("should handle invalid JSON preferences", func() { preferences := json.RawMessage(`invalid json`) details := Details{ URI: "https://huggingface.co/test/my-model", Preferences: preferences, } _, err := importer.Import(details) Expect(err).To(HaveOccurred()) }) It("should extract filename correctly from URI with path", func() { details := importers.Details{ URI: "https://huggingface.co/test/path/to/model", } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("model")) }) It("should include use_tokenizer_template in config", func() { details := Details{ URI: "https://huggingface.co/test/my-model", } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) 
Expect(modelConfig.ConfigFile).To(ContainSubstring("use_tokenizer_template: true")) }) It("should include known_usecases in config", func() { details := Details{ URI: "https://huggingface.co/test/my-model", } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.ConfigFile).To(ContainSubstring("known_usecases:")) Expect(modelConfig.ConfigFile).To(ContainSubstring("- chat")) }) }) }) ================================================ FILE: core/gallery/importers/vllm.go ================================================ package importers import ( "encoding/json" "path/filepath" "strings" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/schema" "go.yaml.in/yaml/v2" ) var _ Importer = &VLLMImporter{} type VLLMImporter struct{} func (i *VLLMImporter) Match(details Details) bool { preferences, err := details.Preferences.MarshalJSON() if err != nil { return false } preferencesMap := make(map[string]any) err = json.Unmarshal(preferences, &preferencesMap) if err != nil { return false } b, ok := preferencesMap["backend"].(string) if ok && b == "vllm" { return true } if details.HuggingFace != nil { for _, file := range details.HuggingFace.Files { if strings.Contains(file.Path, "tokenizer.json") || strings.Contains(file.Path, "tokenizer_config.json") { return true } } } return false } func (i *VLLMImporter) Import(details Details) (gallery.ModelConfig, error) { preferences, err := details.Preferences.MarshalJSON() if err != nil { return gallery.ModelConfig{}, err } preferencesMap := make(map[string]any) err = json.Unmarshal(preferences, &preferencesMap) if err != nil { return gallery.ModelConfig{}, err } name, ok := preferencesMap["name"].(string) if !ok { name = filepath.Base(details.URI) } description, ok := preferencesMap["description"].(string) if !ok { description = "Imported from " + details.URI } backend := "vllm" b, ok := preferencesMap["backend"].(string) if ok { 
backend = b } modelConfig := config.ModelConfig{ Name: name, Description: description, KnownUsecaseStrings: []string{"chat"}, Backend: backend, PredictionOptions: schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{ Model: details.URI, }, }, TemplateConfig: config.TemplateConfig{ UseTokenizerTemplate: true, }, } data, err := yaml.Marshal(modelConfig) if err != nil { return gallery.ModelConfig{}, err } return gallery.ModelConfig{ Name: name, Description: description, ConfigFile: string(data), }, nil } ================================================ FILE: core/gallery/importers/vllm_test.go ================================================ package importers_test import ( "encoding/json" "github.com/mudler/LocalAI/core/gallery/importers" . "github.com/mudler/LocalAI/core/gallery/importers" hfapi "github.com/mudler/LocalAI/pkg/huggingface-api" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("VLLMImporter", func() { var importer *VLLMImporter BeforeEach(func() { importer = &VLLMImporter{} }) Context("Match", func() { It("should match when backend preference is vllm", func() { preferences := json.RawMessage(`{"backend": "vllm"}`) details := Details{ URI: "https://example.com/model", Preferences: preferences, } result := importer.Match(details) Expect(result).To(BeTrue()) }) It("should match when HuggingFace details contain tokenizer.json", func() { hfDetails := &hfapi.ModelDetails{ Files: []hfapi.ModelFile{ {Path: "tokenizer.json"}, }, } details := Details{ URI: "https://huggingface.co/test/model", HuggingFace: hfDetails, } result := importer.Match(details) Expect(result).To(BeTrue()) }) It("should match when HuggingFace details contain tokenizer_config.json", func() { hfDetails := &hfapi.ModelDetails{ Files: []hfapi.ModelFile{ {Path: "tokenizer_config.json"}, }, } details := Details{ URI: "https://huggingface.co/test/model", HuggingFace: hfDetails, } result := importer.Match(details) Expect(result).To(BeTrue()) }) It("should 
not match when URI has no tokenizer files and no backend preference", func() { details := Details{ URI: "https://example.com/model.bin", } result := importer.Match(details) Expect(result).To(BeFalse()) }) It("should not match when backend preference is different", func() { preferences := json.RawMessage(`{"backend": "llama-cpp"}`) details := Details{ URI: "https://example.com/model", Preferences: preferences, } result := importer.Match(details) Expect(result).To(BeFalse()) }) It("should return false when JSON preferences are invalid", func() { preferences := json.RawMessage(`invalid json`) details := Details{ URI: "https://example.com/model", Preferences: preferences, } result := importer.Match(details) Expect(result).To(BeFalse()) }) }) Context("Import", func() { It("should import model config with default name and description", func() { details := Details{ URI: "https://huggingface.co/test/my-model", } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("my-model")) Expect(modelConfig.Description).To(Equal("Imported from https://huggingface.co/test/my-model")) Expect(modelConfig.ConfigFile).To(ContainSubstring("backend: vllm")) Expect(modelConfig.ConfigFile).To(ContainSubstring("model: https://huggingface.co/test/my-model")) }) It("should import model config with custom name and description from preferences", func() { preferences := json.RawMessage(`{"name": "custom-model", "description": "Custom description"}`) details := Details{ URI: "https://huggingface.co/test/my-model", Preferences: preferences, } modelConfig, err := importer.Import(details) Expect(err).ToNot(HaveOccurred()) Expect(modelConfig.Name).To(Equal("custom-model")) Expect(modelConfig.Description).To(Equal("Custom description")) }) It("should use custom backend from preferences", func() { preferences := json.RawMessage(`{"backend": "vllm"}`) details := Details{ URI: "https://huggingface.co/test/my-model", Preferences: preferences, } 
// Metadata holds the descriptive fields of a gallery entry.
// Every field is serialised to both JSON and YAML and is omitted when empty.
type Metadata struct {
	// URL is the location associated with the entry.
	// NOTE(review): presumably the URL the gallery config is fetched from at
	// install time — confirm against GalleryModel usage.
	URL string `json:"url,omitempty" yaml:"url,omitempty"`
	// Name is the entry's name.
	Name string `json:"name,omitempty" yaml:"name,omitempty"`
	// Description is a free-form description of the entry.
	Description string `json:"description,omitempty" yaml:"description,omitempty"`
	// License is the license string of the entry.
	License string `json:"license,omitempty" yaml:"license,omitempty"`
	// URLs is a list of URLs related to the entry.
	URLs []string `json:"urls,omitempty" yaml:"urls,omitempty"`
	// Icon is the icon reference of the entry.
	Icon string `json:"icon,omitempty" yaml:"icon,omitempty"`
	// Tags is a list of tags attached to the entry.
	Tags []string `json:"tags,omitempty" yaml:"tags,omitempty"`
	// AdditionalFiles are used to add additional files to the model.
	// Note that the serialised key is "files", not "additional_files".
	AdditionalFiles []File `json:"files,omitempty" yaml:"files,omitempty"`
	// Size is an optional hardcoded model size string (e.g. "500MB", "14.5GB").
	// Used when the size cannot be estimated automatically.
	Size string `json:"size,omitempty" yaml:"size,omitempty"`
	// Gallery is a reference to the gallery which contains the model
	Gallery config.Gallery `json:"gallery,omitempty" yaml:"gallery,omitempty"`
	// Installed is used to indicate if the model is installed or not
	Installed bool `json:"installed,omitempty" yaml:"installed,omitempty"`
	// Backend is the resolved backend engine for this model (e.g. "llama-cpp").
	// Populated at load time from overrides, inline config, or the URL-referenced config file.
	Backend string `json:"backend,omitempty" yaml:"backend,omitempty"`
}
or generated by the alias wanted by the user threads: 14 files: - filename: "" sha: "" uri: "" prompt_templates: - name: "" content: "" */ // ModelConfig is the model configuration which contains all the model details // This configuration is read from the gallery endpoint and is used to download and install the model // It is the internal structure, separated from the request type ModelConfig struct { Description string `yaml:"description"` Icon string `yaml:"icon"` License string `yaml:"license"` URLs []string `yaml:"urls"` Name string `yaml:"name"` ConfigFile string `yaml:"config_file"` Files []File `yaml:"files"` PromptTemplates []PromptTemplate `yaml:"prompt_templates"` } type File struct { Filename string `yaml:"filename" json:"filename"` SHA256 string `yaml:"sha256" json:"sha256"` URI string `yaml:"uri" json:"uri"` } type PromptTemplate struct { Name string `yaml:"name"` Content string `yaml:"content"` } // Installs a model from the gallery func InstallModelFromGallery( ctx context.Context, modelGalleries, backendGalleries []lconfig.Gallery, systemState *system.SystemState, modelLoader *model.ModelLoader, name string, req GalleryModel, downloadStatus func(string, string, string, float64), enforceScan, automaticallyInstallBackend bool) error { applyModel := func(model *GalleryModel) error { name = strings.ReplaceAll(name, string(os.PathSeparator), "__") var config ModelConfig if len(model.URL) > 0 { var err error config, err = GetGalleryConfigFromURLWithContext[ModelConfig](ctx, model.URL, systemState.Model.ModelsPath) if err != nil { return err } config.Description = model.Description config.License = model.License } else if len(model.ConfigFile) > 0 { // TODO: is this worse than using the override method with a blank cfg yaml? 
reYamlConfig, err := yaml.Marshal(model.ConfigFile) if err != nil { return err } config = ModelConfig{ ConfigFile: string(reYamlConfig), Description: model.Description, License: model.License, URLs: model.URLs, Name: model.Name, Files: make([]File, 0), // Real values get added below, must be blank // Prompt Template Skipped for now - I expect in this mode that they will be delivered as files. } } else { return fmt.Errorf("invalid gallery model %+v", model) } installName := model.Name if req.Name != "" { installName = req.Name } // Copy the model configuration from the request schema config.URLs = append(config.URLs, model.URLs...) config.Icon = model.Icon config.Files = append(config.Files, req.AdditionalFiles...) config.Files = append(config.Files, model.AdditionalFiles...) // TODO model.Overrides could be merged with user overrides (not defined yet) if req.Overrides != nil { if err := mergo.Merge(&model.Overrides, req.Overrides, mergo.WithOverride); err != nil { return err } } installedModel, err := InstallModel(ctx, systemState, installName, &config, model.Overrides, downloadStatus, enforceScan) if err != nil { return err } xlog.Debug("Installed model", "model", installedModel.Name) if automaticallyInstallBackend && installedModel.Backend != "" { xlog.Debug("Installing backend", "backend", installedModel.Backend) if err := InstallBackendFromGallery(ctx, backendGalleries, systemState, modelLoader, installedModel.Backend, downloadStatus, false); err != nil { return err } } return nil } models, err := AvailableGalleryModels(modelGalleries, systemState) if err != nil { return err } model := FindGalleryElement(models, name) if model == nil { return fmt.Errorf("no model found with name %q", name) } return applyModel(model) } func InstallModel(ctx context.Context, systemState *system.SystemState, nameOverride string, config *ModelConfig, configOverrides map[string]interface{}, downloadStatus func(string, string, string, float64), enforceScan bool) 
(*lconfig.ModelConfig, error) { basePath := systemState.Model.ModelsPath // Create base path if it doesn't exist err := os.MkdirAll(basePath, 0750) if err != nil { return nil, fmt.Errorf("failed to create base path: %v", err) } if len(configOverrides) > 0 { xlog.Debug("Config overrides", "overrides", configOverrides) } // Download files and verify their SHA for i, file := range config.Files { // Check for cancellation before each file select { case <-ctx.Done(): return nil, ctx.Err() default: } xlog.Debug("Checking file exists and matches SHA", "filename", file.Filename) if err := utils.VerifyPath(file.Filename, basePath); err != nil { return nil, err } // Create file path filePath := filepath.Join(basePath, file.Filename) if enforceScan { scanResults, err := downloader.HuggingFaceScan(downloader.URI(file.URI)) if err != nil && errors.Is(err, downloader.ErrUnsafeFilesFound) { xlog.Error("Contains unsafe file(s)!", "model", config.Name, "clamAV", scanResults.ClamAVInfectedFiles, "pickles", scanResults.DangerousPickles) return nil, err } } uri := downloader.URI(file.URI) if err := uri.DownloadFileWithContext(ctx, filePath, file.SHA256, i, len(config.Files), downloadStatus); err != nil { return nil, err } } // Write prompt template contents to separate files for _, template := range config.PromptTemplates { if err := utils.VerifyPath(template.Name+".tmpl", basePath); err != nil { return nil, err } // Create file path filePath := filepath.Join(basePath, template.Name+".tmpl") // Create parent directory err := os.MkdirAll(filepath.Dir(filePath), 0750) if err != nil { return nil, fmt.Errorf("failed to create parent directory for prompt template %q: %v", template.Name, err) } // Create and write file content err = os.WriteFile(filePath, []byte(template.Content), 0644) if err != nil { return nil, fmt.Errorf("failed to write prompt template %q: %v", template.Name, err) } xlog.Debug("Prompt template written", "template", template.Name) } name := config.Name if nameOverride 
!= "" { name = nameOverride } if err := utils.VerifyPath(name+".yaml", basePath); err != nil { return nil, err } modelConfig := lconfig.ModelConfig{} // write config file if len(configOverrides) != 0 || len(config.ConfigFile) != 0 { configFilePath := filepath.Join(basePath, name+".yaml") // Read and update config file as map[string]interface{} configMap := make(map[string]interface{}) err = yaml.Unmarshal([]byte(config.ConfigFile), &configMap) if err != nil { return nil, fmt.Errorf("failed to unmarshal config YAML: %v", err) } configMap["name"] = name if configOverrides != nil { if err := mergo.Merge(&configMap, configOverrides, mergo.WithOverride); err != nil { return nil, err } } // Write updated config file updatedConfigYAML, err := yaml.Marshal(configMap) if err != nil { return nil, fmt.Errorf("failed to marshal updated config YAML: %v", err) } err = yaml.Unmarshal(updatedConfigYAML, &modelConfig) if err != nil { return nil, fmt.Errorf("failed to unmarshal updated config YAML: %v", err) } if valid, err := modelConfig.Validate(); !valid { return nil, fmt.Errorf("failed to validate updated config YAML: %v", err) } err = os.WriteFile(configFilePath, updatedConfigYAML, 0644) if err != nil { return nil, fmt.Errorf("failed to write updated config file: %v", err) } xlog.Debug("Written config file", "file", configFilePath) } // Save the model gallery file for further reference modelFile := filepath.Join(basePath, galleryFileName(name)) data, err := yaml.Marshal(config) if err != nil { return nil, err } xlog.Debug("Written gallery file", "file", modelFile) return &modelConfig, os.WriteFile(modelFile, data, 0644) } func galleryFileName(name string) string { return "._gallery_" + name + ".yaml" } func GetLocalModelConfiguration(basePath string, name string) (*ModelConfig, error) { name = strings.ReplaceAll(name, string(os.PathSeparator), "__") galleryFile := filepath.Join(basePath, galleryFileName(name)) return ReadConfigFile[ModelConfig](galleryFile) } func 
listModelFiles(systemState *system.SystemState, name string) ([]string, error) { configFile := filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", name)) if err := utils.VerifyPath(configFile, systemState.Model.ModelsPath); err != nil { return nil, fmt.Errorf("failed to verify path %s: %w", configFile, err) } // os.PathSeparator is not allowed in model names. Replace them with "__" to avoid conflicts with file paths. name = strings.ReplaceAll(name, string(os.PathSeparator), "__") galleryFile := filepath.Join(systemState.Model.ModelsPath, galleryFileName(name)) if err := utils.VerifyPath(galleryFile, systemState.Model.ModelsPath); err != nil { return nil, fmt.Errorf("failed to verify path %s: %w", galleryFile, err) } additionalFiles := []string{} allFiles := []string{} // Galleryname is the name of the model in this case dat, err := os.ReadFile(configFile) if err == nil { modelConfig := &lconfig.ModelConfig{} err = yaml.Unmarshal(dat, &modelConfig) if err != nil { return nil, err } if modelConfig.Model != "" { additionalFiles = append(additionalFiles, modelConfig.ModelFileName()) } if modelConfig.MMProj != "" { additionalFiles = append(additionalFiles, modelConfig.MMProjFileName()) } } // read the model config galleryconfig, err := ReadConfigFile[ModelConfig](galleryFile) if err == nil && galleryconfig != nil { for _, f := range galleryconfig.Files { fullPath := filepath.Join(systemState.Model.ModelsPath, f.Filename) if err := utils.VerifyPath(fullPath, systemState.Model.ModelsPath); err != nil { return allFiles, fmt.Errorf("failed to verify path %s: %w", fullPath, err) } allFiles = append(allFiles, fullPath) } } else { xlog.Error("failed to read gallery file", "error", err, "file", configFile) } for _, f := range additionalFiles { fullPath := filepath.Join(filepath.Join(systemState.Model.ModelsPath, f)) if err := utils.VerifyPath(fullPath, systemState.Model.ModelsPath); err != nil { return allFiles, fmt.Errorf("failed to verify path %s: %w", 
fullPath, err) } allFiles = append(allFiles, fullPath) } allFiles = append(allFiles, galleryFile) // skip duplicates allFiles = utils.Unique(allFiles) return allFiles, nil } func DeleteModelFromSystem(systemState *system.SystemState, name string) error { configFile := filepath.Join(systemState.Model.ModelsPath, fmt.Sprintf("%s.yaml", name)) filesToRemove, err := listModelFiles(systemState, name) if err != nil { return err } allOtherFiles := []string{} // Get all files of all other models fi, err := os.ReadDir(systemState.Model.ModelsPath) if err != nil { return err } for _, f := range fi { if f.IsDir() { continue } if strings.HasPrefix(f.Name(), "._gallery_") { continue } if !strings.HasSuffix(f.Name(), ".yaml") && !strings.HasSuffix(f.Name(), ".yml") { continue } if f.Name() == fmt.Sprintf("%s.yaml", name) || f.Name() == fmt.Sprintf("%s.yml", name) { continue } name := strings.TrimSuffix(f.Name(), ".yaml") name = strings.TrimSuffix(name, ".yml") xlog.Debug("Checking file", "file", f.Name()) files, err := listModelFiles(systemState, name) if err != nil { xlog.Debug("failed to list files for model", "error", err, "model", f.Name()) continue } allOtherFiles = append(allOtherFiles, files...) } xlog.Debug("Files to remove", "files", filesToRemove) xlog.Debug("All other files", "files", allOtherFiles) // Removing files for _, f := range filesToRemove { if slices.Contains(allOtherFiles, f) { xlog.Debug("Skipping file because it is part of another model", "file", f) continue } if e := os.Remove(f); e != nil { xlog.Error("failed to remove file", "error", e, "file", f) } } return os.Remove(configFile) } // This is ***NEVER*** going to be perfect or finished. // This is a BEST EFFORT function to surface known-vulnerable models to users. 
func SafetyScanGalleryModels(galleries []lconfig.Gallery, systemState *system.SystemState) error { galleryModels, err := AvailableGalleryModels(galleries, systemState) if err != nil { return err } for _, gM := range galleryModels { if gM.Installed { err = errors.Join(err, SafetyScanGalleryModel(gM)) } } return err } func SafetyScanGalleryModel(galleryModel *GalleryModel) error { for _, file := range galleryModel.AdditionalFiles { scanResults, err := downloader.HuggingFaceScan(downloader.URI(file.URI)) if err != nil && errors.Is(err, downloader.ErrUnsafeFilesFound) { xlog.Error("Contains unsafe file(s)!", "model", galleryModel.Name, "clamAV", scanResults.ClamAVInfectedFiles, "pickles", scanResults.DangerousPickles) return err } } return nil } ================================================ FILE: core/gallery/models_test.go ================================================ package gallery_test import ( "context" "errors" "os" "path/filepath" "github.com/mudler/LocalAI/core/config" . "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/system" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" "gopkg.in/yaml.v3" ) const bertEmbeddingsURL = `https://gist.githubusercontent.com/mudler/0a080b166b87640e8644b09c2aee6e3b/raw/f0e8c26bb72edc16d9fbafbfd6638072126ff225/bert-embeddings-gallery.yaml` var _ = Describe("Model test", func() { BeforeEach(func() { if os.Getenv("FIXTURES") == "" { Skip("FIXTURES env var not set, skipping model tests") } }) Context("Downloading", func() { It("applies model correctly", func() { tempdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) defer os.RemoveAll(tempdir) c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) systemState, err := system.GetSystemState( system.WithModelPath(tempdir), ) Expect(err).ToNot(HaveOccurred()) _, err = InstallModel(context.TODO(), systemState, "", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "cerebras.yaml"} { _, err = os.Stat(filepath.Join(tempdir, f)) Expect(err).ToNot(HaveOccurred()) } content := map[string]interface{}{} dat, err := os.ReadFile(filepath.Join(tempdir, "cerebras.yaml")) Expect(err).ToNot(HaveOccurred()) err = yaml.Unmarshal(dat, content) Expect(err).ToNot(HaveOccurred()) Expect(content["context_size"]).To(Equal(1024)) }) It("applies model from gallery correctly", func() { tempdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) defer os.RemoveAll(tempdir) gallery := []GalleryModel{{ Metadata: Metadata{ Name: "bert", URL: bertEmbeddingsURL, }, }} out, err := yaml.Marshal(gallery) Expect(err).ToNot(HaveOccurred()) galleryFilePath := filepath.Join(tempdir, "gallery_simple.yaml") err = os.WriteFile(galleryFilePath, out, 0600) Expect(err).ToNot(HaveOccurred()) Expect(filepath.IsAbs(galleryFilePath)).To(BeTrue(), galleryFilePath) galleries := []config.Gallery{ { Name: "test", URL: "file://" + 
galleryFilePath, }, } systemState, err := system.GetSystemState( system.WithModelPath(tempdir), ) Expect(err).ToNot(HaveOccurred()) models, err := AvailableGalleryModels(galleries, systemState) Expect(err).ToNot(HaveOccurred()) Expect(len(models)).To(Equal(1)) Expect(models[0].Name).To(Equal("bert")) Expect(models[0].URL).To(Equal(bertEmbeddingsURL)) Expect(models[0].Installed).To(BeFalse()) err = InstallModelFromGallery(context.TODO(), galleries, []config.Gallery{}, systemState, nil, "test@bert", GalleryModel{}, func(s1, s2, s3 string, f float64) {}, true, true) Expect(err).ToNot(HaveOccurred()) dat, err := os.ReadFile(filepath.Join(tempdir, "bert.yaml")) Expect(err).ToNot(HaveOccurred()) content := map[string]interface{}{} err = yaml.Unmarshal(dat, &content) Expect(err).ToNot(HaveOccurred()) Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this")) models, err = AvailableGalleryModels(galleries, systemState) Expect(err).ToNot(HaveOccurred()) Expect(len(models)).To(Equal(1)) Expect(models[0].Installed).To(BeTrue()) // delete err = DeleteModelFromSystem(systemState, "bert") Expect(err).ToNot(HaveOccurred()) models, err = AvailableGalleryModels(galleries, systemState) Expect(err).ToNot(HaveOccurred()) Expect(len(models)).To(Equal(1)) Expect(models[0].Installed).To(BeFalse()) _, err = os.Stat(filepath.Join(tempdir, "bert.yaml")) Expect(err).To(HaveOccurred()) Expect(errors.Is(err, os.ErrNotExist)).To(BeTrue()) }) It("renames model correctly", func() { tempdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) defer os.RemoveAll(tempdir) c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) systemState, err := system.GetSystemState( system.WithModelPath(tempdir), ) Expect(err).ToNot(HaveOccurred()) _, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) 
Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { _, err = os.Stat(filepath.Join(tempdir, f)) Expect(err).ToNot(HaveOccurred()) } }) It("overrides parameters", func() { tempdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) defer os.RemoveAll(tempdir) c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) systemState, err := system.GetSystemState( system.WithModelPath(tempdir), ) Expect(err).ToNot(HaveOccurred()) _, err = InstallModel(context.TODO(), systemState, "foo", c, map[string]interface{}{"backend": "foo"}, func(string, string, string, float64) {}, true) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "foo.yaml"} { _, err = os.Stat(filepath.Join(tempdir, f)) Expect(err).ToNot(HaveOccurred()) } content := map[string]interface{}{} dat, err := os.ReadFile(filepath.Join(tempdir, "foo.yaml")) Expect(err).ToNot(HaveOccurred()) err = yaml.Unmarshal(dat, content) Expect(err).ToNot(HaveOccurred()) Expect(content["backend"]).To(Equal("foo")) }) It("catches path traversals", func() { tempdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) defer os.RemoveAll(tempdir) c, err := ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) systemState, err := system.GetSystemState( system.WithModelPath(tempdir), ) Expect(err).ToNot(HaveOccurred()) _, err = InstallModel(context.TODO(), systemState, "../../../foo", c, map[string]interface{}{}, func(string, string, string, float64) {}, true) Expect(err).To(HaveOccurred()) }) It("handles nil configOverrides without panic", func() { tempdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) defer os.RemoveAll(tempdir) c, err := 
ReadConfigFile[ModelConfig](filepath.Join(os.Getenv("FIXTURES"), "gallery_simple.yaml")) Expect(err).ToNot(HaveOccurred()) systemState, err := system.GetSystemState( system.WithModelPath(tempdir), ) Expect(err).ToNot(HaveOccurred()) _, err = InstallModel(context.TODO(), systemState, "test-model", c, nil, func(string, string, string, float64) {}, true) Expect(err).ToNot(HaveOccurred()) for _, f := range []string{"cerebras", "cerebras-completion.tmpl", "cerebras-chat.tmpl", "test-model.yaml"} { _, err = os.Stat(filepath.Join(tempdir, f)) Expect(err).ToNot(HaveOccurred()) } }) It("does not delete shared model files when one config is deleted", func() { tempdir, err := os.MkdirTemp("", "test") Expect(err).ToNot(HaveOccurred()) defer os.RemoveAll(tempdir) systemState, err := system.GetSystemState( system.WithModelPath(tempdir), ) Expect(err).ToNot(HaveOccurred()) // Create a shared model file sharedModelFile := filepath.Join(tempdir, "shared_model.bin") err = os.WriteFile(sharedModelFile, []byte("fake model content"), 0600) Expect(err).ToNot(HaveOccurred()) // Create first model configuration config1 := `name: model1 model: shared_model.bin` err = os.WriteFile(filepath.Join(tempdir, "model1.yaml"), []byte(config1), 0600) Expect(err).ToNot(HaveOccurred()) // Create first model's gallery file galleryConfig1 := ModelConfig{ Name: "model1", Files: []File{ {Filename: "shared_model.bin"}, }, } galleryData1, err := yaml.Marshal(galleryConfig1) Expect(err).ToNot(HaveOccurred()) err = os.WriteFile(filepath.Join(tempdir, "._gallery_model1.yaml"), galleryData1, 0600) Expect(err).ToNot(HaveOccurred()) // Create second model configuration sharing the same model file config2 := `name: model2 model: shared_model.bin` err = os.WriteFile(filepath.Join(tempdir, "model2.yaml"), []byte(config2), 0600) Expect(err).ToNot(HaveOccurred()) // Create second model's gallery file galleryConfig2 := ModelConfig{ Name: "model2", Files: []File{ {Filename: "shared_model.bin"}, }, } galleryData2, err := 
yaml.Marshal(galleryConfig2) Expect(err).ToNot(HaveOccurred()) err = os.WriteFile(filepath.Join(tempdir, "._gallery_model2.yaml"), galleryData2, 0600) Expect(err).ToNot(HaveOccurred()) // Verify both configurations exist _, err = os.Stat(filepath.Join(tempdir, "model1.yaml")) Expect(err).ToNot(HaveOccurred()) _, err = os.Stat(filepath.Join(tempdir, "model2.yaml")) Expect(err).ToNot(HaveOccurred()) // Verify the shared model file exists _, err = os.Stat(sharedModelFile) Expect(err).ToNot(HaveOccurred()) // Delete the first model err = DeleteModelFromSystem(systemState, "model1") Expect(err).ToNot(HaveOccurred()) // Verify the first configuration is deleted _, err = os.Stat(filepath.Join(tempdir, "model1.yaml")) Expect(err).To(HaveOccurred()) Expect(errors.Is(err, os.ErrNotExist)).To(BeTrue()) // Verify the shared model file still exists (not deleted because model2 still uses it) _, err = os.Stat(sharedModelFile) Expect(err).ToNot(HaveOccurred(), "shared model file should not be deleted when used by other configs") // Verify the second configuration still exists _, err = os.Stat(filepath.Join(tempdir, "model2.yaml")) Expect(err).ToNot(HaveOccurred()) // Now delete the second model err = DeleteModelFromSystem(systemState, "model2") Expect(err).ToNot(HaveOccurred()) // Verify the second configuration is deleted _, err = os.Stat(filepath.Join(tempdir, "model2.yaml")) Expect(err).To(HaveOccurred()) Expect(errors.Is(err, os.ErrNotExist)).To(BeTrue()) // Verify the shared model file is now deleted (no more references) _, err = os.Stat(sharedModelFile) Expect(err).To(HaveOccurred(), "shared model file should be deleted when no configs reference it") Expect(errors.Is(err, os.ErrNotExist)).To(BeTrue()) }) }) }) ================================================ FILE: core/gallery/models_types.go ================================================ package gallery import ( "fmt" "github.com/mudler/LocalAI/core/config" ) // GalleryModel is the struct used to represent a model in the 
gallery returned by the endpoint. // It is used to install the model by resolving the URL and downloading the files. // The other fields are used to override the configuration of the model. type GalleryModel struct { Metadata `json:",inline" yaml:",inline"` // config_file is read in the situation where URL is blank - and therefore this is a base config. ConfigFile map[string]interface{} `json:"config_file,omitempty" yaml:"config_file,omitempty"` // Overrides are used to override the configuration of the model located at URL Overrides map[string]interface{} `json:"overrides,omitempty" yaml:"overrides,omitempty"` } func (m *GalleryModel) GetInstalled() bool { return m.Installed } func (m *GalleryModel) GetLicense() string { return m.License } func (m *GalleryModel) SetGallery(gallery config.Gallery) { m.Gallery = gallery } func (m *GalleryModel) SetInstalled(installed bool) { m.Installed = installed } func (m *GalleryModel) GetName() string { return m.Name } func (m *GalleryModel) GetGallery() config.Gallery { return m.Gallery } func (m GalleryModel) ID() string { return fmt.Sprintf("%s@%s", m.Gallery.Name, m.Name) } func (m *GalleryModel) GetTags() []string { return m.Tags } func (m *GalleryModel) GetDescription() string { return m.Description } ================================================ FILE: core/gallery/request_test.go ================================================ package gallery_test import ( . "github.com/mudler/LocalAI/core/gallery" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" ) var _ = Describe("Gallery API tests", func() { Context("requests", func() { It("parses github with a branch", func() { req := GalleryModel{ Metadata: Metadata{ URL: "github:go-skynet/model-gallery/gpt4all-j.yaml@main", }, } e, err := GetGalleryConfigFromURL[ModelConfig](req.URL, "") Expect(err).ToNot(HaveOccurred()) Expect(e.Name).To(Equal("gpt4all-j")) }) }) }) ================================================ FILE: core/http/app.go ================================================ package http import ( "embed" "errors" "fmt" "io/fs" "mime" "net/http" "os" "path/filepath" "strings" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/http/endpoints/localai" httpMiddleware "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/routes" "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" "github.com/mudler/xlog" ) // Embed a directory // //go:embed static/* var embedDirStatic embed.FS // Embed React UI build output // //go:embed react-ui/dist/* var reactUI embed.FS var quietPaths = []string{"/api/operations", "/api/resources", "/healthz", "/readyz"} // @title LocalAI API // @version 2.0.0 // @description The LocalAI Rest API. 
// @termsOfService // @contact.name LocalAI // @contact.url https://localai.io // @license.name MIT // @license.url https://raw.githubusercontent.com/mudler/LocalAI/master/LICENSE // @BasePath / // @securityDefinitions.apikey BearerAuth // @in header // @name Authorization func API(application *application.Application) (*echo.Echo, error) { e := echo.New() // Set body limit if application.ApplicationConfig().UploadLimitMB > 0 { e.Use(middleware.BodyLimit(fmt.Sprintf("%dM", application.ApplicationConfig().UploadLimitMB))) } // SPA fallback handler, set later when React UI is available var spaFallback func(echo.Context) error // Set error handler if !application.ApplicationConfig().OpaqueErrors { e.HTTPErrorHandler = func(err error, c echo.Context) { code := http.StatusInternalServerError var he *echo.HTTPError if errors.As(err, &he) { code = he.Code } // Handle 404 errors: serve React SPA for HTML requests, JSON otherwise if code == http.StatusNotFound { if spaFallback != nil { accept := c.Request().Header.Get("Accept") contentType := c.Request().Header.Get("Content-Type") if strings.Contains(accept, "text/html") && !strings.Contains(contentType, "application/json") { spaFallback(c) return } } notFoundHandler(c) return } // Send custom error page c.JSON(code, schema.ErrorResponse{ Error: &schema.APIError{Message: err.Error(), Code: code}, }) } } else { e.HTTPErrorHandler = func(err error, c echo.Context) { code := http.StatusInternalServerError var he *echo.HTTPError if errors.As(err, &he) { code = he.Code } c.NoContent(code) } } // Set renderer e.Renderer = renderEngine() // Hide banner e.HideBanner = true e.HidePort = true // Middleware - StripPathPrefix must be registered early as it uses Rewrite which runs before routing e.Pre(httpMiddleware.StripPathPrefix()) e.Pre(middleware.RemoveTrailingSlash()) if application.ApplicationConfig().MachineTag != "" { e.Use(func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { 
c.Response().Header().Set("Machine-Tag", application.ApplicationConfig().MachineTag) return next(c) } }) } // Custom logger middleware using xlog e.Use(func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { req := c.Request() res := c.Response() err := next(c) // Fix for #7989: Reduce log verbosity of Web UI polling, resources API, and health checks // These paths are logged at DEBUG level (hidden by default) instead of INFO. isQuietPath := false for _, path := range quietPaths { if req.URL.Path == path { isQuietPath = true break } } if isQuietPath && res.Status == 200 { xlog.Debug("HTTP request", "method", req.Method, "path", req.URL.Path, "status", res.Status) } else { xlog.Info("HTTP request", "method", req.Method, "path", req.URL.Path, "status", res.Status) } return err } }) // Recover middleware if !application.ApplicationConfig().Debug { e.Use(middleware.Recover()) } // Metrics middleware if !application.ApplicationConfig().DisableMetrics { metricsService, err := services.NewLocalAIMetricsService() if err != nil { return nil, err } if metricsService != nil { e.Use(localai.LocalAIMetricsAPIMiddleware(metricsService)) e.Server.RegisterOnShutdown(func() { metricsService.Shutdown() }) } } // Health Checks should always be exempt from auth, so register these first routes.HealthRoutes(e) // Build auth middleware: use the new auth.Middleware when auth is enabled or // as a unified replacement for the legacy key-auth middleware. 
authMiddleware := auth.Middleware(application.AuthDB(), application.ApplicationConfig()) // Favicon handler e.GET("/favicon.svg", func(c echo.Context) error { data, err := embedDirStatic.ReadFile("static/favicon.svg") if err != nil { return c.NoContent(http.StatusNotFound) } c.Response().Header().Set("Content-Type", "image/svg+xml") return c.Blob(http.StatusOK, "image/svg+xml", data) }) // Static files - use fs.Sub to create a filesystem rooted at "static" staticFS, err := fs.Sub(embedDirStatic, "static") if err != nil { return nil, fmt.Errorf("failed to create static filesystem: %w", err) } e.StaticFS("/static", staticFS) // Generated content directories if application.ApplicationConfig().GeneratedContentDir != "" { os.MkdirAll(application.ApplicationConfig().GeneratedContentDir, 0750) audioPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "audio") imagePath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "images") videoPath := filepath.Join(application.ApplicationConfig().GeneratedContentDir, "videos") os.MkdirAll(audioPath, 0750) os.MkdirAll(imagePath, 0750) os.MkdirAll(videoPath, 0750) e.Static("/generated-audio", audioPath) e.Static("/generated-images", imagePath) e.Static("/generated-videos", videoPath) } // Initialize usage recording when auth DB is available if application.AuthDB() != nil { httpMiddleware.InitUsageRecorder(application.AuthDB()) } // Auth is applied to _all_ endpoints. Filtering out endpoints to bypass is // the role of the exempt-path logic inside the middleware. 
e.Use(authMiddleware) // Feature and model access control (after auth middleware, before routes) if application.AuthDB() != nil { e.Use(auth.RequireRouteFeature(application.AuthDB())) e.Use(auth.RequireModelAccess(application.AuthDB())) } // CORS middleware if application.ApplicationConfig().CORS { corsConfig := middleware.CORSConfig{} if application.ApplicationConfig().CORSAllowOrigins != "" { corsConfig.AllowOrigins = strings.Split(application.ApplicationConfig().CORSAllowOrigins, ",") } e.Use(middleware.CORSWithConfig(corsConfig)) } else { e.Use(middleware.CORS()) } // CSRF middleware (enabled by default, disable with LOCALAI_DISABLE_CSRF=true) // // Protection relies on Echo's Sec-Fetch-Site header check (supported by all // modern browsers). The legacy cookie+token approach is removed because // Echo's Sec-Fetch-Site short-circuit never sets the cookie, so the frontend // could never read a token to send back. if !application.ApplicationConfig().DisableCSRF { xlog.Debug("Enabling CSRF middleware (Sec-Fetch-Site mode)") e.Use(middleware.CSRFWithConfig(middleware.CSRFConfig{ Skipper: func(c echo.Context) bool { // Skip CSRF for API clients using auth headers (may be cross-origin) if c.Request().Header.Get("Authorization") != "" { return true } if c.Request().Header.Get("x-api-key") != "" || c.Request().Header.Get("xi-api-key") != "" { return true } // Skip when Sec-Fetch-Site header is absent (older browsers, reverse // proxies that strip the header). The SameSite=Lax cookie attribute // provides baseline CSRF protection for these clients. if c.Request().Header.Get("Sec-Fetch-Site") == "" { return true } return false }, // Allow same-site requests (subdomains / different ports) in addition // to same-origin which Echo already permits by default. 
AllowSecFetchSiteFunc: func(c echo.Context) (bool, error) { secFetchSite := c.Request().Header.Get("Sec-Fetch-Site") if secFetchSite == "same-site" { return true, nil } // cross-site: block return false, nil }, })) } // Admin middleware: enforces admin role when auth is enabled, no-op otherwise var adminMiddleware echo.MiddlewareFunc if application.AuthDB() != nil { adminMiddleware = auth.RequireAdmin() } else { adminMiddleware = auth.NoopMiddleware() } // Feature middlewares: per-feature access control agentsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureAgents) skillsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureSkills) collectionsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureCollections) mcpJobsMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCPJobs) requestExtractor := httpMiddleware.NewRequestExtractor(application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) // Register auth routes (login, callback, API keys, user management) routes.RegisterAuthRoutes(e, application) routes.RegisterElevenLabsRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) // Create opcache for tracking UI operations (used by both UI and LocalAI routes) var opcache *services.OpCache if !application.ApplicationConfig().DisableWebUI { opcache = services.NewOpCache(application.GalleryService()) } mcpMw := auth.RequireFeature(application.AuthDB(), auth.FeatureMCP) routes.RegisterLocalAIRoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application.TemplatesEvaluator(), application, adminMiddleware, mcpJobsMw, mcpMw) routes.RegisterAgentPoolRoutes(e, application, agentsMw, skillsMw, collectionsMw) routes.RegisterOpenAIRoutes(e, requestExtractor, application) routes.RegisterAnthropicRoutes(e, requestExtractor, application) 
routes.RegisterOpenResponsesRoutes(e, requestExtractor, application) if !application.ApplicationConfig().DisableWebUI { routes.RegisterUIAPIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), opcache, application, adminMiddleware) routes.RegisterUIRoutes(e, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig(), application.GalleryService(), adminMiddleware) // Serve React SPA from / with SPA fallback via 404 handler reactFS, fsErr := fs.Sub(reactUI, "react-ui/dist") if fsErr != nil { xlog.Warn("React UI not available (build with 'make core/http/react-ui/dist')", "error", fsErr) } else { serveIndex := func(c echo.Context) error { indexHTML, err := reactUI.ReadFile("react-ui/dist/index.html") if err != nil { return c.String(http.StatusNotFound, "React UI not built") } // Inject for reverse-proxy support baseURL := httpMiddleware.BaseURL(c) if baseURL != "" { baseTag := `` indexHTML = []byte(strings.Replace(string(indexHTML), "", "\n "+baseTag, 1)) } return c.HTMLBlob(http.StatusOK, indexHTML) } // Enable SPA fallback in the 404 handler for client-side routing spaFallback = serveIndex // Serve React SPA at /app e.GET("/app", serveIndex) e.GET("/app/*", serveIndex) // prefixRedirect performs a redirect that preserves X-Forwarded-Prefix for reverse-proxy support. 
prefixRedirect := func(c echo.Context, target string) error { if prefix := c.Request().Header.Get("X-Forwarded-Prefix"); prefix != "" { target = strings.TrimSuffix(prefix, "/") + target } return c.Redirect(http.StatusMovedPermanently, target) } // Redirect / to /app e.GET("/", func(c echo.Context) error { return prefixRedirect(c, "/app") }) // Backward compatibility: redirect /browse/* to /app/* e.GET("/browse", func(c echo.Context) error { return prefixRedirect(c, "/app") }) e.GET("/browse/*", func(c echo.Context) error { p := c.Param("*") return prefixRedirect(c, "/app/"+p) }) // Serve React static assets (JS, CSS, etc.) serveReactAsset := func(c echo.Context) error { p := "assets/" + c.Param("*") f, err := reactFS.Open(p) if err == nil { defer f.Close() stat, statErr := f.Stat() if statErr == nil && !stat.IsDir() { contentType := mime.TypeByExtension(filepath.Ext(p)) if contentType == "" { contentType = echo.MIMEOctetStream } return c.Stream(http.StatusOK, contentType, f) } } return echo.NewHTTPError(http.StatusNotFound) } e.GET("/assets/*", serveReactAsset) } } routes.RegisterJINARoutes(e, requestExtractor, application.ModelConfigLoader(), application.ModelLoader(), application.ApplicationConfig()) // Note: 404 handling is done via HTTPErrorHandler above, no need for catch-all route // Log startup message e.Server.RegisterOnShutdown(func() { xlog.Info("LocalAI API server shutting down") }) return e, nil } ================================================ FILE: core/http/app_test.go ================================================ package http_test import ( "bytes" "context" "encoding/json" "fmt" "io" "net/http" "os" "path/filepath" "runtime" "time" "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/config" . "github.com/mudler/LocalAI/core/http" "github.com/mudler/LocalAI/core/schema" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/pkg/downloader" "github.com/mudler/LocalAI/pkg/system" . 
"github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "gopkg.in/yaml.v3" "github.com/mudler/xlog" openaigo "github.com/otiai10/openaigo" "github.com/sashabaranov/go-openai" "github.com/sashabaranov/go-openai/jsonschema" ) const apiKey = "joshua" const bearerKey = "Bearer " + apiKey const testPrompt = `### System: You are an AI assistant that follows instruction extremely well. Help as much as you can. ### Instruction: Say hello. ### Response:` type modelApplyRequest struct { ID string `json:"id"` URL string `json:"url"` ConfigURL string `json:"config_url"` Name string `json:"name"` Overrides map[string]interface{} `json:"overrides"` } func getModelStatus(url string) (response map[string]interface{}) { // Create the HTTP request req, err := http.NewRequest("GET", url, nil) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) if err != nil { fmt.Println("Error creating request:", err) return } client := &http.Client{} resp, err := client.Do(req) if err != nil { fmt.Println("Error sending request:", err) return } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println("Error reading response body:", err) return } // Unmarshal the response into a map[string]interface{} err = json.Unmarshal(body, &response) if err != nil { fmt.Println("Error unmarshaling JSON response:", err) return } return } func getModels(url string) ([]gallery.GalleryModel, error) { response := []gallery.GalleryModel{} uri := downloader.URI(url) // TODO: No tests currently seem to exercise file:// urls. Fix? 
err := uri.ReadWithAuthorizationAndCallback(context.TODO(), "", bearerKey, func(url string, i []byte) error { // Unmarshal YAML data into a struct return json.Unmarshal(i, &response) }) return response, err } func postModelApplyRequest(url string, request modelApplyRequest) (response map[string]interface{}) { //url := "http://localhost:AI/models/apply" // Create the request payload payload, err := json.Marshal(request) if err != nil { fmt.Println("Error marshaling JSON:", err) return } // Create the HTTP request req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload)) if err != nil { fmt.Println("Error creating request:", err) return } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) // Make the request client := &http.Client{} resp, err := client.Do(req) if err != nil { fmt.Println("Error making request:", err) return } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { fmt.Println("Error reading response body:", err) return } // Unmarshal the response into a map[string]interface{} err = json.Unmarshal(body, &response) if err != nil { fmt.Println("Error unmarshaling JSON response:", err) return } return } func postRequestJSON[B any](url string, bodyJson *B) error { payload, err := json.Marshal(bodyJson) if err != nil { return err } GinkgoWriter.Printf("POST %s: %s\n", url, string(payload)) req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload)) if err != nil { return err } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) if err != nil { return err } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return err } if resp.StatusCode < 200 || resp.StatusCode >= 400 { return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)) } return nil } func postRequestResponseJSON[B1 any, B2 any](url string, reqJson *B1, respJson *B2) error { 
payload, err := json.Marshal(reqJson) if err != nil { return err } GinkgoWriter.Printf("POST %s: %s\n", url, string(payload)) req, err := http.NewRequest("POST", url, bytes.NewBuffer(payload)) if err != nil { return err } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) if err != nil { return err } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return err } if resp.StatusCode < 200 || resp.StatusCode >= 400 { return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)) } return json.Unmarshal(body, respJson) } func putRequestJSON[B any](url string, bodyJson *B) error { payload, err := json.Marshal(bodyJson) if err != nil { return err } GinkgoWriter.Printf("PUT %s: %s\n", url, string(payload)) req, err := http.NewRequest("PUT", url, bytes.NewBuffer(payload)) if err != nil { return err } req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) if err != nil { return err } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return err } if resp.StatusCode < 200 || resp.StatusCode >= 400 { return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)) } return nil } func postInvalidRequest(url string) (error, int) { req, err := http.NewRequest("POST", url, bytes.NewBufferString("invalid request")) if err != nil { return err, -1 } req.Header.Set("Content-Type", "application/json") client := &http.Client{} resp, err := client.Do(req) if err != nil { return err, -1 } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return err, -1 } if resp.StatusCode < 200 || resp.StatusCode >= 400 { return fmt.Errorf("unexpected status code: %d, body: %s", resp.StatusCode, string(body)), resp.StatusCode } return nil, resp.StatusCode } func getRequest(url string, header 
http.Header) (error, int, []byte) { req, err := http.NewRequest("GET", url, nil) if err != nil { return err, -1, nil } req.Header = header client := &http.Client{} resp, err := client.Do(req) if err != nil { return err, -1, nil } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return err, -1, nil } return nil, resp.StatusCode, body } const bertEmbeddingsURL = `https://gist.githubusercontent.com/mudler/0a080b166b87640e8644b09c2aee6e3b/raw/f0e8c26bb72edc16d9fbafbfd6638072126ff225/bert-embeddings-gallery.yaml` var _ = Describe("API test", func() { var app *echo.Echo var client *openai.Client var client2 *openaigo.Client var c context.Context var cancel context.CancelFunc var tmpdir string var modelDir string commonOpts := []config.AppOption{ config.WithDebug(true), } Context("API with ephemeral models", func() { BeforeEach(func(sc SpecContext) { var err error tmpdir, err = os.MkdirTemp("", "") Expect(err).ToNot(HaveOccurred()) backendPath := os.Getenv("BACKENDS_PATH") modelDir = filepath.Join(tmpdir, "models") err = os.Mkdir(modelDir, 0750) Expect(err).ToNot(HaveOccurred()) c, cancel = context.WithCancel(context.Background()) g := []gallery.GalleryModel{ { Metadata: gallery.Metadata{ Name: "bert", URL: bertEmbeddingsURL, }, Overrides: map[string]interface{}{"backend": "llama-cpp"}, }, { Metadata: gallery.Metadata{ Name: "bert2", URL: bertEmbeddingsURL, AdditionalFiles: []gallery.File{{Filename: "foo.yaml", URI: bertEmbeddingsURL}}, }, Overrides: map[string]interface{}{"foo": "bar"}, }, } out, err := yaml.Marshal(g) Expect(err).ToNot(HaveOccurred()) err = os.WriteFile(filepath.Join(modelDir, "gallery_simple.yaml"), out, 0600) Expect(err).ToNot(HaveOccurred()) galleries := []config.Gallery{ { Name: "test", URL: "file://" + filepath.Join(modelDir, "gallery_simple.yaml"), }, } systemState, err := system.GetSystemState( system.WithBackendPath(backendPath), system.WithModelPath(modelDir), ) Expect(err).ToNot(HaveOccurred()) application, err := 
application.New( append(commonOpts, config.WithContext(c), config.WithSystemState(systemState), config.WithGalleries(galleries), config.WithApiKeys([]string{apiKey}), )...) Expect(err).ToNot(HaveOccurred()) app, err = API(application) Expect(err).ToNot(HaveOccurred()) go func() { if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { xlog.Error("server error", "error", err) } }() defaultConfig := openai.DefaultConfig(apiKey) defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" client2 = openaigo.NewClient("") client2.BaseURL = defaultConfig.BaseURL // Wait for API to be ready client = openai.NewClientWithConfig(defaultConfig) Eventually(func() error { _, err := client.ListModels(context.TODO()) return err }, "2m").ShouldNot(HaveOccurred()) }) AfterEach(func(sc SpecContext) { cancel() if app != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() err := app.Shutdown(ctx) Expect(err).ToNot(HaveOccurred()) } err := os.RemoveAll(tmpdir) Expect(err).ToNot(HaveOccurred()) _, err = os.ReadDir(tmpdir) Expect(err).To(HaveOccurred()) }) Context("Auth Tests", func() { It("Should fail if the api key is missing", func() { err, sc := postInvalidRequest("http://127.0.0.1:9090/models/available") Expect(err).ToNot(BeNil()) Expect(sc).To(Equal(401)) }) }) Context("URL routing Tests", func() { It("Should support reverse-proxy when unauthenticated", func() { err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{ "X-Forwarded-Proto": {"https"}, "X-Forwarded-Host": {"example.org"}, "X-Forwarded-Prefix": {"/myprefix/"}, }) Expect(err).To(BeNil(), "error") Expect(sc).To(Equal(200), "status code") // Non-API paths pass through to the React SPA (which handles login client-side) Expect(string(body)).To(ContainSubstring(``), "body") Expect(string(body)).To(ContainSubstring(`
`), "should serve React SPA") }) It("Should support reverse-proxy when authenticated", func() { err, sc, body := getRequest("http://127.0.0.1:9090/myprefix/", http.Header{ "Authorization": {bearerKey}, "X-Forwarded-Proto": {"https"}, "X-Forwarded-Host": {"example.org"}, "X-Forwarded-Prefix": {"/myprefix/"}, }) Expect(err).To(BeNil(), "error") Expect(sc).To(Equal(200), "status code") Expect(string(body)).To(ContainSubstring(``), "body") }) }) Context("Applying models", func() { It("applies models from a gallery", func() { models, err := getModels("http://127.0.0.1:9090/models/available") Expect(err).To(BeNil()) Expect(len(models)).To(Equal(2), fmt.Sprint(models)) Expect(models[0].Installed).To(BeFalse(), fmt.Sprint(models)) Expect(models[1].Installed).To(BeFalse(), fmt.Sprint(models)) response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ ID: "test@bert2", }) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) uuid := response["uuid"].(string) resp := map[string]interface{}{} Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) fmt.Println(response) resp = response return response["processed"].(bool) }, "360s", "10s").Should(Equal(true)) Expect(resp["message"]).ToNot(ContainSubstring("error")) dat, err := os.ReadFile(filepath.Join(modelDir, "bert2.yaml")) Expect(err).ToNot(HaveOccurred()) _, err = os.ReadFile(filepath.Join(modelDir, "foo.yaml")) Expect(err).ToNot(HaveOccurred()) content := map[string]interface{}{} err = yaml.Unmarshal(dat, &content) Expect(err).ToNot(HaveOccurred()) Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this")) Expect(content["foo"]).To(Equal("bar")) models, err = getModels("http://127.0.0.1:9090/models/available") Expect(err).To(BeNil()) Expect(len(models)).To(Equal(2), fmt.Sprint(models)) Expect(models[0].Name).To(Or(Equal("bert"), Equal("bert2"))) Expect(models[1].Name).To(Or(Equal("bert"), Equal("bert2"))) 
for _, m := range models { if m.Name == "bert2" { Expect(m.Installed).To(BeTrue()) } else { Expect(m.Installed).To(BeFalse()) } } }) It("overrides models", func() { response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ URL: bertEmbeddingsURL, Name: "bert", Overrides: map[string]interface{}{ "backend": "llama", }, }) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) uuid := response["uuid"].(string) Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) return response["processed"].(bool) }, "360s", "10s").Should(Equal(true)) dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml")) Expect(err).ToNot(HaveOccurred()) content := map[string]interface{}{} err = yaml.Unmarshal(dat, &content) Expect(err).ToNot(HaveOccurred()) Expect(content["backend"]).To(Equal("llama")) }) It("apply models without overrides", func() { response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ URL: bertEmbeddingsURL, Name: "bert", Overrides: map[string]interface{}{}, }) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) uuid := response["uuid"].(string) Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) return response["processed"].(bool) }, "360s", "10s").Should(Equal(true)) dat, err := os.ReadFile(filepath.Join(modelDir, "bert.yaml")) Expect(err).ToNot(HaveOccurred()) content := map[string]interface{}{} err = yaml.Unmarshal(dat, &content) Expect(err).ToNot(HaveOccurred()) Expect(content["usage"]).To(ContainSubstring("You can test this model with curl like this")) }) }) Context("Importing models from URI", func() { var testYamlFile string BeforeEach(func() { // Create a test YAML config file yamlContent := `name: test-import-model backend: llama-cpp description: Test model imported from file URI parameters: model: path/to/model.gguf temperature: 0.7 ` testYamlFile = filepath.Join(tmpdir, 
"test-import.yaml") err := os.WriteFile(testYamlFile, []byte(yamlContent), 0644) Expect(err).ToNot(HaveOccurred()) }) AfterEach(func() { err := os.Remove(testYamlFile) Expect(err).ToNot(HaveOccurred()) }) It("should import model from file:// URI pointing to local YAML config", func() { importReq := schema.ImportModelRequest{ URI: "file://" + testYamlFile, Preferences: json.RawMessage(`{}`), } var response schema.GalleryResponse err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) Expect(err).ToNot(HaveOccurred()) Expect(response.ID).ToNot(BeEmpty()) uuid := response.ID resp := map[string]interface{}{} Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) resp = response return response["processed"].(bool) }, "360s", "10s").Should(Equal(true)) // Check that the model was imported successfully Expect(resp["message"]).ToNot(ContainSubstring("error")) Expect(resp["error"]).To(BeNil()) // Verify the model config file was created dat, err := os.ReadFile(filepath.Join(modelDir, "test-import-model.yaml")) Expect(err).ToNot(HaveOccurred()) content := map[string]interface{}{} err = yaml.Unmarshal(dat, &content) Expect(err).ToNot(HaveOccurred()) Expect(content["name"]).To(Equal("test-import-model")) Expect(content["backend"]).To(Equal("llama-cpp")) }) It("should return error when file:// URI points to non-existent file", func() { nonExistentFile := filepath.Join(tmpdir, "nonexistent.yaml") importReq := schema.ImportModelRequest{ URI: "file://" + nonExistentFile, Preferences: json.RawMessage(`{}`), } var response schema.GalleryResponse err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) // The endpoint should return an error immediately Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("failed to discover model config")) }) }) Context("Importing models from URI can't point to absolute paths", func() { var testYamlFile string 
BeforeEach(func() { // Create a test YAML config file yamlContent := `name: test-import-model backend: llama-cpp description: Test model imported from file URI parameters: model: /path/to/model.gguf temperature: 0.7 ` testYamlFile = filepath.Join(tmpdir, "test-import.yaml") err := os.WriteFile(testYamlFile, []byte(yamlContent), 0644) Expect(err).ToNot(HaveOccurred()) }) AfterEach(func() { err := os.Remove(testYamlFile) Expect(err).ToNot(HaveOccurred()) }) It("should fail to import model from file:// URI pointing to local YAML config", func() { importReq := schema.ImportModelRequest{ URI: "file://" + testYamlFile, Preferences: json.RawMessage(`{}`), } var response schema.GalleryResponse err := postRequestResponseJSON("http://127.0.0.1:9090/models/import-uri", &importReq, &response) Expect(err).ToNot(HaveOccurred()) Expect(response.ID).ToNot(BeEmpty()) uuid := response.ID resp := map[string]interface{}{} Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) resp = response return response["processed"].(bool) }, "360s", "10s").Should(Equal(true)) // Check that the model was imported successfully Expect(resp["message"]).To(ContainSubstring("error")) Expect(resp["error"]).ToNot(BeNil()) }) }) }) Context("Model gallery", func() { BeforeEach(func() { var err error tmpdir, err = os.MkdirTemp("", "") backendPath := os.Getenv("BACKENDS_PATH") Expect(err).ToNot(HaveOccurred()) modelDir = filepath.Join(tmpdir, "models") backendAssetsDir := filepath.Join(tmpdir, "backend-assets") err = os.Mkdir(backendAssetsDir, 0750) Expect(err).ToNot(HaveOccurred()) c, cancel = context.WithCancel(context.Background()) galleries := []config.Gallery{ { Name: "localai", URL: "https://raw.githubusercontent.com/mudler/LocalAI/refs/heads/master/gallery/index.yaml", }, } systemState, err := system.GetSystemState( system.WithBackendPath(backendPath), system.WithModelPath(modelDir), ) Expect(err).ToNot(HaveOccurred()) application, err := application.New( 
append(commonOpts, config.WithContext(c), config.WithGeneratedContentDir(tmpdir), config.WithSystemState(systemState), config.WithGalleries(galleries), )..., ) Expect(err).ToNot(HaveOccurred()) app, err = API(application) Expect(err).ToNot(HaveOccurred()) go func() { if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { xlog.Error("server error", "error", err) } }() defaultConfig := openai.DefaultConfig("") defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" client2 = openaigo.NewClient("") client2.BaseURL = defaultConfig.BaseURL // Wait for API to be ready client = openai.NewClientWithConfig(defaultConfig) Eventually(func() error { _, err := client.ListModels(context.TODO()) return err }, "2m").ShouldNot(HaveOccurred()) }) AfterEach(func() { cancel() if app != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() err := app.Shutdown(ctx) Expect(err).ToNot(HaveOccurred()) } err := os.RemoveAll(tmpdir) Expect(err).ToNot(HaveOccurred()) _, err = os.ReadDir(tmpdir) Expect(err).To(HaveOccurred()) }) It("runs gguf models (chat)", Label("llama-gguf"), func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } modelName := "qwen3-1.7b" response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ ID: "localai@" + modelName, }) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) uuid := response["uuid"].(string) Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) return response["processed"].(bool) }, "900s", "10s").Should(Equal(true)) By("testing chat") resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: modelName, Messages: []openai.ChatCompletionMessage{ { Role: "user", Content: "How much is 2+2?", }, }}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).To(Or(ContainSubstring("4"), 
ContainSubstring("four"))) By("testing functions") resp2, err := client.CreateChatCompletion( context.TODO(), openai.ChatCompletionRequest{ Model: modelName, Messages: []openai.ChatCompletionMessage{ { Role: "user", Content: "What is the weather like in San Francisco (celsius)?", }, }, Functions: []openai.FunctionDefinition{ openai.FunctionDefinition{ Name: "get_current_weather", Description: "Get the current weather", Parameters: jsonschema.Definition{ Type: jsonschema.Object, Properties: map[string]jsonschema.Definition{ "location": { Type: jsonschema.String, Description: "The city and state, e.g. San Francisco, CA", }, "unit": { Type: jsonschema.String, Enum: []string{"celcius", "fahrenheit"}, }, }, Required: []string{"location"}, }, }, }, }) Expect(err).ToNot(HaveOccurred()) Expect(len(resp2.Choices)).To(Equal(1)) Expect(resp2.Choices[0].Message.FunctionCall).ToNot(BeNil()) Expect(resp2.Choices[0].Message.FunctionCall.Name).To(Equal("get_current_weather"), resp2.Choices[0].Message.FunctionCall.Name) var res map[string]string err = json.Unmarshal([]byte(resp2.Choices[0].Message.FunctionCall.Arguments), &res) Expect(err).ToNot(HaveOccurred()) Expect(res["location"]).To(ContainSubstring("San Francisco"), fmt.Sprint(res)) Expect(res["unit"]).To(Equal("celcius"), fmt.Sprint(res)) Expect(string(resp2.Choices[0].FinishReason)).To(Equal("function_call"), fmt.Sprint(resp2.Choices[0].FinishReason)) }) It("installs and is capable to run tts", Label("tts"), func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ ID: "localai@voice-en-us-kathleen-low", }) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) uuid := response["uuid"].(string) Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) fmt.Println(response) return response["processed"].(bool) }, "360s", "10s").Should(Equal(true)) // An HTTP Post to the 
/tts endpoint should return a wav audio file resp, err := http.Post("http://127.0.0.1:9090/tts", "application/json", bytes.NewBuffer([]byte(`{"input": "Hello world", "model": "voice-en-us-kathleen-low"}`))) Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp)) dat, err := io.ReadAll(resp.Body) Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp)) Expect(resp.StatusCode).To(Equal(200), fmt.Sprint(string(dat))) Expect(resp.Header.Get("Content-Type")).To(Or(Equal("audio/x-wav"), Equal("audio/wav"), Equal("audio/vnd.wave"))) }) It("installs and is capable to generate images", Label("stablediffusion"), func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } response := postModelApplyRequest("http://127.0.0.1:9090/models/apply", modelApplyRequest{ ID: "localai@sd-1.5-ggml", Name: "stablediffusion", }) Expect(response["uuid"]).ToNot(BeEmpty(), fmt.Sprint(response)) uuid := response["uuid"].(string) Eventually(func() bool { response := getModelStatus("http://127.0.0.1:9090/models/jobs/" + uuid) fmt.Println(response) return response["processed"].(bool) }, "1200s", "10s").Should(Equal(true)) resp, err := http.Post( "http://127.0.0.1:9090/v1/images/generations", "application/json", bytes.NewBuffer([]byte(`{ "prompt": "a lovely cat", "step": 1, "seed":9000, "size": "256x256", "n":2}`))) // The response should contain an URL Expect(err).ToNot(HaveOccurred(), fmt.Sprint(resp)) dat, err := io.ReadAll(resp.Body) Expect(err).ToNot(HaveOccurred(), "error reading /image/generations response") imgUrlResp := &schema.OpenAIResponse{} err = json.Unmarshal(dat, imgUrlResp) Expect(err).ToNot(HaveOccurred(), fmt.Sprint(dat)) Expect(imgUrlResp.Data).ToNot(Or(BeNil(), BeZero())) imgUrl := imgUrlResp.Data[0].URL Expect(imgUrl).To(ContainSubstring("http://127.0.0.1:9090/"), imgUrl) Expect(imgUrl).To(ContainSubstring(".png"), imgUrl) imgResp, err := http.Get(imgUrl) Expect(err).To(BeNil()) Expect(imgResp).ToNot(BeNil()) Expect(imgResp.StatusCode).To(Equal(200)) 
Expect(imgResp.ContentLength).To(BeNumerically(">", 0)) imgData := make([]byte, 512) count, err := io.ReadFull(imgResp.Body, imgData) Expect(err).To(Or(BeNil(), MatchError(io.EOF))) Expect(count).To(BeNumerically(">", 0)) Expect(count).To(BeNumerically("<=", 512)) Expect(http.DetectContentType(imgData)).To(Equal("image/png")) }) }) Context("API query", func() { BeforeEach(func() { modelPath := os.Getenv("MODELS_PATH") backendPath := os.Getenv("BACKENDS_PATH") c, cancel = context.WithCancel(context.Background()) var err error systemState, err := system.GetSystemState( system.WithBackendPath(backendPath), system.WithModelPath(modelPath), ) Expect(err).ToNot(HaveOccurred()) application, err := application.New( append(commonOpts, config.WithExternalBackend("transformers", os.Getenv("TRANSFORMER_BACKEND")), config.WithContext(c), config.WithSystemState(systemState), )...) Expect(err).ToNot(HaveOccurred()) app, err = API(application) Expect(err).ToNot(HaveOccurred()) go func() { if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { xlog.Error("server error", "error", err) } }() defaultConfig := openai.DefaultConfig("") defaultConfig.BaseURL = "http://127.0.0.1:9090/v1" client2 = openaigo.NewClient("") client2.BaseURL = defaultConfig.BaseURL // Wait for API to be ready client = openai.NewClientWithConfig(defaultConfig) Eventually(func() error { _, err := client.ListModels(context.TODO()) return err }, "2m").ShouldNot(HaveOccurred()) }) AfterEach(func() { cancel() if app != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() err := app.Shutdown(ctx) Expect(err).ToNot(HaveOccurred()) } }) It("returns the models list", func() { models, err := client.ListModels(context.TODO()) Expect(err).ToNot(HaveOccurred()) // A model called "bert" can be present in the model directory depending on the order of the tests Expect(len(models.Models)).To(BeNumerically(">=", 8)) }) It("can generate completions via ggml", 
func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } resp, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "testmodel.ggml", Prompt: testPrompt}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Text).ToNot(BeEmpty()) }) It("can generate chat completions via ggml", func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "testmodel.ggml", Messages: []openai.ChatCompletionMessage{openai.ChatCompletionMessage{Role: "user", Content: testPrompt}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) }) It("returns logprobs in chat completions when requested", func() { if runtime.GOOS != "linux" { Skip("test only on linux") } topLogprobsVal := 3 response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{ Model: "testmodel.ggml", LogProbs: true, TopLogProbs: topLogprobsVal, Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(response.Choices)).To(Equal(1)) Expect(response.Choices[0].Message).ToNot(BeNil()) Expect(response.Choices[0].Message.Content).ToNot(BeEmpty()) // Verify logprobs are present and have correct structure Expect(response.Choices[0].LogProbs).ToNot(BeNil()) Expect(response.Choices[0].LogProbs.Content).ToNot(BeEmpty()) Expect(len(response.Choices[0].LogProbs.Content)).To(BeNumerically(">", 1)) foundatLeastToken := "" foundAtLeastBytes := []byte{} foundAtLeastTopLogprobBytes := []byte{} foundatLeastTopLogprob := "" // Verify logprobs content structure matches OpenAI format for _, logprobContent := range response.Choices[0].LogProbs.Content { // Bytes can be empty for certain tokens (special tokens, etc.), so we don't require it if len(logprobContent.Bytes) > 0 { 
foundAtLeastBytes = logprobContent.Bytes } if len(logprobContent.Token) > 0 { foundatLeastToken = logprobContent.Token } Expect(logprobContent.LogProb).To(BeNumerically("<=", 0)) // Logprobs are always <= 0 Expect(len(logprobContent.TopLogProbs)).To(BeNumerically(">", 1)) // If top_logprobs is requested, verify top_logprobs array respects the limit if len(logprobContent.TopLogProbs) > 0 { // Should respect top_logprobs limit (3 in this test) Expect(len(logprobContent.TopLogProbs)).To(BeNumerically("<=", topLogprobsVal)) for _, topLogprob := range logprobContent.TopLogProbs { if len(topLogprob.Bytes) > 0 { foundAtLeastTopLogprobBytes = topLogprob.Bytes } if len(topLogprob.Token) > 0 { foundatLeastTopLogprob = topLogprob.Token } Expect(topLogprob.LogProb).To(BeNumerically("<=", 0)) } } } Expect(foundAtLeastBytes).ToNot(BeEmpty()) Expect(foundAtLeastTopLogprobBytes).ToNot(BeEmpty()) Expect(foundatLeastToken).ToNot(BeEmpty()) Expect(foundatLeastTopLogprob).ToNot(BeEmpty()) }) It("applies logit_bias to chat completions when requested", func() { if runtime.GOOS != "linux" { Skip("test only on linux") } // logit_bias is a map of token IDs (as strings) to bias values (-100 to 100) // According to OpenAI API: modifies the likelihood of specified tokens appearing in the completion logitBias := map[string]int{ "15043": 1, // Bias token ID 15043 (example token ID) with bias value 1 } response, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{ Model: "testmodel.ggml", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}, LogitBias: logitBias, }) Expect(err).ToNot(HaveOccurred()) Expect(len(response.Choices)).To(Equal(1)) Expect(response.Choices[0].Message).ToNot(BeNil()) Expect(response.Choices[0].Message.Content).ToNot(BeEmpty()) // If logit_bias is applied, the response should be generated successfully // We can't easily verify the bias effect without knowing the actual token IDs for the model, // but the fact that the 
request succeeds confirms the API accepts and processes logit_bias }) It("returns errors", func() { _, err := client.CreateCompletion(context.TODO(), openai.CompletionRequest{Model: "foomodel", Prompt: testPrompt}) Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("error, status code: 500, status: 500 Internal Server Error, message: could not load model - all backends returned error:")) }) It("shows the external backend", func() { // Only run on linux if runtime.GOOS != "linux" { Skip("test supported only on linux") } // do an http request to the /system endpoint resp, err := http.Get("http://127.0.0.1:9090/system") Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) dat, err := io.ReadAll(resp.Body) Expect(err).ToNot(HaveOccurred()) Expect(string(dat)).To(ContainSubstring("llama-cpp")) }) It("transcribes audio", func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } resp, err := client.CreateTranscription( context.Background(), openai.AudioRequest{ Model: openai.Whisper1, FilePath: filepath.Join(os.Getenv("TEST_DIR"), "audio.wav"), }, ) Expect(err).ToNot(HaveOccurred()) Expect(resp.Text).To(ContainSubstring("This is the Micro Machine Man presenting")) }) It("calculate embeddings", func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } embeddingModel := openai.AdaEmbeddingV2 resp, err := client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{ Model: embeddingModel, Input: []string{"sun", "cat"}, }, ) Expect(err).ToNot(HaveOccurred(), err) Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 4096)) Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 4096)) sunEmbedding := resp.Data[0].Embedding resp2, err := client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{ Model: embeddingModel, Input: []string{"sun"}, }, ) Expect(err).ToNot(HaveOccurred()) Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding)) 
Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding)) resp3, err := client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{ Model: embeddingModel, Input: []string{"cat"}, }, ) Expect(err).ToNot(HaveOccurred()) Expect(resp3.Data[0].Embedding).To(Equal(resp.Data[1].Embedding)) Expect(resp3.Data[0].Embedding).ToNot(Equal(sunEmbedding)) }) Context("External gRPC calls", func() { It("calculate embeddings with sentencetransformers", func() { if runtime.GOOS != "linux" { Skip("test supported only on linux") } resp, err := client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{ Model: openai.AdaCodeSearchCode, Input: []string{"sun", "cat"}, }, ) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Data[0].Embedding)).To(BeNumerically("==", 384)) Expect(len(resp.Data[1].Embedding)).To(BeNumerically("==", 384)) sunEmbedding := resp.Data[0].Embedding resp2, err := client.CreateEmbeddings( context.Background(), openai.EmbeddingRequest{ Model: openai.AdaCodeSearchCode, Input: []string{"sun"}, }, ) Expect(err).ToNot(HaveOccurred()) Expect(resp2.Data[0].Embedding).To(Equal(sunEmbedding)) Expect(resp2.Data[0].Embedding).ToNot(Equal(resp.Data[1].Embedding)) }) }) // See tests/integration/stores_test Context("Stores", Label("stores"), func() { BeforeEach(func() { // Only run on linux if runtime.GOOS != "linux" { Skip("test supported only on linux") } }) It("sets, gets, finds and deletes entries", func() { ks := [][]float32{ {0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}, {0.7, 0.8, 0.9}, } vs := []string{ "test1", "test2", "test3", } setBody := schema.StoresSet{ Keys: ks, Values: vs, } url := "http://127.0.0.1:9090/stores/" err := postRequestJSON(url+"set", &setBody) Expect(err).ToNot(HaveOccurred()) getBody := schema.StoresGet{ Keys: ks, } var getRespBody schema.StoresGetResponse err = postRequestResponseJSON(url+"get", &getBody, &getRespBody) Expect(err).ToNot(HaveOccurred()) Expect(len(getRespBody.Keys)).To(Equal(len(ks))) for i, v := range 
getRespBody.Keys { if v[0] == 0.1 { Expect(getRespBody.Values[i]).To(Equal("test1")) } else if v[0] == 0.4 { Expect(getRespBody.Values[i]).To(Equal("test2")) } else { Expect(getRespBody.Values[i]).To(Equal("test3")) } } deleteBody := schema.StoresDelete{ Keys: [][]float32{ {0.1, 0.2, 0.3}, }, } err = postRequestJSON(url+"delete", &deleteBody) Expect(err).ToNot(HaveOccurred()) findBody := schema.StoresFind{ Key: []float32{0.1, 0.3, 0.7}, Topk: 10, } var findRespBody schema.StoresFindResponse err = postRequestResponseJSON(url+"find", &findBody, &findRespBody) Expect(err).ToNot(HaveOccurred()) Expect(len(findRespBody.Keys)).To(Equal(2)) for i, v := range findRespBody.Keys { if v[0] == 0.4 { Expect(findRespBody.Values[i]).To(Equal("test2")) } else { Expect(findRespBody.Values[i]).To(Equal("test3")) } Expect(findRespBody.Similarities[i]).To(BeNumerically(">=", -1)) Expect(findRespBody.Similarities[i]).To(BeNumerically("<=", 1)) } }) Context("Agent Jobs", Label("agent-jobs"), func() { It("creates and manages tasks", func() { // Create a task taskBody := map[string]interface{}{ "name": "Test Task", "description": "Test Description", "model": "testmodel.ggml", "prompt": "Hello {{.name}}", "enabled": true, } var createResp map[string]interface{} err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp) Expect(err).ToNot(HaveOccurred()) Expect(createResp["id"]).ToNot(BeEmpty()) taskID := createResp["id"].(string) // Get the task var task schema.Task resp, err := http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) body, _ := io.ReadAll(resp.Body) json.Unmarshal(body, &task) Expect(task.Name).To(Equal("Test Task")) // List tasks resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks") Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) var tasks []schema.Task body, _ = io.ReadAll(resp.Body) json.Unmarshal(body, &tasks) 
Expect(len(tasks)).To(BeNumerically(">=", 1)) // Update task taskBody["name"] = "Updated Task" err = putRequestJSON("http://127.0.0.1:9090/api/agent/tasks/"+taskID, &taskBody) Expect(err).ToNot(HaveOccurred()) // Verify update resp, err = http.Get("http://127.0.0.1:9090/api/agent/tasks/" + taskID) Expect(err).ToNot(HaveOccurred()) body, _ = io.ReadAll(resp.Body) json.Unmarshal(body, &task) Expect(task.Name).To(Equal("Updated Task")) // Delete task req, _ := http.NewRequest("DELETE", "http://127.0.0.1:9090/api/agent/tasks/"+taskID, nil) req.Header.Set("Authorization", bearerKey) resp, err = http.DefaultClient.Do(req) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) }) It("executes and monitors jobs", func() { // Create a task first taskBody := map[string]interface{}{ "name": "Job Test Task", "model": "testmodel.ggml", "prompt": "Say hello", "enabled": true, } var createResp map[string]interface{} err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp) Expect(err).ToNot(HaveOccurred()) taskID := createResp["id"].(string) // Execute a job jobBody := map[string]interface{}{ "task_id": taskID, "parameters": map[string]string{}, } var jobResp schema.JobExecutionResponse err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/jobs/execute", &jobBody, &jobResp) Expect(err).ToNot(HaveOccurred()) Expect(jobResp.JobID).ToNot(BeEmpty()) jobID := jobResp.JobID // Get job status var job schema.Job resp, err := http.Get("http://127.0.0.1:9090/api/agent/jobs/" + jobID) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) body, _ := io.ReadAll(resp.Body) json.Unmarshal(body, &job) Expect(job.ID).To(Equal(jobID)) Expect(job.TaskID).To(Equal(taskID)) // List jobs resp, err = http.Get("http://127.0.0.1:9090/api/agent/jobs") Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) var jobs []schema.Job body, _ = io.ReadAll(resp.Body) json.Unmarshal(body, &jobs) 
Expect(len(jobs)).To(BeNumerically(">=", 1)) // Cancel job (if still pending/running) if job.Status == schema.JobStatusPending || job.Status == schema.JobStatusRunning { req, _ := http.NewRequest("POST", "http://127.0.0.1:9090/api/agent/jobs/"+jobID+"/cancel", nil) req.Header.Set("Authorization", bearerKey) resp, err = http.DefaultClient.Do(req) Expect(err).ToNot(HaveOccurred()) Expect(resp.StatusCode).To(Equal(200)) } }) It("executes task by name", func() { // Create a task with a specific name taskBody := map[string]interface{}{ "name": "Named Task", "model": "testmodel.ggml", "prompt": "Hello", "enabled": true, } var createResp map[string]interface{} err := postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks", &taskBody, &createResp) Expect(err).ToNot(HaveOccurred()) // Execute by name paramsBody := map[string]string{"param1": "value1"} var jobResp schema.JobExecutionResponse err = postRequestResponseJSON("http://127.0.0.1:9090/api/agent/tasks/Named Task/execute", ¶msBody, &jobResp) Expect(err).ToNot(HaveOccurred()) Expect(jobResp.JobID).ToNot(BeEmpty()) }) }) }) }) Context("Config file", func() { BeforeEach(func() { if runtime.GOOS != "linux" { Skip("run this test only on linux") } modelPath := os.Getenv("MODELS_PATH") backendPath := os.Getenv("BACKENDS_PATH") c, cancel = context.WithCancel(context.Background()) var err error systemState, err := system.GetSystemState( system.WithBackendPath(backendPath), system.WithModelPath(modelPath), ) Expect(err).ToNot(HaveOccurred()) application, err := application.New( append(commonOpts, config.WithContext(c), config.WithSystemState(systemState), config.WithConfigFile(os.Getenv("CONFIG_FILE")))..., ) Expect(err).ToNot(HaveOccurred()) app, err = API(application) Expect(err).ToNot(HaveOccurred()) go func() { if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { xlog.Error("server error", "error", err) } }() defaultConfig := openai.DefaultConfig("") defaultConfig.BaseURL = 
"http://127.0.0.1:9090/v1" client2 = openaigo.NewClient("") client2.BaseURL = defaultConfig.BaseURL // Wait for API to be ready client = openai.NewClientWithConfig(defaultConfig) Eventually(func() error { _, err := client.ListModels(context.TODO()) return err }, "2m").ShouldNot(HaveOccurred()) }) AfterEach(func() { cancel() if app != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() err := app.Shutdown(ctx) Expect(err).ToNot(HaveOccurred()) } }) It("can generate chat completions from config file (list1)", func() { resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list1", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) }) It("can generate chat completions from config file (list2)", func() { resp, err := client.CreateChatCompletion(context.TODO(), openai.ChatCompletionRequest{Model: "list2", Messages: []openai.ChatCompletionMessage{{Role: "user", Content: testPrompt}}}) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Message.Content).ToNot(BeEmpty()) }) It("can generate edit completions from config file", func() { request := openaigo.EditCreateRequestBody{ Model: "list2", Instruction: "foo", Input: "bar", } resp, err := client2.CreateEdit(context.Background(), request) Expect(err).ToNot(HaveOccurred()) Expect(len(resp.Choices)).To(Equal(1)) Expect(resp.Choices[0].Text).ToNot(BeEmpty()) }) }) }) ================================================ FILE: core/http/auth/apikeys.go ================================================ package auth import ( "crypto/hmac" "crypto/rand" "crypto/sha256" "encoding/hex" "fmt" "time" "github.com/google/uuid" "gorm.io/gorm" ) const ( apiKeyPrefix = "lai-" apiKeyRandBytes = 32 // 32 bytes = 64 hex chars keyPrefixLen = 8 // display prefix length 
(from the random part) ) // GenerateAPIKey generates a new API key. Returns the plaintext key, // its HMAC-SHA256 hash, and a display prefix. func GenerateAPIKey(hmacSecret string) (plaintext, hash, prefix string, err error) { b := make([]byte, apiKeyRandBytes) if _, err := rand.Read(b); err != nil { return "", "", "", fmt.Errorf("failed to generate API key: %w", err) } randHex := hex.EncodeToString(b) plaintext = apiKeyPrefix + randHex hash = HashAPIKey(plaintext, hmacSecret) prefix = plaintext[:len(apiKeyPrefix)+keyPrefixLen] return plaintext, hash, prefix, nil } // HashAPIKey returns the HMAC-SHA256 hex digest of the given plaintext key. // If hmacSecret is empty, falls back to plain SHA-256 for backward compatibility. func HashAPIKey(plaintext, hmacSecret string) string { if hmacSecret == "" { h := sha256.Sum256([]byte(plaintext)) return hex.EncodeToString(h[:]) } mac := hmac.New(sha256.New, []byte(hmacSecret)) mac.Write([]byte(plaintext)) return hex.EncodeToString(mac.Sum(nil)) } // CreateAPIKey generates and stores a new API key for the given user. // Returns the plaintext key (shown once) and the database record. func CreateAPIKey(db *gorm.DB, userID, name, role, hmacSecret string, expiresAt *time.Time) (string, *UserAPIKey, error) { plaintext, hash, prefix, err := GenerateAPIKey(hmacSecret) if err != nil { return "", nil, err } record := &UserAPIKey{ ID: uuid.New().String(), UserID: userID, Name: name, KeyHash: hash, KeyPrefix: prefix, Role: role, ExpiresAt: expiresAt, } if err := db.Create(record).Error; err != nil { return "", nil, fmt.Errorf("failed to store API key: %w", err) } return plaintext, record, nil } // ValidateAPIKey looks up an API key by hashing the plaintext and searching // the database. Returns the key record if found, or an error. // Updates LastUsed on successful validation. 
func ValidateAPIKey(db *gorm.DB, plaintext, hmacSecret string) (*UserAPIKey, error) { hash := HashAPIKey(plaintext, hmacSecret) var key UserAPIKey if err := db.Preload("User").Where("key_hash = ?", hash).First(&key).Error; err != nil { return nil, fmt.Errorf("invalid API key") } if key.ExpiresAt != nil && time.Now().After(*key.ExpiresAt) { return nil, fmt.Errorf("API key expired") } if key.User.Status != StatusActive { return nil, fmt.Errorf("user account is not active") } // Update LastUsed now := time.Now() db.Model(&key).Update("last_used", now) return &key, nil } // ListAPIKeys returns all API keys for the given user (without plaintext). func ListAPIKeys(db *gorm.DB, userID string) ([]UserAPIKey, error) { var keys []UserAPIKey if err := db.Where("user_id = ?", userID).Order("created_at DESC").Find(&keys).Error; err != nil { return nil, err } return keys, nil } // RevokeAPIKey deletes an API key. Only the owner can revoke their own key. func RevokeAPIKey(db *gorm.DB, keyID, userID string) error { result := db.Where("id = ? AND user_id = ?", keyID, userID).Delete(&UserAPIKey{}) if result.RowsAffected == 0 { return fmt.Errorf("API key not found or not owned by user") } return result.Error } // CleanExpiredAPIKeys removes all API keys that have passed their expiry time. func CleanExpiredAPIKeys(db *gorm.DB) error { return db.Where("expires_at IS NOT NULL AND expires_at < ?", time.Now()).Delete(&UserAPIKey{}).Error } ================================================ FILE: core/http/auth/apikeys_test.go ================================================ //go:build auth package auth_test import ( "strings" "github.com/mudler/LocalAI/core/http/auth" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" "gorm.io/gorm" ) var _ = Describe("API Keys", func() { var ( db *gorm.DB user *auth.User ) // Use empty HMAC secret for tests (falls back to plain SHA-256) hmacSecret := "" BeforeEach(func() { db = testDB() user = createTestUser(db, "apikey@example.com", auth.RoleUser, auth.ProviderGitHub) }) Describe("GenerateAPIKey", func() { It("returns key with 'lai-' prefix", func() { plaintext, _, _, err := auth.GenerateAPIKey(hmacSecret) Expect(err).ToNot(HaveOccurred()) Expect(plaintext).To(HavePrefix("lai-")) }) It("returns consistent hash for same plaintext", func() { plaintext, hash, _, err := auth.GenerateAPIKey(hmacSecret) Expect(err).ToNot(HaveOccurred()) Expect(auth.HashAPIKey(plaintext, hmacSecret)).To(Equal(hash)) }) It("returns prefix for display", func() { _, _, prefix, err := auth.GenerateAPIKey(hmacSecret) Expect(err).ToNot(HaveOccurred()) Expect(prefix).To(HavePrefix("lai-")) Expect(len(prefix)).To(Equal(12)) // "lai-" + 8 chars }) It("generates unique keys", func() { key1, _, _, _ := auth.GenerateAPIKey(hmacSecret) key2, _, _, _ := auth.GenerateAPIKey(hmacSecret) Expect(key1).ToNot(Equal(key2)) }) }) Describe("CreateAPIKey", func() { It("stores hashed key in DB", func() { plaintext, record, err := auth.CreateAPIKey(db, user.ID, "test key", auth.RoleUser, hmacSecret, nil) Expect(err).ToNot(HaveOccurred()) Expect(plaintext).To(HavePrefix("lai-")) Expect(record.KeyHash).To(Equal(auth.HashAPIKey(plaintext, hmacSecret))) }) It("does not store plaintext in DB", func() { plaintext, _, err := auth.CreateAPIKey(db, user.ID, "test key", auth.RoleUser, hmacSecret, nil) Expect(err).ToNot(HaveOccurred()) var keys []auth.UserAPIKey db.Find(&keys) for _, k := range keys { Expect(k.KeyHash).ToNot(Equal(plaintext)) Expect(strings.Contains(k.KeyHash, "lai-")).To(BeFalse()) } }) It("inherits role from parameter", func() { _, record, err := auth.CreateAPIKey(db, user.ID, "admin key", auth.RoleAdmin, hmacSecret, nil) Expect(err).ToNot(HaveOccurred()) 
Expect(record.Role).To(Equal(auth.RoleAdmin)) }) }) Describe("ValidateAPIKey", func() { It("returns UserAPIKey for valid key", func() { plaintext, _, err := auth.CreateAPIKey(db, user.ID, "valid key", auth.RoleUser, hmacSecret, nil) Expect(err).ToNot(HaveOccurred()) found, err := auth.ValidateAPIKey(db, plaintext, hmacSecret) Expect(err).ToNot(HaveOccurred()) Expect(found).ToNot(BeNil()) Expect(found.UserID).To(Equal(user.ID)) }) It("returns error for invalid key", func() { _, err := auth.ValidateAPIKey(db, "lai-invalidkey12345678901234567890", hmacSecret) Expect(err).To(HaveOccurred()) }) It("updates LastUsed timestamp", func() { plaintext, record, err := auth.CreateAPIKey(db, user.ID, "used key", auth.RoleUser, hmacSecret, nil) Expect(err).ToNot(HaveOccurred()) Expect(record.LastUsed).To(BeNil()) _, err = auth.ValidateAPIKey(db, plaintext, hmacSecret) Expect(err).ToNot(HaveOccurred()) var updated auth.UserAPIKey db.First(&updated, "id = ?", record.ID) Expect(updated.LastUsed).ToNot(BeNil()) }) It("loads associated user", func() { plaintext, _, err := auth.CreateAPIKey(db, user.ID, "with user", auth.RoleUser, hmacSecret, nil) Expect(err).ToNot(HaveOccurred()) found, err := auth.ValidateAPIKey(db, plaintext, hmacSecret) Expect(err).ToNot(HaveOccurred()) Expect(found.User.ID).To(Equal(user.ID)) Expect(found.User.Email).To(Equal("apikey@example.com")) }) }) Describe("ListAPIKeys", func() { It("returns all keys for the user", func() { auth.CreateAPIKey(db, user.ID, "key1", auth.RoleUser, hmacSecret, nil) auth.CreateAPIKey(db, user.ID, "key2", auth.RoleUser, hmacSecret, nil) keys, err := auth.ListAPIKeys(db, user.ID) Expect(err).ToNot(HaveOccurred()) Expect(keys).To(HaveLen(2)) }) It("does not return other users' keys", func() { other := createTestUser(db, "other@example.com", auth.RoleUser, auth.ProviderGitHub) auth.CreateAPIKey(db, user.ID, "my key", auth.RoleUser, hmacSecret, nil) auth.CreateAPIKey(db, other.ID, "other key", auth.RoleUser, hmacSecret, nil) keys, err 
:= auth.ListAPIKeys(db, user.ID) Expect(err).ToNot(HaveOccurred()) Expect(keys).To(HaveLen(1)) Expect(keys[0].Name).To(Equal("my key")) }) }) Context("with HMAC secret", func() { hmacSecretVal := "test-hmac-secret-456" It("generates different hash than empty secret", func() { plaintext, _, _, err := auth.GenerateAPIKey("") Expect(err).ToNot(HaveOccurred()) hashEmpty := auth.HashAPIKey(plaintext, "") hashHMAC := auth.HashAPIKey(plaintext, hmacSecretVal) Expect(hashEmpty).ToNot(Equal(hashHMAC)) }) It("round-trips CreateAPIKey and ValidateAPIKey with HMAC secret", func() { plaintext, _, err := auth.CreateAPIKey(db, user.ID, "hmac key", auth.RoleUser, hmacSecretVal, nil) Expect(err).ToNot(HaveOccurred()) found, err := auth.ValidateAPIKey(db, plaintext, hmacSecretVal) Expect(err).ToNot(HaveOccurred()) Expect(found).ToNot(BeNil()) Expect(found.UserID).To(Equal(user.ID)) }) It("does not validate with wrong HMAC secret", func() { plaintext, _, err := auth.CreateAPIKey(db, user.ID, "hmac key2", auth.RoleUser, hmacSecretVal, nil) Expect(err).ToNot(HaveOccurred()) _, err = auth.ValidateAPIKey(db, plaintext, "wrong-secret") Expect(err).To(HaveOccurred()) }) It("does not validate key created with empty secret using non-empty secret", func() { plaintext, _, err := auth.CreateAPIKey(db, user.ID, "empty-secret key", auth.RoleUser, "", nil) Expect(err).ToNot(HaveOccurred()) _, err = auth.ValidateAPIKey(db, plaintext, hmacSecretVal) Expect(err).To(HaveOccurred()) }) It("does not validate key created with non-empty secret using empty secret", func() { plaintext, _, err := auth.CreateAPIKey(db, user.ID, "nonempty-secret key", auth.RoleUser, hmacSecretVal, nil) Expect(err).ToNot(HaveOccurred()) _, err = auth.ValidateAPIKey(db, plaintext, "") Expect(err).To(HaveOccurred()) }) }) Describe("RevokeAPIKey", func() { It("deletes the key record", func() { plaintext, record, err := auth.CreateAPIKey(db, user.ID, "to revoke", auth.RoleUser, hmacSecret, nil) Expect(err).ToNot(HaveOccurred()) err 
= auth.RevokeAPIKey(db, record.ID, user.ID) Expect(err).ToNot(HaveOccurred()) _, err = auth.ValidateAPIKey(db, plaintext, hmacSecret) Expect(err).To(HaveOccurred()) }) It("only allows owner to revoke their own key", func() { _, record, err := auth.CreateAPIKey(db, user.ID, "mine", auth.RoleUser, hmacSecret, nil) Expect(err).ToNot(HaveOccurred()) other := createTestUser(db, "attacker@example.com", auth.RoleUser, auth.ProviderGitHub) err = auth.RevokeAPIKey(db, record.ID, other.ID) Expect(err).To(HaveOccurred()) }) }) }) ================================================ FILE: core/http/auth/auth_suite_test.go ================================================ //go:build auth package auth_test import ( "testing" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) func TestAuth(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Auth Suite") } ================================================ FILE: core/http/auth/db.go ================================================ package auth import ( "fmt" "strings" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/logger" ) // InitDB initializes the auth database. If databaseURL starts with "postgres://" // or "postgresql://", it connects to PostgreSQL; otherwise it treats the value // as a SQLite file path (use ":memory:" for in-memory). // SQLite support requires building with the "auth" build tag (CGO). 
func InitDB(databaseURL string) (*gorm.DB, error) { var dialector gorm.Dialector if strings.HasPrefix(databaseURL, "postgres://") || strings.HasPrefix(databaseURL, "postgresql://") { dialector = postgres.Open(databaseURL) } else { d, err := openSQLiteDialector(databaseURL) if err != nil { return nil, err } dialector = d } db, err := gorm.Open(dialector, &gorm.Config{ Logger: logger.Default.LogMode(logger.Silent), }) if err != nil { return nil, fmt.Errorf("failed to open auth database: %w", err) } if err := db.AutoMigrate(&User{}, &Session{}, &UserAPIKey{}, &UsageRecord{}, &UserPermission{}, &InviteCode{}); err != nil { return nil, fmt.Errorf("failed to migrate auth tables: %w", err) } // Create composite index on users(provider, subject) for fast OAuth lookups if err := db.Exec("CREATE INDEX IF NOT EXISTS idx_users_provider_subject ON users(provider, subject)").Error; err != nil { // Ignore error on postgres if index already exists if !strings.Contains(err.Error(), "already exists") { return nil, fmt.Errorf("failed to create composite index: %w", err) } } return db, nil } ================================================ FILE: core/http/auth/db_nosqlite.go ================================================ //go:build !auth package auth import ( "fmt" "gorm.io/gorm" ) func openSQLiteDialector(path string) (gorm.Dialector, error) { return nil, fmt.Errorf("SQLite auth database requires building with -tags auth (CGO); use DATABASE_URL with PostgreSQL instead") } ================================================ FILE: core/http/auth/db_sqlite.go ================================================ //go:build auth package auth import ( "gorm.io/driver/sqlite" "gorm.io/gorm" ) func openSQLiteDialector(path string) (gorm.Dialector, error) { return sqlite.Open(path), nil } ================================================ FILE: core/http/auth/db_test.go ================================================ //go:build auth package auth_test import ( 
"github.com/mudler/LocalAI/core/http/auth" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("InitDB", func() { Context("SQLite", func() { It("creates all tables with in-memory SQLite", func() { db, err := auth.InitDB(":memory:") Expect(err).ToNot(HaveOccurred()) Expect(db).ToNot(BeNil()) // Verify tables exist Expect(db.Migrator().HasTable(&auth.User{})).To(BeTrue()) Expect(db.Migrator().HasTable(&auth.Session{})).To(BeTrue()) Expect(db.Migrator().HasTable(&auth.UserAPIKey{})).To(BeTrue()) }) It("is idempotent - running twice does not error", func() { db, err := auth.InitDB(":memory:") Expect(err).ToNot(HaveOccurred()) // Re-migrate on same DB should succeed err = db.AutoMigrate(&auth.User{}, &auth.Session{}, &auth.UserAPIKey{}) Expect(err).ToNot(HaveOccurred()) }) It("creates composite index on users(provider, subject)", func() { db, err := auth.InitDB(":memory:") Expect(err).ToNot(HaveOccurred()) // Insert a user to verify the index doesn't prevent normal operations user := &auth.User{ ID: "test-1", Provider: auth.ProviderGitHub, Subject: "12345", Role: "admin", Status: auth.StatusActive, } Expect(db.Create(user).Error).ToNot(HaveOccurred()) // Query using the indexed columns should work var found auth.User Expect(db.Where("provider = ? AND subject = ?", auth.ProviderGitHub, "12345").First(&found).Error).ToNot(HaveOccurred()) Expect(found.ID).To(Equal("test-1")) }) }) }) ================================================ FILE: core/http/auth/features.go ================================================ package auth // RouteFeature maps a route pattern + HTTP method to a required feature. type RouteFeature struct { Method string // "POST", "GET", "*" (any) Pattern string // Echo route pattern, e.g. "/v1/chat/completions" Feature string // Feature constant, e.g. FeatureChat } // RouteFeatureRegistry is the single source of truth for endpoint -> feature mappings. // To gate a new endpoint, add an entry here -- no other file changes needed. 
var RouteFeatureRegistry = []RouteFeature{
	// Entries are positional RouteFeature literals: {Method, Pattern, Feature}.

	// Chat / Completions
	{"POST", "/v1/chat/completions", FeatureChat},
	{"POST", "/chat/completions", FeatureChat},
	{"POST", "/v1/completions", FeatureChat},
	{"POST", "/completions", FeatureChat},
	{"POST", "/v1/engines/:model/completions", FeatureChat},
	{"POST", "/v1/edits", FeatureChat},
	{"POST", "/edits", FeatureChat},

	// Anthropic
	{"POST", "/v1/messages", FeatureChat},
	{"POST", "/messages", FeatureChat},

	// Open Responses
	{"POST", "/v1/responses", FeatureChat},
	{"POST", "/responses", FeatureChat},
	{"GET", "/v1/responses", FeatureChat},
	{"GET", "/responses", FeatureChat},

	// Embeddings
	{"POST", "/v1/embeddings", FeatureEmbeddings},
	{"POST", "/embeddings", FeatureEmbeddings},
	{"POST", "/v1/engines/:model/embeddings", FeatureEmbeddings},

	// Images
	{"POST", "/v1/images/generations", FeatureImages},
	{"POST", "/images/generations", FeatureImages},
	{"POST", "/v1/images/inpainting", FeatureImages},
	{"POST", "/images/inpainting", FeatureImages},

	// Audio transcription
	{"POST", "/v1/audio/transcriptions", FeatureAudioTranscription},
	{"POST", "/audio/transcriptions", FeatureAudioTranscription},

	// Audio speech / TTS
	{"POST", "/v1/audio/speech", FeatureAudioSpeech},
	{"POST", "/audio/speech", FeatureAudioSpeech},
	{"POST", "/tts", FeatureAudioSpeech},
	{"POST", "/v1/text-to-speech/:voice-id", FeatureAudioSpeech},

	// VAD
	{"POST", "/vad", FeatureVAD},
	{"POST", "/v1/vad", FeatureVAD},

	// Detection
	{"POST", "/v1/detection", FeatureDetection},

	// Video
	{"POST", "/video", FeatureVideo},

	// Sound generation
	{"POST", "/v1/sound-generation", FeatureSound},

	// Realtime
	{"GET", "/v1/realtime", FeatureRealtime},
	{"POST", "/v1/realtime/sessions", FeatureRealtime},
	{"POST", "/v1/realtime/transcription_session", FeatureRealtime},
	{"POST", "/v1/realtime/calls", FeatureRealtime},

	// MCP
	{"POST", "/v1/mcp/chat/completions", FeatureMCP},
	{"POST", "/mcp/v1/chat/completions", FeatureMCP},
	{"POST", "/mcp/chat/completions", FeatureMCP},

	// Tokenize
	{"POST", "/v1/tokenize", FeatureTokenize},

	// Rerank
	{"POST", "/v1/rerank", FeatureRerank},

	// Stores
	{"POST", "/stores/set", FeatureStores},
	{"POST", "/stores/delete", FeatureStores},
	{"POST", "/stores/get", FeatureStores},
	{"POST", "/stores/find", FeatureStores},
}

// FeatureMeta describes a feature for the admin API/UI.
type FeatureMeta struct {
	Key          string `json:"key"`     // feature constant, e.g. FeatureChat
	Label        string `json:"label"`   // human-readable name
	DefaultValue bool   `json:"default"` // serialized as "default"
}

// AgentFeatureMetas returns metadata for agent features.
// Every entry listed here has DefaultValue false.
func AgentFeatureMetas() []FeatureMeta {
	return []FeatureMeta{
		{FeatureAgents, "Agents", false},
		{FeatureSkills, "Skills", false},
		{FeatureCollections, "Collections", false},
		{FeatureMCPJobs, "MCP CI Jobs", false},
	}
}

// APIFeatureMetas returns metadata for API endpoint features.
// Every entry listed here has DefaultValue true.
func APIFeatureMetas() []FeatureMeta {
	return []FeatureMeta{
		{FeatureChat, "Chat Completions", true},
		{FeatureImages, "Image Generation", true},
		{FeatureAudioSpeech, "Audio Speech / TTS", true},
		{FeatureAudioTranscription, "Audio Transcription", true},
		{FeatureVAD, "Voice Activity Detection", true},
		{FeatureDetection, "Detection", true},
		{FeatureVideo, "Video Generation", true},
		{FeatureEmbeddings, "Embeddings", true},
		{FeatureSound, "Sound Generation", true},
		{FeatureRealtime, "Realtime", true},
		{FeatureRerank, "Rerank", true},
		{FeatureTokenize, "Tokenize", true},
		{FeatureMCP, "MCP", true},
		{FeatureStores, "Stores", true},
	}
}

================================================
FILE: core/http/auth/helpers_test.go
================================================
//go:build auth

package auth_test

import (
	"net/http"
	"net/http/httptest"

	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/auth"
	. "github.com/onsi/gomega"
	"gorm.io/gorm"
)

// testDB creates an in-memory SQLite GORM instance with auto-migration.
func testDB() *gorm.DB { db, err := auth.InitDB(":memory:") Expect(err).ToNot(HaveOccurred()) return db } // createTestUser inserts a user directly into the DB for test setup. func createTestUser(db *gorm.DB, email, role, provider string) *auth.User { user := &auth.User{ ID: generateTestID(), Email: email, Name: "Test User", Provider: provider, Subject: generateTestID(), Role: role, Status: auth.StatusActive, } err := db.Create(user).Error Expect(err).ToNot(HaveOccurred()) return user } // createTestSession creates a session for a user, returns plaintext session token. func createTestSession(db *gorm.DB, userID string) string { sessionID, err := auth.CreateSession(db, userID, "") Expect(err).ToNot(HaveOccurred()) return sessionID } var testIDCounter int func generateTestID() string { testIDCounter++ return "test-id-" + string(rune('a'+testIDCounter)) } // ok is a simple handler that returns 200 OK. func ok(c echo.Context) error { return c.String(http.StatusOK, "ok") } // newAuthTestApp creates a minimal Echo app with the new auth middleware. func newAuthTestApp(db *gorm.DB, appConfig *config.ApplicationConfig) *echo.Echo { e := echo.New() e.Use(auth.Middleware(db, appConfig)) // API routes (require auth) e.GET("/v1/models", ok) e.POST("/v1/chat/completions", ok) e.GET("/api/settings", ok) e.POST("/api/settings", ok) // Auth routes (exempt) e.GET("/api/auth/status", ok) e.GET("/api/auth/github/login", ok) // Static routes e.GET("/app", ok) e.GET("/app/*", ok) return e } // newAdminTestApp creates an Echo app with admin-protected routes. 
func newAdminTestApp(db *gorm.DB, appConfig *config.ApplicationConfig) *echo.Echo { e := echo.New() e.Use(auth.Middleware(db, appConfig)) // Regular routes e.GET("/v1/models", ok) e.POST("/v1/chat/completions", ok) // Admin-only routes adminMw := auth.RequireAdmin() e.POST("/api/settings", ok, adminMw) e.POST("/models/apply", ok, adminMw) e.POST("/backends/apply", ok, adminMw) e.GET("/api/agents", ok, adminMw) // Trace/log endpoints (admin only) e.GET("/api/traces", ok, adminMw) e.POST("/api/traces/clear", ok, adminMw) e.GET("/api/backend-logs", ok, adminMw) e.GET("/api/backend-logs/:modelId", ok, adminMw) // Gallery/management reads (admin only) e.GET("/api/operations", ok, adminMw) e.GET("/api/models", ok, adminMw) e.GET("/api/backends", ok, adminMw) e.GET("/api/resources", ok, adminMw) e.GET("/api/p2p/workers", ok, adminMw) // Agent task/job routes (admin only) e.POST("/api/agent/tasks", ok, adminMw) e.GET("/api/agent/tasks", ok, adminMw) e.GET("/api/agent/jobs", ok, adminMw) // System info (admin only) e.GET("/system", ok, adminMw) e.GET("/backend/monitor", ok, adminMw) return e } // doRequest performs an HTTP request against the given Echo app and returns the recorder. 
func doRequest(e *echo.Echo, method, path string, opts ...func(*http.Request)) *httptest.ResponseRecorder {
	// The request has no body and always carries a JSON Content-Type; opts
	// run afterwards and may add or override headers and cookies.
	req := httptest.NewRequest(method, path, nil)
	req.Header.Set("Content-Type", "application/json")
	for _, opt := range opts {
		opt(req)
	}
	rec := httptest.NewRecorder()
	e.ServeHTTP(rec, req)
	return rec
}

// withBearerToken returns a request option that sets "Authorization: Bearer <token>".
func withBearerToken(token string) func(*http.Request) {
	return func(req *http.Request) {
		req.Header.Set("Authorization", "Bearer "+token)
	}
}

// withXApiKey returns a request option that sets the "x-api-key" header.
func withXApiKey(key string) func(*http.Request) {
	return func(req *http.Request) {
		req.Header.Set("x-api-key", key)
	}
}

// withSessionCookie returns a request option that adds a "session" cookie.
func withSessionCookie(sessionID string) func(*http.Request) {
	return func(req *http.Request) {
		req.AddCookie(&http.Cookie{Name: "session", Value: sessionID})
	}
}

// withTokenCookie returns a request option that adds a "token" cookie.
func withTokenCookie(token string) func(*http.Request) {
	return func(req *http.Request) {
		req.AddCookie(&http.Cookie{Name: "token", Value: token})
	}
}

================================================
FILE: core/http/auth/middleware.go
================================================
package auth

import (
	"bytes"
	"crypto/subtle"
	"encoding/json"
	"io"
	"net/http"
	"strings"

	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/schema"
	"gorm.io/gorm"
)

// Echo context keys under which the authenticated user and role are stored.
const (
	contextKeyUser = "auth_user"
	contextKeyRole = "auth_role"
)

// Middleware returns an Echo middleware that handles authentication.
//
// Resolution order, as implemented below:
//  1. If auth is not enabled (db == nil) AND no legacy API keys are
//     configured, pass through.
//  2. If auth is enabled, try to authenticate against the database
//     (tryAuthenticate):
//     a. session cookie -> session lookup
//     b. Authorization: Bearer -> session ID, then user API key
//     c. x-api-key / xi-api-key -> user API key
//     d. "token" cookie -> user API key
//     On success the user and role are stored in the context, and a
//     cookie-based session is handed to MaybeRotateSession.
//  3. If still unauthenticated and legacy keys are configured, compare the
//     key returned by extractKey with the legacy ApiKeys; a match yields a
//     synthetic admin user.
//  4. Authenticated requests and exempt paths (isExemptPath) proceed. Note
//     that exempt paths still go through steps 2 and 3 first, so a user is
//     attached to the context when credentials are present.
//  5. Unauthenticated requests to API paths (isAPIPath) receive authError,
//     unless the legacy GET exemption is enabled and the route matches one
//     of HttpGetExemptedEndpoints.
//  6. Otherwise pass through (static assets, UI pages, etc.)
func Middleware(db *gorm.DB, appConfig *config.ApplicationConfig) echo.MiddlewareFunc {
	return func(next echo.HandlerFunc) echo.HandlerFunc {
		return func(c echo.Context) error {
			authEnabled := db != nil
			hasLegacyKeys := len(appConfig.ApiKeys) > 0

			// 1. No auth at all
			if !authEnabled && !hasLegacyKeys {
				return next(c)
			}

			path := c.Request().URL.Path
			exempt := isExemptPath(path, appConfig)
			authenticated := false

			// 2. Try to authenticate (populates user in context if possible)
			if authEnabled {
				user := tryAuthenticate(c, db, appConfig)
				if user != nil {
					c.Set(contextKeyUser, user)
					c.Set(contextKeyRole, user.Role)
					authenticated = true

					// Session rotation for cookie-based sessions
					// ("_auth_session" is only set by tryAuthenticate on the cookie path).
					if session, ok := c.Get("_auth_session").(*Session); ok {
						MaybeRotateSession(c, db, session, appConfig.Auth.APIKeyHMACSecret)
					}
				}
			}

			// 3. Legacy API key validation (works whether auth is enabled or not)
			if !authenticated && hasLegacyKeys {
				key := extractKey(c)
				if key != "" && isValidLegacyKey(key, appConfig) {
					// Legacy keys are not tied to a DB user; they act as admin.
					syntheticUser := &User{
						ID:   "legacy-api-key",
						Name: "API Key User",
						Role: RoleAdmin,
					}
					c.Set(contextKeyUser, syntheticUser)
					c.Set(contextKeyRole, RoleAdmin)
					authenticated = true
				}
			}

			// 4. If authenticated or exempt path, proceed
			if authenticated || exempt {
				return next(c)
			}

			// 5. Require auth for API paths
			if isAPIPath(path) {
				// Check GET exemptions for legacy keys
				// (matched against the Echo route pattern, c.Path(), not the raw URL path).
				if hasLegacyKeys && appConfig.DisableApiKeyRequirementForHttpGet && c.Request().Method == http.MethodGet {
					for _, rx := range appConfig.HttpGetExemptedEndpoints {
						if rx.MatchString(c.Path()) {
							return next(c)
						}
					}
				}
				return authError(c, appConfig)
			}

			// 6. Non-API paths (UI, static assets) pass through.
			// The React UI handles login redirects client-side.
			return next(c)
		}
	}
}

// RequireAdmin returns middleware that checks the user has admin role.
func RequireAdmin() echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { user := GetUser(c) if user == nil { return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{ Error: &schema.APIError{ Message: "Authentication required", Code: http.StatusUnauthorized, Type: "authentication_error", }, }) } if user.Role != RoleAdmin { return c.JSON(http.StatusForbidden, schema.ErrorResponse{ Error: &schema.APIError{ Message: "Admin access required", Code: http.StatusForbidden, Type: "authorization_error", }, }) } return next(c) } } } // NoopMiddleware returns a middleware that does nothing (pass-through). // Used when auth is disabled to satisfy route registration that expects // an admin middleware parameter. func NoopMiddleware() echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return next } } // RequireFeature returns middleware that checks the user has access to the given feature. // If no auth DB is provided, it passes through (backward compat). // Admins always pass. Regular users must have the feature enabled in their permissions. 
func RequireFeature(db *gorm.DB, feature string) echo.MiddlewareFunc { if db == nil { return NoopMiddleware() } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { user := GetUser(c) if user == nil { return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{ Error: &schema.APIError{ Message: "Authentication required", Code: http.StatusUnauthorized, Type: "authentication_error", }, }) } if user.Role == RoleAdmin { return next(c) } perm, err := GetCachedUserPermissions(c, db, user.ID) if err != nil { return c.JSON(http.StatusForbidden, schema.ErrorResponse{ Error: &schema.APIError{ Message: "feature not enabled for your account", Code: http.StatusForbidden, Type: "authorization_error", }, }) } val, exists := perm.Permissions[feature] if !exists { if !isDefaultOnFeature(feature) { return c.JSON(http.StatusForbidden, schema.ErrorResponse{ Error: &schema.APIError{ Message: "feature not enabled for your account", Code: http.StatusForbidden, Type: "authorization_error", }, }) } } else if !val { return c.JSON(http.StatusForbidden, schema.ErrorResponse{ Error: &schema.APIError{ Message: "feature not enabled for your account", Code: http.StatusForbidden, Type: "authorization_error", }, }) } return next(c) } } } // GetUser returns the authenticated user from the echo context, or nil. func GetUser(c echo.Context) *User { u, ok := c.Get(contextKeyUser).(*User) if !ok { return nil } return u } // GetUserRole returns the role of the authenticated user, or empty string. func GetUserRole(c echo.Context) string { role, _ := c.Get(contextKeyRole).(string) return role } // RequireRouteFeature returns a global middleware that checks the user has access // to the feature required by the matched route. It uses the RouteFeatureRegistry // to look up the required feature for each route pattern + HTTP method. // If no entry matches, the request passes through (no restriction). 
func RequireRouteFeature(db *gorm.DB) echo.MiddlewareFunc { if db == nil { return NoopMiddleware() } // Pre-build lookup map: "METHOD:pattern" -> feature lookup := map[string]string{} for _, rf := range RouteFeatureRegistry { lookup[rf.Method+":"+rf.Pattern] = rf.Feature } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { path := c.Path() // Echo route pattern (e.g. "/v1/engines/:model/completions") method := c.Request().Method feature := lookup[method+":"+path] if feature == "" { feature = lookup["*:"+path] } if feature == "" { return next(c) // no restriction for this route } user := GetUser(c) if user == nil { return next(c) // auth middleware handles unauthenticated } if user.Role == RoleAdmin { return next(c) } perm, err := GetCachedUserPermissions(c, db, user.ID) if err != nil { return c.JSON(http.StatusInternalServerError, schema.ErrorResponse{ Error: &schema.APIError{ Message: "failed to check permissions", Code: http.StatusInternalServerError, Type: "server_error", }, }) } val, exists := perm.Permissions[feature] if !exists { if !isDefaultOnFeature(feature) { return c.JSON(http.StatusForbidden, schema.ErrorResponse{ Error: &schema.APIError{ Message: "feature not enabled for your account: " + feature, Code: http.StatusForbidden, Type: "authorization_error", }, }) } } else if !val { return c.JSON(http.StatusForbidden, schema.ErrorResponse{ Error: &schema.APIError{ Message: "feature not enabled for your account: " + feature, Code: http.StatusForbidden, Type: "authorization_error", }, }) } return next(c) } } } // RequireModelAccess returns a global middleware that checks the user is allowed // to use the resolved model. It extracts the model name directly from the request // (path param, query param, JSON body, or form value) rather than relying on a // context key set by downstream route-specific middleware. 
func RequireModelAccess(db *gorm.DB) echo.MiddlewareFunc { if db == nil { return NoopMiddleware() } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { user := GetUser(c) if user == nil { return next(c) } if user.Role == RoleAdmin { return next(c) } // Check if this user even has a model allowlist enabled before // doing the expensive body read. Most users won't have restrictions. // Uses request-scoped cache to avoid duplicate DB hit when // RequireRouteFeature already fetched permissions. perm, err := GetCachedUserPermissions(c, db, user.ID) if err != nil { return c.JSON(http.StatusInternalServerError, schema.ErrorResponse{ Error: &schema.APIError{ Message: "failed to check permissions", Code: http.StatusInternalServerError, Type: "server_error", }, }) } allowlist := perm.AllowedModels if !allowlist.Enabled { return next(c) } modelName := extractModelFromRequest(c) if modelName == "" { return next(c) } for _, m := range allowlist.Models { if m == modelName { return next(c) } } return c.JSON(http.StatusForbidden, schema.ErrorResponse{ Error: &schema.APIError{ Message: "access denied to model: " + modelName, Code: http.StatusForbidden, Type: "authorization_error", }, }) } } } // extractModelFromRequest extracts the model name from various request sources. // It checks URL path params, query params, JSON body, and form values. // For JSON bodies, it peeks at the body and resets it so downstream handlers // can still read it. func extractModelFromRequest(c echo.Context) string { // 1. URL path param (e.g. /v1/engines/:model/completions) if model := c.Param("model"); model != "" { return model } // 2. Query param if model := c.QueryParam("model"); model != "" { return model } // 3. 
Peek at JSON body if strings.HasPrefix(c.Request().Header.Get("Content-Type"), "application/json") { body, err := io.ReadAll(c.Request().Body) c.Request().Body = io.NopCloser(bytes.NewReader(body)) // always reset if err == nil && len(body) > 0 { var m struct { Model string `json:"model"` } if json.Unmarshal(body, &m) == nil && m.Model != "" { return m.Model } } } // 4. Form value (multipart/form-data) if model := c.FormValue("model"); model != "" { return model } return "" } // tryAuthenticate attempts to authenticate the request using the database. func tryAuthenticate(c echo.Context, db *gorm.DB, appConfig *config.ApplicationConfig) *User { hmacSecret := appConfig.Auth.APIKeyHMACSecret // a. Session cookie if cookie, err := c.Cookie(sessionCookie); err == nil && cookie.Value != "" { if user, session := ValidateSession(db, cookie.Value, hmacSecret); user != nil { // Store session for rotation check in middleware c.Set("_auth_session", session) return user } } // b. Authorization: Bearer token authHeader := c.Request().Header.Get("Authorization") if strings.HasPrefix(authHeader, "Bearer ") { token := strings.TrimPrefix(authHeader, "Bearer ") // Try as session ID first if user, _ := ValidateSession(db, token, hmacSecret); user != nil { return user } // Try as user API key if key, err := ValidateAPIKey(db, token, hmacSecret); err == nil { return &key.User } } // c. x-api-key / xi-api-key headers for _, header := range []string{"x-api-key", "xi-api-key"} { if key := c.Request().Header.Get(header); key != "" { if apiKey, err := ValidateAPIKey(db, key, hmacSecret); err == nil { return &apiKey.User } } } // d. token cookie (legacy) if cookie, err := c.Cookie("token"); err == nil && cookie.Value != "" { // Try as user API key if key, err := ValidateAPIKey(db, cookie.Value, hmacSecret); err == nil { return &key.User } } return nil } // extractKey extracts an API key from the request (all sources). 
func extractKey(c echo.Context) string { // Authorization header auth := c.Request().Header.Get("Authorization") if strings.HasPrefix(auth, "Bearer ") { return strings.TrimPrefix(auth, "Bearer ") } if auth != "" { return auth } // x-api-key if key := c.Request().Header.Get("x-api-key"); key != "" { return key } // xi-api-key if key := c.Request().Header.Get("xi-api-key"); key != "" { return key } // token cookie if cookie, err := c.Cookie("token"); err == nil && cookie.Value != "" { return cookie.Value } return "" } // isValidLegacyKey checks if the key matches any configured API key // using constant-time comparison to prevent timing attacks. func isValidLegacyKey(key string, appConfig *config.ApplicationConfig) bool { for _, validKey := range appConfig.ApiKeys { if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 { return true } } return false } // isExemptPath returns true if the path should skip authentication. func isExemptPath(path string, appConfig *config.ApplicationConfig) bool { // Auth endpoints are always public if strings.HasPrefix(path, "/api/auth/") { return true } // Check configured exempt paths for _, p := range appConfig.PathWithoutAuth { if strings.HasPrefix(path, p) { return true } } return false } // isAPIPath returns true for paths that always require authentication. func isAPIPath(path string) bool { return strings.HasPrefix(path, "/api/") || strings.HasPrefix(path, "/v1/") || strings.HasPrefix(path, "/models/") || strings.HasPrefix(path, "/backends/") || strings.HasPrefix(path, "/backend/") || strings.HasPrefix(path, "/tts") || strings.HasPrefix(path, "/vad") || strings.HasPrefix(path, "/video") || strings.HasPrefix(path, "/stores/") || strings.HasPrefix(path, "/system") || strings.HasPrefix(path, "/ws/") || strings.HasPrefix(path, "/generated-") || path == "/metrics" } // authError returns an appropriate error response. 
func authError(c echo.Context, appConfig *config.ApplicationConfig) error { c.Response().Header().Set("WWW-Authenticate", "Bearer") if appConfig.OpaqueErrors { return c.NoContent(http.StatusUnauthorized) } contentType := c.Request().Header.Get("Content-Type") if strings.Contains(contentType, "application/json") { return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{ Error: &schema.APIError{ Message: "An authentication key is required", Code: http.StatusUnauthorized, Type: "invalid_request_error", }, }) } return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{ Error: &schema.APIError{ Message: "An authentication key is required", Code: http.StatusUnauthorized, Type: "invalid_request_error", }, }) } ================================================ FILE: core/http/auth/middleware_test.go ================================================ //go:build auth package auth_test import ( "net/http" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/auth" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" "gorm.io/gorm" ) var _ = Describe("Auth Middleware", func() { Context("auth disabled, no API keys", func() { var app *echo.Echo BeforeEach(func() { appConfig := config.NewApplicationConfig() app = newAuthTestApp(nil, appConfig) }) It("passes through all requests", func() { rec := doRequest(app, http.MethodGet, "/v1/models") Expect(rec.Code).To(Equal(http.StatusOK)) }) It("passes through POST requests", func() { rec := doRequest(app, http.MethodPost, "/v1/chat/completions") Expect(rec.Code).To(Equal(http.StatusOK)) }) }) Context("auth disabled, API keys configured", func() { var app *echo.Echo const validKey = "sk-test-key-123" BeforeEach(func() { appConfig := config.NewApplicationConfig() appConfig.ApiKeys = []string{validKey} app = newAuthTestApp(nil, appConfig) }) It("returns 401 for request without key", func() { rec := doRequest(app, http.MethodGet, "/v1/models") Expect(rec.Code).To(Equal(http.StatusUnauthorized)) }) It("passes with valid Bearer token", func() { rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(validKey)) Expect(rec.Code).To(Equal(http.StatusOK)) }) It("passes with valid x-api-key header", func() { rec := doRequest(app, http.MethodGet, "/v1/models", withXApiKey(validKey)) Expect(rec.Code).To(Equal(http.StatusOK)) }) It("passes with valid token cookie", func() { rec := doRequest(app, http.MethodGet, "/v1/models", withTokenCookie(validKey)) Expect(rec.Code).To(Equal(http.StatusOK)) }) It("returns 401 for invalid key", func() { rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken("wrong-key")) Expect(rec.Code).To(Equal(http.StatusUnauthorized)) }) }) Context("auth enabled with database", func() { var ( db *gorm.DB app *echo.Echo appConfig *config.ApplicationConfig user *auth.User ) BeforeEach(func() { db = testDB() appConfig = config.NewApplicationConfig() app = newAuthTestApp(db, appConfig) user = createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub) }) It("allows 
requests with valid session cookie", func() { sessionID := createTestSession(db, user.ID) rec := doRequest(app, http.MethodGet, "/v1/models", withSessionCookie(sessionID)) Expect(rec.Code).To(Equal(http.StatusOK)) }) It("allows requests with valid session as Bearer token", func() { sessionID := createTestSession(db, user.ID) rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(sessionID)) Expect(rec.Code).To(Equal(http.StatusOK)) }) It("allows requests with valid user API key as Bearer token", func() { plaintext, _, err := auth.CreateAPIKey(db, user.ID, "test", auth.RoleUser, "", nil) Expect(err).ToNot(HaveOccurred()) rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(plaintext)) Expect(rec.Code).To(Equal(http.StatusOK)) }) It("allows requests with legacy API_KEY as admin bypass", func() { appConfig.ApiKeys = []string{"legacy-key-123"} app = newAuthTestApp(db, appConfig) rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken("legacy-key-123")) Expect(rec.Code).To(Equal(http.StatusOK)) }) It("returns 401 for expired session", func() { sessionID := createTestSession(db, user.ID) // Manually expire (session ID in DB is the hash) hash := auth.HashAPIKey(sessionID, "") db.Model(&auth.Session{}).Where("id = ?", hash). 
Update("expires_at", "2020-01-01") rec := doRequest(app, http.MethodGet, "/v1/models", withSessionCookie(sessionID)) Expect(rec.Code).To(Equal(http.StatusUnauthorized)) }) It("returns 401 for invalid session ID", func() { rec := doRequest(app, http.MethodGet, "/v1/models", withSessionCookie("invalid-session-id")) Expect(rec.Code).To(Equal(http.StatusUnauthorized)) }) It("returns 401 for revoked API key", func() { plaintext, record, err := auth.CreateAPIKey(db, user.ID, "to revoke", auth.RoleUser, "", nil) Expect(err).ToNot(HaveOccurred()) err = auth.RevokeAPIKey(db, record.ID, user.ID) Expect(err).ToNot(HaveOccurred()) rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(plaintext)) Expect(rec.Code).To(Equal(http.StatusUnauthorized)) }) It("skips auth for /api/auth/* paths", func() { rec := doRequest(app, http.MethodGet, "/api/auth/status") Expect(rec.Code).To(Equal(http.StatusOK)) }) It("skips auth for PathWithoutAuth paths", func() { rec := doRequest(app, http.MethodGet, "/healthz") // healthz is not registered in our test app, so it'll be 404/405 but NOT 401 Expect(rec.Code).ToNot(Equal(http.StatusUnauthorized)) }) It("returns 401 for unauthenticated API requests", func() { rec := doRequest(app, http.MethodGet, "/v1/models") Expect(rec.Code).To(Equal(http.StatusUnauthorized)) }) It("allows unauthenticated access to non-API paths when no legacy keys", func() { rec := doRequest(app, http.MethodGet, "/app") Expect(rec.Code).To(Equal(http.StatusOK)) }) }) Describe("RequireAdmin", func() { var ( db *gorm.DB appConfig *config.ApplicationConfig ) BeforeEach(func() { db = testDB() appConfig = config.NewApplicationConfig() }) It("passes for admin user", func() { admin := createTestUser(db, "admin@example.com", auth.RoleAdmin, auth.ProviderGitHub) sessionID := createTestSession(db, admin.ID) app := newAdminTestApp(db, appConfig) rec := doRequest(app, http.MethodPost, "/api/settings", withSessionCookie(sessionID)) Expect(rec.Code).To(Equal(http.StatusOK)) }) 
It("returns 403 for user role", func() { user := createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub) sessionID := createTestSession(db, user.ID) app := newAdminTestApp(db, appConfig) rec := doRequest(app, http.MethodPost, "/api/settings", withSessionCookie(sessionID)) Expect(rec.Code).To(Equal(http.StatusForbidden)) }) It("returns 401 when no user in context", func() { app := newAdminTestApp(db, appConfig) rec := doRequest(app, http.MethodPost, "/api/settings") Expect(rec.Code).To(Equal(http.StatusUnauthorized)) }) It("allows admin to access model management", func() { admin := createTestUser(db, "admin@example.com", auth.RoleAdmin, auth.ProviderGitHub) sessionID := createTestSession(db, admin.ID) app := newAdminTestApp(db, appConfig) rec := doRequest(app, http.MethodPost, "/models/apply", withSessionCookie(sessionID)) Expect(rec.Code).To(Equal(http.StatusOK)) }) It("blocks user from model management", func() { user := createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub) sessionID := createTestSession(db, user.ID) app := newAdminTestApp(db, appConfig) rec := doRequest(app, http.MethodPost, "/models/apply", withSessionCookie(sessionID)) Expect(rec.Code).To(Equal(http.StatusForbidden)) }) It("allows user to access regular inference endpoints", func() { user := createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub) sessionID := createTestSession(db, user.ID) app := newAdminTestApp(db, appConfig) rec := doRequest(app, http.MethodPost, "/v1/chat/completions", withSessionCookie(sessionID)) Expect(rec.Code).To(Equal(http.StatusOK)) }) It("allows legacy API key (admin bypass) on admin routes", func() { appConfig.ApiKeys = []string{"admin-key"} app := newAdminTestApp(db, appConfig) rec := doRequest(app, http.MethodPost, "/api/settings", withBearerToken("admin-key")) Expect(rec.Code).To(Equal(http.StatusOK)) }) It("allows admin to access trace endpoints", func() { admin := createTestUser(db, 
"admin2@example.com", auth.RoleAdmin, auth.ProviderGitHub) sessionID := createTestSession(db, admin.ID) app := newAdminTestApp(db, appConfig) rec := doRequest(app, http.MethodGet, "/api/traces", withSessionCookie(sessionID)) Expect(rec.Code).To(Equal(http.StatusOK)) rec = doRequest(app, http.MethodGet, "/api/backend-logs", withSessionCookie(sessionID)) Expect(rec.Code).To(Equal(http.StatusOK)) }) It("blocks non-admin from trace endpoints", func() { user := createTestUser(db, "user2@example.com", auth.RoleUser, auth.ProviderGitHub) sessionID := createTestSession(db, user.ID) app := newAdminTestApp(db, appConfig) rec := doRequest(app, http.MethodGet, "/api/traces", withSessionCookie(sessionID)) Expect(rec.Code).To(Equal(http.StatusForbidden)) rec = doRequest(app, http.MethodGet, "/api/backend-logs", withSessionCookie(sessionID)) Expect(rec.Code).To(Equal(http.StatusForbidden)) }) It("allows admin to access agent job endpoints", func() { admin := createTestUser(db, "admin3@example.com", auth.RoleAdmin, auth.ProviderGitHub) sessionID := createTestSession(db, admin.ID) app := newAdminTestApp(db, appConfig) rec := doRequest(app, http.MethodGet, "/api/agent/tasks", withSessionCookie(sessionID)) Expect(rec.Code).To(Equal(http.StatusOK)) rec = doRequest(app, http.MethodGet, "/api/agent/jobs", withSessionCookie(sessionID)) Expect(rec.Code).To(Equal(http.StatusOK)) }) It("blocks non-admin from agent job endpoints", func() { user := createTestUser(db, "user3@example.com", auth.RoleUser, auth.ProviderGitHub) sessionID := createTestSession(db, user.ID) app := newAdminTestApp(db, appConfig) rec := doRequest(app, http.MethodGet, "/api/agent/tasks", withSessionCookie(sessionID)) Expect(rec.Code).To(Equal(http.StatusForbidden)) rec = doRequest(app, http.MethodGet, "/api/agent/jobs", withSessionCookie(sessionID)) Expect(rec.Code).To(Equal(http.StatusForbidden)) }) It("blocks non-admin from system/management endpoints", func() { user := createTestUser(db, "user4@example.com", 
auth.RoleUser, auth.ProviderGitHub) sessionID := createTestSession(db, user.ID) app := newAdminTestApp(db, appConfig) for _, path := range []string{"/api/operations", "/api/models", "/api/backends", "/api/resources", "/api/p2p/workers", "/system", "/backend/monitor"} { rec := doRequest(app, http.MethodGet, path, withSessionCookie(sessionID)) Expect(rec.Code).To(Equal(http.StatusForbidden), "expected 403 for path: "+path) } }) It("allows admin to access system/management endpoints", func() { admin := createTestUser(db, "admin4@example.com", auth.RoleAdmin, auth.ProviderGitHub) sessionID := createTestSession(db, admin.ID) app := newAdminTestApp(db, appConfig) for _, path := range []string{"/api/operations", "/api/models", "/api/backends", "/api/resources", "/api/p2p/workers", "/system", "/backend/monitor"} { rec := doRequest(app, http.MethodGet, path, withSessionCookie(sessionID)) Expect(rec.Code).To(Equal(http.StatusOK), "expected 200 for path: "+path) } }) }) }) ================================================ FILE: core/http/auth/models.go ================================================ package auth import ( "database/sql/driver" "encoding/json" "fmt" "time" ) // Auth provider constants. const ( ProviderLocal = "local" ProviderGitHub = "github" ProviderOIDC = "oidc" ) // User represents an authenticated user. type User struct { ID string `gorm:"primaryKey;size:36"` Email string `gorm:"size:255;index"` Name string `gorm:"size:255"` AvatarURL string `gorm:"size:512"` Provider string `gorm:"size:50"` // ProviderLocal, ProviderGitHub, ProviderOIDC Subject string `gorm:"size:255"` // provider-specific user ID PasswordHash string `json:"-"` // bcrypt hash, empty for OAuth-only users Role string `gorm:"size:20;default:user"` Status string `gorm:"size:20;default:active"` // "active", "pending" CreatedAt time.Time UpdatedAt time.Time } // Session represents a user login session. 
type Session struct { ID string `gorm:"primaryKey;size:64"` // HMAC-SHA256 hash of session token UserID string `gorm:"size:36;index"` ExpiresAt time.Time RotatedAt time.Time CreatedAt time.Time User User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"` } // UserAPIKey represents a user-generated API key for programmatic access. type UserAPIKey struct { ID string `gorm:"primaryKey;size:36"` UserID string `gorm:"size:36;index"` Name string `gorm:"size:255"` // user-provided label KeyHash string `gorm:"size:64;uniqueIndex"` KeyPrefix string `gorm:"size:12"` // first 8 chars of key for display Role string `gorm:"size:20"` CreatedAt time.Time ExpiresAt *time.Time `gorm:"index"` LastUsed *time.Time User User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"` } // PermissionMap is a flexible map of feature -> enabled, stored as JSON text. // Known features: "agents", "skills", "collections", "mcp_jobs". // New features can be added without schema changes. type PermissionMap map[string]bool // Value implements driver.Valuer for GORM JSON serialization. func (p PermissionMap) Value() (driver.Value, error) { if p == nil { return "{}", nil } b, err := json.Marshal(p) if err != nil { return nil, fmt.Errorf("failed to marshal PermissionMap: %w", err) } return string(b), nil } // Scan implements sql.Scanner for GORM JSON deserialization. func (p *PermissionMap) Scan(value any) error { if value == nil { *p = PermissionMap{} return nil } var bytes []byte switch v := value.(type) { case string: bytes = []byte(v) case []byte: bytes = v default: return fmt.Errorf("cannot scan %T into PermissionMap", value) } return json.Unmarshal(bytes, p) } // InviteCode represents an admin-generated invitation for user registration. 
type InviteCode struct { ID string `gorm:"primaryKey;size:36"` Code string `gorm:"uniqueIndex;not null;size:64"` // HMAC-SHA256 hash of invite code CodePrefix string `gorm:"size:12"` // first 8 chars for admin display CreatedBy string `gorm:"size:36;not null"` UsedBy *string `gorm:"size:36"` UsedAt *time.Time ExpiresAt time.Time `gorm:"not null;index"` CreatedAt time.Time Creator User `gorm:"foreignKey:CreatedBy"` Consumer *User `gorm:"foreignKey:UsedBy"` } // ModelAllowlist controls which models a user can access. // When Enabled is false (default), all models are allowed. type ModelAllowlist struct { Enabled bool `json:"enabled"` Models []string `json:"models,omitempty"` } // Value implements driver.Valuer for GORM JSON serialization. func (m ModelAllowlist) Value() (driver.Value, error) { b, err := json.Marshal(m) if err != nil { return nil, fmt.Errorf("failed to marshal ModelAllowlist: %w", err) } return string(b), nil } // Scan implements sql.Scanner for GORM JSON deserialization. func (m *ModelAllowlist) Scan(value any) error { if value == nil { *m = ModelAllowlist{} return nil } var bytes []byte switch v := value.(type) { case string: bytes = []byte(v) case []byte: bytes = v default: return fmt.Errorf("cannot scan %T into ModelAllowlist", value) } return json.Unmarshal(bytes, m) } // UserPermission stores per-user feature permissions. 
type UserPermission struct {
	ID     string `gorm:"primaryKey;size:36"`
	UserID string `gorm:"size:36;uniqueIndex"` // unique index: at most one permission row per user
	// Permissions is the feature -> enabled map, persisted as JSON text.
	Permissions PermissionMap `gorm:"type:text"`
	// AllowedModels is the per-user model allowlist, persisted as JSON text.
	AllowedModels ModelAllowlist `gorm:"type:text"`
	CreatedAt     time.Time
	UpdatedAt     time.Time
	User          User `gorm:"foreignKey:UserID;constraint:OnDelete:CASCADE"`
}

================================================
FILE: core/http/auth/oauth.go
================================================
package auth

import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"encoding/hex"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"
	"time"

	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/google/uuid"
	"github.com/labstack/echo/v4"
	"github.com/mudler/xlog"
	"golang.org/x/oauth2"
	githubOAuth "golang.org/x/oauth2/github"
	"gorm.io/gorm"
)

// providerEntry holds the OAuth2/OIDC config for a single provider.
type providerEntry struct {
	oauth2Config oauth2.Config
	oidcVerifier *oidc.IDTokenVerifier // nil for GitHub (API-based user info)
	name         string
	// NOTE(review): populated for GitHub in NewOAuthManager; confirm it is
	// actually read anywhere, since the GitHub fetch helpers in this file
	// use a hard-coded URL.
	userInfoURL string // only used for GitHub
}

// oauthUserInfo is a provider-agnostic representation of an authenticated user.
type oauthUserInfo struct {
	Subject   string // provider-specific user identifier
	Email     string
	Name      string
	AvatarURL string
}

// OAuthManager manages multiple OAuth/OIDC providers.
type OAuthManager struct {
	providers map[string]*providerEntry // keyed by provider name
}

// OAuthParams groups the parameters needed to create an OAuthManager.
type OAuthParams struct {
	GitHubClientID     string
	GitHubClientSecret string
	OIDCIssuer         string
	OIDCClientID       string
	OIDCClientSecret   string
}

// NewOAuthManager creates an OAuthManager from the given params.
func NewOAuthManager(baseURL string, params OAuthParams) (*OAuthManager, error) { m := &OAuthManager{providers: make(map[string]*providerEntry)} if params.GitHubClientID != "" { m.providers[ProviderGitHub] = &providerEntry{ name: ProviderGitHub, oauth2Config: oauth2.Config{ ClientID: params.GitHubClientID, ClientSecret: params.GitHubClientSecret, Endpoint: githubOAuth.Endpoint, RedirectURL: baseURL + "/api/auth/github/callback", Scopes: []string{"user:email", "read:user"}, }, userInfoURL: "https://api.github.com/user", } } if params.OIDCClientID != "" && params.OIDCIssuer != "" { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() provider, err := oidc.NewProvider(ctx, params.OIDCIssuer) if err != nil { return nil, fmt.Errorf("OIDC discovery failed for %s: %w", params.OIDCIssuer, err) } verifier := provider.Verifier(&oidc.Config{ClientID: params.OIDCClientID}) m.providers[ProviderOIDC] = &providerEntry{ name: ProviderOIDC, oauth2Config: oauth2.Config{ ClientID: params.OIDCClientID, ClientSecret: params.OIDCClientSecret, Endpoint: provider.Endpoint(), RedirectURL: baseURL + "/api/auth/oidc/callback", Scopes: []string{oidc.ScopeOpenID, "profile", "email"}, }, oidcVerifier: verifier, } } return m, nil } // Providers returns the list of configured provider names. func (m *OAuthManager) Providers() []string { names := make([]string, 0, len(m.providers)) for name := range m.providers { names = append(names, name) } return names } // LoginHandler redirects the user to the OAuth provider's login page. 
func (m *OAuthManager) LoginHandler(providerName string) echo.HandlerFunc { return func(c echo.Context) error { provider, ok := m.providers[providerName] if !ok { return c.JSON(http.StatusNotFound, map[string]string{"error": "unknown provider"}) } state, err := generateState() if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to generate state"}) } secure := isSecure(c) c.SetCookie(&http.Cookie{ Name: "oauth_state", Value: state, Path: "/", HttpOnly: true, Secure: secure, SameSite: http.SameSiteLaxMode, MaxAge: 600, // 10 minutes }) // Store invite code in cookie if provided if inviteCode := c.QueryParam("invite_code"); inviteCode != "" { c.SetCookie(&http.Cookie{ Name: "invite_code", Value: inviteCode, Path: "/", HttpOnly: true, Secure: secure, SameSite: http.SameSiteLaxMode, MaxAge: 600, }) } url := provider.oauth2Config.AuthCodeURL(state) return c.Redirect(http.StatusTemporaryRedirect, url) } } // CallbackHandler handles the OAuth callback, creates/updates the user, and // creates a session. 
func (m *OAuthManager) CallbackHandler(providerName string, db *gorm.DB, adminEmail, registrationMode, hmacSecret string) echo.HandlerFunc { return func(c echo.Context) error { provider, ok := m.providers[providerName] if !ok { return c.JSON(http.StatusNotFound, map[string]string{"error": "unknown provider"}) } // Validate state stateCookie, err := c.Cookie("oauth_state") if err != nil || stateCookie.Value == "" || subtle.ConstantTimeCompare([]byte(stateCookie.Value), []byte(c.QueryParam("state"))) != 1 { return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid OAuth state"}) } // Clear state cookie c.SetCookie(&http.Cookie{ Name: "oauth_state", Value: "", Path: "/", HttpOnly: true, Secure: isSecure(c), MaxAge: -1, }) // Exchange code for token code := c.QueryParam("code") if code == "" { return c.JSON(http.StatusBadRequest, map[string]string{"error": "missing authorization code"}) } ctx, cancel := context.WithTimeout(c.Request().Context(), 30*time.Second) defer cancel() token, err := provider.oauth2Config.Exchange(ctx, code) if err != nil { xlog.Error("OAuth code exchange failed", "provider", providerName, "error", err) return c.JSON(http.StatusBadRequest, map[string]string{"error": "OAuth authentication failed"}) } // Fetch user info — branch based on provider type var userInfo *oauthUserInfo if provider.oidcVerifier != nil { userInfo, err = extractOIDCUserInfo(ctx, provider.oidcVerifier, token) } else { userInfo, err = fetchGitHubUserInfoAsOAuth(ctx, token.AccessToken) } if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to fetch user info"}) } // Retrieve invite code from cookie if present var inviteCode string if ic, err := c.Cookie("invite_code"); err == nil && ic.Value != "" { inviteCode = ic.Value // Clear the invite code cookie c.SetCookie(&http.Cookie{ Name: "invite_code", Value: "", Path: "/", HttpOnly: true, Secure: isSecure(c), MaxAge: -1, }) } // Upsert user (with invite code support) user, 
err := upsertOAuthUser(db, providerName, userInfo, adminEmail, registrationMode) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create user"}) } // For new users that are pending, check if they have a valid invite if user.Status != StatusActive && inviteCode != "" { if invite, err := ValidateInvite(db, inviteCode, hmacSecret); err == nil { user.Status = StatusActive db.Model(user).Update("status", StatusActive) ConsumeInvite(db, invite, user.ID) } } if user.Status != StatusActive { if registrationMode == "invite" { return c.JSON(http.StatusForbidden, map[string]string{"error": "a valid invite code is required to register"}) } return c.JSON(http.StatusForbidden, map[string]string{"error": "account pending approval"}) } // Maybe promote on login MaybePromote(db, user, adminEmail) // Create session sessionID, err := CreateSession(db, user.ID, hmacSecret) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create session"}) } SetSessionCookie(c, sessionID) return c.Redirect(http.StatusTemporaryRedirect, "/app") } } // extractOIDCUserInfo extracts user info from the OIDC ID token. 
func extractOIDCUserInfo(ctx context.Context, verifier *oidc.IDTokenVerifier, token *oauth2.Token) (*oauthUserInfo, error) { rawIDToken, ok := token.Extra("id_token").(string) if !ok || rawIDToken == "" { return nil, fmt.Errorf("no id_token in token response") } idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { return nil, fmt.Errorf("failed to verify ID token: %w", err) } var claims struct { Sub string `json:"sub"` Email string `json:"email"` Name string `json:"name"` Picture string `json:"picture"` } if err := idToken.Claims(&claims); err != nil { return nil, fmt.Errorf("failed to parse ID token claims: %w", err) } return &oauthUserInfo{ Subject: claims.Sub, Email: claims.Email, Name: claims.Name, AvatarURL: claims.Picture, }, nil } type githubUserInfo struct { ID int `json:"id"` Login string `json:"login"` Name string `json:"name"` Email string `json:"email"` AvatarURL string `json:"avatar_url"` } type githubEmail struct { Email string `json:"email"` Primary bool `json:"primary"` Verified bool `json:"verified"` } // fetchGitHubUserInfoAsOAuth fetches GitHub user info and returns it as oauthUserInfo. 
func fetchGitHubUserInfoAsOAuth(ctx context.Context, accessToken string) (*oauthUserInfo, error) { info, err := fetchGitHubUserInfo(ctx, accessToken) if err != nil { return nil, err } return &oauthUserInfo{ Subject: fmt.Sprintf("%d", info.ID), Email: info.Email, Name: info.Name, AvatarURL: info.AvatarURL, }, nil } func fetchGitHubUserInfo(ctx context.Context, accessToken string) (*githubUserInfo, error) { client := &http.Client{Timeout: 10 * time.Second} req, _ := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user", nil) req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("Accept", "application/json") resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return nil, err } var info githubUserInfo if err := json.Unmarshal(body, &info); err != nil { return nil, err } // If no public email, fetch from /user/emails if info.Email == "" { info.Email, _ = fetchGitHubPrimaryEmail(ctx, accessToken) } return &info, nil } func fetchGitHubPrimaryEmail(ctx context.Context, accessToken string) (string, error) { client := &http.Client{Timeout: 10 * time.Second} req, _ := http.NewRequestWithContext(ctx, "GET", "https://api.github.com/user/emails", nil) req.Header.Set("Authorization", "Bearer "+accessToken) req.Header.Set("Accept", "application/json") resp, err := client.Do(req) if err != nil { return "", err } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { return "", err } var emails []githubEmail if err := json.Unmarshal(body, &emails); err != nil { return "", err } for _, e := range emails { if e.Primary && e.Verified { return e.Email, nil } } // Fall back to first verified email for _, e := range emails { if e.Verified { return e.Email, nil } } return "", fmt.Errorf("no verified email found") } func upsertOAuthUser(db *gorm.DB, provider string, info *oauthUserInfo, adminEmail, registrationMode string) (*User, error) { // Normalize 
email from provider (#10) if info.Email != "" { info.Email = strings.ToLower(strings.TrimSpace(info.Email)) } var user User err := db.Where("provider = ? AND subject = ?", provider, info.Subject).First(&user).Error if err == nil { // Existing user — update profile fields user.Name = info.Name user.AvatarURL = info.AvatarURL if info.Email != "" { user.Email = info.Email } db.Save(&user) return &user, nil } // New user — empty registration mode defaults to "approval" effectiveMode := registrationMode if effectiveMode == "" { effectiveMode = "approval" } status := StatusActive if effectiveMode == "approval" || effectiveMode == "invite" { status = StatusPending } role := AssignRole(db, info.Email, adminEmail) // First user is always active regardless of registration mode if role == RoleAdmin { status = StatusActive } user = User{ ID: uuid.New().String(), Email: info.Email, Name: info.Name, AvatarURL: info.AvatarURL, Provider: provider, Subject: info.Subject, Role: role, Status: status, } if err := db.Create(&user).Error; err != nil { return nil, err } return &user, nil } func generateState() (string, error) { b := make([]byte, 16) if _, err := rand.Read(b); err != nil { return "", err } return hex.EncodeToString(b), nil } ================================================ FILE: core/http/auth/password.go ================================================ package auth import "golang.org/x/crypto/bcrypt" // HashPassword returns a bcrypt hash of the given password. func HashPassword(password string) (string, error) { bytes, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) return string(bytes), err } // CheckPassword compares a bcrypt hash with a plaintext password. 
func CheckPassword(hash, password string) bool { return bcrypt.CompareHashAndPassword([]byte(hash), []byte(password)) == nil } ================================================ FILE: core/http/auth/permissions.go ================================================ package auth import ( "github.com/google/uuid" "github.com/labstack/echo/v4" "gorm.io/gorm" ) const contextKeyPermissions = "auth_permissions" // GetCachedUserPermissions returns the user's permission record, using a // request-scoped cache stored in the echo context. This avoids duplicate // DB lookups when multiple middlewares (RequireRouteFeature, RequireModelAccess) // both need permissions in the same request. func GetCachedUserPermissions(c echo.Context, db *gorm.DB, userID string) (*UserPermission, error) { if perm, ok := c.Get(contextKeyPermissions).(*UserPermission); ok && perm != nil { return perm, nil } perm, err := GetUserPermissions(db, userID) if err != nil { return nil, err } c.Set(contextKeyPermissions, perm) return perm, nil } // Feature name constants — all code must use these, never bare strings. const ( // Agent features (default OFF for new users) FeatureAgents = "agents" FeatureSkills = "skills" FeatureCollections = "collections" FeatureMCPJobs = "mcp_jobs" // API features (default ON for new users) FeatureChat = "chat" FeatureImages = "images" FeatureAudioSpeech = "audio_speech" FeatureAudioTranscription = "audio_transcription" FeatureVAD = "vad" FeatureDetection = "detection" FeatureVideo = "video" FeatureEmbeddings = "embeddings" FeatureSound = "sound" FeatureRealtime = "realtime" FeatureRerank = "rerank" FeatureTokenize = "tokenize" FeatureMCP = "mcp" FeatureStores = "stores" ) // AgentFeatures lists agent-related features (default OFF). var AgentFeatures = []string{FeatureAgents, FeatureSkills, FeatureCollections, FeatureMCPJobs} // APIFeatures lists API endpoint features (default ON). 
var APIFeatures = []string{
	FeatureChat,
	FeatureImages,
	FeatureAudioSpeech,
	FeatureAudioTranscription,
	FeatureVAD,
	FeatureDetection,
	FeatureVideo,
	FeatureEmbeddings,
	FeatureSound,
	FeatureRealtime,
	FeatureRerank,
	FeatureTokenize,
	FeatureMCP,
	FeatureStores,
}

// AllFeatures is the concatenation of AgentFeatures and APIFeatures, built on
// a fresh slice so neither source slice is modified (used by UI and validation).
var AllFeatures = append(append([]string{}, AgentFeatures...), APIFeatures...)

// defaultOnFeatures holds every feature that counts as enabled when a user's
// permission map has no entry for it; it is derived from APIFeatures.
var defaultOnFeatures = func() map[string]bool {
	set := make(map[string]bool, len(APIFeatures))
	for i := range APIFeatures {
		set[APIFeatures[i]] = true
	}
	return set
}()

// isDefaultOnFeature reports whether an absent permission entry for feature
// should be read as enabled.
func isDefaultOnFeature(feature string) bool {
	return defaultOnFeatures[feature]
}

// GetUserPermissions loads the permission record for userID. When no record
// exists, one with an empty permission map is created and returned; absent
// keys are then resolved through the per-feature defaults by the callers.
// Any other lookup error, or a failed create, is returned.
func GetUserPermissions(db *gorm.DB, userID string) (*UserPermission, error) {
	var perm UserPermission
	switch err := db.Where("user_id = ?", userID).First(&perm).Error; err {
	case nil:
		return &perm, nil
	case gorm.ErrRecordNotFound:
		perm = UserPermission{
			ID:          uuid.New().String(),
			UserID:      userID,
			Permissions: PermissionMap{},
		}
		if createErr := db.Create(&perm).Error; createErr != nil {
			return nil, createErr
		}
		return &perm, nil
	default:
		return nil, err
	}
}

// UpdateUserPermissions replaces the permission map stored for userID,
// creating the record when the user has none yet.
func UpdateUserPermissions(db *gorm.DB, userID string, perms PermissionMap) error {
	var existing UserPermission
	lookupErr := db.Where("user_id = ?", userID).First(&existing).Error

	if lookupErr == gorm.ErrRecordNotFound {
		created := UserPermission{
			ID:          uuid.New().String(),
			UserID:      userID,
			Permissions: perms,
		}
		return db.Create(&created).Error
	}
	if lookupErr != nil {
		return lookupErr
	}

	existing.Permissions = perms
	return db.Save(&existing).Error
}

// HasFeatureAccess returns true if the user is an admin or has the given feature enabled.
// When a feature key is absent from the user's permission map, it checks whether the // feature defaults to ON (API features) or OFF (agent features) for backward compatibility. func HasFeatureAccess(db *gorm.DB, user *User, feature string) bool { if user == nil { return false } if user.Role == RoleAdmin { return true } perm, err := GetUserPermissions(db, user.ID) if err != nil { return false } val, exists := perm.Permissions[feature] if !exists { return isDefaultOnFeature(feature) } return val } // GetPermissionMapForUser returns the effective permission map for a user. // Admins get all features as true (virtual). // For regular users, absent keys are filled with their defaults so the // UI/API always returns a complete picture. func GetPermissionMapForUser(db *gorm.DB, user *User) PermissionMap { if user == nil { return PermissionMap{} } if user.Role == RoleAdmin { m := PermissionMap{} for _, f := range AllFeatures { m[f] = true } return m } perm, err := GetUserPermissions(db, user.ID) if err != nil { return PermissionMap{} } // Fill in defaults for absent keys effective := PermissionMap{} for _, f := range AllFeatures { val, exists := perm.Permissions[f] if exists { effective[f] = val } else { effective[f] = isDefaultOnFeature(f) } } return effective } // GetModelAllowlist returns the model allowlist for a user. func GetModelAllowlist(db *gorm.DB, userID string) ModelAllowlist { perm, err := GetUserPermissions(db, userID) if err != nil { return ModelAllowlist{} } return perm.AllowedModels } // UpdateModelAllowlist updates the model allowlist for a user. func UpdateModelAllowlist(db *gorm.DB, userID string, allowlist ModelAllowlist) error { perm, err := GetUserPermissions(db, userID) if err != nil { return err } perm.AllowedModels = allowlist return db.Save(perm).Error } // IsModelAllowed returns true if the user is allowed to use the given model. // Admins always have access. If the allowlist is not enabled, all models are allowed. 
func IsModelAllowed(db *gorm.DB, user *User, modelName string) bool { if user == nil { return false } if user.Role == RoleAdmin { return true } allowlist := GetModelAllowlist(db, user.ID) if !allowlist.Enabled { return true } for _, m := range allowlist.Models { if m == modelName { return true } } return false } ================================================ FILE: core/http/auth/roles.go ================================================ package auth import ( "fmt" "strings" "time" "gorm.io/gorm" ) const ( RoleAdmin = "admin" RoleUser = "user" StatusActive = "active" StatusPending = "pending" StatusDisabled = "disabled" ) // AssignRole determines the role for a new user. // First user in the database becomes admin. If adminEmail is set and matches, // the user becomes admin. Otherwise, the user gets the "user" role. // Must be called within a transaction that also creates the user to prevent // race conditions on the first-user admin assignment. func AssignRole(tx *gorm.DB, email, adminEmail string) string { var count int64 tx.Model(&User{}).Count(&count) if count == 0 { return RoleAdmin } if adminEmail != "" && strings.EqualFold(email, adminEmail) { return RoleAdmin } return RoleUser } // MaybePromote promotes a user to admin on login if their email matches // adminEmail. It does not demote existing admins. Returns true if the user // was promoted. func MaybePromote(db *gorm.DB, user *User, adminEmail string) bool { if user.Role == RoleAdmin { return false } if adminEmail != "" && strings.EqualFold(user.Email, adminEmail) { user.Role = RoleAdmin db.Model(user).Update("role", RoleAdmin) return true } return false } // ValidateInvite checks that an invite code exists, is unused, and has not expired. // The code is hashed with HMAC-SHA256 before lookup. 
func ValidateInvite(db *gorm.DB, code, hmacSecret string) (*InviteCode, error) { hash := HashAPIKey(code, hmacSecret) var invite InviteCode if err := db.Where("code = ?", hash).First(&invite).Error; err != nil { return nil, fmt.Errorf("invite code not found") } if invite.UsedBy != nil { return nil, fmt.Errorf("invite code already used") } if time.Now().After(invite.ExpiresAt) { return nil, fmt.Errorf("invite code expired") } return &invite, nil } // ConsumeInvite marks an invite code as used by the given user. func ConsumeInvite(db *gorm.DB, invite *InviteCode, userID string) { now := time.Now() invite.UsedBy = &userID invite.UsedAt = &now db.Save(invite) } // NeedsInviteOrApproval returns true if registration gating applies for the given mode. // Admins (first user or matching adminEmail) are never gated. // Must be called within a transaction that also creates the user. func NeedsInviteOrApproval(tx *gorm.DB, email, adminEmail, registrationMode string) bool { // Empty registration mode defaults to "approval" if registrationMode == "" { registrationMode = "approval" } if registrationMode != "approval" && registrationMode != "invite" { return false } // Admin email is never gated if adminEmail != "" && strings.EqualFold(email, adminEmail) { return false } // First user is never gated var count int64 tx.Model(&User{}).Count(&count) if count == 0 { return false } return true } ================================================ FILE: core/http/auth/roles_test.go ================================================ //go:build auth package auth_test import ( "github.com/mudler/LocalAI/core/http/auth" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" "gorm.io/gorm" ) var _ = Describe("Roles", func() { var db *gorm.DB BeforeEach(func() { db = testDB() }) Describe("AssignRole", func() { It("returns admin for the first user (empty DB)", func() { role := auth.AssignRole(db, "first@example.com", "") Expect(role).To(Equal(auth.RoleAdmin)) }) It("returns user for the second user", func() { createTestUser(db, "first@example.com", auth.RoleAdmin, auth.ProviderGitHub) role := auth.AssignRole(db, "second@example.com", "") Expect(role).To(Equal(auth.RoleUser)) }) It("returns admin when email matches adminEmail", func() { createTestUser(db, "first@example.com", auth.RoleAdmin, auth.ProviderGitHub) role := auth.AssignRole(db, "admin@example.com", "admin@example.com") Expect(role).To(Equal(auth.RoleAdmin)) }) It("is case-insensitive for admin email match", func() { createTestUser(db, "first@example.com", auth.RoleAdmin, auth.ProviderGitHub) role := auth.AssignRole(db, "Admin@Example.COM", "admin@example.com") Expect(role).To(Equal(auth.RoleAdmin)) }) It("returns user when email does not match adminEmail", func() { createTestUser(db, "first@example.com", auth.RoleAdmin, auth.ProviderGitHub) role := auth.AssignRole(db, "other@example.com", "admin@example.com") Expect(role).To(Equal(auth.RoleUser)) }) }) Describe("MaybePromote", func() { It("promotes user to admin when email matches", func() { user := createTestUser(db, "promoted@example.com", auth.RoleUser, auth.ProviderGitHub) promoted := auth.MaybePromote(db, user, "promoted@example.com") Expect(promoted).To(BeTrue()) Expect(user.Role).To(Equal(auth.RoleAdmin)) // Verify in DB var dbUser auth.User db.First(&dbUser, "id = ?", user.ID) Expect(dbUser.Role).To(Equal(auth.RoleAdmin)) }) It("does not promote when email does not match", func() { user := createTestUser(db, "user@example.com", auth.RoleUser, auth.ProviderGitHub) promoted := auth.MaybePromote(db, user, "admin@example.com") Expect(promoted).To(BeFalse()) 
Expect(user.Role).To(Equal(auth.RoleUser)) }) It("does not demote an existing admin", func() { user := createTestUser(db, "admin@example.com", auth.RoleAdmin, auth.ProviderGitHub) promoted := auth.MaybePromote(db, user, "other@example.com") Expect(promoted).To(BeFalse()) Expect(user.Role).To(Equal(auth.RoleAdmin)) }) }) }) ================================================ FILE: core/http/auth/session.go ================================================ package auth import ( "crypto/rand" "encoding/hex" "fmt" "net/http" "time" "github.com/labstack/echo/v4" "gorm.io/gorm" ) const ( sessionDuration = 30 * 24 * time.Hour // 30 days sessionIDBytes = 32 // 32 bytes = 64 hex chars sessionCookie = "session" sessionRotationInterval = 1 * time.Hour ) // CreateSession creates a new session for the given user, returning the // plaintext token (64-char hex string). The stored session ID is the // HMAC-SHA256 hash of the token. func CreateSession(db *gorm.DB, userID, hmacSecret string) (string, error) { b := make([]byte, sessionIDBytes) if _, err := rand.Read(b); err != nil { return "", fmt.Errorf("failed to generate session ID: %w", err) } plaintext := hex.EncodeToString(b) hash := HashAPIKey(plaintext, hmacSecret) now := time.Now() session := Session{ ID: hash, UserID: userID, ExpiresAt: now.Add(sessionDuration), RotatedAt: now, } if err := db.Create(&session).Error; err != nil { return "", fmt.Errorf("failed to create session: %w", err) } return plaintext, nil } // ValidateSession hashes the plaintext token and looks up the session. // Returns the associated user and session, or (nil, nil) if not found/expired. func ValidateSession(db *gorm.DB, token, hmacSecret string) (*User, *Session) { hash := HashAPIKey(token, hmacSecret) var session Session if err := db.Preload("User").Where("id = ? 
AND expires_at > ?", hash, time.Now()).First(&session).Error; err != nil { return nil, nil } if session.User.Status != StatusActive { return nil, nil } return &session.User, &session } // DeleteSession removes a session by hashing the plaintext token. func DeleteSession(db *gorm.DB, token, hmacSecret string) error { hash := HashAPIKey(token, hmacSecret) return db.Where("id = ?", hash).Delete(&Session{}).Error } // CleanExpiredSessions removes all sessions that have passed their expiry time. func CleanExpiredSessions(db *gorm.DB) error { return db.Where("expires_at < ?", time.Now()).Delete(&Session{}).Error } // DeleteUserSessions removes all sessions for the given user. func DeleteUserSessions(db *gorm.DB, userID string) error { return db.Where("user_id = ?", userID).Delete(&Session{}).Error } // RotateSession creates a new session for the same user, deletes the old one, // and returns the new plaintext token. func RotateSession(db *gorm.DB, oldSession *Session, hmacSecret string) (string, error) { b := make([]byte, sessionIDBytes) if _, err := rand.Read(b); err != nil { return "", fmt.Errorf("failed to generate session ID: %w", err) } plaintext := hex.EncodeToString(b) hash := HashAPIKey(plaintext, hmacSecret) now := time.Now() newSession := Session{ ID: hash, UserID: oldSession.UserID, ExpiresAt: oldSession.ExpiresAt, RotatedAt: now, } err := db.Transaction(func(tx *gorm.DB) error { if err := tx.Create(&newSession).Error; err != nil { return err } return tx.Where("id = ?", oldSession.ID).Delete(&Session{}).Error }) if err != nil { return "", fmt.Errorf("failed to rotate session: %w", err) } return plaintext, nil } // MaybeRotateSession checks if the session should be rotated and does so if needed. // Called from the auth middleware after successful cookie-based authentication. 
func MaybeRotateSession(c echo.Context, db *gorm.DB, session *Session, hmacSecret string) { if session == nil { return } rotatedAt := session.RotatedAt if rotatedAt.IsZero() { rotatedAt = session.CreatedAt } if time.Since(rotatedAt) < sessionRotationInterval { return } newToken, err := RotateSession(db, session, hmacSecret) if err != nil { // Rotation failure is non-fatal; the old session remains valid return } SetSessionCookie(c, newToken) } // isSecure returns true when the request arrived over HTTPS, either directly // or via a reverse proxy that sets X-Forwarded-Proto. func isSecure(c echo.Context) bool { return c.Scheme() == "https" } // SetSessionCookie sets the session cookie on the response. func SetSessionCookie(c echo.Context, sessionID string) { cookie := &http.Cookie{ Name: sessionCookie, Value: sessionID, Path: "/", HttpOnly: true, Secure: isSecure(c), SameSite: http.SameSiteLaxMode, MaxAge: int(sessionDuration.Seconds()), } c.SetCookie(cookie) } // SetTokenCookie sets an httpOnly "token" cookie for legacy API key auth. func SetTokenCookie(c echo.Context, token string) { cookie := &http.Cookie{ Name: "token", Value: token, Path: "/", HttpOnly: true, Secure: isSecure(c), SameSite: http.SameSiteLaxMode, MaxAge: int(sessionDuration.Seconds()), } c.SetCookie(cookie) } // ClearSessionCookie clears the session cookie. func ClearSessionCookie(c echo.Context) { cookie := &http.Cookie{ Name: sessionCookie, Value: "", Path: "/", HttpOnly: true, Secure: isSecure(c), SameSite: http.SameSiteLaxMode, MaxAge: -1, } c.SetCookie(cookie) } ================================================ FILE: core/http/auth/session_test.go ================================================ //go:build auth package auth_test import ( "time" "github.com/mudler/LocalAI/core/http/auth" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" "gorm.io/gorm" ) var _ = Describe("Sessions", func() { var ( db *gorm.DB user *auth.User ) // Use empty HMAC secret for basic tests hmacSecret := "" BeforeEach(func() { db = testDB() user = createTestUser(db, "session@example.com", auth.RoleUser, auth.ProviderGitHub) }) Describe("CreateSession", func() { It("creates a session and returns 64-char hex plaintext token", func() { token, err := auth.CreateSession(db, user.ID, hmacSecret) Expect(err).ToNot(HaveOccurred()) Expect(token).To(HaveLen(64)) }) It("stores the hash (not plaintext) in the DB", func() { token, err := auth.CreateSession(db, user.ID, hmacSecret) Expect(err).ToNot(HaveOccurred()) hash := auth.HashAPIKey(token, hmacSecret) var session auth.Session err = db.First(&session, "id = ?", hash).Error Expect(err).ToNot(HaveOccurred()) Expect(session.UserID).To(Equal(user.ID)) // The plaintext token should NOT be stored as the ID Expect(session.ID).ToNot(Equal(token)) Expect(session.ID).To(Equal(hash)) }) It("sets expiry to approximately 30 days from now", func() { token, err := auth.CreateSession(db, user.ID, hmacSecret) Expect(err).ToNot(HaveOccurred()) hash := auth.HashAPIKey(token, hmacSecret) var session auth.Session db.First(&session, "id = ?", hash) expectedExpiry := time.Now().Add(30 * 24 * time.Hour) Expect(session.ExpiresAt).To(BeTemporally("~", expectedExpiry, time.Minute)) }) It("sets RotatedAt on creation", func() { token, err := auth.CreateSession(db, user.ID, hmacSecret) Expect(err).ToNot(HaveOccurred()) hash := auth.HashAPIKey(token, hmacSecret) var session auth.Session db.First(&session, "id = ?", hash) Expect(session.RotatedAt).To(BeTemporally("~", time.Now(), time.Minute)) }) It("associates session with correct user", func() { token, err := auth.CreateSession(db, user.ID, hmacSecret) Expect(err).ToNot(HaveOccurred()) hash := auth.HashAPIKey(token, hmacSecret) var session auth.Session db.First(&session, "id = ?", hash) Expect(session.UserID).To(Equal(user.ID)) }) }) 
Describe("ValidateSession", func() { It("returns user for valid session", func() { token := createTestSession(db, user.ID) found, session := auth.ValidateSession(db, token, hmacSecret) Expect(found).ToNot(BeNil()) Expect(found.ID).To(Equal(user.ID)) Expect(session).ToNot(BeNil()) }) It("returns nil for non-existent session", func() { found, session := auth.ValidateSession(db, "nonexistent-session-id", hmacSecret) Expect(found).To(BeNil()) Expect(session).To(BeNil()) }) It("returns nil for expired session", func() { token := createTestSession(db, user.ID) hash := auth.HashAPIKey(token, hmacSecret) // Manually expire the session db.Model(&auth.Session{}).Where("id = ?", hash). Update("expires_at", time.Now().Add(-1*time.Hour)) found, _ := auth.ValidateSession(db, token, hmacSecret) Expect(found).To(BeNil()) }) }) Describe("DeleteSession", func() { It("removes the session from DB", func() { token := createTestSession(db, user.ID) err := auth.DeleteSession(db, token, hmacSecret) Expect(err).ToNot(HaveOccurred()) found, _ := auth.ValidateSession(db, token, hmacSecret) Expect(found).To(BeNil()) }) It("does not error on non-existent session", func() { err := auth.DeleteSession(db, "nonexistent", hmacSecret) Expect(err).ToNot(HaveOccurred()) }) }) Describe("CleanExpiredSessions", func() { It("removes expired sessions", func() { token := createTestSession(db, user.ID) hash := auth.HashAPIKey(token, hmacSecret) // Manually expire the session db.Model(&auth.Session{}).Where("id = ?", hash). 
Update("expires_at", time.Now().Add(-1*time.Hour)) err := auth.CleanExpiredSessions(db) Expect(err).ToNot(HaveOccurred()) var count int64 db.Model(&auth.Session{}).Where("id = ?", hash).Count(&count) Expect(count).To(Equal(int64(0))) }) It("keeps active sessions", func() { token := createTestSession(db, user.ID) hash := auth.HashAPIKey(token, hmacSecret) err := auth.CleanExpiredSessions(db) Expect(err).ToNot(HaveOccurred()) var count int64 db.Model(&auth.Session{}).Where("id = ?", hash).Count(&count) Expect(count).To(Equal(int64(1))) }) }) Describe("RotateSession", func() { It("creates a new session and deletes the old one", func() { token := createTestSession(db, user.ID) hash := auth.HashAPIKey(token, hmacSecret) // Get the old session var oldSession auth.Session db.First(&oldSession, "id = ?", hash) newToken, err := auth.RotateSession(db, &oldSession, hmacSecret) Expect(err).ToNot(HaveOccurred()) Expect(newToken).To(HaveLen(64)) Expect(newToken).ToNot(Equal(token)) // Old session should be gone var count int64 db.Model(&auth.Session{}).Where("id = ?", hash).Count(&count) Expect(count).To(Equal(int64(0))) // New session should exist and validate found, _ := auth.ValidateSession(db, newToken, hmacSecret) Expect(found).ToNot(BeNil()) Expect(found.ID).To(Equal(user.ID)) }) It("preserves user ID and expiry", func() { token := createTestSession(db, user.ID) hash := auth.HashAPIKey(token, hmacSecret) var oldSession auth.Session db.First(&oldSession, "id = ?", hash) newToken, err := auth.RotateSession(db, &oldSession, hmacSecret) Expect(err).ToNot(HaveOccurred()) newHash := auth.HashAPIKey(newToken, hmacSecret) var newSession auth.Session db.First(&newSession, "id = ?", newHash) Expect(newSession.UserID).To(Equal(oldSession.UserID)) Expect(newSession.ExpiresAt).To(BeTemporally("~", oldSession.ExpiresAt, time.Second)) }) }) Context("with HMAC secret", func() { hmacSecret := "test-hmac-secret-123" It("creates and validates sessions with HMAC secret", func() { token, err 
:= auth.CreateSession(db, user.ID, hmacSecret) Expect(err).ToNot(HaveOccurred()) found, session := auth.ValidateSession(db, token, hmacSecret) Expect(found).ToNot(BeNil()) Expect(found.ID).To(Equal(user.ID)) Expect(session).ToNot(BeNil()) }) It("does not validate with wrong HMAC secret", func() { token, err := auth.CreateSession(db, user.ID, hmacSecret) Expect(err).ToNot(HaveOccurred()) found, _ := auth.ValidateSession(db, token, "wrong-secret") Expect(found).To(BeNil()) }) It("does not validate with empty HMAC secret", func() { token, err := auth.CreateSession(db, user.ID, hmacSecret) Expect(err).ToNot(HaveOccurred()) found, _ := auth.ValidateSession(db, token, "") Expect(found).To(BeNil()) }) It("session created with empty secret does not validate with non-empty secret", func() { token, err := auth.CreateSession(db, user.ID, "") Expect(err).ToNot(HaveOccurred()) found, _ := auth.ValidateSession(db, token, hmacSecret) Expect(found).To(BeNil()) }) It("deletes session with correct HMAC secret", func() { token, err := auth.CreateSession(db, user.ID, hmacSecret) Expect(err).ToNot(HaveOccurred()) err = auth.DeleteSession(db, token, hmacSecret) Expect(err).ToNot(HaveOccurred()) found, _ := auth.ValidateSession(db, token, hmacSecret) Expect(found).To(BeNil()) }) It("rotates session with HMAC secret", func() { token, err := auth.CreateSession(db, user.ID, hmacSecret) Expect(err).ToNot(HaveOccurred()) hash := auth.HashAPIKey(token, hmacSecret) var oldSession auth.Session db.First(&oldSession, "id = ?", hash) newToken, err := auth.RotateSession(db, &oldSession, hmacSecret) Expect(err).ToNot(HaveOccurred()) // Old token should not validate found, _ := auth.ValidateSession(db, token, hmacSecret) Expect(found).To(BeNil()) // New token should validate found, _ = auth.ValidateSession(db, newToken, hmacSecret) Expect(found).ToNot(BeNil()) Expect(found.ID).To(Equal(user.ID)) }) }) }) ================================================ FILE: core/http/auth/usage.go 
================================================ package auth import ( "fmt" "strings" "time" "gorm.io/gorm" ) // UsageRecord represents a single API request's token usage. type UsageRecord struct { ID uint `gorm:"primaryKey;autoIncrement"` UserID string `gorm:"size:36;index:idx_usage_user_time"` UserName string `gorm:"size:255"` Model string `gorm:"size:255;index"` Endpoint string `gorm:"size:255"` PromptTokens int64 CompletionTokens int64 TotalTokens int64 Duration int64 // milliseconds CreatedAt time.Time `gorm:"index:idx_usage_user_time"` } // RecordUsage inserts a usage record. func RecordUsage(db *gorm.DB, record *UsageRecord) error { return db.Create(record).Error } // UsageBucket is an aggregated time bucket for the dashboard. type UsageBucket struct { Bucket string `json:"bucket"` Model string `json:"model"` UserID string `json:"user_id,omitempty"` UserName string `json:"user_name,omitempty"` PromptTokens int64 `json:"prompt_tokens"` CompletionTokens int64 `json:"completion_tokens"` TotalTokens int64 `json:"total_tokens"` RequestCount int64 `json:"request_count"` } // UsageTotals is a summary of all usage. type UsageTotals struct { PromptTokens int64 `json:"prompt_tokens"` CompletionTokens int64 `json:"completion_tokens"` TotalTokens int64 `json:"total_tokens"` RequestCount int64 `json:"request_count"` } // periodToWindow returns the time window and SQL date format for a period. 
func periodToWindow(period string, isSQLite bool) (time.Time, string) { now := time.Now() var since time.Time var dateFmt string switch period { case "day": since = now.Add(-24 * time.Hour) if isSQLite { dateFmt = "strftime('%Y-%m-%d %H:00', created_at)" } else { dateFmt = "to_char(date_trunc('hour', created_at), 'YYYY-MM-DD HH24:00')" } case "week": since = now.Add(-7 * 24 * time.Hour) if isSQLite { dateFmt = "strftime('%Y-%m-%d', created_at)" } else { dateFmt = "to_char(date_trunc('day', created_at), 'YYYY-MM-DD')" } case "all": since = time.Time{} // zero time = no filter if isSQLite { dateFmt = "strftime('%Y-%m', created_at)" } else { dateFmt = "to_char(date_trunc('month', created_at), 'YYYY-MM')" } default: // "month" since = now.Add(-30 * 24 * time.Hour) if isSQLite { dateFmt = "strftime('%Y-%m-%d', created_at)" } else { dateFmt = "to_char(date_trunc('day', created_at), 'YYYY-MM-DD')" } } return since, dateFmt } func isSQLiteDB(db *gorm.DB) bool { return strings.Contains(db.Dialector.Name(), "sqlite") } // GetUserUsage returns aggregated usage for a single user. func GetUserUsage(db *gorm.DB, userID, period string) ([]UsageBucket, error) { sqlite := isSQLiteDB(db) since, dateFmt := periodToWindow(period, sqlite) bucketExpr := fmt.Sprintf("%s as bucket", dateFmt) query := db.Model(&UsageRecord{}). Select(bucketExpr+", model, "+ "SUM(prompt_tokens) as prompt_tokens, "+ "SUM(completion_tokens) as completion_tokens, "+ "SUM(total_tokens) as total_tokens, "+ "COUNT(*) as request_count"). Where("user_id = ?", userID). Group("bucket, model"). Order("bucket ASC") if !since.IsZero() { query = query.Where("created_at >= ?", since) } var buckets []UsageBucket if err := query.Find(&buckets).Error; err != nil { return nil, err } return buckets, nil } // GetAllUsage returns aggregated usage for all users (admin). Optional userID filter. 
func GetAllUsage(db *gorm.DB, period, userID string) ([]UsageBucket, error) { sqlite := isSQLiteDB(db) since, dateFmt := periodToWindow(period, sqlite) bucketExpr := fmt.Sprintf("%s as bucket", dateFmt) query := db.Model(&UsageRecord{}). Select(bucketExpr+", model, user_id, user_name, "+ "SUM(prompt_tokens) as prompt_tokens, "+ "SUM(completion_tokens) as completion_tokens, "+ "SUM(total_tokens) as total_tokens, "+ "COUNT(*) as request_count"). Group("bucket, model, user_id, user_name"). Order("bucket ASC") if !since.IsZero() { query = query.Where("created_at >= ?", since) } if userID != "" { query = query.Where("user_id = ?", userID) } var buckets []UsageBucket if err := query.Find(&buckets).Error; err != nil { return nil, err } return buckets, nil } ================================================ FILE: core/http/auth/usage_test.go ================================================ //go:build auth package auth_test import ( "time" "github.com/mudler/LocalAI/core/http/auth" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" ) var _ = Describe("Usage", func() { Describe("RecordUsage", func() { It("inserts a usage record", func() { db := testDB() record := &auth.UsageRecord{ UserID: "user-1", UserName: "Test User", Model: "gpt-4", Endpoint: "/v1/chat/completions", PromptTokens: 100, CompletionTokens: 50, TotalTokens: 150, Duration: 1200, CreatedAt: time.Now(), } err := auth.RecordUsage(db, record) Expect(err).ToNot(HaveOccurred()) Expect(record.ID).ToNot(BeZero()) }) }) Describe("GetUserUsage", func() { It("returns aggregated usage for a specific user", func() { db := testDB() // Insert records for two users for i := 0; i < 3; i++ { err := auth.RecordUsage(db, &auth.UsageRecord{ UserID: "user-a", UserName: "Alice", Model: "gpt-4", Endpoint: "/v1/chat/completions", PromptTokens: 100, TotalTokens: 150, CreatedAt: time.Now(), }) Expect(err).ToNot(HaveOccurred()) } err := auth.RecordUsage(db, &auth.UsageRecord{ UserID: "user-b", UserName: "Bob", Model: "gpt-4", PromptTokens: 200, TotalTokens: 300, CreatedAt: time.Now(), }) Expect(err).ToNot(HaveOccurred()) buckets, err := auth.GetUserUsage(db, "user-a", "month") Expect(err).ToNot(HaveOccurred()) Expect(buckets).ToNot(BeEmpty()) // All returned buckets should be for user-a's model totalPrompt := int64(0) for _, b := range buckets { totalPrompt += b.PromptTokens } Expect(totalPrompt).To(Equal(int64(300))) }) It("filters by period", func() { db := testDB() // Record in the past (beyond day window) err := auth.RecordUsage(db, &auth.UsageRecord{ UserID: "user-c", UserName: "Carol", Model: "gpt-4", PromptTokens: 100, TotalTokens: 100, CreatedAt: time.Now().Add(-48 * time.Hour), }) Expect(err).ToNot(HaveOccurred()) // Record now err = auth.RecordUsage(db, &auth.UsageRecord{ UserID: "user-c", UserName: "Carol", Model: "gpt-4", PromptTokens: 200, TotalTokens: 200, CreatedAt: time.Now(), }) Expect(err).ToNot(HaveOccurred()) // Day period should only include recent record buckets, err := auth.GetUserUsage(db, "user-c", "day") 
Expect(err).ToNot(HaveOccurred()) totalPrompt := int64(0) for _, b := range buckets { totalPrompt += b.PromptTokens } Expect(totalPrompt).To(Equal(int64(200))) // Month period should include both buckets, err = auth.GetUserUsage(db, "user-c", "month") Expect(err).ToNot(HaveOccurred()) totalPrompt = 0 for _, b := range buckets { totalPrompt += b.PromptTokens } Expect(totalPrompt).To(Equal(int64(300))) }) }) Describe("GetAllUsage", func() { It("returns usage for all users", func() { db := testDB() for _, uid := range []string{"user-x", "user-y"} { err := auth.RecordUsage(db, &auth.UsageRecord{ UserID: uid, UserName: uid, Model: "gpt-4", PromptTokens: 100, TotalTokens: 150, CreatedAt: time.Now(), }) Expect(err).ToNot(HaveOccurred()) } buckets, err := auth.GetAllUsage(db, "month", "") Expect(err).ToNot(HaveOccurred()) Expect(len(buckets)).To(BeNumerically(">=", 2)) }) It("filters by user ID when specified", func() { db := testDB() err := auth.RecordUsage(db, &auth.UsageRecord{ UserID: "user-p", UserName: "Pat", Model: "gpt-4", PromptTokens: 100, TotalTokens: 100, CreatedAt: time.Now(), }) Expect(err).ToNot(HaveOccurred()) err = auth.RecordUsage(db, &auth.UsageRecord{ UserID: "user-q", UserName: "Quinn", Model: "gpt-4", PromptTokens: 200, TotalTokens: 200, CreatedAt: time.Now(), }) Expect(err).ToNot(HaveOccurred()) buckets, err := auth.GetAllUsage(db, "month", "user-p") Expect(err).ToNot(HaveOccurred()) for _, b := range buckets { Expect(b.UserID).To(Equal("user-p")) } }) }) }) ================================================ FILE: core/http/endpoints/anthropic/messages.go ================================================ package anthropic import ( "encoding/json" "fmt" "strings" "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai" 
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/templates" "github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) // MessagesEndpoint is the Anthropic Messages API endpoint // https://docs.anthropic.com/claude/reference/messages_post // @Summary Generate a message response for the given messages and model. // @Param request body schema.AnthropicRequest true "query params" // @Success 200 {object} schema.AnthropicResponse "Response" // @Router /v1/messages [post] func MessagesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { id := uuid.New().String() input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.AnthropicRequest) if !ok || input.Model == "" { return sendAnthropicError(c, 400, "invalid_request_error", "model is required") } cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { return sendAnthropicError(c, 400, "invalid_request_error", "model configuration not found") } if input.MaxTokens <= 0 { return sendAnthropicError(c, 400, "invalid_request_error", "max_tokens is required and must be greater than 0") } xlog.Debug("Anthropic Messages endpoint configuration read", "config", cfg) // Convert Anthropic messages to OpenAI format for internal processing openAIMessages := convertAnthropicToOpenAIMessages(input) // Convert Anthropic tools to internal Functions format funcs, shouldUseFn := convertAnthropicTools(input, cfg) // MCP injection: prompts, resources, and tools var mcpToolInfos []mcpTools.MCPToolInfo mcpServers := mcpTools.MCPServersFromMetadata(input.Metadata) mcpPromptName, mcpPromptArgs := mcpTools.MCPPromptFromMetadata(input.Metadata) mcpResourceURIs := mcpTools.MCPResourcesFromMetadata(input.Metadata) if (len(mcpServers) > 
0 || mcpPromptName != "" || len(mcpResourceURIs) > 0) && (cfg.MCP.Servers != "" || cfg.MCP.Stdio != "") { remote, stdio, mcpErr := cfg.MCP.MCPConfigFromYAML() if mcpErr == nil { namedSessions, sessErr := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, mcpServers) if sessErr == nil && len(namedSessions) > 0 { // Prompt injection if mcpPromptName != "" { prompts, discErr := mcpTools.DiscoverMCPPrompts(c.Request().Context(), namedSessions) if discErr == nil { promptMsgs, getErr := mcpTools.GetMCPPrompt(c.Request().Context(), prompts, mcpPromptName, mcpPromptArgs) if getErr == nil { var injected []schema.Message for _, pm := range promptMsgs { injected = append(injected, schema.Message{ Role: string(pm.Role), Content: mcpTools.PromptMessageToText(pm), }) } openAIMessages = append(injected, openAIMessages...) xlog.Debug("Anthropic MCP prompt injected", "prompt", mcpPromptName, "messages", len(injected)) } else { xlog.Error("Failed to get MCP prompt", "error", getErr) } } } // Resource injection if len(mcpResourceURIs) > 0 { resources, discErr := mcpTools.DiscoverMCPResources(c.Request().Context(), namedSessions) if discErr == nil { var resourceTexts []string for _, uri := range mcpResourceURIs { content, readErr := mcpTools.ReadMCPResource(c.Request().Context(), resources, uri) if readErr != nil { xlog.Error("Failed to read MCP resource", "error", readErr, "uri", uri) continue } name := uri for _, r := range resources { if r.URI == uri { name = r.Name break } } resourceTexts = append(resourceTexts, fmt.Sprintf("--- MCP Resource: %s ---\n%s", name, content)) } if len(resourceTexts) > 0 && len(openAIMessages) > 0 { lastIdx := len(openAIMessages) - 1 suffix := "\n\n" + strings.Join(resourceTexts, "\n\n") switch ct := openAIMessages[lastIdx].Content.(type) { case string: openAIMessages[lastIdx].Content = ct + suffix default: openAIMessages[lastIdx].Content = fmt.Sprintf("%v%s", ct, suffix) } xlog.Debug("Anthropic MCP resources injected", "count", 
len(resourceTexts)) } } } // Tool injection if len(mcpServers) > 0 { discovered, discErr := mcpTools.DiscoverMCPTools(c.Request().Context(), namedSessions) if discErr == nil { mcpToolInfos = discovered for _, ti := range mcpToolInfos { funcs = append(funcs, ti.Function) } shouldUseFn = len(funcs) > 0 && cfg.ShouldUseFunctions() xlog.Debug("Anthropic MCP tools injected", "count", len(mcpToolInfos), "total_funcs", len(funcs)) } else { xlog.Error("Failed to discover MCP tools", "error", discErr) } } } } else { xlog.Error("Failed to parse MCP config", "error", mcpErr) } } // Create an OpenAI-compatible request for internal processing openAIReq := &schema.OpenAIRequest{ PredictionOptions: schema.PredictionOptions{ BasicModelRequest: schema.BasicModelRequest{Model: input.Model}, Temperature: input.Temperature, TopK: input.TopK, TopP: input.TopP, Maxtokens: &input.MaxTokens, }, Messages: openAIMessages, Stream: input.Stream, Context: input.Context, Cancel: input.Cancel, } // Set stop sequences if len(input.StopSequences) > 0 { openAIReq.Stop = input.StopSequences } // Merge config settings if input.Temperature != nil { cfg.Temperature = input.Temperature } if input.TopK != nil { cfg.TopK = input.TopK } if input.TopP != nil { cfg.TopP = input.TopP } cfg.Maxtokens = &input.MaxTokens if len(input.StopSequences) > 0 { cfg.StopWords = append(cfg.StopWords, input.StopSequences...) 
} // Template the prompt with tools if available predInput := evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn) xlog.Debug("Anthropic Messages - Prompt (after templating)", "prompt", predInput) if input.Stream { return handleAnthropicStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpToolInfos, evaluator) } return handleAnthropicNonStream(c, id, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpToolInfos, evaluator) } } func handleAnthropicNonStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) error { mcpMaxIterations := 10 if cfg.Agent.MaxIterations > 0 { mcpMaxIterations = cfg.Agent.MaxIterations } hasMCPTools := len(mcpToolInfos) > 0 for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ { // Re-template on each MCP iteration since messages may have changed if mcpIteration > 0 { predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn) xlog.Debug("Anthropic MCP re-templating", "iteration", mcpIteration, "prompt_len", len(predInput)) } // Populate openAIReq fields for ComputeChoices openAIReq.Tools = convertFuncsToOpenAITools(funcs) openAIReq.ToolsChoice = input.ToolChoice openAIReq.Metadata = input.Metadata var result string cb := func(s string, c *[]schema.Choice) { result = s } _, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil) if err != nil { xlog.Error("Anthropic model inference failed", "error", err) return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err)) } // Try pre-parsed tool calls from C++ autoparser first, 
fall back to text parsing var toolCalls []functions.FuncCallResults if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 { xlog.Debug("[ChatDeltas] Anthropic: using pre-parsed tool calls", "count", len(deltaToolCalls)) toolCalls = deltaToolCalls } else { xlog.Debug("[ChatDeltas] Anthropic: no pre-parsed tool calls, falling back to Go-side text parsing") toolCalls = functions.ParseFunctionCall(result, cfg.FunctionsConfig) } // MCP server-side tool execution: if any tool calls are MCP tools, execute and loop if hasMCPTools && shouldUseFn && len(toolCalls) > 0 { var hasMCPCalls bool for _, tc := range toolCalls { if mcpTools.IsMCPTool(mcpToolInfos, tc.Name) { hasMCPCalls = true break } } if hasMCPCalls { // Append assistant message with tool_calls to conversation assistantMsg := schema.Message{ Role: "assistant", Content: result, } for i, tc := range toolCalls { toolCallID := tc.ID if toolCallID == "" { toolCallID = fmt.Sprintf("toolu_%s_%d", id, i) } assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, schema.ToolCall{ Index: i, ID: toolCallID, Type: "function", FunctionCall: schema.FunctionCall{ Name: tc.Name, Arguments: tc.Arguments, }, }) } openAIReq.Messages = append(openAIReq.Messages, assistantMsg) // Execute each MCP tool call and append results for _, tc := range assistantMsg.ToolCalls { if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { continue } xlog.Debug("Executing MCP tool (Anthropic)", "tool", tc.FunctionCall.Name, "iteration", mcpIteration) toolResult, toolErr := mcpTools.ExecuteMCPToolCall( c.Request().Context(), mcpToolInfos, tc.FunctionCall.Name, tc.FunctionCall.Arguments, ) if toolErr != nil { xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr) toolResult = fmt.Sprintf("Error: %v", toolErr) } openAIReq.Messages = append(openAIReq.Messages, schema.Message{ Role: "tool", Content: toolResult, StringContent: toolResult, ToolCallID: tc.ID, Name: 
tc.FunctionCall.Name, }) } xlog.Debug("Anthropic MCP tools executed, re-running inference", "iteration", mcpIteration) continue // next MCP iteration } } // No MCP tools to execute, build and return response var contentBlocks []schema.AnthropicContentBlock var stopReason string if shouldUseFn && len(toolCalls) > 0 { stopReason = "tool_use" for _, tc := range toolCalls { var inputArgs map[string]interface{} if err := json.Unmarshal([]byte(tc.Arguments), &inputArgs); err != nil { xlog.Warn("Failed to parse tool call arguments as JSON", "error", err, "args", tc.Arguments) inputArgs = map[string]interface{}{"raw": tc.Arguments} } contentBlocks = append(contentBlocks, schema.AnthropicContentBlock{ Type: "tool_use", ID: fmt.Sprintf("toolu_%s_%d", id, len(contentBlocks)), Name: tc.Name, Input: inputArgs, }) } textContent := functions.ParseTextContent(result, cfg.FunctionsConfig) if textContent != "" { contentBlocks = append([]schema.AnthropicContentBlock{{Type: "text", Text: textContent}}, contentBlocks...) 
} } else { stopReason = "end_turn" contentBlocks = []schema.AnthropicContentBlock{ {Type: "text", Text: result}, } } resp := &schema.AnthropicResponse{ ID: fmt.Sprintf("msg_%s", id), Type: "message", Role: "assistant", Model: input.Model, StopReason: &stopReason, Content: contentBlocks, Usage: schema.AnthropicUsage{ InputTokens: tokenUsage.Prompt, OutputTokens: tokenUsage.Completion, }, } if respData, err := json.Marshal(resp); err == nil { xlog.Debug("Anthropic Response", "response", string(respData)) } return c.JSON(200, resp) } // end MCP iteration loop return sendAnthropicError(c, 500, "api_error", "MCP iteration limit reached") } func handleAnthropicStream(c echo.Context, id string, input *schema.AnthropicRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) error { c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Cache-Control", "no-cache") c.Response().Header().Set("Connection", "keep-alive") // Send message_start event messageStart := schema.AnthropicStreamEvent{ Type: "message_start", Message: &schema.AnthropicStreamMessage{ ID: fmt.Sprintf("msg_%s", id), Type: "message", Role: "assistant", Content: []schema.AnthropicContentBlock{}, Model: input.Model, Usage: schema.AnthropicUsage{InputTokens: 0, OutputTokens: 0}, }, } sendAnthropicSSE(c, messageStart) mcpMaxIterations := 10 if cfg.Agent.MaxIterations > 0 { mcpMaxIterations = cfg.Agent.MaxIterations } hasMCPTools := len(mcpToolInfos) > 0 for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ { // Re-template on MCP iterations if mcpIteration > 0 { predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn) xlog.Debug("Anthropic MCP stream re-templating", "iteration", mcpIteration) } 
// Track accumulated content for tool call detection accumulatedContent := "" currentBlockIndex := 0 inToolCall := false toolCallsEmitted := 0 // Send initial content_block_start event contentBlockStart := schema.AnthropicStreamEvent{ Type: "content_block_start", Index: currentBlockIndex, ContentBlock: &schema.AnthropicContentBlock{Type: "text", Text: ""}, } sendAnthropicSSE(c, contentBlockStart) // Collect tool calls for MCP execution var collectedToolCalls []functions.FuncCallResults tokenCallback := func(token string, usage backend.TokenUsage) bool { accumulatedContent += token if shouldUseFn { cleanedResult := functions.CleanupLLMResult(accumulatedContent, cfg.FunctionsConfig) toolCalls := functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig) if len(toolCalls) > toolCallsEmitted { if !inToolCall && currentBlockIndex == 0 { sendAnthropicSSE(c, schema.AnthropicStreamEvent{ Type: "content_block_stop", Index: currentBlockIndex, }) currentBlockIndex++ inToolCall = true } for i := toolCallsEmitted; i < len(toolCalls); i++ { tc := toolCalls[i] sendAnthropicSSE(c, schema.AnthropicStreamEvent{ Type: "content_block_start", Index: currentBlockIndex, ContentBlock: &schema.AnthropicContentBlock{ Type: "tool_use", ID: fmt.Sprintf("toolu_%s_%d", id, i), Name: tc.Name, }, }) sendAnthropicSSE(c, schema.AnthropicStreamEvent{ Type: "content_block_delta", Index: currentBlockIndex, Delta: &schema.AnthropicStreamDelta{ Type: "input_json_delta", PartialJSON: tc.Arguments, }, }) sendAnthropicSSE(c, schema.AnthropicStreamEvent{ Type: "content_block_stop", Index: currentBlockIndex, }) currentBlockIndex++ } collectedToolCalls = toolCalls toolCallsEmitted = len(toolCalls) return true } } if !inToolCall { sendAnthropicSSE(c, schema.AnthropicStreamEvent{ Type: "content_block_delta", Index: 0, Delta: &schema.AnthropicStreamDelta{ Type: "text_delta", Text: token, }, }) } return true } // Populate openAIReq fields for ComputeChoices openAIReq.Tools = convertFuncsToOpenAITools(funcs) 
openAIReq.ToolsChoice = input.ToolChoice openAIReq.Metadata = input.Metadata _, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, func(s string, c *[]schema.Choice) {}, tokenCallback) if err != nil { xlog.Error("Anthropic stream model inference failed", "error", err) return sendAnthropicError(c, 500, "api_error", fmt.Sprintf("model inference failed: %v", err)) } // Also check chat deltas for tool calls if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 && len(collectedToolCalls) == 0 { collectedToolCalls = deltaToolCalls } // MCP streaming tool execution: if we collected MCP tool calls, execute and loop if hasMCPTools && len(collectedToolCalls) > 0 { var hasMCPCalls bool for _, tc := range collectedToolCalls { if mcpTools.IsMCPTool(mcpToolInfos, tc.Name) { hasMCPCalls = true break } } if hasMCPCalls { // Append assistant message with tool_calls assistantMsg := schema.Message{ Role: "assistant", Content: accumulatedContent, } for i, tc := range collectedToolCalls { toolCallID := tc.ID if toolCallID == "" { toolCallID = fmt.Sprintf("toolu_%s_%d", id, i) } assistantMsg.ToolCalls = append(assistantMsg.ToolCalls, schema.ToolCall{ Index: i, ID: toolCallID, Type: "function", FunctionCall: schema.FunctionCall{ Name: tc.Name, Arguments: tc.Arguments, }, }) } openAIReq.Messages = append(openAIReq.Messages, assistantMsg) // Execute MCP tool calls for _, tc := range assistantMsg.ToolCalls { if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { continue } xlog.Debug("Executing MCP tool (Anthropic stream)", "tool", tc.FunctionCall.Name, "iteration", mcpIteration) toolResult, toolErr := mcpTools.ExecuteMCPToolCall( c.Request().Context(), mcpToolInfos, tc.FunctionCall.Name, tc.FunctionCall.Arguments, ) if toolErr != nil { xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr) toolResult = fmt.Sprintf("Error: %v", toolErr) } openAIReq.Messages = 
append(openAIReq.Messages, schema.Message{ Role: "tool", Content: toolResult, StringContent: toolResult, ToolCallID: tc.ID, Name: tc.FunctionCall.Name, }) } xlog.Debug("Anthropic MCP streaming tools executed, re-running inference", "iteration", mcpIteration) continue // next MCP iteration } } // No MCP tools to execute, close stream if !inToolCall { sendAnthropicSSE(c, schema.AnthropicStreamEvent{ Type: "content_block_stop", Index: 0, }) } stopReason := "end_turn" if toolCallsEmitted > 0 { stopReason = "tool_use" } sendAnthropicSSE(c, schema.AnthropicStreamEvent{ Type: "message_delta", Delta: &schema.AnthropicStreamDelta{ StopReason: &stopReason, }, Usage: &schema.AnthropicUsage{ OutputTokens: tokenUsage.Completion, }, }) sendAnthropicSSE(c, schema.AnthropicStreamEvent{ Type: "message_stop", }) return nil } // end MCP iteration loop // Safety fallback sendAnthropicSSE(c, schema.AnthropicStreamEvent{ Type: "message_stop", }) return nil } func convertFuncsToOpenAITools(funcs functions.Functions) []functions.Tool { tools := make([]functions.Tool, len(funcs)) for i, f := range funcs { tools[i] = functions.Tool{Type: "function", Function: f} } return tools } func sendAnthropicSSE(c echo.Context, event schema.AnthropicStreamEvent) { data, err := json.Marshal(event) if err != nil { xlog.Error("Failed to marshal SSE event", "error", err) return } fmt.Fprintf(c.Response().Writer, "event: %s\ndata: %s\n\n", event.Type, string(data)) c.Response().Flush() } func sendAnthropicError(c echo.Context, statusCode int, errorType, message string) error { resp := schema.AnthropicErrorResponse{ Type: "error", Error: schema.AnthropicError{ Type: errorType, Message: message, }, } return c.JSON(statusCode, resp) } func convertAnthropicToOpenAIMessages(input *schema.AnthropicRequest) []schema.Message { var messages []schema.Message // Add system message if present if input.System != "" { sysStr := string(input.System) messages = append(messages, schema.Message{ Role: "system", 
StringContent: sysStr, Content: sysStr, }) } // Convert Anthropic messages to OpenAI format for _, msg := range input.Messages { openAIMsg := schema.Message{ Role: msg.Role, } // Handle content (can be string or array of content blocks) switch content := msg.Content.(type) { case string: openAIMsg.StringContent = content openAIMsg.Content = content case []interface{}: // Handle array of content blocks var textContent string var stringImages []string var toolCalls []schema.ToolCall toolCallIndex := 0 for _, block := range content { if blockMap, ok := block.(map[string]interface{}); ok { blockType, _ := blockMap["type"].(string) switch blockType { case "text": if text, ok := blockMap["text"].(string); ok { textContent += text } case "image": // Handle image content if source, ok := blockMap["source"].(map[string]interface{}); ok { if sourceType, ok := source["type"].(string); ok && sourceType == "base64" { if data, ok := source["data"].(string); ok { mediaType, _ := source["media_type"].(string) // Format as data URI dataURI := fmt.Sprintf("data:%s;base64,%s", mediaType, data) stringImages = append(stringImages, dataURI) } } } case "tool_use": // Convert tool_use to ToolCall format toolID, _ := blockMap["id"].(string) toolName, _ := blockMap["name"].(string) toolInput := blockMap["input"] // Serialize input to JSON string inputJSON, err := json.Marshal(toolInput) if err != nil { xlog.Warn("Failed to marshal tool input", "error", err) inputJSON = []byte("{}") } toolCalls = append(toolCalls, schema.ToolCall{ Index: toolCallIndex, ID: toolID, Type: "function", FunctionCall: schema.FunctionCall{ Name: toolName, Arguments: string(inputJSON), }, }) toolCallIndex++ case "tool_result": // Convert tool_result to a message with role "tool" // This is handled by creating a separate message after this block // For now, we'll add it as text content toolUseID, _ := blockMap["tool_use_id"].(string) isError := false if isErrorPtr, ok := blockMap["is_error"].(*bool); ok && isErrorPtr 
!= nil { isError = *isErrorPtr } var resultText string if resultContent, ok := blockMap["content"]; ok { switch rc := resultContent.(type) { case string: resultText = rc case []interface{}: // Array of content blocks for _, cb := range rc { if cbMap, ok := cb.(map[string]interface{}); ok { if cbMap["type"] == "text" { if text, ok := cbMap["text"].(string); ok { resultText += text } } } } } } // Add tool result as a tool role message // We need to handle this differently - create a new message if msg.Role == "user" { // Store tool result info for creating separate message prefix := "" if isError { prefix = "Error: " } textContent += fmt.Sprintf("\n[Tool Result for %s]: %s%s", toolUseID, prefix, resultText) } } } } openAIMsg.StringContent = textContent openAIMsg.Content = textContent openAIMsg.StringImages = stringImages // Add tool calls if present if len(toolCalls) > 0 { openAIMsg.ToolCalls = toolCalls } } messages = append(messages, openAIMsg) } return messages } // convertAnthropicTools converts Anthropic tools to internal Functions format func convertAnthropicTools(input *schema.AnthropicRequest, cfg *config.ModelConfig) (functions.Functions, bool) { if len(input.Tools) == 0 { return nil, false } var funcs functions.Functions for _, tool := range input.Tools { f := functions.Function{ Name: tool.Name, Description: tool.Description, Parameters: tool.InputSchema, } funcs = append(funcs, f) } // Handle tool_choice if input.ToolChoice != nil { switch tc := input.ToolChoice.(type) { case string: // "auto", "any", or "none" if tc == "any" { // Force the model to use one of the tools cfg.SetFunctionCallString("required") } else if tc == "none" { // Don't use tools return nil, false } // "auto" is the default - let model decide case map[string]interface{}: // Specific tool selection: {"type": "tool", "name": "tool_name"} if tcType, ok := tc["type"].(string); ok && tcType == "tool" { if name, ok := tc["name"].(string); ok { // Force specific tool 
cfg.SetFunctionCallString(name) } } } } return funcs, len(funcs) > 0 && cfg.ShouldUseFunctions() } ================================================ FILE: core/http/endpoints/elevenlabs/soundgeneration.go ================================================ package elevenlabs import ( "path/filepath" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/audio" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) // SoundGenerationEndpoint is the ElevenLabs SoundGeneration endpoint https://elevenlabs.io/docs/api-reference/sound-generation // @Summary Generates audio from the input text. // @Param request body schema.ElevenLabsSoundGenerationRequest true "query params" // @Success 200 {string} binary "Response" // @Router /v1/sound-generation [post] func SoundGenerationEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsSoundGenerationRequest) if !ok || input.ModelID == "" { return echo.ErrBadRequest } cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { return echo.ErrBadRequest } xlog.Debug("Sound Generation Request about to be sent to backend", "modelFile", "modelFile", "backend", cfg.Backend) language := input.Language if language == "" { language = input.VocalLanguage } var bpm *int32 if input.BPM != nil { b := int32(*input.BPM) bpm = &b } filePath, _, err := backend.SoundGeneration( input.Text, input.Duration, input.Temperature, input.DoSample, nil, nil, input.Think, input.Caption, input.Lyrics, bpm, input.Keyscale, language, input.Timesignature, input.Instrumental, ml, appConfig, *cfg) if err != nil { return err } filePath, contentType := 
audio.NormalizeAudioFile(filePath) if contentType != "" { c.Response().Header().Set("Content-Type", contentType) } return c.Attachment(filePath, filepath.Base(filePath)) } } ================================================ FILE: core/http/endpoints/elevenlabs/tts.go ================================================ package elevenlabs import ( "path/filepath" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/audio" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) // TTSEndpoint is the OpenAI Speech API endpoint https://platform.openai.com/docs/api-reference/audio/createSpeech // @Summary Generates audio from the input text. // @Param voice-id path string true "Account ID" // @Param request body schema.TTSRequest true "query params" // @Success 200 {string} binary "Response" // @Router /v1/text-to-speech/{voice-id} [post] func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { voiceID := c.Param("voice-id") input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.ElevenLabsTTSRequest) if !ok || input.ModelID == "" { return echo.ErrBadRequest } cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { return echo.ErrBadRequest } xlog.Debug("elevenlabs TTS request received", "modelName", input.ModelID) filePath, _, err := backend.ModelTTS(input.Text, voiceID, input.LanguageCode, ml, appConfig, *cfg) if err != nil { return err } filePath, contentType := audio.NormalizeAudioFile(filePath) if contentType != "" { c.Response().Header().Set("Content-Type", contentType) } return c.Attachment(filePath, filepath.Base(filePath)) } } ================================================ FILE: 
core/http/endpoints/explorer/dashboard.go ================================================ package explorer import ( "encoding/base64" "net/http" "sort" "strings" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/explorer" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/internal" ) func Dashboard() echo.HandlerFunc { return func(c echo.Context) error { summary := map[string]interface{}{ "Title": "LocalAI API - " + internal.PrintableVersion(), "Version": internal.PrintableVersion(), "BaseURL": middleware.BaseURL(c), } contentType := c.Request().Header.Get("Content-Type") accept := c.Request().Header.Get("Accept") if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "html")) { // The client expects a JSON response return c.JSON(http.StatusOK, summary) } else { // Render index return c.Render(http.StatusOK, "views/explorer", summary) } } } type AddNetworkRequest struct { Token string `json:"token"` Name string `json:"name"` Description string `json:"description"` } type Network struct { explorer.TokenData Token string `json:"token"` } func ShowNetworks(db *explorer.Database) echo.HandlerFunc { return func(c echo.Context) error { results := []Network{} for _, token := range db.TokenList() { networkData, exists := db.Get(token) // get the token data hasWorkers := false for _, cluster := range networkData.Clusters { if len(cluster.Workers) > 0 { hasWorkers = true break } } if exists && hasWorkers { results = append(results, Network{TokenData: networkData, Token: token}) } } // order by number of clusters sort.Slice(results, func(i, j int) bool { return len(results[i].Clusters) > len(results[j].Clusters) }) return c.JSON(http.StatusOK, results) } } func AddNetwork(db *explorer.Database) echo.HandlerFunc { return func(c echo.Context) error { request := new(AddNetworkRequest) if err := c.Bind(request); err != nil { return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": 
"Cannot parse JSON"}) } if request.Token == "" { return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token is required"}) } if request.Name == "" { return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Name is required"}) } if request.Description == "" { return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Description is required"}) } // TODO: check if token is valid, otherwise reject // try to decode the token from base64 _, err := base64.StdEncoding.DecodeString(request.Token) if err != nil { return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Invalid token"}) } if _, exists := db.Get(request.Token); exists { return c.JSON(http.StatusBadRequest, map[string]interface{}{"error": "Token already exists"}) } err = db.Set(request.Token, explorer.TokenData{Name: request.Name, Description: request.Description}) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]interface{}{"error": "Cannot add token"}) } return c.JSON(http.StatusOK, map[string]interface{}{"message": "Token added"}) } } ================================================ FILE: core/http/endpoints/jina/rerank.go ================================================ package jina import ( "net/http" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) // JINARerankEndpoint acts like the Jina reranker endpoint (https://jina.ai/reranker/) // @Summary Reranks a list of phrases by relevance to a given text query. 
// @Param request body schema.JINARerankRequest true "query params" // @Success 200 {object} schema.JINARerankResponse "Response" // @Router /v1/rerank [post] func JINARerankEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.JINARerankRequest) if !ok || input.Model == "" { return echo.ErrBadRequest } cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { return echo.ErrBadRequest } xlog.Debug("JINA Rerank Request received", "model", input.Model) var requestTopN int32 docs := int32(len(input.Documents)) if input.TopN == nil { // omit top_n to get all requestTopN = docs } else { requestTopN = int32(*input.TopN) if requestTopN < 1 { return c.JSON(http.StatusUnprocessableEntity, "top_n - should be greater than or equal to 1") } if requestTopN > docs { // make it more obvious for backends requestTopN = docs } } request := &proto.RerankRequest{ Query: input.Query, TopN: requestTopN, Documents: input.Documents, } results, err := backend.Rerank(request, ml, appConfig, *cfg) if err != nil { return err } response := &schema.JINARerankResponse{ Model: input.Model, } for _, r := range results.Results { response.Results = append(response.Results, schema.JINADocumentResult{ Index: int(r.Index), Document: schema.JINAText{Text: r.Text}, RelevanceScore: float64(r.RelevanceScore), }) } response.Usage.TotalTokens = int(results.Usage.TotalTokens) response.Usage.PromptTokens = int(results.Usage.PromptTokens) return c.JSON(http.StatusOK, response) } } ================================================ FILE: core/http/endpoints/localai/agent_collections.go ================================================ package localai import ( "net/http" "net/url" "strings" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" ) func 
ListCollectionsEndpoint(app *application.Application) echo.HandlerFunc {
	// Lists the caller's collections; when wantsAllUsers(c) is true the
	// response additionally carries the other users' collections under
	// "user_groups".
	return func(c echo.Context) error {
		svc := app.AgentPoolService()
		userID := getUserID(c)
		cols, err := svc.ListCollectionsForUser(userID)
		if err != nil {
			return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
		}

		resp := map[string]any{
			"collections": cols,
			"count":       len(cols),
		}

		// Admin cross-user aggregation
		if wantsAllUsers(c) {
			usm := svc.UserServicesManager()
			if usm != nil {
				// Best effort: a failure to enumerate users just yields no groups.
				userIDs, _ := usm.ListAllUserIDs()
				userGroups := map[string]any{}
				for _, uid := range userIDs {
					if uid == userID {
						continue
					}
					userCols, err := svc.ListCollectionsForUser(uid)
					if err != nil || len(userCols) == 0 {
						continue
					}
					userGroups[uid] = map[string]any{"collections": userCols}
				}
				if len(userGroups) > 0 {
					resp["user_groups"] = userGroups
				}
			}
		}

		return c.JSON(http.StatusOK, resp)
	}
}

// CreateCollectionEndpoint creates a collection named by the JSON "name"
// field for the calling user and answers 201 on success.
func CreateCollectionEndpoint(app *application.Application) echo.HandlerFunc {
	return func(c echo.Context) error {
		svc := app.AgentPoolService()
		userID := getUserID(c)
		var payload struct {
			Name string `json:"name"`
		}
		if err := c.Bind(&payload); err != nil {
			return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
		}
		if err := svc.CreateCollectionForUser(userID, payload.Name); err != nil {
			return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
		}
		return c.JSON(http.StatusCreated, map[string]string{"status": "ok", "name": payload.Name})
	}
}

// UploadToCollectionEndpoint stores the multipart form file "file" in the
// collection given by the "name" path parameter.
func UploadToCollectionEndpoint(app *application.Application) echo.HandlerFunc {
	return func(c echo.Context) error {
		svc := app.AgentPoolService()
		userID := effectiveUserID(c)
		name := c.Param("name")
		file, err := c.FormFile("file")
		if err != nil {
			return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"})
		}
		src, err := file.Open()
		if err != nil {
			return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
		}
		defer src.Close()
		if err := svc.UploadToCollectionForUser(userID, name, file.Filename, src); err != nil {
			if strings.Contains(err.Error(), "not found") {
				return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
			}
			return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
		}
		return c.JSON(http.StatusOK, map[string]string{"status": "ok", "filename": file.Filename})
	}
}

// ListCollectionEntriesEndpoint lists the entries of the named collection.
func ListCollectionEntriesEndpoint(app *application.Application) echo.HandlerFunc {
	return func(c echo.Context) error {
		svc := app.AgentPoolService()
		userID := effectiveUserID(c)
		entries, err := svc.ListCollectionEntriesForUser(userID, c.Param("name"))
		if err != nil {
			if strings.Contains(err.Error(), "not found") {
				return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
			}
			return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
		}
		return c.JSON(http.StatusOK, map[string]any{
			"entries": entries,
			"count":   len(entries),
		})
	}
}

// GetCollectionEntryContentEndpoint returns the content and chunk count of
// one entry. The entry comes from the wildcard path segment and is
// path-unescaped when possible (the raw value is used otherwise).
func GetCollectionEntryContentEndpoint(app *application.Application) echo.HandlerFunc {
	return func(c echo.Context) error {
		svc := app.AgentPoolService()
		userID := effectiveUserID(c)
		entryParam := c.Param("*")
		entry, err := url.PathUnescape(entryParam)
		if err != nil {
			entry = entryParam
		}
		content, chunkCount, err := svc.GetCollectionEntryContentForUser(userID, c.Param("name"), entry)
		if err != nil {
			if strings.Contains(err.Error(), "not found") {
				return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
			}
			return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
		}
		return c.JSON(http.StatusOK, map[string]any{
			"content":     content,
			"chunk_count": chunkCount,
		})
	}
}

// SearchCollectionEndpoint runs a query ("query", "max_results") against the
// named collection.
func SearchCollectionEndpoint(app *application.Application) echo.HandlerFunc {
	return func(c echo.Context) error {
		svc := app.AgentPoolService()
		userID := effectiveUserID(c)
		var payload struct {
			Query      string `json:"query"`
			MaxResults int    `json:"max_results"`
		}
		if err := c.Bind(&payload); err != nil {
			return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
		}
		results, err := svc.SearchCollectionForUser(userID, c.Param("name"), payload.Query, payload.MaxResults)
		if err != nil {
			if strings.Contains(err.Error(), "not found") {
				return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
			}
			return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
		}
		return c.JSON(http.StatusOK, map[string]any{
			"results": results,
			"count":   len(results),
		})
	}
}

// ResetCollectionEndpoint resets the named collection.
func ResetCollectionEndpoint(app *application.Application) echo.HandlerFunc {
	return func(c echo.Context) error {
		svc := app.AgentPoolService()
		userID := effectiveUserID(c)
		if err := svc.ResetCollectionForUser(userID, c.Param("name")); err != nil {
			if strings.Contains(err.Error(), "not found") {
				return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
			}
			return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
		}
		return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
	}
}

// DeleteCollectionEntryEndpoint removes the entry named in the JSON "entry"
// field and returns the entries that remain.
func DeleteCollectionEntryEndpoint(app *application.Application) echo.HandlerFunc {
	return func(c echo.Context) error {
		svc := app.AgentPoolService()
		userID := effectiveUserID(c)
		var payload struct {
			Entry string `json:"entry"`
		}
		if err := c.Bind(&payload); err != nil {
			return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
		}
		remaining, err := svc.DeleteCollectionEntryForUser(userID, c.Param("name"), payload.Entry)
		if err != nil {
			if strings.Contains(err.Error(), "not found") {
				return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
			}
			return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
		}
		return c.JSON(http.StatusOK, map[string]any{
			"remaining_entries": remaining,
			"count":             len(remaining),
		})
	}
}

// AddCollectionSourceEndpoint attaches an external source URL to the named
// collection. An update_interval below 1 is replaced by 60.
func AddCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc {
	return func(c echo.Context) error {
		svc := app.AgentPoolService()
		userID := effectiveUserID(c)
		var payload struct {
			URL            string `json:"url"`
			UpdateInterval int    `json:"update_interval"`
		}
		if err := c.Bind(&payload); err != nil {
			return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
		}
		if payload.UpdateInterval < 1 {
			payload.UpdateInterval = 60
		}
		if err := svc.AddCollectionSourceForUser(userID, c.Param("name"), payload.URL, payload.UpdateInterval); err != nil {
			if strings.Contains(err.Error(), "not found") {
				return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
			}
			return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
		}
		return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
	}
}

// RemoveCollectionSourceEndpoint detaches the source URL given in the JSON
// "url" field from the named collection.
func RemoveCollectionSourceEndpoint(app *application.Application) echo.HandlerFunc {
	return func(c echo.Context) error {
		svc := app.AgentPoolService()
		userID := effectiveUserID(c)
		var payload struct {
			URL string `json:"url"`
		}
		if err := c.Bind(&payload); err != nil {
			return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()})
		}
		if err := svc.RemoveCollectionSourceForUser(userID, c.Param("name"), payload.URL); err != nil {
			// Report a "not found" error as 404, like every other
			// per-collection handler in this file; previously it was a 500.
			if strings.Contains(err.Error(), "not found") {
				return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()})
			}
			return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()})
		}
		return c.JSON(http.StatusOK, map[string]string{"status": "ok"})
	}
}

// GetCollectionEntryRawFileEndpoint serves the original uploaded binary file.
func GetCollectionEntryRawFileEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) entryParam := c.Param("*") entry, err := url.PathUnescape(entryParam) if err != nil { entry = entryParam } fpath, err := svc.GetCollectionEntryFilePathForUser(userID, c.Param("name"), entry) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.File(fpath) } } func ListCollectionSourcesEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) sources, err := svc.ListCollectionSourcesForUser(userID, c.Param("name")) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]any{ "sources": sources, "count": len(sources), }) } } ================================================ FILE: core/http/endpoints/localai/agent_jobs.go ================================================ package localai import ( "fmt" "net/http" "strconv" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" ) // getJobService returns the job service for the current user. // Falls back to the global service when no user is authenticated. 
func getJobService(app *application.Application, c echo.Context) *services.AgentJobService { userID := getUserID(c) if userID == "" { return app.AgentJobService() } svc := app.AgentPoolService() if svc == nil { return app.AgentJobService() } jobSvc, err := svc.JobServiceForUser(userID) if err != nil { return app.AgentJobService() } return jobSvc } func CreateTaskEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { var task schema.Task if err := c.Bind(&task); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request body: " + err.Error()}) } id, err := getJobService(app, c).CreateTask(task) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, map[string]string{"id": id}) } } func UpdateTaskEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") var task schema.Task if err := c.Bind(&task); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request body: " + err.Error()}) } if err := getJobService(app, c).UpdateTask(id, task); err != nil { if err.Error() == "task not found: "+id { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"message": "Task updated"}) } } func DeleteTaskEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") if err := getJobService(app, c).DeleteTask(id); err != nil { if err.Error() == "task not found: "+id { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"message": "Task deleted"}) } } func ListTasksEndpoint(app 
*application.Application) echo.HandlerFunc { return func(c echo.Context) error { jobSvc := getJobService(app, c) tasks := jobSvc.ListTasks() // Admin cross-user aggregation if wantsAllUsers(c) { svc := app.AgentPoolService() if svc != nil { usm := svc.UserServicesManager() if usm != nil { userID := getUserID(c) userIDs, _ := usm.ListAllUserIDs() userGroups := map[string]any{} for _, uid := range userIDs { if uid == userID { continue } userJobSvc, err := svc.JobServiceForUser(uid) if err != nil { continue } userTasks := userJobSvc.ListTasks() if len(userTasks) == 0 { continue } userGroups[uid] = map[string]any{"tasks": userTasks} } if len(userGroups) > 0 { return c.JSON(http.StatusOK, map[string]any{ "tasks": tasks, "user_groups": userGroups, }) } } } } return c.JSON(http.StatusOK, tasks) } } func GetTaskEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") task, err := getJobService(app, c).GetTask(id) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, task) } } func ExecuteJobEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { var req schema.JobExecutionRequest if err := c.Bind(&req); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request body: " + err.Error()}) } if req.Parameters == nil { req.Parameters = make(map[string]string) } var multimedia *schema.MultimediaAttachment if len(req.Images) > 0 || len(req.Videos) > 0 || len(req.Audios) > 0 || len(req.Files) > 0 { multimedia = &schema.MultimediaAttachment{ Images: req.Images, Videos: req.Videos, Audios: req.Audios, Files: req.Files, } } jobID, err := getJobService(app, c).ExecuteJob(req.TaskID, req.Parameters, "api", multimedia) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } baseURL := c.Scheme() + "://" + c.Request().Host return 
c.JSON(http.StatusCreated, schema.JobExecutionResponse{ JobID: jobID, Status: "pending", URL: baseURL + "/api/agent/jobs/" + jobID, }) } } func GetJobEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") job, err := getJobService(app, c).GetJob(id) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, job) } } func ListJobsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { var taskID *string var status *schema.JobStatus limit := 0 if taskIDParam := c.QueryParam("task_id"); taskIDParam != "" { taskID = &taskIDParam } if statusParam := c.QueryParam("status"); statusParam != "" { s := schema.JobStatus(statusParam) status = &s } if limitParam := c.QueryParam("limit"); limitParam != "" { if l, err := strconv.Atoi(limitParam); err == nil { limit = l } } jobSvc := getJobService(app, c) jobs := jobSvc.ListJobs(taskID, status, limit) // Admin cross-user aggregation if wantsAllUsers(c) { svc := app.AgentPoolService() if svc != nil { usm := svc.UserServicesManager() if usm != nil { userID := getUserID(c) userIDs, _ := usm.ListAllUserIDs() userGroups := map[string]any{} for _, uid := range userIDs { if uid == userID { continue } userJobSvc, err := svc.JobServiceForUser(uid) if err != nil { continue } userJobs := userJobSvc.ListJobs(taskID, status, limit) if len(userJobs) == 0 { continue } userGroups[uid] = map[string]any{"jobs": userJobs} } if len(userGroups) > 0 { return c.JSON(http.StatusOK, map[string]any{ "jobs": jobs, "user_groups": userGroups, }) } } } } return c.JSON(http.StatusOK, jobs) } } func CancelJobEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") if err := getJobService(app, c).CancelJob(id); err != nil { if err.Error() == "job not found: "+id { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return 
c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"message": "Job cancelled"}) } } func DeleteJobEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { id := c.Param("id") if err := getJobService(app, c).DeleteJob(id); err != nil { if err.Error() == "job not found: "+id { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"message": "Job deleted"}) } } func ExecuteTaskByNameEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { name := c.Param("name") var params map[string]string if c.Request().ContentLength > 0 { if err := c.Bind(¶ms); err != nil { body := make(map[string]interface{}) if err := c.Bind(&body); err == nil { params = make(map[string]string) for k, v := range body { if str, ok := v.(string); ok { params[k] = str } else { params[k] = fmt.Sprintf("%v", v) } } } else { params = make(map[string]string) } } } else { params = make(map[string]string) } jobSvc := getJobService(app, c) tasks := jobSvc.ListTasks() var task *schema.Task for _, t := range tasks { if t.Name == name { task = &t break } } if task == nil { return c.JSON(http.StatusNotFound, map[string]string{"error": "Task not found: " + name}) } jobID, err := jobSvc.ExecuteJob(task.ID, params, "api", nil) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } baseURL := c.Scheme() + "://" + c.Request().Host return c.JSON(http.StatusCreated, schema.JobExecutionResponse{ JobID: jobID, Status: "pending", URL: baseURL + "/api/agent/jobs/" + jobID, }) } } ================================================ FILE: core/http/endpoints/localai/agent_responses.go ================================================ package localai import ( "bytes" "encoding/json" "fmt" 
"io" "net/http" "time" "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" coreTypes "github.com/mudler/LocalAGI/core/types" "github.com/mudler/xlog" "github.com/sashabaranov/go-openai" ) // agentResponsesRequest is the minimal subset of the OpenResponses request body // needed to route to an agent. type agentResponsesRequest struct { Model string `json:"model"` Input json.RawMessage `json:"input"` PreviousResponseID string `json:"previous_response_id,omitempty"` Tools []json.RawMessage `json:"tools,omitempty"` ToolChoice json.RawMessage `json:"tool_choice,omitempty"` } // AgentResponsesInterceptor returns a middleware that intercepts /v1/responses // requests when the model name matches an agent in the pool. If no agent matches, // it restores the request body and falls through to the normal responses pipeline. func AgentResponsesInterceptor(app *application.Application) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() if svc == nil { return next(c) } // Read and buffer the body so we can peek at the model name body, err := io.ReadAll(c.Request().Body) if err != nil { return next(c) } // Always restore the body for the next handler c.Request().Body = io.NopCloser(bytes.NewReader(body)) var req agentResponsesRequest if err := json.Unmarshal(body, &req); err != nil || req.Model == "" { return next(c) } // Check if this model name is an agent ag := svc.GetAgent(req.Model) if ag == nil { return next(c) } // This is an agent — handle the request directly messages := parseInputToMessages(req.Input) if len(messages) == 0 { return c.JSON(http.StatusBadRequest, map[string]any{ "error": map[string]string{ "type": "invalid_request_error", "message": "no input messages provided", }, }) } jobOptions := []coreTypes.JobOption{ coreTypes.WithConversationHistory(messages), } res := ag.Ask(jobOptions...) 
if res == nil { return c.JSON(http.StatusInternalServerError, map[string]any{ "error": map[string]string{ "type": "server_error", "message": "agent request failed or was cancelled", }, }) } if res.Error != nil { xlog.Error("Error asking agent via responses API", "agent", req.Model, "error", res.Error) return c.JSON(http.StatusInternalServerError, map[string]any{ "error": map[string]string{ "type": "server_error", "message": res.Error.Error(), }, }) } id := fmt.Sprintf("resp_%s", uuid.New().String()) return c.JSON(http.StatusOK, map[string]any{ "id": id, "object": "response", "created_at": time.Now().Unix(), "status": "completed", "model": req.Model, "previous_response_id": nil, "output": []any{ map[string]any{ "type": "message", "id": fmt.Sprintf("msg_%d", time.Now().UnixNano()), "status": "completed", "role": "assistant", "content": []map[string]any{ { "type": "output_text", "text": res.Response, "annotations": []any{}, }, }, }, }, }) } } } // parseInputToMessages converts the raw JSON input (string or message array) to openai messages. 
func parseInputToMessages(raw json.RawMessage) []openai.ChatCompletionMessage { if len(raw) == 0 { return nil } // Try as string first var text string if err := json.Unmarshal(raw, &text); err == nil && text != "" { return []openai.ChatCompletionMessage{ {Role: "user", Content: text}, } } // Try as array of message objects var messages []struct { Type string `json:"type,omitempty"` Role string `json:"role,omitempty"` Content json.RawMessage `json:"content,omitempty"` CallId string `json:"call_id,omitempty"` Name string `json:"name,omitempty"` Arguments string `json:"arguments,omitempty"` Output string `json:"output,omitempty"` } if err := json.Unmarshal(raw, &messages); err != nil { return nil } var result []openai.ChatCompletionMessage for _, m := range messages { switch m.Type { case "function_call": result = append(result, openai.ChatCompletionMessage{ Role: "assistant", ToolCalls: []openai.ToolCall{ { Type: "function", ID: m.CallId, Function: openai.FunctionCall{ Arguments: m.Arguments, Name: m.Name, }, }, }, }) case "function_call_output": if m.CallId != "" && m.Output != "" { result = append(result, openai.ChatCompletionMessage{ Role: "tool", Content: m.Output, ToolCallID: m.CallId, }) } default: if m.Role == "" { continue } content := parseMessageContent(m.Content) if content != "" { result = append(result, openai.ChatCompletionMessage{ Role: m.Role, Content: content, }) } } } return result } // parseMessageContent extracts text from either a string or array of content items. 
func parseMessageContent(raw json.RawMessage) string { if len(raw) == 0 { return "" } var text string if err := json.Unmarshal(raw, &text); err == nil { return text } var items []struct { Type string `json:"type"` Text string `json:"text,omitempty"` } if err := json.Unmarshal(raw, &items); err == nil { for _, item := range items { if item.Type == "text" || item.Type == "input_text" { return item.Text } } } return "" } ================================================ FILE: core/http/endpoints/localai/agent_skills.go ================================================ package localai import ( "io" "net/http" "strings" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" skilldomain "github.com/mudler/skillserver/pkg/domain" ) type skillResponse struct { Name string `json:"name"` Content string `json:"content"` Description string `json:"description,omitempty"` License string `json:"license,omitempty"` Compatibility string `json:"compatibility,omitempty"` Metadata map[string]string `json:"metadata,omitempty"` AllowedTools string `json:"allowed-tools,omitempty"` ReadOnly bool `json:"readOnly"` } func skillToResponse(s skilldomain.Skill) skillResponse { out := skillResponse{Name: s.Name, Content: s.Content, ReadOnly: s.ReadOnly} if s.Metadata != nil { out.Description = s.Metadata.Description out.License = s.Metadata.License out.Compatibility = s.Metadata.Compatibility out.Metadata = s.Metadata.Metadata out.AllowedTools = s.Metadata.AllowedTools.String() } return out } func skillsToResponses(skills []skilldomain.Skill) []skillResponse { out := make([]skillResponse, len(skills)) for i, s := range skills { out[i] = skillToResponse(s) } return out } func ListSkillsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) skills, err := svc.ListSkillsForUser(userID) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } 
// Admin cross-user aggregation if wantsAllUsers(c) { usm := svc.UserServicesManager() if usm != nil { userIDs, _ := usm.ListAllUserIDs() userGroups := map[string]any{} for _, uid := range userIDs { if uid == userID { continue } userSkills, err := svc.ListSkillsForUser(uid) if err != nil || len(userSkills) == 0 { continue } userGroups[uid] = map[string]any{"skills": skillsToResponses(userSkills)} } resp := map[string]any{ "skills": skillsToResponses(skills), } if len(userGroups) > 0 { resp["user_groups"] = userGroups } return c.JSON(http.StatusOK, resp) } } return c.JSON(http.StatusOK, skillsToResponses(skills)) } } func GetSkillsConfigEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) cfg := svc.GetSkillsConfigForUser(userID) return c.JSON(http.StatusOK, cfg) } } func SearchSkillsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) query := c.QueryParam("q") skills, err := svc.SearchSkillsForUser(userID, query) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, skillsToResponses(skills)) } } func CreateSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) var payload struct { Name string `json:"name"` Description string `json:"description"` Content string `json:"content"` License string `json:"license,omitempty"` Compatibility string `json:"compatibility,omitempty"` AllowedTools string `json:"allowed-tools,omitempty"` Metadata map[string]string `json:"metadata,omitempty"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } skill, err := svc.CreateSkillForUser(userID, payload.Name, payload.Description, payload.Content, 
payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata) if err != nil { if strings.Contains(err.Error(), "already exists") { return c.JSON(http.StatusConflict, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, skillToResponse(*skill)) } } func GetSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) skill, err := svc.GetSkillForUser(userID, c.Param("name")) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, skillToResponse(*skill)) } } func UpdateSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) var payload struct { Description string `json:"description"` Content string `json:"content"` License string `json:"license,omitempty"` Compatibility string `json:"compatibility,omitempty"` AllowedTools string `json:"allowed-tools,omitempty"` Metadata map[string]string `json:"metadata,omitempty"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } skill, err := svc.UpdateSkillForUser(userID, c.Param("name"), payload.Description, payload.Content, payload.License, payload.Compatibility, payload.AllowedTools, payload.Metadata) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, skillToResponse(*skill)) } } func DeleteSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) if err := 
svc.DeleteSkillForUser(userID, c.Param("name")); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) } } func ExportSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) name := c.Param("*") data, err := svc.ExportSkillForUser(userID, name) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } c.Response().Header().Set("Content-Disposition", "attachment; filename="+name+".tar.gz") c.Response().Header().Set("Content-Type", "application/gzip") return c.Blob(http.StatusOK, "application/gzip", data) } } func ImportSkillEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) file, err := c.FormFile("file") if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "file required"}) } src, err := file.Open() if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } defer src.Close() data, err := io.ReadAll(src) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } skill, err := svc.ImportSkillForUser(userID, data) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, skill) } } // --- Skill Resources --- func ListSkillResourcesEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) resources, skill, err := svc.ListSkillResourcesForUser(userID, c.Param("name")) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } scripts := []map[string]any{} references := []map[string]any{} assets := []map[string]any{} for _, res := 
range resources { m := map[string]any{ "path": res.Path, "name": res.Name, "size": res.Size, "mime_type": res.MimeType, "readable": res.Readable, "modified": res.Modified.Format("2006-01-02T15:04:05Z07:00"), } switch res.Type { case "script": scripts = append(scripts, m) case "reference": references = append(references, m) case "asset": assets = append(assets, m) } } return c.JSON(http.StatusOK, map[string]any{ "scripts": scripts, "references": references, "assets": assets, "readOnly": skill.ReadOnly, }) } } func GetSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) content, info, err := svc.GetSkillResourceForUser(userID, c.Param("name"), c.Param("*")) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } if c.QueryParam("encoding") == "base64" || !info.Readable { return c.JSON(http.StatusOK, map[string]any{ "content": content.Content, "encoding": content.Encoding, "mime_type": content.MimeType, "size": content.Size, }) } c.Response().Header().Set("Content-Type", content.MimeType) return c.String(http.StatusOK, content.Content) } } func CreateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) file, err := c.FormFile("file") if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "file is required"}) } path := c.FormValue("path") if path == "" { path = file.Filename } src, err := file.Open() if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "failed to open file"}) } defer src.Close() data, err := io.ReadAll(src) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } if err := svc.CreateSkillResourceForUser(userID, c.Param("name"), path, data); err != nil { return c.JSON(http.StatusBadRequest, 
map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, map[string]string{"path": path}) } } func UpdateSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) var payload struct { Content string `json:"content"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } if err := svc.UpdateSkillResourceForUser(userID, c.Param("name"), c.Param("*"), payload.Content); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) } } func DeleteSkillResourceEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) if err := svc.DeleteSkillResourceForUser(userID, c.Param("name"), c.Param("*")); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) } } // --- Git Repos --- func ListGitReposEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) repos, err := svc.ListGitReposForUser(userID) if err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, repos) } } func AddGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) var payload struct { URL string `json:"url"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } repo, err := svc.AddGitRepoForUser(userID, payload.URL) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return 
c.JSON(http.StatusCreated, repo) } } func UpdateGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) var payload struct { URL string `json:"url"` Enabled *bool `json:"enabled"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } repo, err := svc.UpdateGitRepoForUser(userID, c.Param("id"), payload.URL, payload.Enabled) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, repo) } } func DeleteGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) if err := svc.DeleteGitRepoForUser(userID, c.Param("id")); err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) } } func SyncGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) if err := svc.SyncGitRepoForUser(userID, c.Param("id")); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusAccepted, map[string]string{"status": "syncing"}) } } func ToggleGitRepoEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) repo, err := svc.ToggleGitRepoForUser(userID, c.Param("id")) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return 
c.JSON(http.StatusOK, repo) } } ================================================ FILE: core/http/endpoints/localai/agents.go ================================================ package localai import ( "encoding/json" "fmt" "io" "net/http" "os" "path/filepath" "sort" "strings" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAGI/core/state" coreTypes "github.com/mudler/LocalAGI/core/types" agiServices "github.com/mudler/LocalAGI/services" ) // getUserID extracts the scoped user ID from the request context. // Returns empty string when auth is not active (backward compat). func getUserID(c echo.Context) string { user := auth.GetUser(c) if user == nil { return "" } return user.ID } // isAdminUser returns true if the authenticated user has admin role. func isAdminUser(c echo.Context) bool { user := auth.GetUser(c) return user != nil && user.Role == auth.RoleAdmin } // wantsAllUsers returns true if the request has ?all_users=true and the user is admin. func wantsAllUsers(c echo.Context) bool { return c.QueryParam("all_users") == "true" && isAdminUser(c) } // effectiveUserID returns the user ID to scope operations to. // SECURITY: Only admins may supply ?user_id= to operate on another user's // resources. Non-admin callers always get their own ID regardless of query params. 
func effectiveUserID(c echo.Context) string { if targetUID := c.QueryParam("user_id"); targetUID != "" && isAdminUser(c) { return targetUID } return getUserID(c) } func ListAgentsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) statuses := svc.ListAgentsForUser(userID) agents := make([]string, 0, len(statuses)) for name := range statuses { agents = append(agents, name) } sort.Strings(agents) resp := map[string]any{ "agents": agents, "agentCount": len(agents), "actions": len(agiServices.AvailableActions), "connectors": len(agiServices.AvailableConnectors), "statuses": statuses, } if hubURL := svc.AgentHubURL(); hubURL != "" { resp["agent_hub_url"] = hubURL } // Admin cross-user aggregation if wantsAllUsers(c) { grouped := svc.ListAllAgentsGrouped() userGroups := map[string]any{} for uid, agentList := range grouped { if uid == userID || uid == "" { continue } userGroups[uid] = map[string]any{"agents": agentList} } if len(userGroups) > 0 { resp["user_groups"] = userGroups } } return c.JSON(http.StatusOK, resp) } } func CreateAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) var cfg state.AgentConfig if err := c.Bind(&cfg); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } if err := svc.CreateAgentForUser(userID, &cfg); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, map[string]string{"status": "ok"}) } } func GetAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) name := c.Param("name") ag := svc.GetAgentForUser(userID, name) if ag == nil { return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"}) } return 
c.JSON(http.StatusOK, map[string]any{ "active": !ag.Paused(), }) } } func UpdateAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) name := c.Param("name") var cfg state.AgentConfig if err := c.Bind(&cfg); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } if err := svc.UpdateAgentForUser(userID, name, &cfg); err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) } } func DeleteAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) name := c.Param("name") if err := svc.DeleteAgentForUser(userID, name); err != nil { return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) } } func GetAgentConfigEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) name := c.Param("name") cfg := svc.GetAgentConfigForUser(userID, name) if cfg == nil { return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"}) } return c.JSON(http.StatusOK, cfg) } } func PauseAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) if err := svc.PauseAgentForUser(userID, c.Param("name")); err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) } } func ResumeAgentEndpoint(app *application.Application) 
echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) if err := svc.ResumeAgentForUser(userID, c.Param("name")); err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]string{"status": "ok"}) } } func GetAgentStatusEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) name := c.Param("name") history := svc.GetAgentStatusForUser(userID, name) if history == nil { history = &state.Status{ActionResults: []coreTypes.ActionState{}} } entries := []string{} for i := len(history.Results()) - 1; i >= 0; i-- { h := history.Results()[i] actionName := "" if h.ActionCurrentState.Action != nil { actionName = h.ActionCurrentState.Action.Definition().Name.String() } entries = append(entries, fmt.Sprintf("Reasoning: %s\nAction taken: %s\nParameters: %+v\nResult: %s", h.Reasoning, actionName, h.ActionCurrentState.Params, h.Result)) } return c.JSON(http.StatusOK, map[string]any{ "Name": name, "History": entries, }) } } func GetAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) name := c.Param("name") history, err := svc.GetAgentObservablesForUser(userID, name) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]any{ "Name": name, "History": history, }) } } func ClearAgentObservablesEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) name := c.Param("name") if err := svc.ClearAgentObservablesForUser(userID, name); err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, map[string]any{"Name": name, 
"cleared": true}) } } func ChatWithAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) name := c.Param("name") var payload struct { Message string `json:"message"` } if err := c.Bind(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "Invalid request format"}) } message := strings.TrimSpace(payload.Message) if message == "" { return c.JSON(http.StatusBadRequest, map[string]string{"error": "Message cannot be empty"}) } messageID, err := svc.ChatForUser(userID, name, message) if err != nil { if strings.Contains(err.Error(), "not found") { return c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusInternalServerError, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusAccepted, map[string]any{ "status": "message_received", "message_id": messageID, }) } } func AgentSSEEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) name := c.Param("name") manager := svc.GetSSEManagerForUser(userID, name) if manager == nil { return c.JSON(http.StatusNotFound, map[string]string{"error": "Agent not found"}) } return services.HandleSSE(c, manager) } } type agentConfigMetaResponse struct { state.AgentConfigMeta OutputsDir string `json:"OutputsDir"` } func GetAgentConfigMetaEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() return c.JSON(http.StatusOK, agentConfigMetaResponse{ AgentConfigMeta: svc.GetConfigMeta(), OutputsDir: svc.OutputsDir(), }) } } func ExportAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := effectiveUserID(c) name := c.Param("name") data, err := svc.ExportAgentForUser(userID, name) if err != nil { return 
c.JSON(http.StatusNotFound, map[string]string{"error": err.Error()}) } c.Response().Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%s.json", name)) return c.JSONBlob(http.StatusOK, data) } } func ImportAgentEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() userID := getUserID(c) // Try multipart form file first file, err := c.FormFile("file") if err == nil { src, err := file.Open() if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "failed to open file"}) } defer src.Close() data, err := io.ReadAll(src) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "failed to read file"}) } if err := svc.ImportAgentForUser(userID, data); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, map[string]string{"status": "ok"}) } // Try JSON body var cfg state.AgentConfig if err := json.NewDecoder(c.Request().Body).Decode(&cfg); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request: provide a file or JSON body"}) } data, err := json.Marshal(&cfg) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } if err := svc.ImportAgentForUser(userID, data); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusCreated, map[string]string{"status": "ok"}) } } // --- Actions --- func ListActionsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() return c.JSON(http.StatusOK, map[string]any{ "actions": svc.ListAvailableActions(), }) } } func GetActionDefinitionEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() actionName := c.Param("name") var payload struct { Config map[string]string 
`json:"config"` } if err := json.NewDecoder(c.Request().Body).Decode(&payload); err != nil { payload.Config = map[string]string{} } def, err := svc.GetActionDefinition(actionName, payload.Config) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, def) } } func ExecuteActionEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() actionName := c.Param("name") var payload struct { Config map[string]string `json:"config"` Params coreTypes.ActionParams `json:"params"` } if err := json.NewDecoder(c.Request().Body).Decode(&payload); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"}) } result, err := svc.ExecuteAction(c.Request().Context(), actionName, payload.Config, payload.Params) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": err.Error()}) } return c.JSON(http.StatusOK, result) } } func AgentFileEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { svc := app.AgentPoolService() requestedPath := c.QueryParam("path") if requestedPath == "" { return c.JSON(http.StatusBadRequest, map[string]string{"error": "no file path specified"}) } // Resolve the real path (follows symlinks, eliminates ..) 
resolved, err := filepath.EvalSymlinks(filepath.Clean(requestedPath)) if err != nil { return c.JSON(http.StatusNotFound, map[string]string{"error": "file not found"}) } // Determine the allowed outputs directory — scoped to the user when auth is active allowedDir := svc.OutputsDir() user := auth.GetUser(c) if user != nil { allowedDir = filepath.Join(allowedDir, user.ID) } allowedDirResolved, _ := filepath.EvalSymlinks(filepath.Clean(allowedDir)) if utils.InTrustedRoot(resolved, allowedDirResolved) != nil { return c.JSON(http.StatusForbidden, map[string]string{"error": "access denied"}) } info, err := os.Stat(resolved) if err != nil || info.IsDir() { return c.JSON(http.StatusNotFound, map[string]string{"error": "file not found"}) } return c.File(resolved) } } ================================================ FILE: core/http/endpoints/localai/backend.go ================================================ package localai import ( "encoding/json" "fmt" "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" ) type BackendEndpointService struct { galleries []config.Gallery backendPath string backendSystemPath string backendApplier *services.GalleryService } type GalleryBackend struct { ID string `json:"id"` } func CreateBackendEndpointService(galleries []config.Gallery, systemState *system.SystemState, backendApplier *services.GalleryService) BackendEndpointService { return BackendEndpointService{ galleries: galleries, backendPath: systemState.Backend.BackendsPath, backendSystemPath: systemState.Backend.BackendsSystemPath, backendApplier: backendApplier, } } // GetOpStatusEndpoint returns the job status // @Summary Returns the job status // @Success 200 {object} services.GalleryOpStatus "Response" // 
@Router /backends/jobs/{uuid} [get] func (mgs *BackendEndpointService) GetOpStatusEndpoint() echo.HandlerFunc { return func(c echo.Context) error { status := mgs.backendApplier.GetStatus(c.Param("uuid")) if status == nil { return fmt.Errorf("could not find any status for ID") } return c.JSON(200, status) } } // GetAllStatusEndpoint returns all the jobs status progress // @Summary Returns all the jobs status progress // @Success 200 {object} map[string]services.GalleryOpStatus "Response" // @Router /backends/jobs [get] func (mgs *BackendEndpointService) GetAllStatusEndpoint() echo.HandlerFunc { return func(c echo.Context) error { return c.JSON(200, mgs.backendApplier.GetAllStatus()) } } // ApplyBackendEndpoint installs a new backend to a LocalAI instance // @Summary Install backends to LocalAI. // @Param request body GalleryBackend true "query params" // @Success 200 {object} schema.BackendResponse "Response" // @Router /backends/apply [post] func (mgs *BackendEndpointService) ApplyBackendEndpoint() echo.HandlerFunc { return func(c echo.Context) error { input := new(GalleryBackend) // Get input data from the request body if err := c.Bind(input); err != nil { return err } uuid, err := uuid.NewUUID() if err != nil { return err } mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{ ID: uuid.String(), GalleryElementName: input.ID, Galleries: mgs.galleries, } return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())}) } } // DeleteBackendEndpoint lets delete backends from a LocalAI instance // @Summary delete backends from LocalAI. 
// @Param name path string true "Backend name" // @Success 200 {object} schema.BackendResponse "Response" // @Router /backends/delete/{name} [post] func (mgs *BackendEndpointService) DeleteBackendEndpoint() echo.HandlerFunc { return func(c echo.Context) error { backendName := c.Param("name") mgs.backendApplier.BackendGalleryChannel <- services.GalleryOp[gallery.GalleryBackend, any]{ Delete: true, GalleryElementName: backendName, Galleries: mgs.galleries, } uuid, err := uuid.NewUUID() if err != nil { return err } return c.JSON(200, schema.BackendResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%sbackends/jobs/%s", middleware.BaseURL(c), uuid.String())}) } } // ListBackendsEndpoint list the available backends configured in LocalAI // @Summary List all Backends // @Success 200 {object} []gallery.GalleryBackend "Response" // @Router /backends [get] func (mgs *BackendEndpointService) ListBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc { return func(c echo.Context) error { backends, err := gallery.ListSystemBackends(systemState) if err != nil { return err } return c.JSON(200, backends.GetAll()) } } // ListModelGalleriesEndpoint list the available galleries configured in LocalAI // @Summary List all Galleries // @Success 200 {object} []config.Gallery "Response" // @Router /backends/galleries [get] // NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents! 
func (mgs *BackendEndpointService) ListBackendGalleriesEndpoint() echo.HandlerFunc { return func(c echo.Context) error { xlog.Debug("Listing backend galleries", "galleries", mgs.galleries) dat, err := json.Marshal(mgs.galleries) if err != nil { return err } return c.Blob(200, "application/json", dat) } } // ListAvailableBackendsEndpoint list the available backends in the galleries configured in LocalAI // @Summary List all available Backends // @Success 200 {object} []gallery.GalleryBackend "Response" // @Router /backends/available [get] func (mgs *BackendEndpointService) ListAvailableBackendsEndpoint(systemState *system.SystemState) echo.HandlerFunc { return func(c echo.Context) error { backends, err := gallery.AvailableBackends(mgs.galleries, systemState) if err != nil { return err } return c.JSON(200, backends) } } ================================================ FILE: core/http/endpoints/localai/backend_monitor.go ================================================ package localai import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" ) // BackendMonitorEndpoint returns the status of the specified backend // @Summary Backend monitor endpoint // @Param request body schema.BackendMonitorRequest true "Backend statistics request" // @Success 200 {object} proto.StatusResponse "Response" // @Router /backend/monitor [get] func BackendMonitorEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc { return func(c echo.Context) error { input := new(schema.BackendMonitorRequest) // Get input data from the request body if err := c.Bind(input); err != nil { return err } resp, err := bm.CheckAndSample(input.Model) if err != nil { return err } return c.JSON(200, resp) } } // BackendShutdownEndpoint shuts down the specified backend // @Summary Backend monitor endpoint // @Param request body schema.BackendMonitorRequest true "Backend statistics request" // @Router /backend/shutdown [post] func 
BackendShutdownEndpoint(bm *services.BackendMonitorService) echo.HandlerFunc { return func(c echo.Context) error { input := new(schema.BackendMonitorRequest) // Get input data from the request body if err := c.Bind(input); err != nil { return err } return bm.ShutdownModel(input.Model) } } ================================================ FILE: core/http/endpoints/localai/cors_proxy.go ================================================ package localai import ( "fmt" "io" "net/http" "net/url" "strings" "time" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/xlog" ) var corsProxyClient = &http.Client{ Timeout: 10 * time.Minute, } // CORSProxyEndpoint proxies HTTP requests to external MCP servers, // solving CORS issues for browser-based MCP connections. // The target URL is passed as a query parameter: /api/cors-proxy?url=https://... func CORSProxyEndpoint(appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { targetURL := c.QueryParam("url") if targetURL == "" { return c.JSON(http.StatusBadRequest, map[string]string{"error": "missing 'url' query parameter"}) } parsed, err := url.Parse(targetURL) if err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid target URL"}) } if parsed.Scheme != "http" && parsed.Scheme != "https" { return c.JSON(http.StatusBadRequest, map[string]string{"error": "only http and https schemes are supported"}) } xlog.Debug("CORS proxy request", "method", c.Request().Method, "target", targetURL) proxyReq, err := http.NewRequestWithContext( c.Request().Context(), c.Request().Method, targetURL, c.Request().Body, ) if err != nil { return fmt.Errorf("failed to create proxy request: %w", err) } // Copy headers from the original request, excluding hop-by-hop headers skipHeaders := map[string]bool{ "Host": true, "Connection": true, "Keep-Alive": true, "Transfer-Encoding": true, "Upgrade": true, "Origin": true, "Referer": true, } for key, values := 
range c.Request().Header { if skipHeaders[key] { continue } for _, v := range values { proxyReq.Header.Add(key, v) } } resp, err := corsProxyClient.Do(proxyReq) if err != nil { xlog.Error("CORS proxy request failed", "error", err, "target", targetURL) return c.JSON(http.StatusBadGateway, map[string]string{"error": "proxy request failed: " + err.Error()}) } defer resp.Body.Close() // Copy response headers for key, values := range resp.Header { lower := strings.ToLower(key) // Skip CORS headers — we'll set our own if strings.HasPrefix(lower, "access-control-") { continue } for _, v := range values { c.Response().Header().Add(key, v) } } // Set CORS headers to allow browser access c.Response().Header().Set("Access-Control-Allow-Origin", "*") c.Response().Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") c.Response().Header().Set("Access-Control-Allow-Headers", "*") c.Response().Header().Set("Access-Control-Expose-Headers", "*") c.Response().WriteHeader(resp.StatusCode) // Stream the response body _, err = io.Copy(c.Response().Writer, resp.Body) return err } } // CORSProxyOptionsEndpoint handles CORS preflight requests for the proxy. 
func CORSProxyOptionsEndpoint() echo.HandlerFunc { return func(c echo.Context) error { c.Response().Header().Set("Access-Control-Allow-Origin", "*") c.Response().Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") c.Response().Header().Set("Access-Control-Allow-Headers", "*") c.Response().Header().Set("Access-Control-Max-Age", "86400") return c.NoContent(http.StatusNoContent) } } ================================================ FILE: core/http/endpoints/localai/detection.go ================================================ package localai import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/xlog" ) // DetectionEndpoint is the LocalAI Detection endpoint https://localai.io/docs/api-reference/detection // @Summary Detects objects in the input image. 
// @Param request body schema.DetectionRequest true "query params" // @Success 200 {object} schema.DetectionResponse "Response" // @Router /v1/detection [post] func DetectionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.DetectionRequest) if !ok || input.Model == "" { return echo.ErrBadRequest } cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { return echo.ErrBadRequest } xlog.Debug("Detection", "image", input.Image, "modelFile", "modelFile", "backend", cfg.Backend) image, err := utils.GetContentURIAsBase64(input.Image) if err != nil { return err } res, err := backend.Detection(image, ml, appConfig, *cfg) if err != nil { return err } response := schema.DetectionResponse{ Detections: make([]schema.Detection, len(res.Detections)), } for i, detection := range res.Detections { response.Detections[i] = schema.Detection{ X: detection.X, Y: detection.Y, Width: detection.Width, Height: detection.Height, ClassName: detection.ClassName, } } return c.JSON(200, response) } } ================================================ FILE: core/http/endpoints/localai/edit_model.go ================================================ package localai import ( "fmt" "io" "net/http" "net/url" "os" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" httpUtils "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/utils" "gopkg.in/yaml.v3" ) // GetEditModelPage renders the edit model page with current configuration func GetEditModelPage(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { modelName := c.Param("name") if decoded, err := url.PathUnescape(modelName); err == nil { 
modelName = decoded } if modelName == "" { response := ModelResponse{ Success: false, Error: "Model name is required", } return c.JSON(http.StatusBadRequest, response) } modelConfig, exists := cl.GetModelConfig(modelName) if !exists { response := ModelResponse{ Success: false, Error: "Model configuration not found", } return c.JSON(http.StatusNotFound, response) } modelConfigFile := modelConfig.GetModelConfigFile() if modelConfigFile == "" { response := ModelResponse{ Success: false, Error: "Model configuration file not found", } return c.JSON(http.StatusNotFound, response) } configData, err := os.ReadFile(modelConfigFile) if err != nil { response := ModelResponse{ Success: false, Error: "Failed to read configuration file: " + err.Error(), } return c.JSON(http.StatusInternalServerError, response) } // Render the edit page with the current configuration templateData := struct { Title string ModelName string Config *config.ModelConfig ConfigJSON string ConfigYAML string BaseURL string Version string DisableRuntimeSettings bool }{ Title: "LocalAI - Edit Model " + modelName, ModelName: modelName, Config: &modelConfig, ConfigYAML: string(configData), BaseURL: httpUtils.BaseURL(c), Version: internal.PrintableVersion(), DisableRuntimeSettings: appConfig.DisableRuntimeSettings, } return c.Render(http.StatusOK, "views/model-editor", templateData) } } // EditModelEndpoint handles updating existing model configurations func EditModelEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { modelName := c.Param("name") if decoded, err := url.PathUnescape(modelName); err == nil { modelName = decoded } if modelName == "" { response := ModelResponse{ Success: false, Error: "Model name is required", } return c.JSON(http.StatusBadRequest, response) } modelConfig, exists := cl.GetModelConfig(modelName) if !exists { response := ModelResponse{ Success: false, Error: "Existing model 
configuration not found", } return c.JSON(http.StatusNotFound, response) } // Get the raw body body, err := io.ReadAll(c.Request().Body) if err != nil { response := ModelResponse{ Success: false, Error: "Failed to read request body: " + err.Error(), } return c.JSON(http.StatusBadRequest, response) } if len(body) == 0 { response := ModelResponse{ Success: false, Error: "Request body is empty", } return c.JSON(http.StatusBadRequest, response) } // Check content to see if it's a valid model config var req config.ModelConfig // Parse YAML if err := yaml.Unmarshal(body, &req); err != nil { response := ModelResponse{ Success: false, Error: "Failed to parse YAML: " + err.Error(), } return c.JSON(http.StatusBadRequest, response) } // Validate required fields if req.Name == "" { response := ModelResponse{ Success: false, Error: "Name is required", } return c.JSON(http.StatusBadRequest, response) } // Validate the configuration if valid, _ := req.Validate(); !valid { response := ModelResponse{ Success: false, Error: "Validation failed", Details: []string{"Configuration validation failed. 
Please check your YAML syntax and required fields."}, } return c.JSON(http.StatusBadRequest, response) } // Load the existing configuration configPath := modelConfig.GetModelConfigFile() if err := utils.VerifyPath(configPath, appConfig.SystemState.Model.ModelsPath); err != nil { response := ModelResponse{ Success: false, Error: "Model configuration not trusted: " + err.Error(), } return c.JSON(http.StatusNotFound, response) } // Write new content to file if err := os.WriteFile(configPath, body, 0644); err != nil { response := ModelResponse{ Success: false, Error: "Failed to write configuration file: " + err.Error(), } return c.JSON(http.StatusInternalServerError, response) } // Reload configurations if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil { response := ModelResponse{ Success: false, Error: "Failed to reload configurations: " + err.Error(), } return c.JSON(http.StatusInternalServerError, response) } // Shutdown the running model to apply new configuration (e.g., context_size) // The model will be reloaded on the next inference request if err := ml.ShutdownModel(modelName); err != nil { // Log the error but don't fail the request - the config was saved successfully // The model can still be manually reloaded or restarted fmt.Printf("Warning: Failed to shutdown model '%s': %v\n", modelName, err) } // Preload the model if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil { response := ModelResponse{ Success: false, Error: "Failed to preload model: " + err.Error(), } return c.JSON(http.StatusInternalServerError, response) } // Return success response response := ModelResponse{ Success: true, Message: fmt.Sprintf("Model '%s' updated successfully. 
Model has been reloaded with new configuration.", modelName), Filename: configPath, Config: req, } return c.JSON(200, response) } } // ReloadModelsEndpoint handles reloading model configurations from disk func ReloadModelsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { // Reload configurations if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil { response := ModelResponse{ Success: false, Error: "Failed to reload configurations: " + err.Error(), } return c.JSON(http.StatusInternalServerError, response) } // Preload the models if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil { response := ModelResponse{ Success: false, Error: "Failed to preload models: " + err.Error(), } return c.JSON(http.StatusInternalServerError, response) } // Return success response response := ModelResponse{ Success: true, Message: "Model configurations reloaded successfully", } return c.JSON(http.StatusOK, response) } } ================================================ FILE: core/http/endpoints/localai/edit_model_test.go ================================================ package localai_test import ( "bytes" "encoding/json" "io" "net/http" "net/http/httptest" "os" "path/filepath" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" . "github.com/mudler/LocalAI/core/http/endpoints/localai" "github.com/mudler/LocalAI/pkg/system" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" ) // testRenderer is a simple renderer for tests that returns JSON type testRenderer struct{} func (t *testRenderer) Render(w io.Writer, name string, data interface{}, c echo.Context) error { // For tests, just return the data as JSON return json.NewEncoder(w).Encode(data) } var _ = Describe("Edit Model test", func() { var tempDir string BeforeEach(func() { var err error tempDir, err = os.MkdirTemp("", "localai-test") Expect(err).ToNot(HaveOccurred()) }) AfterEach(func() { os.RemoveAll(tempDir) }) Context("Edit Model endpoint", func() { It("should edit a model", func() { systemState, err := system.GetSystemState( system.WithModelPath(filepath.Join(tempDir)), ) Expect(err).ToNot(HaveOccurred()) applicationConfig := config.NewApplicationConfig( config.WithSystemState(systemState), ) //modelLoader := model.NewModelLoader(systemState, true) modelConfigLoader := config.NewModelConfigLoader(systemState.Model.ModelsPath) // Define Echo app and register all routes upfront app := echo.New() // Set up a simple renderer for the test app.Renderer = &testRenderer{} app.POST("/import-model", ImportModelEndpoint(modelConfigLoader, applicationConfig)) app.GET("/edit-model/:name", GetEditModelPage(modelConfigLoader, applicationConfig)) requestBody := bytes.NewBufferString(`{"name": "foo", "backend": "foo", "model": "foo"}`) req := httptest.NewRequest("POST", "/import-model", requestBody) req.Header.Set("Content-Type", "application/json") rec := httptest.NewRecorder() app.ServeHTTP(rec, req) body, err := io.ReadAll(rec.Body) Expect(err).ToNot(HaveOccurred()) Expect(string(body)).To(ContainSubstring("Model configuration created successfully")) Expect(rec.Code).To(Equal(http.StatusOK)) req = httptest.NewRequest("GET", "/edit-model/foo", nil) rec = httptest.NewRecorder() app.ServeHTTP(rec, req) body, err = io.ReadAll(rec.Body) Expect(err).ToNot(HaveOccurred()) // The response contains the model configuration with backend field 
Expect(string(body)).To(ContainSubstring(`"backend":"foo"`)) Expect(string(body)).To(ContainSubstring(`"name":"foo"`)) Expect(rec.Code).To(Equal(http.StatusOK)) }) }) }) ================================================ FILE: core/http/endpoints/localai/gallery.go ================================================ package localai import ( "encoding/json" "fmt" "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/pkg/system" "github.com/mudler/xlog" ) type ModelGalleryEndpointService struct { galleries []config.Gallery backendGalleries []config.Gallery modelPath string galleryApplier *services.GalleryService configLoader *config.ModelConfigLoader } type GalleryModel struct { ID string `json:"id"` gallery.GalleryModel } func CreateModelGalleryEndpointService(galleries []config.Gallery, backendGalleries []config.Gallery, systemState *system.SystemState, galleryApplier *services.GalleryService, configLoader *config.ModelConfigLoader) ModelGalleryEndpointService { return ModelGalleryEndpointService{ galleries: galleries, backendGalleries: backendGalleries, modelPath: systemState.Model.ModelsPath, galleryApplier: galleryApplier, configLoader: configLoader, } } // GetOpStatusEndpoint returns the job status // @Summary Returns the job status // @Success 200 {object} services.GalleryOpStatus "Response" // @Router /models/jobs/{uuid} [get] func (mgs *ModelGalleryEndpointService) GetOpStatusEndpoint() echo.HandlerFunc { return func(c echo.Context) error { status := mgs.galleryApplier.GetStatus(c.Param("uuid")) if status == nil { return fmt.Errorf("could not find any status for ID") } return c.JSON(200, status) } } // GetAllStatusEndpoint returns all the jobs status progress // @Summary Returns all the jobs status progress // @Success 200 
{object} map[string]services.GalleryOpStatus "Response" // @Router /models/jobs [get] func (mgs *ModelGalleryEndpointService) GetAllStatusEndpoint() echo.HandlerFunc { return func(c echo.Context) error { return c.JSON(200, mgs.galleryApplier.GetAllStatus()) } } // ApplyModelGalleryEndpoint installs a new model to a LocalAI instance from the model gallery // @Summary Install models to LocalAI. // @Param request body GalleryModel true "query params" // @Success 200 {object} schema.GalleryResponse "Response" // @Router /models/apply [post] func (mgs *ModelGalleryEndpointService) ApplyModelGalleryEndpoint() echo.HandlerFunc { return func(c echo.Context) error { input := new(GalleryModel) // Get input data from the request body if err := c.Bind(input); err != nil { return err } uuid, err := uuid.NewUUID() if err != nil { return err } mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{ Req: input.GalleryModel, ID: uuid.String(), GalleryElementName: input.ID, Galleries: mgs.galleries, BackendGalleries: mgs.backendGalleries, } return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())}) } } // DeleteModelGalleryEndpoint lets delete models from a LocalAI instance // @Summary delete models to LocalAI. 
// @Param name path string true "Model name" // @Success 200 {object} schema.GalleryResponse "Response" // @Router /models/delete/{name} [post] func (mgs *ModelGalleryEndpointService) DeleteModelGalleryEndpoint() echo.HandlerFunc { return func(c echo.Context) error { modelName := c.Param("name") mgs.galleryApplier.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{ Delete: true, GalleryElementName: modelName, } mgs.configLoader.RemoveModelConfig(modelName) uuid, err := uuid.NewUUID() if err != nil { return err } return c.JSON(200, schema.GalleryResponse{ID: uuid.String(), StatusURL: fmt.Sprintf("%smodels/jobs/%s", middleware.BaseURL(c), uuid.String())}) } } // ListModelFromGalleryEndpoint list the available models for installation from the active galleries // @Summary List installable models. // @Success 200 {object} []gallery.GalleryModel "Response" // @Router /models/available [get] func (mgs *ModelGalleryEndpointService) ListModelFromGalleryEndpoint(systemState *system.SystemState) echo.HandlerFunc { return func(c echo.Context) error { models, err := gallery.AvailableGalleryModels(mgs.galleries, systemState) if err != nil { xlog.Error("could not list models from galleries", "error", err) return err } xlog.Debug("Available models from galleries", "modelCount", len(models), "galleryCount", len(mgs.galleries)) m := []gallery.Metadata{} for _, mm := range models { m = append(m, mm.Metadata) } xlog.Debug("Models", "models", m) dat, err := json.Marshal(m) if err != nil { return fmt.Errorf("could not marshal models: %w", err) } return c.Blob(200, "application/json", dat) } } // ListModelGalleriesEndpoint list the available galleries configured in LocalAI // @Summary List all Galleries // @Success 200 {object} []config.Gallery "Response" // @Router /models/galleries [get] // NOTE: This is different (and much simpler!) than above! This JUST lists the model galleries that have been loaded, not their contents! 
func (mgs *ModelGalleryEndpointService) ListModelGalleriesEndpoint() echo.HandlerFunc { return func(c echo.Context) error { xlog.Debug("Listing model galleries", "galleries", mgs.galleries) dat, err := json.Marshal(mgs.galleries) if err != nil { return err } return c.Blob(200, "application/json", dat) } } ================================================ FILE: core/http/endpoints/localai/get_token_metrics.go ================================================ package localai import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/xlog" "github.com/mudler/LocalAI/pkg/model" ) // TODO: This is not yet in use. Needs middleware rework, since it is not referenced. // TokenMetricsEndpoint is an endpoint to get TokensProcessed Per Second for Active SlotID // // @Summary Get TokenMetrics for Active Slot. // @Accept json // @Produce audio/x-wav // @Success 200 {string} binary "generated audio/wav file" // @Router /v1/tokenMetrics [get] // @Router /tokenMetrics [get] func TokenMetricsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { input := new(schema.TokenMetricsRequest) // Get input data from the request body if err := c.Bind(input); err != nil { return err } modelFile, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_NAME).(string) if !ok || modelFile != "" { modelFile = input.Model xlog.Warn("Model not found in context", "model", input.Model) } cfg, err := cl.LoadModelConfigFileByNameDefaultOptions(modelFile, appConfig) if err != nil { xlog.Error("Error loading model config", "error", err) modelFile = input.Model xlog.Warn("Model not found in context", "model", input.Model) } else { modelFile = cfg.Model } xlog.Debug("Token Metrics for model", "model", modelFile) response, err := 
backend.TokenMetrics(modelFile, ml, appConfig, *cfg) if err != nil { return err } return c.JSON(200, response) } } ================================================ FILE: core/http/endpoints/localai/import_model.go ================================================ package localai import ( "context" "encoding/json" "fmt" "io" "net/http" "os" "path/filepath" "strings" "time" "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery/importers" httpUtils "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/LocalAI/pkg/vram" "gopkg.in/yaml.v3" ) // ImportModelURIEndpoint handles creating new model configurations from a URI func ImportModelURIEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, galleryService *services.GalleryService, opcache *services.OpCache) echo.HandlerFunc { return func(c echo.Context) error { input := new(schema.ImportModelRequest) if err := c.Bind(input); err != nil { return err } modelConfig, err := importers.DiscoverModelConfig(input.URI, input.Preferences) if err != nil { return fmt.Errorf("failed to discover model config: %w", err) } resp := schema.GalleryResponse{ StatusURL: fmt.Sprintf("%smodels/jobs/%s", httpUtils.BaseURL(c), ""), } if len(modelConfig.Files) > 0 { files := make([]vram.FileInput, 0, len(modelConfig.Files)) for _, f := range modelConfig.Files { files = append(files, vram.FileInput{URI: f.URI, Size: 0}) } estCtx, cancel := context.WithTimeout(c.Request().Context(), 5*time.Second) defer cancel() opts := vram.EstimateOptions{ContextLength: 8192} result, err := vram.Estimate(estCtx, files, opts, vram.DefaultCachedSizeResolver(), vram.DefaultCachedGGUFReader()) if err == nil { if result.SizeBytes > 0 { resp.EstimatedSizeBytes = result.SizeBytes 
resp.EstimatedSizeDisplay = result.SizeDisplay } if result.VRAMBytes > 0 { resp.EstimatedVRAMBytes = result.VRAMBytes resp.EstimatedVRAMDisplay = result.VRAMDisplay } } } uuid, err := uuid.NewUUID() if err != nil { return err } // Determine gallery ID for tracking - use model name if available, otherwise use URI galleryID := input.URI if modelConfig.Name != "" { galleryID = modelConfig.Name } // Register operation in opcache if available (for UI progress tracking) if opcache != nil { opcache.Set(galleryID, uuid.String()) } galleryService.ModelGalleryChannel <- services.GalleryOp[gallery.GalleryModel, gallery.ModelConfig]{ Req: gallery.GalleryModel{ Overrides: map[string]interface{}{}, }, ID: uuid.String(), GalleryElementName: galleryID, GalleryElement: &modelConfig, BackendGalleries: appConfig.BackendGalleries, } resp.ID = uuid.String() resp.StatusURL = fmt.Sprintf("%smodels/jobs/%s", httpUtils.BaseURL(c), uuid.String()) return c.JSON(200, resp) } } // ImportModelEndpoint handles creating new model configurations func ImportModelEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { // Get the raw body body, err := io.ReadAll(c.Request().Body) if err != nil { response := ModelResponse{ Success: false, Error: "Failed to read request body: " + err.Error(), } return c.JSON(http.StatusBadRequest, response) } if len(body) == 0 { response := ModelResponse{ Success: false, Error: "Request body is empty", } return c.JSON(http.StatusBadRequest, response) } // Check content type to determine how to parse contentType := c.Request().Header.Get("Content-Type") var modelConfig config.ModelConfig if strings.Contains(contentType, "application/json") { // Parse JSON if err := json.Unmarshal(body, &modelConfig); err != nil { response := ModelResponse{ Success: false, Error: "Failed to parse JSON: " + err.Error(), } return c.JSON(http.StatusBadRequest, response) } } else if strings.Contains(contentType, 
"application/x-yaml") || strings.Contains(contentType, "text/yaml") { // Parse YAML if err := yaml.Unmarshal(body, &modelConfig); err != nil { response := ModelResponse{ Success: false, Error: "Failed to parse YAML: " + err.Error(), } return c.JSON(http.StatusBadRequest, response) } } else { // Try to auto-detect format if len(body) > 0 && strings.TrimSpace(string(body))[0] == '{' { // Looks like JSON if err := json.Unmarshal(body, &modelConfig); err != nil { response := ModelResponse{ Success: false, Error: "Failed to parse JSON: " + err.Error(), } return c.JSON(http.StatusBadRequest, response) } } else { // Assume YAML if err := yaml.Unmarshal(body, &modelConfig); err != nil { response := ModelResponse{ Success: false, Error: "Failed to parse YAML: " + err.Error(), } return c.JSON(http.StatusBadRequest, response) } } } // Validate required fields if modelConfig.Name == "" { response := ModelResponse{ Success: false, Error: "Name is required", } return c.JSON(http.StatusBadRequest, response) } // Set defaults modelConfig.SetDefaults(appConfig.ToConfigLoaderOptions()...) 
// Validate the configuration if valid, _ := modelConfig.Validate(); !valid { response := ModelResponse{ Success: false, Error: "Invalid configuration", } return c.JSON(http.StatusBadRequest, response) } // Create the configuration file configPath := filepath.Join(appConfig.SystemState.Model.ModelsPath, modelConfig.Name+".yaml") if err := utils.VerifyPath(modelConfig.Name+".yaml", appConfig.SystemState.Model.ModelsPath); err != nil { response := ModelResponse{ Success: false, Error: "Model path not trusted: " + err.Error(), } return c.JSON(http.StatusBadRequest, response) } // Marshal to YAML for storage yamlData, err := yaml.Marshal(&modelConfig) if err != nil { response := ModelResponse{ Success: false, Error: "Failed to marshal configuration: " + err.Error(), } return c.JSON(http.StatusInternalServerError, response) } // Write the file if err := os.WriteFile(configPath, yamlData, 0644); err != nil { response := ModelResponse{ Success: false, Error: "Failed to write configuration file: " + err.Error(), } return c.JSON(http.StatusInternalServerError, response) } // Reload configurations if err := cl.LoadModelConfigsFromPath(appConfig.SystemState.Model.ModelsPath, appConfig.ToConfigLoaderOptions()...); err != nil { response := ModelResponse{ Success: false, Error: "Failed to reload configurations: " + err.Error(), } return c.JSON(http.StatusInternalServerError, response) } // Preload the model if err := cl.Preload(appConfig.SystemState.Model.ModelsPath); err != nil { response := ModelResponse{ Success: false, Error: "Failed to preload model: " + err.Error(), } return c.JSON(http.StatusInternalServerError, response) } // Return success response response := ModelResponse{ Success: true, Message: "Model configuration created successfully", Filename: filepath.Base(configPath), } return c.JSON(200, response) } } ================================================ FILE: core/http/endpoints/localai/localai_suite_test.go ================================================ 
package localai_test

import (
	"testing"

	. "github.com/onsi/ginkgo/v2"
	. "github.com/onsi/gomega"
)

// TestLocalAIEndpoints is the go-test entry point for this package's Ginkgo
// specs: it registers Ginkgo's Fail as the Gomega fail handler and runs the
// suite under the name "LocalAI Endpoints test suite".
func TestLocalAIEndpoints(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "LocalAI Endpoints test suite")
}

================================================
FILE: core/http/endpoints/localai/mcp.go
================================================
package localai

import (
	"fmt"
	"strings"

	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/endpoints/openai"
	"github.com/mudler/LocalAI/core/http/middleware"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/core/templates"
	"github.com/mudler/LocalAI/pkg/model"
)

// MCP SSE Event Types (kept for backward compatibility with MCP endpoint consumers)

// MCPReasoningEvent serialises to JSON as {"type", "content"}.
type MCPReasoningEvent struct {
	Type    string `json:"type"`
	Content string `json:"content"`
}

// MCPToolCallEvent serialises to JSON as {"type", "name", "arguments", "reasoning"};
// arguments is a free-form JSON object.
type MCPToolCallEvent struct {
	Type      string                 `json:"type"`
	Name      string                 `json:"name"`
	Arguments map[string]interface{} `json:"arguments"`
	Reasoning string                 `json:"reasoning"`
}

// MCPToolResultEvent serialises to JSON as {"type", "name", "result"}.
type MCPToolResultEvent struct {
	Type   string `json:"type"`
	Name   string `json:"name"`
	Result string `json:"result"`
}

// MCPStatusEvent serialises to JSON as {"type", "message"}.
type MCPStatusEvent struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}

// MCPAssistantEvent serialises to JSON as {"type", "content"}.
type MCPAssistantEvent struct {
	Type    string `json:"type"`
	Content string `json:"content"`
}

// MCPErrorEvent serialises to JSON as {"type", "message"}.
type MCPErrorEvent struct {
	Type    string `json:"type"`
	Message string `json:"message"`
}

// MCPEndpoint is the endpoint for MCP chat completions.
// It enables all MCP servers for the model and delegates to the standard chat endpoint,
// which handles MCP tool injection and server-side execution.
// Both streaming and non-streaming modes use standard OpenAI response format.
// @Summary MCP chat completions with automatic tool execution // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/mcp/chat/completions [post] func MCPEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { chatHandler := openai.ChatEndpoint(cl, ml, evaluator, appConfig) return func(c echo.Context) error { input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { return echo.ErrBadRequest } modelConfig, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || modelConfig == nil { return echo.ErrBadRequest } if modelConfig.MCP.Servers == "" && modelConfig.MCP.Stdio == "" { return fmt.Errorf("no MCP servers configured") } // Enable all MCP servers if none explicitly specified (preserve original behavior) if input.Metadata == nil { input.Metadata = map[string]string{} } if _, hasMCP := input.Metadata["mcp_servers"]; !hasMCP { remote, stdio, err := modelConfig.MCP.MCPConfigFromYAML() if err != nil { return fmt.Errorf("failed to get MCP config: %w", err) } var allServers []string for name := range remote.Servers { allServers = append(allServers, name) } for name := range stdio.Servers { allServers = append(allServers, name) } input.Metadata["mcp_servers"] = strings.Join(allServers, ",") } // Delegate to the standard chat endpoint which handles MCP tool // injection and server-side execution for both streaming and non-streaming. 
return chatHandler(c) } } ================================================ FILE: core/http/endpoints/localai/mcp_prompts.go ================================================ package localai import ( "fmt" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" ) // MCPPromptsEndpoint returns the list of MCP prompts for a given model. // GET /v1/mcp/prompts/:model func MCPPromptsEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { modelName := c.Param("model") if modelName == "" { return echo.ErrBadRequest } cfg, exists := cl.GetModelConfig(modelName) if !exists { return fmt.Errorf("model %q not found", modelName) } if cfg.MCP.Servers == "" && cfg.MCP.Stdio == "" { return c.JSON(200, []any{}) } remote, stdio, err := cfg.MCP.MCPConfigFromYAML() if err != nil { return fmt.Errorf("failed to parse MCP config: %w", err) } namedSessions, err := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, nil) if err != nil { return fmt.Errorf("failed to get MCP sessions: %w", err) } prompts, err := mcpTools.DiscoverMCPPrompts(c.Request().Context(), namedSessions) if err != nil { return fmt.Errorf("failed to discover MCP prompts: %w", err) } type promptArgJSON struct { Name string `json:"name"` Description string `json:"description,omitempty"` Required bool `json:"required,omitempty"` } type promptJSON struct { Name string `json:"name"` Description string `json:"description,omitempty"` Title string `json:"title,omitempty"` Arguments []promptArgJSON `json:"arguments,omitempty"` Server string `json:"server"` } var result []promptJSON for _, p := range prompts { pj := promptJSON{ Name: p.PromptName, Description: p.Description, Title: p.Title, Server: p.ServerName, } for _, arg := range p.Arguments { pj.Arguments = append(pj.Arguments, promptArgJSON{ Name: arg.Name, Description: arg.Description, Required: arg.Required, }) } 
result = append(result, pj) } return c.JSON(200, result) } } // MCPGetPromptEndpoint expands a prompt by name with the given arguments. // POST /v1/mcp/prompts/:model/:prompt func MCPGetPromptEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { modelName := c.Param("model") promptName := c.Param("prompt") if modelName == "" || promptName == "" { return echo.ErrBadRequest } cfg, exists := cl.GetModelConfig(modelName) if !exists { return fmt.Errorf("model %q not found", modelName) } if cfg.MCP.Servers == "" && cfg.MCP.Stdio == "" { return fmt.Errorf("no MCP servers configured for model %q", modelName) } var req struct { Arguments map[string]string `json:"arguments"` } if err := c.Bind(&req); err != nil { return echo.ErrBadRequest } remote, stdio, err := cfg.MCP.MCPConfigFromYAML() if err != nil { return fmt.Errorf("failed to parse MCP config: %w", err) } namedSessions, err := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, nil) if err != nil { return fmt.Errorf("failed to get MCP sessions: %w", err) } prompts, err := mcpTools.DiscoverMCPPrompts(c.Request().Context(), namedSessions) if err != nil { return fmt.Errorf("failed to discover MCP prompts: %w", err) } messages, err := mcpTools.GetMCPPrompt(c.Request().Context(), prompts, promptName, req.Arguments) if err != nil { return fmt.Errorf("failed to get prompt: %w", err) } type messageJSON struct { Role string `json:"role"` Content string `json:"content"` } var result []messageJSON for _, m := range messages { result = append(result, messageJSON{ Role: string(m.Role), Content: mcpTools.PromptMessageToText(m), }) } return c.JSON(200, map[string]any{ "messages": result, }) } } ================================================ FILE: core/http/endpoints/localai/mcp_resources.go ================================================ package localai import ( "fmt" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" 
mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" ) // MCPResourcesEndpoint returns the list of MCP resources for a given model. // GET /v1/mcp/resources/:model func MCPResourcesEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { modelName := c.Param("model") if modelName == "" { return echo.ErrBadRequest } cfg, exists := cl.GetModelConfig(modelName) if !exists { return fmt.Errorf("model %q not found", modelName) } if cfg.MCP.Servers == "" && cfg.MCP.Stdio == "" { return c.JSON(200, []any{}) } remote, stdio, err := cfg.MCP.MCPConfigFromYAML() if err != nil { return fmt.Errorf("failed to parse MCP config: %w", err) } namedSessions, err := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, nil) if err != nil { return fmt.Errorf("failed to get MCP sessions: %w", err) } resources, err := mcpTools.DiscoverMCPResources(c.Request().Context(), namedSessions) if err != nil { return fmt.Errorf("failed to discover MCP resources: %w", err) } type resourceJSON struct { Name string `json:"name"` URI string `json:"uri"` Description string `json:"description,omitempty"` MIMEType string `json:"mimeType,omitempty"` Server string `json:"server"` } var result []resourceJSON for _, r := range resources { result = append(result, resourceJSON{ Name: r.Name, URI: r.URI, Description: r.Description, MIMEType: r.MIMEType, Server: r.ServerName, }) } return c.JSON(200, result) } } // MCPReadResourceEndpoint reads a specific MCP resource by URI. 
// POST /v1/mcp/resources/:model/read func MCPReadResourceEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { modelName := c.Param("model") if modelName == "" { return echo.ErrBadRequest } cfg, exists := cl.GetModelConfig(modelName) if !exists { return fmt.Errorf("model %q not found", modelName) } if cfg.MCP.Servers == "" && cfg.MCP.Stdio == "" { return fmt.Errorf("no MCP servers configured for model %q", modelName) } var req struct { URI string `json:"uri"` } if err := c.Bind(&req); err != nil || req.URI == "" { return echo.ErrBadRequest } remote, stdio, err := cfg.MCP.MCPConfigFromYAML() if err != nil { return fmt.Errorf("failed to parse MCP config: %w", err) } namedSessions, err := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, nil) if err != nil { return fmt.Errorf("failed to get MCP sessions: %w", err) } resources, err := mcpTools.DiscoverMCPResources(c.Request().Context(), namedSessions) if err != nil { return fmt.Errorf("failed to discover MCP resources: %w", err) } content, err := mcpTools.ReadMCPResource(c.Request().Context(), resources, req.URI) if err != nil { return fmt.Errorf("failed to read resource: %w", err) } // Find the resource info for mimeType mimeType := "" for _, r := range resources { if r.URI == req.URI { mimeType = r.MIMEType break } } return c.JSON(200, map[string]any{ "uri": req.URI, "content": content, "mimeType": mimeType, }) } } ================================================ FILE: core/http/endpoints/localai/mcp_tools.go ================================================ package localai import ( "fmt" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" "github.com/mudler/LocalAI/core/http/middleware" ) // MCPServersEndpoint returns the list of MCP servers and their tools for a given model. 
// GET /v1/mcp/servers/:model func MCPServersEndpoint(cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { modelName := c.Param("model") if modelName == "" { return echo.ErrBadRequest } cfg, exists := cl.GetModelConfig(modelName) if !exists { return fmt.Errorf("model %q not found", modelName) } if cfg.MCP.Servers == "" && cfg.MCP.Stdio == "" { return c.JSON(200, map[string]any{ "model": modelName, "servers": []any{}, }) } remote, stdio, err := cfg.MCP.MCPConfigFromYAML() if err != nil { return fmt.Errorf("failed to parse MCP config: %w", err) } namedSessions, err := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, nil) if err != nil { return fmt.Errorf("failed to get MCP sessions: %w", err) } servers, err := mcpTools.ListMCPServers(c.Request().Context(), namedSessions) if err != nil { return fmt.Errorf("failed to list MCP servers: %w", err) } return c.JSON(200, map[string]any{ "model": modelName, "servers": servers, }) } } // MCPServersEndpointFromMiddleware is a version that uses the middleware-resolved model config. // This allows it to use the same middleware chain as other endpoints. 
func MCPServersEndpointFromMiddleware() echo.HandlerFunc { return func(c echo.Context) error { cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { return echo.ErrBadRequest } if cfg.MCP.Servers == "" && cfg.MCP.Stdio == "" { return c.JSON(200, map[string]any{ "model": cfg.Name, "servers": []any{}, }) } remote, stdio, err := cfg.MCP.MCPConfigFromYAML() if err != nil { return fmt.Errorf("failed to parse MCP config: %w", err) } namedSessions, err := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, nil) if err != nil { return fmt.Errorf("failed to get MCP sessions: %w", err) } servers, err := mcpTools.ListMCPServers(c.Request().Context(), namedSessions) if err != nil { return fmt.Errorf("failed to list MCP servers: %w", err) } return c.JSON(200, map[string]any{ "model": cfg.Name, "servers": servers, }) } } ================================================ FILE: core/http/endpoints/localai/metrics.go ================================================ package localai import ( "time" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/services" "github.com/prometheus/client_golang/prometheus/promhttp" ) // LocalAIMetricsEndpoint returns the metrics endpoint for LocalAI // @Summary Prometheus metrics endpoint // @Param request body config.Gallery true "Gallery details" // @Router /metrics [get] func LocalAIMetricsEndpoint() echo.HandlerFunc { return echo.WrapHandler(promhttp.Handler()) } type apiMiddlewareConfig struct { Filter func(c echo.Context) bool metricsService *services.LocalAIMetricsService } func LocalAIMetricsAPIMiddleware(metrics *services.LocalAIMetricsService) echo.MiddlewareFunc { cfg := apiMiddlewareConfig{ metricsService: metrics, Filter: func(c echo.Context) bool { return c.Path() == "/metrics" }, } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if cfg.Filter != nil && cfg.Filter(c) { return next(c) } path := c.Path() method := 
c.Request().Method start := time.Now() err := next(c) elapsed := float64(time.Since(start)) / float64(time.Second) cfg.metricsService.ObserveAPICall(method, path, elapsed) return err } } } ================================================ FILE: core/http/endpoints/localai/p2p.go ================================================ package localai import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/p2p" "github.com/mudler/LocalAI/core/schema" ) // ShowP2PNodes returns the P2P Nodes // @Summary Returns available P2P nodes // @Success 200 {object} []schema.P2PNodesResponse "Response" // @Router /api/p2p [get] func ShowP2PNodes(appConfig *config.ApplicationConfig) echo.HandlerFunc { // Render index return func(c echo.Context) error { return c.JSON(200, schema.P2PNodesResponse{ LlamaCPPNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.LlamaCPPWorkerID)), FederatedNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.FederatedID)), MLXNodes: p2p.GetAvailableNodes(p2p.NetworkID(appConfig.P2PNetworkID, p2p.MLXWorkerID)), }) } } // ShowP2PToken returns the P2P token // @Summary Show the P2P token // @Success 200 {string} string "Response" // @Router /api/p2p/token [get] func ShowP2PToken(appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { return c.String(200, appConfig.P2PToken) } } ================================================ FILE: core/http/endpoints/localai/settings.go ================================================ package localai import ( "encoding/json" "io" "net/http" "os" "path/filepath" "time" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/endpoints/openresponses" "github.com/mudler/LocalAI/core/p2p" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/xlog" ) // GetSettingsEndpoint returns current settings with 
precedence (env > file > defaults) func GetSettingsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { appConfig := app.ApplicationConfig() settings := appConfig.ToRuntimeSettings() return c.JSON(http.StatusOK, settings) } } // UpdateSettingsEndpoint updates settings, saves to file, and applies immediately func UpdateSettingsEndpoint(app *application.Application) echo.HandlerFunc { return func(c echo.Context) error { appConfig := app.ApplicationConfig() startupConfig := app.StartupConfig() if startupConfig == nil { startupConfig = appConfig } body, err := io.ReadAll(c.Request().Body) if err != nil { return c.JSON(http.StatusBadRequest, schema.SettingsResponse{ Success: false, Error: "Failed to read request body: " + err.Error(), }) } var settings config.RuntimeSettings if err := json.Unmarshal(body, &settings); err != nil { return c.JSON(http.StatusBadRequest, schema.SettingsResponse{ Success: false, Error: "Failed to parse JSON: " + err.Error(), }) } // Validate timeouts if provided if settings.WatchdogIdleTimeout != nil { if _, err := time.ParseDuration(*settings.WatchdogIdleTimeout); err != nil { return c.JSON(http.StatusBadRequest, schema.SettingsResponse{ Success: false, Error: "Invalid watchdog_idle_timeout format: " + err.Error(), }) } } if settings.WatchdogBusyTimeout != nil { if _, err := time.ParseDuration(*settings.WatchdogBusyTimeout); err != nil { return c.JSON(http.StatusBadRequest, schema.SettingsResponse{ Success: false, Error: "Invalid watchdog_busy_timeout format: " + err.Error(), }) } } if settings.WatchdogInterval != nil { if _, err := time.ParseDuration(*settings.WatchdogInterval); err != nil { return c.JSON(http.StatusBadRequest, schema.SettingsResponse{ Success: false, Error: "Invalid watchdog_interval format: " + err.Error(), }) } } if settings.LRUEvictionRetryInterval != nil { if _, err := time.ParseDuration(*settings.LRUEvictionRetryInterval); err != nil { return c.JSON(http.StatusBadRequest, 
schema.SettingsResponse{ Success: false, Error: "Invalid lru_eviction_retry_interval format: " + err.Error(), }) } } if settings.OpenResponsesStoreTTL != nil { if *settings.OpenResponsesStoreTTL != "0" && *settings.OpenResponsesStoreTTL != "" { if _, err := time.ParseDuration(*settings.OpenResponsesStoreTTL); err != nil { return c.JSON(http.StatusBadRequest, schema.SettingsResponse{ Success: false, Error: "Invalid open_responses_store_ttl format: " + err.Error(), }) } } } // Generate P2P token before saving so the real token is persisted (not "0") if settings.P2PToken != nil && *settings.P2PToken == "0" { token := p2p.GenerateToken(60, 60) settings.P2PToken = &token } // Save to file if appConfig.DynamicConfigsDir == "" { return c.JSON(http.StatusBadRequest, schema.SettingsResponse{ Success: false, Error: "DynamicConfigsDir is not set", }) } settingsFile := filepath.Join(appConfig.DynamicConfigsDir, "runtime_settings.json") settingsJSON, err := json.MarshalIndent(settings, "", " ") if err != nil { return c.JSON(http.StatusInternalServerError, schema.SettingsResponse{ Success: false, Error: "Failed to marshal settings: " + err.Error(), }) } if err := os.WriteFile(settingsFile, settingsJSON, 0600); err != nil { return c.JSON(http.StatusInternalServerError, schema.SettingsResponse{ Success: false, Error: "Failed to write settings file: " + err.Error(), }) } // Apply settings using centralized method watchdogChanged := appConfig.ApplyRuntimeSettings(&settings) // Handle API keys specially (merge with startup keys) if settings.ApiKeys != nil { envKeys := startupConfig.ApiKeys runtimeKeys := *settings.ApiKeys appConfig.ApiKeys = append(envKeys, runtimeKeys...) 
} // Update backend logging dynamically if settings.EnableBackendLogging != nil { app.ModelLoader().SetBackendLoggingEnabled(*settings.EnableBackendLogging) xlog.Info("Updated backend logging setting", "enableBackendLogging", *settings.EnableBackendLogging) } // Update watchdog dynamically for settings that don't require restart if settings.ForceEvictionWhenBusy != nil { currentWD := app.ModelLoader().GetWatchDog() if currentWD != nil { currentWD.SetForceEvictionWhenBusy(*settings.ForceEvictionWhenBusy) xlog.Info("Updated watchdog force eviction when busy setting", "forceEvictionWhenBusy", *settings.ForceEvictionWhenBusy) } } // Update ModelLoader LRU eviction retry settings dynamically maxRetries := appConfig.LRUEvictionMaxRetries retryInterval := appConfig.LRUEvictionRetryInterval if settings.LRUEvictionMaxRetries != nil { maxRetries = *settings.LRUEvictionMaxRetries } if settings.LRUEvictionRetryInterval != nil { if dur, err := time.ParseDuration(*settings.LRUEvictionRetryInterval); err == nil { retryInterval = dur } } if settings.LRUEvictionMaxRetries != nil || settings.LRUEvictionRetryInterval != nil { app.ModelLoader().SetLRUEvictionRetrySettings(maxRetries, retryInterval) xlog.Info("Updated LRU eviction retry settings", "maxRetries", maxRetries, "retryInterval", retryInterval) } // Update Open Responses store TTL dynamically if settings.OpenResponsesStoreTTL != nil { ttl := time.Duration(0) if *settings.OpenResponsesStoreTTL != "0" && *settings.OpenResponsesStoreTTL != "" { if dur, err := time.ParseDuration(*settings.OpenResponsesStoreTTL); err == nil { ttl = dur } else { xlog.Warn("Invalid Open Responses store TTL format", "ttl", *settings.OpenResponsesStoreTTL, "error", err) } } // Import the store package store := openresponses.GetGlobalStore() store.SetTTL(ttl) xlog.Info("Updated Open Responses store TTL", "ttl", ttl) } // Check if agent job retention changed agentJobChanged := settings.AgentJobRetentionDays != nil // Restart watchdog if settings changed 
if watchdogChanged { if settings.WatchdogEnabled != nil && !*settings.WatchdogEnabled { if err := app.StopWatchdog(); err != nil { xlog.Error("Failed to stop watchdog", "error", err) return c.JSON(http.StatusInternalServerError, schema.SettingsResponse{ Success: false, Error: "Settings saved but failed to stop watchdog: " + err.Error(), }) } } else { if err := app.RestartWatchdog(); err != nil { xlog.Error("Failed to restart watchdog", "error", err) return c.JSON(http.StatusInternalServerError, schema.SettingsResponse{ Success: false, Error: "Settings saved but failed to restart watchdog: " + err.Error(), }) } } } // Restart agent job service if retention days changed if agentJobChanged { if err := app.RestartAgentJobService(); err != nil { xlog.Error("Failed to restart agent job service", "error", err) return c.JSON(http.StatusInternalServerError, schema.SettingsResponse{ Success: false, Error: "Settings saved but failed to restart agent job service: " + err.Error(), }) } } // Restart P2P if P2P settings changed p2pChanged := settings.P2PToken != nil || settings.P2PNetworkID != nil || settings.Federated != nil if p2pChanged { if settings.P2PToken != nil && *settings.P2PToken == "" { if err := app.StopP2P(); err != nil { xlog.Error("Failed to stop P2P", "error", err) return c.JSON(http.StatusInternalServerError, schema.SettingsResponse{ Success: false, Error: "Settings saved but failed to stop P2P: " + err.Error(), }) } } else { if err := app.RestartP2P(); err != nil { xlog.Error("Failed to restart P2P", "error", err) return c.JSON(http.StatusInternalServerError, schema.SettingsResponse{ Success: false, Error: "Settings saved but failed to restart P2P: " + err.Error(), }) } } } return c.JSON(http.StatusOK, schema.SettingsResponse{ Success: true, Message: "Settings updated successfully", }) } } ================================================ FILE: core/http/endpoints/localai/stores.go ================================================ package localai import ( 
"github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/store" ) func StoresSetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { input := new(schema.StoresSet) if err := c.Bind(input); err != nil { return err } sb, err := backend.StoreBackend(sl, appConfig, input.Store, input.Backend) if err != nil { return err } vals := make([][]byte, len(input.Values)) for i, v := range input.Values { vals[i] = []byte(v) } err = store.SetCols(c.Request().Context(), sb, input.Keys, vals) if err != nil { return err } return c.NoContent(200) } } func StoresDeleteEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { input := new(schema.StoresDelete) if err := c.Bind(input); err != nil { return err } sb, err := backend.StoreBackend(sl, appConfig, input.Store, input.Backend) if err != nil { return err } if err := store.DeleteCols(c.Request().Context(), sb, input.Keys); err != nil { return err } return c.NoContent(200) } } func StoresGetEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { input := new(schema.StoresGet) if err := c.Bind(input); err != nil { return err } sb, err := backend.StoreBackend(sl, appConfig, input.Store, input.Backend) if err != nil { return err } keys, vals, err := store.GetCols(c.Request().Context(), sb, input.Keys) if err != nil { return err } res := schema.StoresGetResponse{ Keys: keys, Values: make([]string, len(vals)), } for i, v := range vals { res.Values[i] = string(v) } return c.JSON(200, res) } } func StoresFindEndpoint(sl *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { input := new(schema.StoresFind) if err := 
c.Bind(input); err != nil { return err } sb, err := backend.StoreBackend(sl, appConfig, input.Store, input.Backend) if err != nil { return err } keys, vals, similarities, err := store.Find(c.Request().Context(), sb, input.Key, input.Topk) if err != nil { return err } res := schema.StoresFindResponse{ Keys: keys, Values: make([]string, len(vals)), Similarities: similarities, } for i, v := range vals { res.Values[i] = string(v) } return c.JSON(200, res) } } ================================================ FILE: core/http/endpoints/localai/system.go ================================================ package localai import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/model" ) // SystemInformations returns the system informations // @Summary Show the LocalAI instance information // @Success 200 {object} schema.SystemInformationResponse "Response" // @Router /system [get] func SystemInformations(ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { availableBackends := []string{} loadedModels := ml.ListLoadedModels() for b := range appConfig.ExternalGRPCBackends { availableBackends = append(availableBackends, b) } for b := range ml.GetAllExternalBackends(nil) { availableBackends = append(availableBackends, b) } sysmodels := []schema.SysInfoModel{} for _, m := range loadedModels { sysmodels = append(sysmodels, schema.SysInfoModel{ID: m.ID}) } return c.JSON(200, schema.SystemInformationResponse{ Backends: availableBackends, Models: sysmodels, }, ) } } ================================================ FILE: core/http/endpoints/localai/tokenize.go ================================================ package localai import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" 
"github.com/mudler/LocalAI/pkg/model" ) // TokenizeEndpoint exposes a REST API to tokenize the content // @Summary Tokenize the input. // @Param request body schema.TokenizeRequest true "Request" // @Success 200 {object} schema.TokenizeResponse "Response" // @Router /v1/tokenize [post] func TokenizeEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TokenizeRequest) if !ok || input.Model == "" { return echo.ErrBadRequest } cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { return echo.ErrBadRequest } tokenResponse, err := backend.ModelTokenize(input.Content, ml, *cfg, appConfig) if err != nil { return err } return c.JSON(200, tokenResponse) } } ================================================ FILE: core/http/endpoints/localai/tts.go ================================================ package localai import ( "path/filepath" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/audio" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/xlog" ) // TTSEndpoint is the OpenAI Speech API endpoint https://platform.openai.com/docs/api-reference/audio/createSpeech // // @Summary Generates audio from the input text. 
// @Accept json // @Produce audio/x-wav // @Param request body schema.TTSRequest true "query params" // @Success 200 {string} binary "generated audio/wav file" // @Router /v1/audio/speech [post] // @Router /tts [post] func TTSEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.TTSRequest) if !ok || input.Model == "" { return echo.ErrBadRequest } cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { return echo.ErrBadRequest } xlog.Debug("LocalAI TTS Request received", "model", input.Model) if cfg.Backend == "" && input.Backend != "" { cfg.Backend = input.Backend } if input.Language != "" { cfg.Language = input.Language } if input.Voice != "" { cfg.Voice = input.Voice } // Handle streaming TTS if input.Stream { // Set headers for streaming audio c.Response().Header().Set("Content-Type", "audio/wav") c.Response().Header().Set("Transfer-Encoding", "chunked") c.Response().Header().Set("Cache-Control", "no-cache") c.Response().Header().Set("Connection", "keep-alive") // Stream audio chunks as they're generated err := backend.ModelTTSStream(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg, func(audioChunk []byte) error { _, writeErr := c.Response().Write(audioChunk) if writeErr != nil { return writeErr } c.Response().Flush() return nil }) if err != nil { return err } return nil } // Non-streaming TTS (existing behavior) filePath, _, err := backend.ModelTTS(input.Input, cfg.Voice, cfg.Language, ml, appConfig, *cfg) if err != nil { return err } // Resample to requested sample rate if specified if input.SampleRate > 0 { filePath, err = utils.AudioResample(filePath, input.SampleRate) if err != nil { return err } } // Convert generated file to target format filePath, err = utils.AudioConvert(filePath, input.Format) if err != nil { return err 
} filePath, contentType := audio.NormalizeAudioFile(filePath) if contentType != "" { c.Response().Header().Set("Content-Type", contentType) } return c.Attachment(filePath, filepath.Base(filePath)) } } ================================================ FILE: core/http/endpoints/localai/types.go ================================================ package localai // ModelResponse represents the common response structure for model operations type ModelResponse struct { Success bool `json:"success"` Message string `json:"message"` Filename string `json:"filename,omitempty"` Config interface{} `json:"config,omitempty"` Error string `json:"error,omitempty"` Details []string `json:"details,omitempty"` } ================================================ FILE: core/http/endpoints/localai/vad.go ================================================ package localai import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) // VADEndpoint is Voice-Activation-Detection endpoint // @Summary Detect voice fragments in an audio stream // @Accept json // @Param request body schema.VADRequest true "query params" // @Success 200 {object} proto.VADResponse "Response" // @Router /vad [post] func VADEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VADRequest) if !ok || input.Model == "" { return echo.ErrBadRequest } cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { return echo.ErrBadRequest } xlog.Debug("LocalAI VAD Request received", "model", input.Model) resp, err := backend.VAD(input, c.Request().Context(), ml, appConfig, *cfg) if err != nil { return err 
} return c.JSON(200, resp) } } ================================================ FILE: core/http/endpoints/localai/video.go ================================================ package localai import ( "bufio" "encoding/base64" "encoding/json" "fmt" "io" "net/http" "net/url" "os" "path/filepath" "strings" "time" "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/backend" model "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) func downloadFile(url string) (string, error) { // Get the data resp, err := http.Get(url) if err != nil { return "", err } defer resp.Body.Close() // Create the file out, err := os.CreateTemp("", "video") if err != nil { return "", err } defer out.Close() // Write the body to file _, err = io.Copy(out, resp.Body) return out.Name(), err } // /* * curl http://localhost:8080/v1/images/generations \ -H "Content-Type: application/json" \ -d '{ "prompt": "A cute baby sea otter", "n": 1, "size": "512x512" }' * */ // VideoEndpoint // @Summary Creates a video given a prompt. 
// @Param request body schema.VideoRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /video [post] func VideoEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.VideoRequest) if !ok || input.Model == "" { xlog.Error("Video Endpoint - Invalid Input") return echo.ErrBadRequest } config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { xlog.Error("Video Endpoint - Invalid Config") return echo.ErrBadRequest } src := "" if input.StartImage != "" { var fileData []byte var err error // check if input.File is an URL, if so download it and save it // to a temporary file if strings.HasPrefix(input.StartImage, "http://") || strings.HasPrefix(input.StartImage, "https://") { out, err := downloadFile(input.StartImage) if err != nil { return fmt.Errorf("failed downloading file:%w", err) } defer os.RemoveAll(out) fileData, err = os.ReadFile(out) if err != nil { return fmt.Errorf("failed reading file:%w", err) } } else { // base 64 decode the file and write it somewhere // that we will cleanup fileData, err = base64.StdEncoding.DecodeString(input.StartImage) if err != nil { return err } } // Create a temporary file outputFile, err := os.CreateTemp(appConfig.GeneratedContentDir, "b64") if err != nil { return err } // write the base64 result writer := bufio.NewWriter(outputFile) _, err = writer.Write(fileData) if err != nil { outputFile.Close() return err } outputFile.Close() src = outputFile.Name() defer os.RemoveAll(src) } xlog.Debug("Parameter Config", "config", config) switch config.Backend { case "stablediffusion": config.Backend = model.StableDiffusionGGMLBackend case "": config.Backend = model.StableDiffusionGGMLBackend } width := input.Width height := input.Height if width == 0 { width = 512 } if 
height == 0 { height = 512 } b64JSON := input.ResponseFormat == "b64_json" tempDir := "" if !b64JSON { tempDir = filepath.Join(appConfig.GeneratedContentDir, "videos") } // Create a temporary file outputFile, err := os.CreateTemp(tempDir, "b64") if err != nil { return err } outputFile.Close() // TODO: use mime type to determine the extension output := outputFile.Name() + ".mp4" // Rename the temporary file err = os.Rename(outputFile.Name(), output) if err != nil { return err } baseURL := middleware.BaseURL(c) xlog.Debug("VideoEndpoint: Calling VideoGeneration", "num_frames", input.NumFrames, "fps", input.FPS, "cfg_scale", input.CFGScale, "step", input.Step, "seed", input.Seed, "width", width, "height", height, "negative_prompt", input.NegativePrompt) fn, err := backend.VideoGeneration( height, width, input.Prompt, input.NegativePrompt, src, input.EndImage, output, input.NumFrames, input.FPS, input.Seed, input.CFGScale, input.Step, ml, *config, appConfig, ) if err != nil { return err } if err := fn(); err != nil { return err } item := &schema.Item{} if b64JSON { defer os.RemoveAll(output) data, err := os.ReadFile(output) if err != nil { return err } item.B64JSON = base64.StdEncoding.EncodeToString(data) } else { base := filepath.Base(output) item.URL, err = url.JoinPath(baseURL, "generated-videos", base) if err != nil { return err } } id := uuid.New().String() created := int(time.Now().Unix()) resp := &schema.OpenAIResponse{ ID: id, Created: created, Data: []schema.Item{*item}, } jsonResult, _ := json.Marshal(resp) xlog.Debug("Response", "response", string(jsonResult)) // Return the prediction in the response body return c.JSON(200, resp) } } ================================================ FILE: core/http/endpoints/localai/welcome.go ================================================ package localai import ( "strings" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" 
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/internal" "github.com/mudler/LocalAI/pkg/model" ) func WelcomeEndpoint(appConfig *config.ApplicationConfig, cl *config.ModelConfigLoader, ml *model.ModelLoader, opcache *services.OpCache) echo.HandlerFunc { return func(c echo.Context) error { modelConfigs := cl.GetAllModelsConfigs() galleryConfigs := map[string]*gallery.ModelConfig{} installedBackends, err := gallery.ListSystemBackends(appConfig.SystemState) if err != nil { return err } for _, m := range modelConfigs { cfg, err := gallery.GetLocalModelConfiguration(ml.ModelPath, m.Name) if err != nil { continue } galleryConfigs[m.Name] = cfg } loadedModels := ml.ListLoadedModels() loadedModelsMap := map[string]bool{} for _, m := range loadedModels { loadedModelsMap[m.ID] = true } modelsWithoutConfig, _ := services.ListModels(cl, ml, config.NoFilterFn, services.LOOSE_ONLY) // Get model statuses to display in the UI the operation in progress processingModels, taskTypes := opcache.GetStatus() summary := map[string]interface{}{ "Title": "LocalAI API - " + internal.PrintableVersion(), "Version": internal.PrintableVersion(), "BaseURL": middleware.BaseURL(c), "Models": modelsWithoutConfig, "ModelsConfig": modelConfigs, "GalleryConfig": galleryConfigs, "ApplicationConfig": appConfig, "ProcessingModels": processingModels, "TaskTypes": taskTypes, "LoadedModels": loadedModelsMap, "InstalledBackends": installedBackends, "DisableRuntimeSettings": appConfig.DisableRuntimeSettings, } contentType := c.Request().Header.Get("Content-Type") accept := c.Request().Header.Get("Accept") // Default to HTML if Accept header is empty (browser behavior) // Only return JSON if explicitly requested or Content-Type is application/json if strings.Contains(contentType, "application/json") || (accept != "" && !strings.Contains(accept, "text/html")) { // The client expects a JSON response return c.JSON(200, summary) } else { // 
Check if this is the manage route templateName := "views/index" if strings.HasSuffix(c.Request().URL.Path, "/manage") || c.Request().URL.Path == "/manage" { templateName = "views/manage" } // Render appropriate template return c.Render(200, templateName, summary) } } } ================================================ FILE: core/http/endpoints/mcp/tools.go ================================================ package mcp import ( "context" "encoding/json" "fmt" "net/http" "os" "os/exec" "strings" "sync" "time" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/pkg/signals" "github.com/modelcontextprotocol/go-sdk/mcp" "github.com/mudler/xlog" ) // NamedSession pairs an MCP session with its server name and type. type NamedSession struct { Name string Type string // "remote" or "stdio" Session *mcp.ClientSession } // MCPToolInfo holds a discovered MCP tool along with its origin session. type MCPToolInfo struct { ServerName string ToolName string Function functions.Function Session *mcp.ClientSession } // MCPServerInfo describes an MCP server and its available tools, prompts, and resources. type MCPServerInfo struct { Name string `json:"name"` Type string `json:"type"` Tools []string `json:"tools"` Prompts []string `json:"prompts,omitempty"` Resources []string `json:"resources,omitempty"` } // MCPPromptInfo holds a discovered MCP prompt along with its origin session. type MCPPromptInfo struct { ServerName string PromptName string Description string Title string Arguments []*mcp.PromptArgument Session *mcp.ClientSession } // MCPResourceInfo holds a discovered MCP resource along with its origin session. 
type MCPResourceInfo struct { ServerName string Name string URI string Description string MIMEType string Session *mcp.ClientSession } type sessionCache struct { mu sync.Mutex cache map[string][]*mcp.ClientSession cancels map[string]context.CancelFunc } type namedSessionCache struct { mu sync.Mutex cache map[string][]NamedSession cancels map[string]context.CancelFunc } var ( cache = sessionCache{ cache: make(map[string][]*mcp.ClientSession), cancels: make(map[string]context.CancelFunc), } namedCache = namedSessionCache{ cache: make(map[string][]NamedSession), cancels: make(map[string]context.CancelFunc), } client = mcp.NewClient(&mcp.Implementation{Name: "LocalAI", Version: "v1.0.0"}, nil) ) // MCPServersFromMetadata extracts the MCP server list from the metadata map // and returns the list. The "mcp_servers" key is consumed (deleted from the map) // so it doesn't leak to the backend. func MCPServersFromMetadata(metadata map[string]string) []string { raw, ok := metadata["mcp_servers"] if !ok || raw == "" { return nil } delete(metadata, "mcp_servers") servers := strings.Split(raw, ",") for i := range servers { servers[i] = strings.TrimSpace(servers[i]) } return servers } func SessionsFromMCPConfig( name string, remote config.MCPGenericConfig[config.MCPRemoteServers], stdio config.MCPGenericConfig[config.MCPSTDIOServers], ) ([]*mcp.ClientSession, error) { cache.mu.Lock() defer cache.mu.Unlock() sessions, exists := cache.cache[name] if exists { return sessions, nil } allSessions := []*mcp.ClientSession{} ctx, cancel := context.WithCancel(context.Background()) // Get the list of all the tools that the Agent will be esposed to for _, server := range remote.Servers { xlog.Debug("[MCP remote server] Configuration", "server", server) // Create HTTP client with custom roundtripper for bearer token injection httpClient := &http.Client{ Timeout: 360 * time.Second, Transport: newBearerTokenRoundTripper(server.Token, http.DefaultTransport), } transport := 
&mcp.StreamableClientTransport{Endpoint: server.URL, HTTPClient: httpClient} mcpSession, err := client.Connect(ctx, transport, nil) if err != nil { xlog.Error("Failed to connect to MCP server", "error", err, "url", server.URL) continue } xlog.Debug("[MCP remote server] Connected to MCP server", "url", server.URL) cache.cache[name] = append(cache.cache[name], mcpSession) allSessions = append(allSessions, mcpSession) } for _, server := range stdio.Servers { xlog.Debug("[MCP stdio server] Configuration", "server", server) command := exec.Command(server.Command, server.Args...) command.Env = os.Environ() for key, value := range server.Env { command.Env = append(command.Env, key+"="+value) } transport := &mcp.CommandTransport{Command: command} mcpSession, err := client.Connect(ctx, transport, nil) if err != nil { xlog.Error("Failed to start MCP server", "error", err, "command", command) continue } xlog.Debug("[MCP stdio server] Connected to MCP server", "command", command) cache.cache[name] = append(cache.cache[name], mcpSession) allSessions = append(allSessions, mcpSession) } cache.cancels[name] = cancel return allSessions, nil } // NamedSessionsFromMCPConfig returns sessions with their server names preserved. // If enabledServers is non-empty, only servers with matching names are returned. 
func NamedSessionsFromMCPConfig( name string, remote config.MCPGenericConfig[config.MCPRemoteServers], stdio config.MCPGenericConfig[config.MCPSTDIOServers], enabledServers []string, ) ([]NamedSession, error) { namedCache.mu.Lock() defer namedCache.mu.Unlock() allSessions, exists := namedCache.cache[name] if !exists { ctx, cancel := context.WithCancel(context.Background()) for serverName, server := range remote.Servers { xlog.Debug("[MCP remote server] Configuration", "name", serverName, "server", server) httpClient := &http.Client{ Timeout: 360 * time.Second, Transport: newBearerTokenRoundTripper(server.Token, http.DefaultTransport), } transport := &mcp.StreamableClientTransport{Endpoint: server.URL, HTTPClient: httpClient} mcpSession, err := client.Connect(ctx, transport, nil) if err != nil { xlog.Error("Failed to connect to MCP server", "error", err, "name", serverName, "url", server.URL) continue } xlog.Debug("[MCP remote server] Connected", "name", serverName, "url", server.URL) allSessions = append(allSessions, NamedSession{ Name: serverName, Type: "remote", Session: mcpSession, }) } for serverName, server := range stdio.Servers { xlog.Debug("[MCP stdio server] Configuration", "name", serverName, "server", server) command := exec.Command(server.Command, server.Args...) 
command.Env = os.Environ() for key, value := range server.Env { command.Env = append(command.Env, key+"="+value) } transport := &mcp.CommandTransport{Command: command} mcpSession, err := client.Connect(ctx, transport, nil) if err != nil { xlog.Error("Failed to start MCP server", "error", err, "name", serverName, "command", command) continue } xlog.Debug("[MCP stdio server] Connected", "name", serverName, "command", command) allSessions = append(allSessions, NamedSession{ Name: serverName, Type: "stdio", Session: mcpSession, }) } namedCache.cache[name] = allSessions namedCache.cancels[name] = cancel } if len(enabledServers) == 0 { return allSessions, nil } enabled := make(map[string]bool, len(enabledServers)) for _, s := range enabledServers { enabled[s] = true } var filtered []NamedSession for _, ns := range allSessions { if enabled[ns.Name] { filtered = append(filtered, ns) } } return filtered, nil } // DiscoverMCPTools queries each session for its tools and converts them to functions.Function. // Deduplicates by tool name (first server wins). 
func DiscoverMCPTools(ctx context.Context, sessions []NamedSession) ([]MCPToolInfo, error) { seen := make(map[string]bool) var result []MCPToolInfo for _, ns := range sessions { toolsResult, err := ns.Session.ListTools(ctx, nil) if err != nil { xlog.Error("Failed to list tools from MCP server", "error", err, "server", ns.Name) continue } for _, tool := range toolsResult.Tools { if seen[tool.Name] { continue } seen[tool.Name] = true f := functions.Function{ Name: tool.Name, Description: tool.Description, } // Convert InputSchema to map[string]interface{} for functions.Function if tool.InputSchema != nil { schemaBytes, err := json.Marshal(tool.InputSchema) if err == nil { var params map[string]interface{} if json.Unmarshal(schemaBytes, ¶ms) == nil { f.Parameters = params } } } if f.Parameters == nil { f.Parameters = map[string]interface{}{ "type": "object", "properties": map[string]interface{}{}, } } result = append(result, MCPToolInfo{ ServerName: ns.Name, ToolName: tool.Name, Function: f, Session: ns.Session, }) } } return result, nil } // ExecuteMCPToolCall finds the matching tool and executes it. 
func ExecuteMCPToolCall(ctx context.Context, tools []MCPToolInfo, toolName string, arguments string) (string, error) { var toolInfo *MCPToolInfo for i := range tools { if tools[i].ToolName == toolName { toolInfo = &tools[i] break } } if toolInfo == nil { return "", fmt.Errorf("MCP tool %q not found", toolName) } var args map[string]any if arguments != "" { if err := json.Unmarshal([]byte(arguments), &args); err != nil { return "", fmt.Errorf("failed to parse arguments for tool %q: %w", toolName, err) } } result, err := toolInfo.Session.CallTool(ctx, &mcp.CallToolParams{ Name: toolName, Arguments: args, }) if err != nil { return "", fmt.Errorf("MCP tool %q call failed: %w", toolName, err) } // Extract text content from result var texts []string for _, content := range result.Content { if tc, ok := content.(*mcp.TextContent); ok { texts = append(texts, tc.Text) } } if len(texts) == 0 { // Fallback: marshal the whole result data, _ := json.Marshal(result.Content) return string(data), nil } if len(texts) == 1 { return texts[0], nil } combined, _ := json.Marshal(texts) return string(combined), nil } // ListMCPServers returns server info with tool, prompt, and resource names for each session. 
func ListMCPServers(ctx context.Context, sessions []NamedSession) ([]MCPServerInfo, error) { var result []MCPServerInfo for _, ns := range sessions { info := MCPServerInfo{ Name: ns.Name, Type: ns.Type, } toolsResult, err := ns.Session.ListTools(ctx, nil) if err != nil { xlog.Error("Failed to list tools from MCP server", "error", err, "server", ns.Name) } else { for _, tool := range toolsResult.Tools { info.Tools = append(info.Tools, tool.Name) } } promptsResult, err := ns.Session.ListPrompts(ctx, nil) if err != nil { xlog.Debug("Failed to list prompts from MCP server", "error", err, "server", ns.Name) } else { for _, p := range promptsResult.Prompts { info.Prompts = append(info.Prompts, p.Name) } } resourcesResult, err := ns.Session.ListResources(ctx, nil) if err != nil { xlog.Debug("Failed to list resources from MCP server", "error", err, "server", ns.Name) } else { for _, r := range resourcesResult.Resources { info.Resources = append(info.Resources, r.URI) } } result = append(result, info) } return result, nil } // IsMCPTool checks if a tool name is in the MCP tool list. func IsMCPTool(tools []MCPToolInfo, name string) bool { for _, t := range tools { if t.ToolName == name { return true } } return false } // DiscoverMCPPrompts queries each session for its prompts. // Deduplicates by prompt name (first server wins). 
func DiscoverMCPPrompts(ctx context.Context, sessions []NamedSession) ([]MCPPromptInfo, error) { seen := make(map[string]bool) var result []MCPPromptInfo for _, ns := range sessions { promptsResult, err := ns.Session.ListPrompts(ctx, nil) if err != nil { xlog.Error("Failed to list prompts from MCP server", "error", err, "server", ns.Name) continue } for _, p := range promptsResult.Prompts { if seen[p.Name] { continue } seen[p.Name] = true result = append(result, MCPPromptInfo{ ServerName: ns.Name, PromptName: p.Name, Description: p.Description, Title: p.Title, Arguments: p.Arguments, Session: ns.Session, }) } } return result, nil } // GetMCPPrompt finds and expands a prompt by name using the discovered prompts list. func GetMCPPrompt(ctx context.Context, prompts []MCPPromptInfo, name string, args map[string]string) ([]*mcp.PromptMessage, error) { var info *MCPPromptInfo for i := range prompts { if prompts[i].PromptName == name { info = &prompts[i] break } } if info == nil { return nil, fmt.Errorf("MCP prompt %q not found", name) } result, err := info.Session.GetPrompt(ctx, &mcp.GetPromptParams{ Name: name, Arguments: args, }) if err != nil { return nil, fmt.Errorf("MCP prompt %q get failed: %w", name, err) } return result.Messages, nil } // DiscoverMCPResources queries each session for its resources. // Deduplicates by URI (first server wins). 
func DiscoverMCPResources(ctx context.Context, sessions []NamedSession) ([]MCPResourceInfo, error) { seen := make(map[string]bool) var result []MCPResourceInfo for _, ns := range sessions { resourcesResult, err := ns.Session.ListResources(ctx, nil) if err != nil { xlog.Error("Failed to list resources from MCP server", "error", err, "server", ns.Name) continue } for _, r := range resourcesResult.Resources { if seen[r.URI] { continue } seen[r.URI] = true result = append(result, MCPResourceInfo{ ServerName: ns.Name, Name: r.Name, URI: r.URI, Description: r.Description, MIMEType: r.MIMEType, Session: ns.Session, }) } } return result, nil } // ReadMCPResource reads a resource by URI from the matching session. func ReadMCPResource(ctx context.Context, resources []MCPResourceInfo, uri string) (string, error) { var info *MCPResourceInfo for i := range resources { if resources[i].URI == uri { info = &resources[i] break } } if info == nil { return "", fmt.Errorf("MCP resource %q not found", uri) } result, err := info.Session.ReadResource(ctx, &mcp.ReadResourceParams{URI: uri}) if err != nil { return "", fmt.Errorf("MCP resource %q read failed: %w", uri, err) } var texts []string for _, c := range result.Contents { if c.Text != "" { texts = append(texts, c.Text) } } return strings.Join(texts, "\n"), nil } // MCPPromptFromMetadata extracts the prompt name and arguments from metadata. // The "mcp_prompt" and "mcp_prompt_args" keys are consumed (deleted from the map). func MCPPromptFromMetadata(metadata map[string]string) (string, map[string]string) { name, ok := metadata["mcp_prompt"] if !ok || name == "" { return "", nil } delete(metadata, "mcp_prompt") var args map[string]string if raw, ok := metadata["mcp_prompt_args"]; ok && raw != "" { json.Unmarshal([]byte(raw), &args) delete(metadata, "mcp_prompt_args") } return name, args } // MCPResourcesFromMetadata extracts resource URIs from metadata. // The "mcp_resources" key is consumed (deleted from the map). 
func MCPResourcesFromMetadata(metadata map[string]string) []string { raw, ok := metadata["mcp_resources"] if !ok || raw == "" { return nil } delete(metadata, "mcp_resources") uris := strings.Split(raw, ",") for i := range uris { uris[i] = strings.TrimSpace(uris[i]) } return uris } // PromptMessageToText extracts text from a PromptMessage's Content. func PromptMessageToText(msg *mcp.PromptMessage) string { if tc, ok := msg.Content.(*mcp.TextContent); ok { return tc.Text } // Fallback: marshal content data, _ := json.Marshal(msg.Content) return string(data) } // CloseMCPSessions closes all MCP sessions for a given model and removes them from the cache. // This should be called when a model is unloaded or shut down. func CloseMCPSessions(modelName string) { // Close sessions in the unnamed cache cache.mu.Lock() if sessions, ok := cache.cache[modelName]; ok { for _, s := range sessions { s.Close() } delete(cache.cache, modelName) } if cancel, ok := cache.cancels[modelName]; ok { cancel() delete(cache.cancels, modelName) } cache.mu.Unlock() // Close sessions in the named cache namedCache.mu.Lock() if sessions, ok := namedCache.cache[modelName]; ok { for _, ns := range sessions { ns.Session.Close() } delete(namedCache.cache, modelName) } if cancel, ok := namedCache.cancels[modelName]; ok { cancel() delete(namedCache.cancels, modelName) } namedCache.mu.Unlock() xlog.Debug("Closed MCP sessions for model", "model", modelName) } // CloseAllMCPSessions closes all cached MCP sessions across all models. // This should be called during graceful shutdown. 
func CloseAllMCPSessions() { cache.mu.Lock() for name, sessions := range cache.cache { for _, s := range sessions { s.Close() } if cancel, ok := cache.cancels[name]; ok { cancel() } } cache.cache = make(map[string][]*mcp.ClientSession) cache.cancels = make(map[string]context.CancelFunc) cache.mu.Unlock() namedCache.mu.Lock() for name, sessions := range namedCache.cache { for _, ns := range sessions { ns.Session.Close() } if cancel, ok := namedCache.cancels[name]; ok { cancel() } } namedCache.cache = make(map[string][]NamedSession) namedCache.cancels = make(map[string]context.CancelFunc) namedCache.mu.Unlock() xlog.Debug("Closed all MCP sessions") } func init() { signals.RegisterGracefulTerminationHandler(func() { CloseAllMCPSessions() }) } // bearerTokenRoundTripper is a custom roundtripper that injects a bearer token // into HTTP requests type bearerTokenRoundTripper struct { token string base http.RoundTripper } // RoundTrip implements the http.RoundTripper interface func (rt *bearerTokenRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { if rt.token != "" { req.Header.Set("Authorization", "Bearer "+rt.token) } return rt.base.RoundTrip(req) } // newBearerTokenRoundTripper creates a new roundtripper that injects the given token func newBearerTokenRoundTripper(token string, base http.RoundTripper) http.RoundTripper { if base == nil { base = http.DefaultTransport } return &bearerTokenRoundTripper{ token: token, base: base, } } ================================================ FILE: core/http/endpoints/openai/chat.go ================================================ package openai import ( "encoding/json" "fmt" "strings" "time" "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/functions" 
reason "github.com/mudler/LocalAI/pkg/reasoning" "github.com/mudler/LocalAI/core/templates" pb "github.com/mudler/LocalAI/pkg/grpc/proto" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) // mergeToolCallDeltas merges streaming tool call deltas into complete tool calls. // In SSE streaming, a single tool call arrives as multiple chunks sharing the same Index: // the first chunk carries the ID, Type, and Name; subsequent chunks append to Arguments. func mergeToolCallDeltas(existing []schema.ToolCall, deltas []schema.ToolCall) []schema.ToolCall { byIndex := make(map[int]int, len(existing)) // tool call Index -> position in slice for i, tc := range existing { byIndex[tc.Index] = i } for _, d := range deltas { pos, found := byIndex[d.Index] if !found { byIndex[d.Index] = len(existing) existing = append(existing, d) continue } // Merge into existing entry tc := &existing[pos] if d.ID != "" { tc.ID = d.ID } if d.Type != "" { tc.Type = d.Type } if d.FunctionCall.Name != "" { tc.FunctionCall.Name = d.FunctionCall.Name } tc.FunctionCall.Arguments += d.FunctionCall.Arguments } return existing } // ChatEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/chat/create // @Summary Generate a chat completions for a given prompt and model. // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/chat/completions [post] func ChatEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, startupOptions *config.ApplicationConfig) echo.HandlerFunc { var id, textContentToReturn string var created int process := func(s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error { initialMessage := schema.OpenAIResponse{ ID: id, Created: created, Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. 
Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0, FinishReason: nil}}, Object: "chat.completion.chunk", } responses <- initialMessage // Detect if thinking token is already in prompt or template // When UseTokenizerTemplate is enabled, predInput is empty, so we check the template var template string if config.TemplateConfig.UseTokenizerTemplate { template = config.GetModelTemplate() } else { template = s } thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig) extractor := reason.NewReasoningExtractor(thinkingStartToken, config.ReasoningConfig) _, _, _, err := ComputeChoices(req, s, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, tokenUsage backend.TokenUsage) bool { reasoningDelta, contentDelta := extractor.ProcessToken(s) usage := schema.OpenAIUsage{ PromptTokens: tokenUsage.Prompt, CompletionTokens: tokenUsage.Completion, TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, } if extraUsage { usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing } delta := &schema.Message{} if contentDelta != "" { delta.Content = &contentDelta } if reasoningDelta != "" { delta.Reasoning = &reasoningDelta } resp := schema.OpenAIResponse{ ID: id, Created: created, Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. 
Choices: []schema.Choice{{Delta: delta, Index: 0, FinishReason: nil}}, Object: "chat.completion.chunk", Usage: usage, } responses <- resp return true }) close(responses) return err } processTools := func(noAction string, prompt string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error { // Detect if thinking token is already in prompt or template var template string if config.TemplateConfig.UseTokenizerTemplate { template = config.GetModelTemplate() } else { template = prompt } thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig) extractor := reason.NewReasoningExtractor(thinkingStartToken, config.ReasoningConfig) result := "" lastEmittedCount := 0 sentInitialRole := false _, tokenUsage, chatDeltas, err := ComputeChoices(req, prompt, config, cl, startupOptions, loader, func(s string, c *[]schema.Choice) {}, func(s string, usage backend.TokenUsage) bool { result += s reasoningDelta, contentDelta := extractor.ProcessToken(s) // Emit reasoning deltas in their own SSE chunks before any tool-call chunks // (OpenAI spec: reasoning and tool_calls never share a delta) if reasoningDelta != "" { responses <- schema.OpenAIResponse{ ID: id, Created: created, Model: req.Model, Choices: []schema.Choice{{ Delta: &schema.Message{Reasoning: &reasoningDelta}, Index: 0, }}, Object: "chat.completion.chunk", } } // Stream content deltas (cleaned of reasoning tags) while no tool calls // have been detected. Once the incremental parser finds tool calls, // content stops — per OpenAI spec, content and tool_calls don't mix. 
if lastEmittedCount == 0 && contentDelta != "" { if !sentInitialRole { responses <- schema.OpenAIResponse{ ID: id, Created: created, Model: req.Model, Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0}}, Object: "chat.completion.chunk", } sentInitialRole = true } responses <- schema.OpenAIResponse{ ID: id, Created: created, Model: req.Model, Choices: []schema.Choice{{ Delta: &schema.Message{Content: &contentDelta}, Index: 0, }}, Object: "chat.completion.chunk", } } // Try incremental XML parsing for streaming support using iterative parser // This allows emitting partial tool calls as they're being generated cleanedResult := functions.CleanupLLMResult(result, config.FunctionsConfig) // Determine XML format from config var xmlFormat *functions.XMLToolCallFormat if config.FunctionsConfig.XMLFormat != nil { xmlFormat = config.FunctionsConfig.XMLFormat } else if config.FunctionsConfig.XMLFormatPreset != "" { xmlFormat = functions.GetXMLFormatPreset(config.FunctionsConfig.XMLFormatPreset) } // Use iterative parser for streaming (partial parsing enabled) // Try XML parsing first partialResults, parseErr := functions.ParseXMLIterative(cleanedResult, xmlFormat, true) if parseErr == nil && len(partialResults) > 0 { // Emit new XML tool calls that weren't emitted before if len(partialResults) > lastEmittedCount { for i := lastEmittedCount; i < len(partialResults); i++ { toolCall := partialResults[i] initialMessage := schema.OpenAIResponse{ ID: id, Created: created, Model: req.Model, Choices: []schema.Choice{{ Delta: &schema.Message{ Role: "assistant", ToolCalls: []schema.ToolCall{ { Index: i, ID: id, Type: "function", FunctionCall: schema.FunctionCall{ Name: toolCall.Name, }, }, }, }, Index: 0, FinishReason: nil, }}, Object: "chat.completion.chunk", } select { case responses <- initialMessage: default: } } lastEmittedCount = len(partialResults) } } else { // Try JSON tool call parsing for streaming // Check if the result looks like JSON tool calls 
jsonResults, jsonErr := functions.ParseJSONIterative(cleanedResult, true) if jsonErr == nil && len(jsonResults) > 0 { // Check if these are tool calls (have "name" and optionally "arguments") for _, jsonObj := range jsonResults { if name, ok := jsonObj["name"].(string); ok && name != "" { // This looks like a tool call args := "{}" if argsVal, ok := jsonObj["arguments"]; ok { if argsStr, ok := argsVal.(string); ok { args = argsStr } else { argsBytes, _ := json.Marshal(argsVal) args = string(argsBytes) } } // Emit tool call initialMessage := schema.OpenAIResponse{ ID: id, Created: created, Model: req.Model, Choices: []schema.Choice{{ Delta: &schema.Message{ Role: "assistant", ToolCalls: []schema.ToolCall{ { Index: lastEmittedCount, ID: id, Type: "function", FunctionCall: schema.FunctionCall{ Name: name, Arguments: args, }, }, }, }, Index: 0, FinishReason: nil, }}, Object: "chat.completion.chunk", } select { case responses <- initialMessage: default: } lastEmittedCount++ } } } } return true }, func(attempt int) bool { // After streaming completes: check if we got actionable content cleaned := extractor.CleanedContent() // Check for tool calls from chat deltas (will be re-checked after ComputeChoices, // but we need to know here whether to retry) hasToolCalls := lastEmittedCount > 0 if cleaned == "" && !hasToolCalls { xlog.Warn("Streaming: backend produced only reasoning, retrying", "reasoning_len", len(extractor.Reasoning()), "attempt", attempt+1) extractor.ResetAndSuppressReasoning() result = "" lastEmittedCount = 0 sentInitialRole = false return true } return false }, ) if err != nil { return err } // Try using pre-parsed tool calls from C++ autoparser (chat deltas) var functionResults []functions.FuncCallResults var reasoning string if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 { xlog.Debug("[ChatDeltas] Using pre-parsed tool calls from C++ autoparser", "count", len(deltaToolCalls)) functionResults = deltaToolCalls // 
Use content/reasoning from deltas too textContentToReturn = functions.ContentFromChatDeltas(chatDeltas) reasoning = functions.ReasoningFromChatDeltas(chatDeltas) } else { // Fallback: parse tool calls from raw text (no chat deltas from backend) xlog.Debug("[ChatDeltas] no pre-parsed tool calls, falling back to Go-side text parsing") reasoning = extractor.Reasoning() cleanedResult := extractor.CleanedContent() textContentToReturn = functions.ParseTextContent(cleanedResult, config.FunctionsConfig) cleanedResult = functions.CleanupLLMResult(cleanedResult, config.FunctionsConfig) functionResults = functions.ParseFunctionCall(cleanedResult, config.FunctionsConfig) } xlog.Debug("[ChatDeltas] final tool call decision", "tool_calls", len(functionResults), "text_content", textContentToReturn) noActionToRun := len(functionResults) > 0 && functionResults[0].Name == noAction || len(functionResults) == 0 switch { case noActionToRun: usage := schema.OpenAIUsage{ PromptTokens: tokenUsage.Prompt, CompletionTokens: tokenUsage.Completion, TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, } if extraUsage { usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing } if sentInitialRole { // Content was already streamed during the callback — just emit usage. delta := &schema.Message{} if reasoning != "" && extractor.Reasoning() == "" { delta.Reasoning = &reasoning } responses <- schema.OpenAIResponse{ ID: id, Created: created, Model: req.Model, Choices: []schema.Choice{{Delta: delta, Index: 0}}, Object: "chat.completion.chunk", Usage: usage, } } else { // Content was NOT streamed — send everything at once (fallback). 
responses <- schema.OpenAIResponse{ ID: id, Created: created, Model: req.Model, Choices: []schema.Choice{{Delta: &schema.Message{Role: "assistant"}, Index: 0}}, Object: "chat.completion.chunk", } result, err := handleQuestion(config, functionResults, extractor.CleanedContent(), prompt) if err != nil { xlog.Error("error handling question", "error", err) return err } delta := &schema.Message{Content: &result} if reasoning != "" { delta.Reasoning = &reasoning } responses <- schema.OpenAIResponse{ ID: id, Created: created, Model: req.Model, Choices: []schema.Choice{{Delta: delta, Index: 0}}, Object: "chat.completion.chunk", Usage: usage, } } default: for i, ss := range functionResults { name, args := ss.Name, ss.Arguments toolCallID := ss.ID if toolCallID == "" { toolCallID = id } initialMessage := schema.OpenAIResponse{ ID: id, Created: created, Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. Choices: []schema.Choice{{ Delta: &schema.Message{ Role: "assistant", ToolCalls: []schema.ToolCall{ { Index: i, ID: toolCallID, Type: "function", FunctionCall: schema.FunctionCall{ Name: name, }, }, }, }, Index: 0, FinishReason: nil, }}, Object: "chat.completion.chunk", } responses <- initialMessage responses <- schema.OpenAIResponse{ ID: id, Created: created, Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. 
Choices: []schema.Choice{{ Delta: &schema.Message{ Role: "assistant", Content: &textContentToReturn, ToolCalls: []schema.ToolCall{ { Index: i, ID: toolCallID, Type: "function", FunctionCall: schema.FunctionCall{ Arguments: args, }, }, }, }, Index: 0, FinishReason: nil, }}, Object: "chat.completion.chunk", } } } close(responses) return err } return func(c echo.Context) error { textContentToReturn = "" id = uuid.New().String() created = int(time.Now().Unix()) input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { return echo.ErrBadRequest } extraUsage := c.Request().Header.Get("Extra-Usage") != "" config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { return echo.ErrBadRequest } xlog.Debug("Chat endpoint configuration read", "config", config) funcs := input.Functions shouldUseFn := len(input.Functions) > 0 && config.ShouldUseFunctions() strictMode := false // MCP tool injection: when mcp_servers is set in metadata and model has MCP config var mcpToolInfos []mcpTools.MCPToolInfo mcpServers := mcpTools.MCPServersFromMetadata(input.Metadata) // MCP prompt and resource injection (extracted before tool injection) mcpPromptName, mcpPromptArgs := mcpTools.MCPPromptFromMetadata(input.Metadata) mcpResourceURIs := mcpTools.MCPResourcesFromMetadata(input.Metadata) if (len(mcpServers) > 0 || mcpPromptName != "" || len(mcpResourceURIs) > 0) && (config.MCP.Servers != "" || config.MCP.Stdio != "") { remote, stdio, mcpErr := config.MCP.MCPConfigFromYAML() if mcpErr == nil { namedSessions, sessErr := mcpTools.NamedSessionsFromMCPConfig(config.Name, remote, stdio, mcpServers) if sessErr == nil && len(namedSessions) > 0 { // Prompt injection: prepend prompt messages to the conversation if mcpPromptName != "" { prompts, discErr := mcpTools.DiscoverMCPPrompts(c.Request().Context(), namedSessions) if discErr == nil { promptMsgs, getErr := 
mcpTools.GetMCPPrompt(c.Request().Context(), prompts, mcpPromptName, mcpPromptArgs) if getErr == nil { var injected []schema.Message for _, pm := range promptMsgs { injected = append(injected, schema.Message{ Role: string(pm.Role), Content: mcpTools.PromptMessageToText(pm), }) } input.Messages = append(injected, input.Messages...) xlog.Debug("MCP prompt injected", "prompt", mcpPromptName, "messages", len(injected)) } else { xlog.Error("Failed to get MCP prompt", "error", getErr) } } else { xlog.Error("Failed to discover MCP prompts", "error", discErr) } } // Resource injection: append resource content to the last user message if len(mcpResourceURIs) > 0 { resources, discErr := mcpTools.DiscoverMCPResources(c.Request().Context(), namedSessions) if discErr == nil { var resourceTexts []string for _, uri := range mcpResourceURIs { content, readErr := mcpTools.ReadMCPResource(c.Request().Context(), resources, uri) if readErr != nil { xlog.Error("Failed to read MCP resource", "error", readErr, "uri", uri) continue } // Find resource name name := uri for _, r := range resources { if r.URI == uri { name = r.Name break } } resourceTexts = append(resourceTexts, fmt.Sprintf("--- MCP Resource: %s ---\n%s", name, content)) } if len(resourceTexts) > 0 && len(input.Messages) > 0 { lastIdx := len(input.Messages) - 1 suffix := "\n\n" + strings.Join(resourceTexts, "\n\n") switch ct := input.Messages[lastIdx].Content.(type) { case string: input.Messages[lastIdx].Content = ct + suffix default: input.Messages[lastIdx].Content = fmt.Sprintf("%v%s", ct, suffix) } xlog.Debug("MCP resources injected", "count", len(resourceTexts)) } } else { xlog.Error("Failed to discover MCP resources", "error", discErr) } } // Tool injection if len(mcpServers) > 0 { discovered, discErr := mcpTools.DiscoverMCPTools(c.Request().Context(), namedSessions) if discErr == nil { mcpToolInfos = discovered for _, ti := range mcpToolInfos { funcs = append(funcs, ti.Function) input.Tools = append(input.Tools, 
functions.Tool{Type: "function", Function: ti.Function}) } shouldUseFn = len(funcs) > 0 && config.ShouldUseFunctions() xlog.Debug("MCP tools injected", "count", len(mcpToolInfos), "total_funcs", len(funcs)) } else { xlog.Error("Failed to discover MCP tools", "error", discErr) } } } } else { xlog.Error("Failed to parse MCP config", "error", mcpErr) } } xlog.Debug("Tool call routing decision", "shouldUseFn", shouldUseFn, "len(input.Functions)", len(input.Functions), "len(input.Tools)", len(input.Tools), "config.ShouldUseFunctions()", config.ShouldUseFunctions(), "config.FunctionToCall()", config.FunctionToCall(), ) for _, f := range input.Functions { if f.Strict { strictMode = true break } } // Allow the user to set custom actions via config file // to be "embedded" in each model noActionName := "answer" noActionDescription := "use this action to answer without performing any action" if config.FunctionsConfig.NoActionFunctionName != "" { noActionName = config.FunctionsConfig.NoActionFunctionName } if config.FunctionsConfig.NoActionDescriptionName != "" { noActionDescription = config.FunctionsConfig.NoActionDescriptionName } // If we are using a response format, we need to generate a grammar for it if config.ResponseFormatMap != nil { d := schema.ChatCompletionResponseFormat{} dat, err := json.Marshal(config.ResponseFormatMap) if err != nil { return err } err = json.Unmarshal(dat, &d) if err != nil { return err } switch d.Type { case "json_object": input.Grammar = functions.JSONBNF case "json_schema": d := schema.JsonSchemaRequest{} dat, err := json.Marshal(config.ResponseFormatMap) if err != nil { return err } err = json.Unmarshal(dat, &d) if err != nil { return err } fs := &functions.JSONFunctionStructure{ AnyOf: []functions.Item{d.JsonSchema.Schema}, } g, err := fs.Grammar(config.FunctionsConfig.GrammarOptions()...) 
if err == nil { input.Grammar = g } else { xlog.Error("Failed generating grammar", "error", err) } } } config.Grammar = input.Grammar if shouldUseFn { xlog.Debug("Response needs to process functions") } switch { // Generates grammar with internal's LocalAI engine case (!config.FunctionsConfig.GrammarConfig.NoGrammar || strictMode) && shouldUseFn: noActionGrammar := functions.Function{ Name: noActionName, Description: noActionDescription, Parameters: map[string]interface{}{ "properties": map[string]interface{}{ "message": map[string]interface{}{ "type": "string", "description": "The message to reply the user with", }}, }, } // Append the no action function if !config.FunctionsConfig.DisableNoAction && !strictMode { funcs = append(funcs, noActionGrammar) } // Force picking one of the functions by the request if config.FunctionToCall() != "" { funcs = funcs.Select(config.FunctionToCall()) } // Update input grammar or json_schema based on use_llama_grammar option jsStruct := funcs.ToJSONStructure(config.FunctionsConfig.FunctionNameKey, config.FunctionsConfig.FunctionNameKey) g, err := jsStruct.Grammar(config.FunctionsConfig.GrammarOptions()...) if err == nil { config.Grammar = g } else { xlog.Error("Failed generating grammar", "error", err) } case input.JSONFunctionGrammarObject != nil: g, err := input.JSONFunctionGrammarObject.Grammar(config.FunctionsConfig.GrammarOptions()...) if err == nil { config.Grammar = g } else { xlog.Error("Failed generating grammar", "error", err) } default: // Force picking one of the functions by the request if config.FunctionToCall() != "" { funcs = funcs.Select(config.FunctionToCall()) } } // process functions if we have any defined or if we have a function call string // functions are not supported in stream mode (yet?) 
toStream := input.Stream xlog.Debug("Parameters", "config", config) var predInput string // If we are using the tokenizer template, we don't need to process the messages // unless we are processing functions if !config.TemplateConfig.UseTokenizerTemplate { predInput = evaluator.TemplateMessages(*input, input.Messages, config, funcs, shouldUseFn) xlog.Debug("Prompt (after templating)", "prompt", predInput) if config.Grammar != "" { xlog.Debug("Grammar", "grammar", config.Grammar) } } switch { case toStream: xlog.Debug("Stream request received") c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Cache-Control", "no-cache") c.Response().Header().Set("Connection", "keep-alive") c.Response().Header().Set("X-Correlation-ID", id) mcpStreamMaxIterations := 10 if config.Agent.MaxIterations > 0 { mcpStreamMaxIterations = config.Agent.MaxIterations } hasMCPToolsStream := len(mcpToolInfos) > 0 for mcpStreamIter := 0; mcpStreamIter <= mcpStreamMaxIterations; mcpStreamIter++ { // Re-template on MCP iterations if mcpStreamIter > 0 && !config.TemplateConfig.UseTokenizerTemplate { predInput = evaluator.TemplateMessages(*input, input.Messages, config, funcs, shouldUseFn) xlog.Debug("MCP stream re-templating", "iteration", mcpStreamIter) } responses := make(chan schema.OpenAIResponse) ended := make(chan error, 1) go func() { if !shouldUseFn { ended <- process(predInput, input, config, ml, responses, extraUsage) } else { ended <- processTools(noActionName, predInput, input, config, ml, responses, extraUsage) } }() usage := &schema.OpenAIUsage{} toolsCalled := false var collectedToolCalls []schema.ToolCall var collectedContent string LOOP: for { select { case <-input.Context.Done(): // Context was cancelled (client disconnected or request cancelled) xlog.Debug("Request context cancelled, stopping stream") input.Cancel() break LOOP case ev := <-responses: if len(ev.Choices) == 0 { xlog.Debug("No choices in the response, skipping") continue } usage 
= &ev.Usage // Copy a pointer to the latest usage chunk so that the stop message can reference it if len(ev.Choices[0].Delta.ToolCalls) > 0 { toolsCalled = true // Collect and merge tool call deltas for MCP execution if hasMCPToolsStream { collectedToolCalls = mergeToolCallDeltas(collectedToolCalls, ev.Choices[0].Delta.ToolCalls) } } // Collect content for MCP conversation history if hasMCPToolsStream && ev.Choices[0].Delta != nil && ev.Choices[0].Delta.Content != nil { if s, ok := ev.Choices[0].Delta.Content.(string); ok { collectedContent += s } else if sp, ok := ev.Choices[0].Delta.Content.(*string); ok && sp != nil { collectedContent += *sp } } respData, err := json.Marshal(ev) if err != nil { xlog.Debug("Failed to marshal response", "error", err) input.Cancel() continue } xlog.Debug("Sending chunk", "chunk", string(respData)) _, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData)) if err != nil { xlog.Debug("Sending chunk failed", "error", err) input.Cancel() return err } c.Response().Flush() case err := <-ended: if err == nil { break LOOP } xlog.Error("Stream ended with error", "error", err) stopReason := FinishReasonStop resp := &schema.OpenAIResponse{ ID: id, Created: created, Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
Choices: []schema.Choice{ { FinishReason: &stopReason, Index: 0, Delta: &schema.Message{Content: "Internal error: " + err.Error()}, }}, Object: "chat.completion.chunk", Usage: *usage, } respData, marshalErr := json.Marshal(resp) if marshalErr != nil { xlog.Error("Failed to marshal error response", "error", marshalErr) // Send a simple error message as fallback fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n") } else { fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) } fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") c.Response().Flush() return nil } } // MCP streaming tool execution: if we collected MCP tool calls, execute and loop if hasMCPToolsStream && toolsCalled && len(collectedToolCalls) > 0 { var hasMCPCalls bool for _, tc := range collectedToolCalls { if mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { hasMCPCalls = true break } } if hasMCPCalls { // Append assistant message with tool_calls assistantMsg := schema.Message{ Role: "assistant", Content: collectedContent, ToolCalls: collectedToolCalls, } input.Messages = append(input.Messages, assistantMsg) // Execute MCP tool calls and stream results as tool_result events for _, tc := range collectedToolCalls { if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { continue } xlog.Debug("Executing MCP tool (stream)", "tool", tc.FunctionCall.Name, "iteration", mcpStreamIter) toolResult, toolErr := mcpTools.ExecuteMCPToolCall( c.Request().Context(), mcpToolInfos, tc.FunctionCall.Name, tc.FunctionCall.Arguments, ) if toolErr != nil { xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr) toolResult = fmt.Sprintf("Error: %v", toolErr) } input.Messages = append(input.Messages, schema.Message{ Role: "tool", Content: toolResult, StringContent: toolResult, ToolCallID: tc.ID, Name: tc.FunctionCall.Name, }) // Stream tool result event to client mcpEvent := map[string]any{ "type": "mcp_tool_result", "name": tc.FunctionCall.Name, 
"result": toolResult, } if mcpEventData, err := json.Marshal(mcpEvent); err == nil { fmt.Fprintf(c.Response().Writer, "data: %s\n\n", mcpEventData) c.Response().Flush() } } xlog.Debug("MCP streaming tools executed, re-running inference", "iteration", mcpStreamIter) continue // next MCP stream iteration } } // No MCP tools to execute, send final stop message finishReason := FinishReasonStop if toolsCalled && len(input.Tools) > 0 { finishReason = FinishReasonToolCalls } else if toolsCalled { finishReason = FinishReasonFunctionCall } resp := &schema.OpenAIResponse{ ID: id, Created: created, Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. Choices: []schema.Choice{ { FinishReason: &finishReason, Index: 0, Delta: &schema.Message{}, }}, Object: "chat.completion.chunk", Usage: *usage, } respData, _ := json.Marshal(resp) fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") c.Response().Flush() xlog.Debug("Stream ended") return nil } // end MCP stream iteration loop // Safety fallback fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") c.Response().Flush() return nil // no streaming mode default: mcpMaxIterations := 10 if config.Agent.MaxIterations > 0 { mcpMaxIterations = config.Agent.MaxIterations } hasMCPTools := len(mcpToolInfos) > 0 for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ { // Re-template on each MCP iteration since messages may have changed if mcpIteration > 0 && !config.TemplateConfig.UseTokenizerTemplate { predInput = evaluator.TemplateMessages(*input, input.Messages, config, funcs, shouldUseFn) xlog.Debug("MCP re-templating", "iteration", mcpIteration, "prompt_len", len(predInput)) } // Detect if thinking token is already in prompt or template var template string if config.TemplateConfig.UseTokenizerTemplate { template = config.GetModelTemplate() // TODO: this should be the parsed jinja template. But for now this is the best we can do. 
} else { template = predInput } thinkingStartToken := reason.DetectThinkingStartToken(template, &config.ReasoningConfig) xlog.Debug("Thinking start token", "thinkingStartToken", thinkingStartToken, "template", template) // When shouldUseFn, the callback just stores the raw text — tool parsing // is deferred to after ComputeChoices so we can check chat deltas first // and avoid redundant Go-side parsing. var cbRawResult, cbReasoning string tokenCallback := func(s string, c *[]schema.Choice) { reasoning, s := reason.ExtractReasoningWithConfig(s, thinkingStartToken, config.ReasoningConfig) if !shouldUseFn { stopReason := FinishReasonStop message := &schema.Message{Role: "assistant", Content: &s} if reasoning != "" { message.Reasoning = &reasoning } *c = append(*c, schema.Choice{FinishReason: &stopReason, Index: 0, Message: message}) return } // Store raw text for deferred tool parsing cbRawResult = s cbReasoning = reasoning } var result []schema.Choice var tokenUsage backend.TokenUsage var err error var chatDeltas []*pb.ChatDelta result, tokenUsage, chatDeltas, err = ComputeChoices( input, predInput, config, cl, startupOptions, ml, tokenCallback, nil, func(attempt int) bool { if !shouldUseFn { return false } // Retry when backend produced only reasoning and no content/tool calls. // Full tool parsing is deferred until after ComputeChoices returns // (when chat deltas are available), but we can detect the empty case here. 
if cbRawResult == "" && textContentToReturn == "" { xlog.Warn("Backend produced reasoning without actionable content, retrying", "reasoning_len", len(cbReasoning), "attempt", attempt+1) cbRawResult = "" cbReasoning = "" textContentToReturn = "" return true } return false }, ) if err != nil { return err } // Tool parsing is deferred here (only when shouldUseFn) so chat deltas are available if shouldUseFn { var funcResults []functions.FuncCallResults // Try pre-parsed tool calls from C++ autoparser first if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 { xlog.Debug("[ChatDeltas] non-SSE: using C++ autoparser tool calls, skipping Go-side parsing", "count", len(deltaToolCalls)) funcResults = deltaToolCalls textContentToReturn = functions.ContentFromChatDeltas(chatDeltas) cbReasoning = functions.ReasoningFromChatDeltas(chatDeltas) } else { // Fallback: parse tool calls from raw text xlog.Debug("[ChatDeltas] non-SSE: no chat deltas, falling back to Go-side text parsing") textContentToReturn = functions.ParseTextContent(cbRawResult, config.FunctionsConfig) cbRawResult = functions.CleanupLLMResult(cbRawResult, config.FunctionsConfig) funcResults = functions.ParseFunctionCall(cbRawResult, config.FunctionsConfig) } noActionsToRun := len(funcResults) > 0 && funcResults[0].Name == noActionName || len(funcResults) == 0 switch { case noActionsToRun: qResult, qErr := handleQuestion(config, funcResults, cbRawResult, predInput) if qErr != nil { xlog.Error("error handling question", "error", qErr) } stopReason := FinishReasonStop message := &schema.Message{Role: "assistant", Content: &qResult} if cbReasoning != "" { message.Reasoning = &cbReasoning } result = append(result, schema.Choice{ FinishReason: &stopReason, Message: message, }) default: toolCallsReason := FinishReasonToolCalls toolChoice := schema.Choice{ FinishReason: &toolCallsReason, Message: &schema.Message{ Role: "assistant", }, } if cbReasoning != "" { 
toolChoice.Message.Reasoning = &cbReasoning } for _, ss := range funcResults { name, args := ss.Name, ss.Arguments toolCallID := ss.ID if toolCallID == "" { toolCallID = id } if len(input.Tools) > 0 { toolChoice.Message.Content = textContentToReturn toolChoice.Message.ToolCalls = append(toolChoice.Message.ToolCalls, schema.ToolCall{ ID: toolCallID, Type: "function", FunctionCall: schema.FunctionCall{ Name: name, Arguments: args, }, }, ) } else { // Deprecated function_call format functionCallReason := FinishReasonFunctionCall message := &schema.Message{ Role: "assistant", Content: &textContentToReturn, FunctionCall: map[string]interface{}{ "name": name, "arguments": args, }, } if cbReasoning != "" { message.Reasoning = &cbReasoning } result = append(result, schema.Choice{ FinishReason: &functionCallReason, Message: message, }) } } if len(input.Tools) > 0 { result = append(result, toolChoice) } } } // MCP server-side tool execution loop: // If we have MCP tools and the model returned tool_calls, execute MCP tools // and re-run inference with the results appended to the conversation. 
if hasMCPTools && len(result) > 0 { var mcpCallsExecuted bool for _, choice := range result { if choice.Message == nil || len(choice.Message.ToolCalls) == 0 { continue } // Check if any tool calls are MCP tools var hasMCPCalls bool for _, tc := range choice.Message.ToolCalls { if mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { hasMCPCalls = true break } } if !hasMCPCalls { continue } // Append assistant message with tool_calls to conversation assistantContent := "" if choice.Message.Content != nil { if s, ok := choice.Message.Content.(string); ok { assistantContent = s } else if sp, ok := choice.Message.Content.(*string); ok && sp != nil { assistantContent = *sp } } assistantMsg := schema.Message{ Role: "assistant", Content: assistantContent, ToolCalls: choice.Message.ToolCalls, } input.Messages = append(input.Messages, assistantMsg) // Execute each MCP tool call and append results for _, tc := range choice.Message.ToolCalls { if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { continue } xlog.Debug("Executing MCP tool", "tool", tc.FunctionCall.Name, "arguments", tc.FunctionCall.Arguments, "iteration", mcpIteration) toolResult, toolErr := mcpTools.ExecuteMCPToolCall( c.Request().Context(), mcpToolInfos, tc.FunctionCall.Name, tc.FunctionCall.Arguments, ) if toolErr != nil { xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr) toolResult = fmt.Sprintf("Error: %v", toolErr) } input.Messages = append(input.Messages, schema.Message{ Role: "tool", Content: toolResult, StringContent: toolResult, ToolCallID: tc.ID, Name: tc.FunctionCall.Name, }) mcpCallsExecuted = true } } if mcpCallsExecuted { xlog.Debug("MCP tools executed, re-running inference", "iteration", mcpIteration, "messages", len(input.Messages)) continue // next MCP iteration } } // No MCP tools to execute (or no MCP tools configured), return response usage := schema.OpenAIUsage{ PromptTokens: tokenUsage.Prompt, CompletionTokens: tokenUsage.Completion, 
TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, } if extraUsage { usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing } resp := &schema.OpenAIResponse{ ID: id, Created: created, Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. Choices: result, Object: "chat.completion", Usage: usage, } respData, _ := json.Marshal(resp) xlog.Debug("Response", "response", string(respData)) // Return the prediction in the response body return c.JSON(200, resp) } // end MCP iteration loop // Should not reach here, but safety fallback return fmt.Errorf("MCP iteration limit reached") } } } func handleQuestion(config *config.ModelConfig, funcResults []functions.FuncCallResults, result, prompt string) (string, error) { if len(funcResults) == 0 && result != "" { xlog.Debug("nothing function results but we had a message from the LLM") return result, nil } xlog.Debug("nothing to do, computing a reply") arg := "" if len(funcResults) > 0 { arg = funcResults[0].Arguments } // If there is a message that the LLM already sends as part of the JSON reply, use it arguments := map[string]interface{}{} if err := json.Unmarshal([]byte(arg), &arguments); err != nil { xlog.Debug("handleQuestion: function result did not contain a valid JSON object") } m, exists := arguments["message"] if exists { switch message := m.(type) { case string: if message != "" { xlog.Debug("Reply received from LLM", "message", message) message = backend.Finetune(*config, prompt, message) xlog.Debug("Reply received from LLM(finetuned)", "message", message) return message, nil } } } xlog.Debug("No action received from LLM, without a message, computing a reply") return "", nil } ================================================ FILE: core/http/endpoints/openai/chat_test.go ================================================ package openai import ( "github.com/mudler/LocalAI/core/config" 
"github.com/mudler/LocalAI/pkg/functions" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" "github.com/mudler/LocalAI/core/schema" ) var _ = Describe("handleQuestion", func() { var cfg *config.ModelConfig BeforeEach(func() { cfg = &config.ModelConfig{} }) Context("with no function results but non-empty result", func() { It("should return the result directly", func() { result, err := handleQuestion(cfg, nil, "Hello world", "prompt") Expect(err).ToNot(HaveOccurred()) Expect(result).To(Equal("Hello world")) }) }) Context("with no function results and empty result", func() { It("should return empty string", func() { result, err := handleQuestion(cfg, nil, "", "prompt") Expect(err).ToNot(HaveOccurred()) Expect(result).To(BeEmpty()) }) }) Context("with function result containing a message argument", func() { It("should extract the message from function arguments", func() { funcResults := []functions.FuncCallResults{ { Name: "answer", Arguments: `{"message": "This is the answer"}`, }, } result, err := handleQuestion(cfg, funcResults, "", "prompt") Expect(err).ToNot(HaveOccurred()) Expect(result).To(Equal("This is the answer")) }) }) Context("with function result containing empty message", func() { It("should return empty string when message is empty", func() { funcResults := []functions.FuncCallResults{ { Name: "answer", Arguments: `{"message": ""}`, }, } result, err := handleQuestion(cfg, funcResults, "", "prompt") Expect(err).ToNot(HaveOccurred()) Expect(result).To(BeEmpty()) }) }) Context("with function result containing invalid JSON arguments", func() { It("should return empty string gracefully", func() { funcResults := []functions.FuncCallResults{ { Name: "answer", Arguments: "not json", }, } result, err := handleQuestion(cfg, funcResults, "", "prompt") Expect(err).ToNot(HaveOccurred()) Expect(result).To(BeEmpty()) }) }) Context("with cleaned content (no think tags)", func() { It("should return content without think tags", func() { // This tests the bug fix: 
handleQuestion should receive cleaned content, // not raw text with tags result, err := handleQuestion(cfg, nil, "Just the answer", "prompt") Expect(err).ToNot(HaveOccurred()) Expect(result).To(Equal("Just the answer")) Expect(result).ToNot(ContainSubstring("")) }) }) Context("with raw think tags passed as result", func() { It("would return content with think tags", func() { result, err := handleQuestion(cfg, nil, "reasoninganswer", "prompt") Expect(err).ToNot(HaveOccurred()) Expect(result).To(Equal("reasoninganswer")) }) }) }) var _ = Describe("mergeToolCallDeltas", func() { Context("with new tool calls", func() { It("should append new tool calls", func() { existing := []schema.ToolCall{} deltas := []schema.ToolCall{ {Index: 0, ID: "tc1", Type: "function", FunctionCall: schema.FunctionCall{Name: "search"}}, } result := mergeToolCallDeltas(existing, deltas) Expect(result).To(HaveLen(1)) Expect(result[0].ID).To(Equal("tc1")) Expect(result[0].FunctionCall.Name).To(Equal("search")) }) }) Context("with argument appending", func() { It("should append arguments to existing tool call", func() { existing := []schema.ToolCall{ {Index: 0, ID: "tc1", Type: "function", FunctionCall: schema.FunctionCall{Name: "search", Arguments: `{"q":`}}, } deltas := []schema.ToolCall{ {Index: 0, FunctionCall: schema.FunctionCall{Arguments: `"hello"}`}}, } result := mergeToolCallDeltas(existing, deltas) Expect(result).To(HaveLen(1)) Expect(result[0].FunctionCall.Arguments).To(Equal(`{"q":"hello"}`)) }) }) Context("with multiple tool calls", func() { It("should track multiple tool calls by index", func() { existing := []schema.ToolCall{} deltas1 := []schema.ToolCall{ {Index: 0, ID: "tc1", Type: "function", FunctionCall: schema.FunctionCall{Name: "search"}}, } result := mergeToolCallDeltas(existing, deltas1) deltas2 := []schema.ToolCall{ {Index: 1, ID: "tc2", Type: "function", FunctionCall: schema.FunctionCall{Name: "browse"}}, } result = mergeToolCallDeltas(result, deltas2) 
Expect(result).To(HaveLen(2)) Expect(result[0].FunctionCall.Name).To(Equal("search")) Expect(result[1].FunctionCall.Name).To(Equal("browse")) }) }) Context("with ID update on existing tool call", func() { It("should update ID when provided in delta", func() { existing := []schema.ToolCall{ {Index: 0, FunctionCall: schema.FunctionCall{Name: "search"}}, } deltas := []schema.ToolCall{ {Index: 0, ID: "new-id"}, } result := mergeToolCallDeltas(existing, deltas) Expect(result).To(HaveLen(1)) Expect(result[0].ID).To(Equal("new-id")) Expect(result[0].FunctionCall.Name).To(Equal("search")) }) }) }) ================================================ FILE: core/http/endpoints/openai/completion.go ================================================ package openai import ( "encoding/json" "errors" "fmt" "time" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/google/uuid" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/templates" "github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) // CompletionEndpoint is the OpenAI Completion API endpoint https://platform.openai.com/docs/api-reference/completions // @Summary Generate completions for a given prompt and model. 
// @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/completions [post] func CompletionEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { process := func(id string, s string, req *schema.OpenAIRequest, config *config.ModelConfig, loader *model.ModelLoader, responses chan schema.OpenAIResponse, extraUsage bool) error { tokenCallback := func(s string, tokenUsage backend.TokenUsage) bool { created := int(time.Now().Unix()) usage := schema.OpenAIUsage{ PromptTokens: tokenUsage.Prompt, CompletionTokens: tokenUsage.Completion, TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, } if extraUsage { usage.TimingTokenGeneration = tokenUsage.TimingTokenGeneration usage.TimingPromptProcessing = tokenUsage.TimingPromptProcessing } resp := schema.OpenAIResponse{ ID: id, Created: created, Model: req.Model, // we have to return what the user sent here, due to OpenAI spec. 
Choices: []schema.Choice{ { Index: 0, Text: s, FinishReason: nil, }, }, Object: "text_completion", Usage: usage, } xlog.Debug("Sending goroutine", "text", s) responses <- resp return true } _, _, _, err := ComputeChoices(req, s, config, cl, appConfig, loader, func(s string, c *[]schema.Choice) {}, tokenCallback) close(responses) return err } return func(c echo.Context) error { created := int(time.Now().Unix()) // Handle Correlation id := c.Request().Header.Get("X-Correlation-ID") if id == "" { id = uuid.New().String() } extraUsage := c.Request().Header.Get("Extra-Usage") != "" input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { return echo.ErrBadRequest } config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { return echo.ErrBadRequest } if config.ResponseFormatMap != nil { d := schema.ChatCompletionResponseFormat{} dat, _ := json.Marshal(config.ResponseFormatMap) _ = json.Unmarshal(dat, &d) if d.Type == "json_object" { input.Grammar = functions.JSONBNF } } config.Grammar = input.Grammar xlog.Debug("Parameter Config", "config", config) if input.Stream { xlog.Debug("Stream request received") c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Cache-Control", "no-cache") c.Response().Header().Set("Connection", "keep-alive") if len(config.PromptStrings) > 1 { return errors.New("cannot handle more than 1 `PromptStrings` when Streaming") } predInput := config.PromptStrings[0] templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{ Input: predInput, SystemPrompt: config.SystemPrompt, ReasoningEffort: input.ReasoningEffort, Metadata: input.Metadata, }) if err == nil { predInput = templatedInput xlog.Debug("Template found, input modified", "input", predInput) } responses := make(chan schema.OpenAIResponse) ended := make(chan error) go 
func() { ended <- process(id, predInput, input, config, ml, responses, extraUsage) }() LOOP: for { select { case ev := <-responses: if len(ev.Choices) == 0 { xlog.Debug("No choices in the response, skipping") continue } respData, err := json.Marshal(ev) if err != nil { xlog.Debug("Failed to marshal response", "error", err) continue } xlog.Debug("Sending chunk", "chunk", string(respData)) _, err = fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(respData)) if err != nil { return err } c.Response().Flush() case err := <-ended: if err == nil { break LOOP } xlog.Error("Stream ended with error", "error", err) stopReason := FinishReasonStop errorResp := schema.OpenAIResponse{ ID: id, Created: created, Model: input.Model, Choices: []schema.Choice{ { Index: 0, FinishReason: &stopReason, Text: "Internal error: " + err.Error(), }, }, Object: "text_completion", } errorData, marshalErr := json.Marshal(errorResp) if marshalErr != nil { xlog.Error("Failed to marshal error response", "error", marshalErr) // Send a simple error message as fallback fmt.Fprintf(c.Response().Writer, "data: {\"error\":\"Internal error\"}\n\n") } else { fmt.Fprintf(c.Response().Writer, "data: %s\n\n", string(errorData)) } c.Response().Flush() return nil } } stopReason := FinishReasonStop resp := &schema.OpenAIResponse{ ID: id, Created: created, Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
Choices: []schema.Choice{ { Index: 0, FinishReason: &stopReason, }, }, Object: "text_completion", } respData, _ := json.Marshal(resp) fmt.Fprintf(c.Response().Writer, "data: %s\n\n", respData) fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") c.Response().Flush() return nil } var result []schema.Choice totalTokenUsage := backend.TokenUsage{} for k, i := range config.PromptStrings { templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.CompletionPromptTemplate, *config, templates.PromptTemplateData{ SystemPrompt: config.SystemPrompt, Input: i, ReasoningEffort: input.ReasoningEffort, Metadata: input.Metadata, }) if err == nil { i = templatedInput xlog.Debug("Template found, input modified", "input", i) } r, tokenUsage, _, err := ComputeChoices( input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) { stopReason := FinishReasonStop *c = append(*c, schema.Choice{Text: s, FinishReason: &stopReason, Index: k}) }, nil) if err != nil { return err } totalTokenUsage.TimingTokenGeneration += tokenUsage.TimingTokenGeneration totalTokenUsage.TimingPromptProcessing += tokenUsage.TimingPromptProcessing result = append(result, r...) } usage := schema.OpenAIUsage{ PromptTokens: totalTokenUsage.Prompt, CompletionTokens: totalTokenUsage.Completion, TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, } if extraUsage { usage.TimingTokenGeneration = totalTokenUsage.TimingTokenGeneration usage.TimingPromptProcessing = totalTokenUsage.TimingPromptProcessing } resp := &schema.OpenAIResponse{ ID: id, Created: created, Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
Choices: result, Object: "text_completion", Usage: usage, } jsonResult, _ := json.Marshal(resp) xlog.Debug("Response", "response", string(jsonResult)) // Return the prediction in the response body return c.JSON(200, resp) } } ================================================ FILE: core/http/endpoints/openai/constants.go ================================================ package openai // Finish reason constants for OpenAI API responses const ( FinishReasonStop = "stop" FinishReasonToolCalls = "tool_calls" FinishReasonFunctionCall = "function_call" ) ================================================ FILE: core/http/endpoints/openai/edit.go ================================================ package openai import ( "encoding/json" "time" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/google/uuid" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/templates" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) // EditEndpoint is the OpenAI edit API endpoint // @Summary OpenAI edit endpoint // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/edits [post] func EditEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { return echo.ErrBadRequest } // Opt-in extra usage flag extraUsage := c.Request().Header.Get("Extra-Usage") != "" config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { return echo.ErrBadRequest } xlog.Debug("Edit Endpoint Input", "input", input) xlog.Debug("Edit Endpoint Config", "config", *config) var result 
[]schema.Choice totalTokenUsage := backend.TokenUsage{} for _, i := range config.InputStrings { templatedInput, err := evaluator.EvaluateTemplateForPrompt(templates.EditPromptTemplate, *config, templates.PromptTemplateData{ Input: i, Instruction: input.Instruction, SystemPrompt: config.SystemPrompt, ReasoningEffort: input.ReasoningEffort, Metadata: input.Metadata, }) if err == nil { i = templatedInput xlog.Debug("Template found, input modified", "input", i) } r, tokenUsage, _, err := ComputeChoices(input, i, config, cl, appConfig, ml, func(s string, c *[]schema.Choice) { *c = append(*c, schema.Choice{Text: s}) }, nil) if err != nil { return err } totalTokenUsage.Prompt += tokenUsage.Prompt totalTokenUsage.Completion += tokenUsage.Completion totalTokenUsage.TimingTokenGeneration += tokenUsage.TimingTokenGeneration totalTokenUsage.TimingPromptProcessing += tokenUsage.TimingPromptProcessing result = append(result, r...) } usage := schema.OpenAIUsage{ PromptTokens: totalTokenUsage.Prompt, CompletionTokens: totalTokenUsage.Completion, TotalTokens: totalTokenUsage.Prompt + totalTokenUsage.Completion, } if extraUsage { usage.TimingTokenGeneration = totalTokenUsage.TimingTokenGeneration usage.TimingPromptProcessing = totalTokenUsage.TimingPromptProcessing } id := uuid.New().String() created := int(time.Now().Unix()) resp := &schema.OpenAIResponse{ ID: id, Created: created, Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. 
Choices: result, Object: "edit", Usage: usage, } jsonResult, _ := json.Marshal(resp) xlog.Debug("Response", "response", string(jsonResult)) // Return the prediction in the response body return c.JSON(200, resp) } } ================================================ FILE: core/http/endpoints/openai/embeddings.go ================================================ package openai import ( "encoding/json" "time" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/pkg/model" "github.com/google/uuid" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/xlog" ) // EmbeddingsEndpoint is the OpenAI Embeddings API endpoint https://platform.openai.com/docs/api-reference/embeddings // @Summary Get a vector representation of a given input that can be easily consumed by machine learning models and algorithms. // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/embeddings [post] func EmbeddingsEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { return echo.ErrBadRequest } config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { return echo.ErrBadRequest } xlog.Debug("Parameter Config", "config", config) items := []schema.Item{} for i, s := range config.InputToken { // get the model function to call for the result embedFn, err := backend.ModelEmbedding("", s, ml, *config, appConfig) if err != nil { return err } embeddings, err := embedFn() if err != nil { return err } items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) } for i, s := range 
config.InputStrings { // get the model function to call for the result embedFn, err := backend.ModelEmbedding(s, []int{}, ml, *config, appConfig) if err != nil { return err } embeddings, err := embedFn() if err != nil { return err } items = append(items, schema.Item{Embedding: embeddings, Index: i, Object: "embedding"}) } id := uuid.New().String() created := int(time.Now().Unix()) resp := &schema.OpenAIResponse{ ID: id, Created: created, Model: input.Model, // we have to return what the user sent here, due to OpenAI spec. Data: items, Object: "list", } jsonResult, _ := json.Marshal(resp) xlog.Debug("Response", "response", string(jsonResult)) // Return the prediction in the response body return c.JSON(200, resp) } } ================================================ FILE: core/http/endpoints/openai/image.go ================================================ package openai import ( "bufio" "encoding/base64" "encoding/json" "fmt" "io" "net/http" "net/url" "os" "path/filepath" "strconv" "strings" "time" "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/backend" model "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/xlog" ) func downloadFile(url string) (string, error) { if err := utils.ValidateExternalURL(url); err != nil { return "", fmt.Errorf("URL validation failed: %w", err) } // Get the data resp, err := http.Get(url) if err != nil { return "", err } defer resp.Body.Close() // Create the file out, err := os.CreateTemp("", "image") if err != nil { return "", err } defer out.Close() // Write the body to file _, err = io.Copy(out, resp.Body) return out.Name(), err } // /* * curl http://localhost:8080/v1/images/generations \ -H "Content-Type: application/json" \ -d '{ "prompt": "A cute baby sea otter", "n": 1, "size": "512x512" }' * */ // ImageEndpoint is the 
OpenAI Image generation API endpoint https://platform.openai.com/docs/api-reference/images/create // @Summary Creates an image given a prompt. // @Param request body schema.OpenAIRequest true "query params" // @Success 200 {object} schema.OpenAIResponse "Response" // @Router /v1/images/generations [post] func ImageEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { xlog.Error("Image Endpoint - Invalid Input") return echo.ErrBadRequest } config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { xlog.Error("Image Endpoint - Invalid Config") return echo.ErrBadRequest } // Process input images (for img2img/inpainting) src := "" if input.File != "" { src = processImageFile(input.File, appConfig.GeneratedContentDir) if src != "" { defer os.RemoveAll(src) } } // Process multiple input images var inputImages []string if len(input.Files) > 0 { for _, file := range input.Files { processedFile := processImageFile(file, appConfig.GeneratedContentDir) if processedFile != "" { inputImages = append(inputImages, processedFile) defer os.RemoveAll(processedFile) } } } // Process reference images var refImages []string if len(input.RefImages) > 0 { for _, file := range input.RefImages { processedFile := processImageFile(file, appConfig.GeneratedContentDir) if processedFile != "" { refImages = append(refImages, processedFile) defer os.RemoveAll(processedFile) } } } xlog.Debug("Parameter Config", "config", config) switch config.Backend { case "stablediffusion": config.Backend = model.StableDiffusionGGMLBackend case "": config.Backend = model.StableDiffusionGGMLBackend } if !strings.Contains(input.Size, "x") { input.Size = "512x512" xlog.Warn("Invalid size, using default 512x512") } sizeParts := 
strings.Split(input.Size, "x") if len(sizeParts) != 2 { return fmt.Errorf("invalid value for 'size'") } width, err := strconv.Atoi(sizeParts[0]) if err != nil { return fmt.Errorf("invalid value for 'size'") } height, err := strconv.Atoi(sizeParts[1]) if err != nil { return fmt.Errorf("invalid value for 'size'") } b64JSON := config.ResponseFormat == "b64_json" // src and clip_skip var result []schema.Item for _, i := range config.PromptStrings { n := input.N if input.N == 0 { n = 1 } for j := 0; j < n; j++ { prompts := strings.Split(i, "|") positive_prompt := prompts[0] negative_prompt := "" if len(prompts) > 1 { negative_prompt = prompts[1] } step := config.Step if step == 0 { step = 15 } if input.Step != 0 { step = input.Step } tempDir := "" if !b64JSON { tempDir = filepath.Join(appConfig.GeneratedContentDir, "images") } // Create a temporary file outputFile, err := os.CreateTemp(tempDir, "b64") if err != nil { return err } outputFile.Close() output := outputFile.Name() + ".png" // Rename the temporary file err = os.Rename(outputFile.Name(), output) if err != nil { return err } baseURL := middleware.BaseURL(c) // Use the first input image as src if available, otherwise use the original src inputSrc := src if len(inputImages) > 0 { inputSrc = inputImages[0] } fn, err := backend.ImageGeneration(height, width, step, *config.Seed, positive_prompt, negative_prompt, inputSrc, output, ml, *config, appConfig, refImages) if err != nil { return err } if err := fn(); err != nil { return err } item := &schema.Item{} if b64JSON { defer os.RemoveAll(output) data, err := os.ReadFile(output) if err != nil { return err } item.B64JSON = base64.StdEncoding.EncodeToString(data) } else { base := filepath.Base(output) item.URL, err = url.JoinPath(baseURL, "generated-images", base) if err != nil { return err } } result = append(result, *item) } } id := uuid.New().String() created := int(time.Now().Unix()) resp := &schema.OpenAIResponse{ ID: id, Created: created, Data: result, Usage: 
schema.OpenAIUsage{ PromptTokens: 0, CompletionTokens: 0, TotalTokens: 0, InputTokens: 0, OutputTokens: 0, InputTokensDetails: &schema.InputTokensDetails{ TextTokens: 0, ImageTokens: 0, }, }, } jsonResult, _ := json.Marshal(resp) xlog.Debug("Response", "response", string(jsonResult)) // Return the prediction in the response body return c.JSON(200, resp) } } // processImageFile handles a single image file (URL or base64) and returns the path to the temporary file func processImageFile(file string, generatedContentDir string) string { fileData := []byte{} var err error // check if file is an URL, if so download it and save it to a temporary file if strings.HasPrefix(file, "http://") || strings.HasPrefix(file, "https://") { out, err := downloadFile(file) if err != nil { xlog.Error("Failed downloading file", "error", err, "file", file) return "" } defer os.RemoveAll(out) fileData, err = os.ReadFile(out) if err != nil { xlog.Error("Failed reading downloaded file", "error", err, "file", out) return "" } } else { // base 64 decode the file and write it somewhere that we will cleanup fileData, err = base64.StdEncoding.DecodeString(file) if err != nil { xlog.Error("Failed decoding base64 file", "error", err) return "" } } // Create a temporary file outputFile, err := os.CreateTemp(generatedContentDir, "b64") if err != nil { xlog.Error("Failed creating temporary file", "error", err) return "" } // write the decoded result writer := bufio.NewWriter(outputFile) _, err = writer.Write(fileData) if err != nil { outputFile.Close() xlog.Error("Failed writing to temporary file", "error", err) return "" } if err := writer.Flush(); err != nil { outputFile.Close() xlog.Error("Failed flushing to temporary file", "error", err) return "" } outputFile.Close() return outputFile.Name() } ================================================ FILE: core/http/endpoints/openai/image_test.go ================================================ package openai import ( "encoding/base64" "os" . 
"github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("processImageFile", func() { var tmpDir string BeforeEach(func() { var err error tmpDir, err = os.MkdirTemp("", "processimage") Expect(err).ToNot(HaveOccurred()) }) AfterEach(func() { os.RemoveAll(tmpDir) }) It("should decode base64 and write all bytes to disk", func() { // 4x4 red pixel PNG (68 bytes raw) — small enough to fit in bufio's // default 4096-byte buffer, which is exactly the scenario where a // missing Flush() produces a 0-byte file. pngBytes := []byte{ 0x89, 0x50, 0x4e, 0x47, 0x0d, 0x0a, 0x1a, 0x0a, // PNG signature 0x00, 0x00, 0x00, 0x0d, 0x49, 0x48, 0x44, 0x52, // IHDR chunk 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x04, 0x08, 0x02, 0x00, 0x00, 0x00, 0x26, 0x93, 0x09, 0x29, 0x00, 0x00, 0x00, 0x1c, 0x49, 0x44, 0x41, // IDAT chunk 0x54, 0x78, 0x9c, 0x62, 0xf8, 0xcf, 0xc0, 0xc0, 0xc0, 0xc0, 0xc0, 0xc0, 0xc0, 0xc0, 0xc0, 0xc0, 0xc0, 0xc0, 0xc0, 0xc0, 0xc0, 0xc0, 0x00, 0x00, 0x00, 0x31, 0x00, 0x01, 0x2e, 0xa8, 0xd1, 0xe5, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4e, 0x44, // IEND chunk 0xae, 0x42, 0x60, 0x82, } b64 := base64.StdEncoding.EncodeToString(pngBytes) outPath := processImageFile(b64, tmpDir) Expect(outPath).ToNot(BeEmpty(), "processImageFile should return a file path") written, err := os.ReadFile(outPath) Expect(err).ToNot(HaveOccurred()) Expect(written).To(Equal(pngBytes), "file on disk must match the original bytes") }) It("should return empty string for invalid base64", func() { outPath := processImageFile("not-valid-base64!!!", tmpDir) Expect(outPath).To(BeEmpty(), "should return empty string for invalid base64") }) }) ================================================ FILE: core/http/endpoints/openai/inference.go ================================================ package openai import ( "encoding/json" "strings" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" pb 
"github.com/mudler/LocalAI/pkg/grpc/proto" model "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) func ComputeChoices( req *schema.OpenAIRequest, predInput string, config *config.ModelConfig, bcl *config.ModelConfigLoader, o *config.ApplicationConfig, loader *model.ModelLoader, cb func(string, *[]schema.Choice), tokenCallback func(string, backend.TokenUsage) bool, shouldRetry ...func(int) bool, ) ([]schema.Choice, backend.TokenUsage, []*pb.ChatDelta, error) { n := req.N // number of completions to return result := []schema.Choice{} if n == 0 { n = 1 } // Extract the optional shouldRetry callback var shouldRetryFn func(int) bool if len(shouldRetry) > 0 { shouldRetryFn = shouldRetry[0] } images := []string{} for _, m := range req.Messages { images = append(images, m.StringImages...) } videos := []string{} for _, m := range req.Messages { videos = append(videos, m.StringVideos...) } audios := []string{} for _, m := range req.Messages { audios = append(audios, m.StringAudios...) 
} // Serialize tools and tool_choice to JSON strings toolsJSON := "" if len(req.Tools) > 0 { toolsBytes, err := json.Marshal(req.Tools) if err == nil { toolsJSON = string(toolsBytes) } } toolChoiceJSON := "" if req.ToolsChoice != nil { toolChoiceBytes, err := json.Marshal(req.ToolsChoice) if err == nil { toolChoiceJSON = string(toolChoiceBytes) } } // Extract logprobs from request // According to OpenAI API: logprobs is boolean, top_logprobs (0-20) controls how many top tokens per position var logprobs *int var topLogprobs *int if req.Logprobs.IsEnabled() { // If logprobs is enabled, use top_logprobs if provided, otherwise default to 1 if req.TopLogprobs != nil { topLogprobs = req.TopLogprobs // For backend compatibility, set logprobs to the top_logprobs value logprobs = req.TopLogprobs } else { // Default to 1 if logprobs is true but top_logprobs not specified val := 1 logprobs = &val topLogprobs = &val } } // Extract logit_bias from request // According to OpenAI API: logit_bias is a map of token IDs (as strings) to bias values (-100 to 100) var logitBias map[string]float64 if len(req.LogitBias) > 0 { logitBias = req.LogitBias } // get the model function to call for the result predFunc, err := backend.ModelInferenceFunc( req.Context, predInput, req.Messages, images, videos, audios, loader, config, bcl, o, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias, req.Metadata) if err != nil { return result, backend.TokenUsage{}, nil, err } tokenUsage := backend.TokenUsage{} var allChatDeltas []*pb.ChatDelta const maxRetries = 5 for i := 0; i < n; i++ { var prediction backend.LLMResponse for attempt := 0; attempt <= maxRetries; attempt++ { p, err := predFunc() if err != nil { return result, backend.TokenUsage{}, nil, err } prediction = p // Built-in: retry on truly empty response (no tokens at all) if strings.TrimSpace(prediction.Response) == "" && attempt < maxRetries { xlog.Warn("Backend returned empty response, retrying", "attempt", attempt+1, 
"maxRetries", maxRetries) continue } tokenUsage.Prompt = prediction.Usage.Prompt tokenUsage.Completion = prediction.Usage.Completion tokenUsage.TimingPromptProcessing = prediction.Usage.TimingPromptProcessing tokenUsage.TimingTokenGeneration = prediction.Usage.TimingTokenGeneration allChatDeltas = prediction.ChatDeltas finetunedResponse := backend.Finetune(*config, predInput, prediction.Response) cb(finetunedResponse, &result) // Caller-driven retry (tool parsing, reasoning-only, etc.) if shouldRetryFn != nil && shouldRetryFn(attempt) && attempt < maxRetries { // Caller has already reset its state inside shouldRetry result = result[:0] allChatDeltas = nil continue } break } // Add logprobs to the last choice if present if prediction.Logprobs != nil && len(result) > 0 { result[len(result)-1].Logprobs = prediction.Logprobs } } return result, tokenUsage, allChatDeltas, err } ================================================ FILE: core/http/endpoints/openai/inference_test.go ================================================ package openai import ( "context" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" pb "github.com/mudler/LocalAI/pkg/grpc/proto" model "github.com/mudler/LocalAI/pkg/model" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" ) type modelInferenceFunc = func( ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, backend.TokenUsage) bool, tools, toolChoice string, logprobs, topLogprobs *int, logitBias map[string]float64, metadata map[string]string, ) (func() (backend.LLMResponse, error), error) var _ = Describe("ComputeChoices", func() { var ( origInference modelInferenceFunc cfg *config.ModelConfig appCfg *config.ApplicationConfig ) // mockInference installs a stub that yields the given responses sequentially. // After all responses are consumed, the last one is repeated. mockInference := func(responses []backend.LLMResponse) { idx := 0 backend.ModelInferenceFunc = func( ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, backend.TokenUsage) bool, tools, toolChoice string, logprobs, topLogprobs *int, logitBias map[string]float64, metadata map[string]string, ) (func() (backend.LLMResponse, error), error) { predFunc := func() (backend.LLMResponse, error) { resp := responses[idx] if idx < len(responses)-1 { idx++ } return resp, nil } return predFunc, nil } } BeforeEach(func() { origInference = backend.ModelInferenceFunc cfg = &config.ModelConfig{} appCfg = config.NewApplicationConfig() }) AfterEach(func() { backend.ModelInferenceFunc = origInference }) makeReq := func() *schema.OpenAIRequest { ctx, cancel := context.WithCancel(context.Background()) _ = cancel return &schema.OpenAIRequest{ Context: ctx, Cancel: cancel, } } Context("normal response (no retry needed)", func() { It("should return choices on first attempt", func() { mockInference([]backend.LLMResponse{ {Response: "Hello world", Usage: backend.TokenUsage{Prompt: 
10, Completion: 5}}, }) var captured string choices, usage, _, err := ComputeChoices( makeReq(), "test prompt", cfg, nil, appCfg, nil, func(s string, c *[]schema.Choice) { captured = s *c = append(*c, schema.Choice{Text: s}) }, nil, ) Expect(err).ToNot(HaveOccurred()) Expect(choices).To(HaveLen(1)) Expect(captured).To(Equal("Hello world")) Expect(usage.Prompt).To(Equal(10)) Expect(usage.Completion).To(Equal(5)) }) }) Context("empty response triggers built-in retry", func() { It("should retry and eventually return non-empty response", func() { mockInference([]backend.LLMResponse{ {Response: ""}, // attempt 0: empty {Response: " "}, // attempt 1: whitespace-only {Response: "Got it", Usage: backend.TokenUsage{Prompt: 8, Completion: 3}}, // attempt 2: success }) choices, usage, _, err := ComputeChoices( makeReq(), "test", cfg, nil, appCfg, nil, func(s string, c *[]schema.Choice) { *c = append(*c, schema.Choice{Text: s}) }, nil, ) Expect(err).ToNot(HaveOccurred()) Expect(choices).To(HaveLen(1)) Expect(choices[0].Text).To(Equal("Got it")) Expect(usage.Prompt).To(Equal(8)) Expect(usage.Completion).To(Equal(3)) }) }) Context("all retries exhausted on empty response", func() { It("should return the empty response after max retries", func() { mockInference([]backend.LLMResponse{ {Response: ""}, // always empty }) choices, _, _, err := ComputeChoices( makeReq(), "test", cfg, nil, appCfg, nil, func(s string, c *[]schema.Choice) { *c = append(*c, schema.Choice{Text: s}) }, nil, ) Expect(err).ToNot(HaveOccurred()) // After maxRetries, it proceeds with the empty response Expect(choices).To(HaveLen(1)) Expect(choices[0].Text).To(BeEmpty()) }) }) Context("shouldRetry callback", func() { It("should call shouldRetry and retry when it returns true", func() { callCount := 0 mockInference([]backend.LLMResponse{ {Response: "reasoning-only", Usage: backend.TokenUsage{Prompt: 5, Completion: 2}}, {Response: "actual-answer", Usage: backend.TokenUsage{Prompt: 5, Completion: 4}}, }) 
retryAttempts := []int{} choices, usage, _, err := ComputeChoices( makeReq(), "test", cfg, nil, appCfg, nil, func(s string, c *[]schema.Choice) { callCount++ *c = append(*c, schema.Choice{Text: s}) }, nil, func(attempt int) bool { retryAttempts = append(retryAttempts, attempt) // Retry on first attempt only return attempt == 0 }, ) Expect(err).ToNot(HaveOccurred()) Expect(choices).To(HaveLen(1)) Expect(choices[0].Text).To(Equal("actual-answer")) // shouldRetry was called twice: once returning true (retry), once returning false (proceed) Expect(retryAttempts).To(Equal([]int{0, 1})) // cb was called twice (once per attempt) Expect(callCount).To(Equal(2)) // Token usage should be from the LATEST attempt Expect(usage.Prompt).To(Equal(5)) Expect(usage.Completion).To(Equal(4)) }) It("should not retry when shouldRetry returns false", func() { mockInference([]backend.LLMResponse{ {Response: "first-response"}, }) shouldRetryCalled := false choices, _, _, err := ComputeChoices( makeReq(), "test", cfg, nil, appCfg, nil, func(s string, c *[]schema.Choice) { *c = append(*c, schema.Choice{Text: s}) }, nil, func(attempt int) bool { shouldRetryCalled = true return false }, ) Expect(err).ToNot(HaveOccurred()) Expect(choices).To(HaveLen(1)) Expect(choices[0].Text).To(Equal("first-response")) Expect(shouldRetryCalled).To(BeTrue()) }) }) Context("shouldRetry not provided (variadic omitted)", func() { It("should work without shouldRetry parameter", func() { mockInference([]backend.LLMResponse{ {Response: "works"}, }) choices, _, _, err := ComputeChoices( makeReq(), "test", cfg, nil, appCfg, nil, func(s string, c *[]schema.Choice) { *c = append(*c, schema.Choice{Text: s}) }, nil, ) Expect(err).ToNot(HaveOccurred()) Expect(choices).To(HaveLen(1)) Expect(choices[0].Text).To(Equal("works")) }) }) Context("token usage from latest attempt", func() { It("should use token usage from the last attempt, not accumulated", func() { mockInference([]backend.LLMResponse{ {Response: "retry-me", Usage: 
backend.TokenUsage{Prompt: 100, Completion: 50}}, {Response: "final", Usage: backend.TokenUsage{Prompt: 10, Completion: 5}}, }) _, usage, _, err := ComputeChoices( makeReq(), "test", cfg, nil, appCfg, nil, func(s string, c *[]schema.Choice) { *c = append(*c, schema.Choice{Text: s}) }, nil, func(attempt int) bool { return attempt == 0 }, ) Expect(err).ToNot(HaveOccurred()) // Should be the LATEST attempt's usage, not accumulated Expect(usage.Prompt).To(Equal(10)) Expect(usage.Completion).To(Equal(5)) }) }) Context("chat deltas from latest attempt", func() { It("should return chat deltas from the last attempt only", func() { mockInference([]backend.LLMResponse{ { Response: "retry-me", ChatDeltas: []*pb.ChatDelta{{Content: "old"}}, }, { Response: "final", ChatDeltas: []*pb.ChatDelta{{Content: "new"}}, }, }) _, _, deltas, err := ComputeChoices( makeReq(), "test", cfg, nil, appCfg, nil, func(s string, c *[]schema.Choice) { *c = append(*c, schema.Choice{Text: s}) }, nil, func(attempt int) bool { return attempt == 0 }, ) Expect(err).ToNot(HaveOccurred()) Expect(deltas).To(HaveLen(1)) Expect(deltas[0].Content).To(Equal("new")) }) }) Context("result choices cleared on retry", func() { It("should only contain choices from the final attempt", func() { mockInference([]backend.LLMResponse{ {Response: "bad-choice"}, {Response: "good-choice"}, }) choices, _, _, err := ComputeChoices( makeReq(), "test", cfg, nil, appCfg, nil, func(s string, c *[]schema.Choice) { *c = append(*c, schema.Choice{Text: s}) }, nil, func(attempt int) bool { return attempt == 0 }, ) Expect(err).ToNot(HaveOccurred()) Expect(choices).To(HaveLen(1)) Expect(choices[0].Text).To(Equal("good-choice")) }) }) Context("shouldRetry with max retries cap", func() { It("should stop retrying after maxRetries even if shouldRetry returns true", func() { attempts := 0 mockInference([]backend.LLMResponse{ {Response: "always-retry"}, }) choices, _, _, err := ComputeChoices( makeReq(), "test", cfg, nil, appCfg, nil, func(s 
string, c *[]schema.Choice) { *c = append(*c, schema.Choice{Text: s}) }, nil, func(attempt int) bool { attempts++ return true // always want to retry }, ) Expect(err).ToNot(HaveOccurred()) Expect(choices).To(HaveLen(1)) // maxRetries is 5, so shouldRetry is called for attempts 0..4, // but attempt 5 is the final one where shouldRetry can't trigger continue Expect(attempts).To(BeNumerically("<=", 6)) }) }) Context("N > 1 completions", func() { It("should produce N separate completions", func() { callIdx := 0 responses := []string{"first", "second", "third"} backend.ModelInferenceFunc = func( ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, backend.TokenUsage) bool, tools, toolChoice string, logprobs, topLogprobs *int, logitBias map[string]float64, metadata map[string]string, ) (func() (backend.LLMResponse, error), error) { predFunc := func() (backend.LLMResponse, error) { resp := backend.LLMResponse{Response: responses[callIdx]} if callIdx < len(responses)-1 { callIdx++ } return resp, nil } return predFunc, nil } req := makeReq() req.N = 3 choices, _, _, err := ComputeChoices( req, "test", cfg, nil, appCfg, nil, func(s string, c *[]schema.Choice) { *c = append(*c, schema.Choice{Text: s}) }, nil, ) Expect(err).ToNot(HaveOccurred()) Expect(choices).To(HaveLen(3)) Expect(choices[0].Text).To(Equal("first")) Expect(choices[1].Text).To(Equal("second")) Expect(choices[2].Text).To(Equal("third")) }) }) Context("with streaming token callback", func() { It("should call tokenCallback for streaming responses", func() { var streamedTokens []string backend.ModelInferenceFunc = func( ctx context.Context, s string, messages schema.Messages, images, videos, audios []string, loader *model.ModelLoader, c *config.ModelConfig, cl *config.ModelConfigLoader, o *config.ApplicationConfig, tokenCallback func(string, 
backend.TokenUsage) bool, tools, toolChoice string, logprobs, topLogprobs *int, logitBias map[string]float64, metadata map[string]string, ) (func() (backend.LLMResponse, error), error) { predFunc := func() (backend.LLMResponse, error) { if tokenCallback != nil { tokenCallback("Hello", backend.TokenUsage{Prompt: 5}) tokenCallback(" world", backend.TokenUsage{Prompt: 5, Completion: 2}) } return backend.LLMResponse{ Response: "Hello world", Usage: backend.TokenUsage{Prompt: 5, Completion: 2}, }, nil } return predFunc, nil } choices, _, _, err := ComputeChoices( makeReq(), "test", cfg, nil, appCfg, nil, func(s string, c *[]schema.Choice) { *c = append(*c, schema.Choice{Text: s}) }, func(s string, usage backend.TokenUsage) bool { streamedTokens = append(streamedTokens, s) return true }, ) Expect(err).ToNot(HaveOccurred()) Expect(choices).To(HaveLen(1)) Expect(streamedTokens).To(Equal([]string{"Hello", " world"})) }) }) }) ================================================ FILE: core/http/endpoints/openai/inpainting.go ================================================ package openai import ( "encoding/base64" "encoding/json" "fmt" "io" "net/http" "net/url" "os" "path/filepath" "strconv" "time" "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/mudler/xlog" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" model "github.com/mudler/LocalAI/pkg/model" ) // InpaintingEndpoint handles POST /v1/images/inpainting // // Swagger / OpenAPI docstring (swaggo): // @Summary Image inpainting // @Description Perform image inpainting. Accepts multipart/form-data with `image` and `mask` files. 
// @Tags images // @Accept multipart/form-data // @Produce application/json // @Param model formData string true "Model identifier" // @Param prompt formData string true "Text prompt guiding the generation" // @Param steps formData int false "Number of inference steps (default 25)" // @Param image formData file true "Original image file" // @Param mask formData file true "Mask image file (white = area to inpaint)" // @Success 200 {object} schema.OpenAIResponse // @Failure 400 {object} map[string]string // @Failure 500 {object} map[string]string // @Router /v1/images/inpainting [post] func InpaintingEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { // Parse basic form values modelName := c.FormValue("model") prompt := c.FormValue("prompt") stepsStr := c.FormValue("steps") if modelName == "" || prompt == "" { xlog.Error("Inpainting Endpoint - missing model or prompt") return echo.ErrBadRequest } // steps default steps := 25 if stepsStr != "" { if v, err := strconv.Atoi(stepsStr); err == nil { steps = v } } // Get uploaded files imageFile, err := c.FormFile("image") if err != nil { xlog.Error("Inpainting Endpoint - missing image file", "error", err) return echo.NewHTTPError(http.StatusBadRequest, "missing image file") } maskFile, err := c.FormFile("mask") if err != nil { xlog.Error("Inpainting Endpoint - missing mask file", "error", err) return echo.NewHTTPError(http.StatusBadRequest, "missing mask file") } // Read files into memory (small files expected) imgSrc, err := imageFile.Open() if err != nil { return err } defer imgSrc.Close() imgBytes, err := io.ReadAll(imgSrc) if err != nil { return err } maskSrc, err := maskFile.Open() if err != nil { return err } defer maskSrc.Close() maskBytes, err := io.ReadAll(maskSrc) if err != nil { return err } // Create JSON with base64 fields expected by backend b64Image := base64.StdEncoding.EncodeToString(imgBytes) b64Mask := 
base64.StdEncoding.EncodeToString(maskBytes) // get model config from context (middleware set it) cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { xlog.Error("Inpainting Endpoint - model config not found in context") return echo.ErrBadRequest } // Use the GeneratedContentDir so the generated PNG is placed where the // HTTP static handler serves `/generated-images`. tmpDir := appConfig.GeneratedContentDir // Ensure the directory exists if err := os.MkdirAll(tmpDir, 0750); err != nil { xlog.Error("Inpainting Endpoint - failed to create generated content dir", "error", err, "dir", tmpDir) return echo.NewHTTPError(http.StatusInternalServerError, "failed to prepare storage") } id := uuid.New().String() jsonPath := filepath.Join(tmpDir, fmt.Sprintf("inpaint_%s.json", id)) jsonFile := map[string]string{ "image": b64Image, "mask_image": b64Mask, } jf, err := os.CreateTemp(tmpDir, "inpaint_") if err != nil { return err } // setup cleanup on error; if everything succeeds we set success = true success := false var dst string var origRef string var maskRef string defer func() { if !success { // Best-effort cleanup; log any failures if jf != nil { if cerr := jf.Close(); cerr != nil { xlog.Warn("Inpainting Endpoint - failed to close temp json file in cleanup", "error", cerr) } if name := jf.Name(); name != "" { if rerr := os.Remove(name); rerr != nil && !os.IsNotExist(rerr) { xlog.Warn("Inpainting Endpoint - failed to remove temp json file in cleanup", "error", rerr, "file", name) } } } if jsonPath != "" { if rerr := os.Remove(jsonPath); rerr != nil && !os.IsNotExist(rerr) { xlog.Warn("Inpainting Endpoint - failed to remove json file in cleanup", "error", rerr, "file", jsonPath) } } if dst != "" { if rerr := os.Remove(dst); rerr != nil && !os.IsNotExist(rerr) { xlog.Warn("Inpainting Endpoint - failed to remove dst file in cleanup", "error", rerr, "file", dst) } } if origRef != "" { if rerr := os.Remove(origRef); rerr != 
nil && !os.IsNotExist(rerr) { xlog.Warn("Inpainting Endpoint - failed to remove orig ref file in cleanup", "error", rerr, "file", origRef) } } if maskRef != "" { if rerr := os.Remove(maskRef); rerr != nil && !os.IsNotExist(rerr) { xlog.Warn("Inpainting Endpoint - failed to remove mask ref file in cleanup", "error", rerr, "file", maskRef) } } } }() // write original image and mask to disk as ref images so backends that // accept reference image files can use them (maintainer request). origTmp, err := os.CreateTemp(tmpDir, "refimg_") if err != nil { return err } if _, err := origTmp.Write(imgBytes); err != nil { _ = origTmp.Close() _ = os.Remove(origTmp.Name()) return err } if cerr := origTmp.Close(); cerr != nil { xlog.Warn("Inpainting Endpoint - failed to close orig temp file", "error", cerr) } origRef = origTmp.Name() maskTmp, err := os.CreateTemp(tmpDir, "refmask_") if err != nil { // cleanup origTmp on error _ = os.Remove(origRef) return err } if _, err := maskTmp.Write(maskBytes); err != nil { _ = maskTmp.Close() _ = os.Remove(maskTmp.Name()) _ = os.Remove(origRef) return err } if cerr := maskTmp.Close(); cerr != nil { xlog.Warn("Inpainting Endpoint - failed to close mask temp file", "error", cerr) } maskRef = maskTmp.Name() // write JSON enc := json.NewEncoder(jf) if err := enc.Encode(jsonFile); err != nil { if cerr := jf.Close(); cerr != nil { xlog.Warn("Inpainting Endpoint - failed to close temp json file after encode error", "error", cerr) } return err } if cerr := jf.Close(); cerr != nil { xlog.Warn("Inpainting Endpoint - failed to close temp json file", "error", cerr) } // rename to desired name if err := os.Rename(jf.Name(), jsonPath); err != nil { return err } // prepare dst outTmp, err := os.CreateTemp(tmpDir, "out_") if err != nil { return err } if cerr := outTmp.Close(); cerr != nil { xlog.Warn("Inpainting Endpoint - failed to close out temp file", "error", cerr) } dst = outTmp.Name() + ".png" if err := os.Rename(outTmp.Name(), dst); err != nil { 
return err } // Determine width/height default width := 512 height := 512 // Call backend image generation via indirection so tests can stub it // Note: ImageGenerationFunc will call into the loaded model's GenerateImage which expects src JSON // Also pass ref images (orig + mask) so backends that support ref images can use them. refImages := []string{origRef, maskRef} fn, err := backend.ImageGenerationFunc(height, width, steps, 0, prompt, "", jsonPath, dst, ml, *cfg, appConfig, refImages) if err != nil { return err } // Execute generation function (blocking) if err := fn(); err != nil { return err } // On success, build response URL using BaseURL middleware helper and // the same `generated-images` prefix used by the server static mount. baseURL := middleware.BaseURL(c) // Build response using url.JoinPath for correct URL escaping imgPath, err := url.JoinPath(baseURL, "generated-images", filepath.Base(dst)) if err != nil { return err } created := int(time.Now().Unix()) resp := &schema.OpenAIResponse{ ID: id, Created: created, Data: []schema.Item{{ URL: imgPath, }}, Usage: schema.OpenAIUsage{ PromptTokens: 0, CompletionTokens: 0, TotalTokens: 0, InputTokens: 0, OutputTokens: 0, InputTokensDetails: &schema.InputTokensDetails{ TextTokens: 0, ImageTokens: 0, }, }, } // mark success so defer cleanup will not remove output files success = true return c.JSON(http.StatusOK, resp) } } ================================================ FILE: core/http/endpoints/openai/inpainting_test.go ================================================ package openai import ( "bytes" "mime/multipart" "net/http" "net/http/httptest" "os" "path/filepath" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" model "github.com/mudler/LocalAI/pkg/model" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" ) func makeMultipartRequest(fields map[string]string, files map[string][]byte) (*http.Request, string) { b := &bytes.Buffer{} w := multipart.NewWriter(b) for k, v := range fields { _ = w.WriteField(k, v) } for fname, content := range files { fw, err := w.CreateFormFile(fname, fname+".png") Expect(err).ToNot(HaveOccurred()) _, err = fw.Write(content) Expect(err).ToNot(HaveOccurred()) } Expect(w.Close()).To(Succeed()) req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", b) req.Header.Set("Content-Type", w.FormDataContentType()) return req, w.FormDataContentType() } var _ = Describe("Inpainting", func() { It("returns error for missing files", func() { e := echo.New() h := InpaintingEndpoint(nil, nil, config.NewApplicationConfig()) req := httptest.NewRequest(http.MethodPost, "/v1/images/inpainting", nil) rec := httptest.NewRecorder() c := e.NewContext(req, rec) err := h(c) Expect(err).To(HaveOccurred()) }) It("handles the happy path", func() { tmpDir, err := os.MkdirTemp("", "gencontent") Expect(err).ToNot(HaveOccurred()) DeferCleanup(func() { os.RemoveAll(tmpDir) }) appConf := config.NewApplicationConfig(config.WithGeneratedContentDir(tmpDir)) orig := backend.ImageGenerationFunc backend.ImageGenerationFunc = func(height, width, step, seed int, positive_prompt, negative_prompt, src, dst string, loader *model.ModelLoader, modelConfig config.ModelConfig, appConfig *config.ApplicationConfig, refImages []string) (func() error, error) { fn := func() error { return os.WriteFile(dst, []byte("PNGDATA"), 0644) } return fn, nil } DeferCleanup(func() { backend.ImageGenerationFunc = orig }) fields := map[string]string{"model": "dreamshaper-8-inpainting", "prompt": "A test"} files := map[string][]byte{"image": []byte("IMAGEDATA"), "mask": []byte("MASKDATA")} reqBuf, _ := makeMultipartRequest(fields, files) rec := httptest.NewRecorder() e := echo.New() c := e.NewContext(reqBuf, rec) c.Set(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG, 
&config.ModelConfig{Backend: "diffusers"}) h := InpaintingEndpoint(nil, nil, appConf) err = h(c) Expect(err).ToNot(HaveOccurred()) Expect(rec.Code).To(Equal(http.StatusOK)) body := rec.Body.String() Expect(body).To(ContainSubstring("generated-images")) idx := bytes.Index(rec.Body.Bytes(), []byte("generated-images/")) Expect(idx).To(BeNumerically(">=", 0)) rest := rec.Body.Bytes()[idx:] end := bytes.IndexAny(rest, "\",}\n") if end == -1 { end = len(rest) } fname := string(rest[len("generated-images/"):end]) _, err = os.Stat(filepath.Join(tmpDir, fname)) Expect(err).ToNot(HaveOccurred()) }) }) ================================================ FILE: core/http/endpoints/openai/list.go ================================================ package openai import ( "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" model "github.com/mudler/LocalAI/pkg/model" "gorm.io/gorm" ) // ListModelsEndpoint is the OpenAI Models API endpoint https://platform.openai.com/docs/api-reference/models // @Summary List and describe the various models available in the API. // @Success 200 {object} schema.ModelsDataResponse "Response" // @Router /v1/models [get] func ListModelsEndpoint(bcl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, db ...*gorm.DB) echo.HandlerFunc { var authDB *gorm.DB if len(db) > 0 { authDB = db[0] } return func(c echo.Context) error { // If blank, no filter is applied. filter := c.QueryParam("filter") // By default, exclude any loose files that are already referenced by a configuration file. var policy services.LooseFilePolicy excludeConfigured := c.QueryParam("excludeConfigured") if excludeConfigured == "" || excludeConfigured == "true" { policy = services.SKIP_IF_CONFIGURED } else { policy = services.ALWAYS_INCLUDE // This replicates current behavior. TODO: give more options to the user? 
} filterFn, err := config.BuildNameFilterFn(filter) if err != nil { return err } modelNames, err := services.ListModels(bcl, ml, filterFn, policy) if err != nil { return err } // Filter models by user's allowlist if auth is enabled if authDB != nil { if user := auth.GetUser(c); user != nil && user.Role != auth.RoleAdmin { perm, err := auth.GetCachedUserPermissions(c, authDB, user.ID) if err == nil && perm.AllowedModels.Enabled { allowed := map[string]bool{} for _, m := range perm.AllowedModels.Models { allowed[m] = true } filtered := make([]string, 0, len(modelNames)) for _, m := range modelNames { if allowed[m] { filtered = append(filtered, m) } } modelNames = filtered } } } // Map from a slice of names to a slice of OpenAIModel response objects dataModels := []schema.OpenAIModel{} for _, m := range modelNames { dataModels = append(dataModels, schema.OpenAIModel{ID: m, Object: "model"}) } return c.JSON(200, schema.ModelsDataResponse{ Object: "list", Data: dataModels, }) } } ================================================ FILE: core/http/endpoints/openai/openai_suite_test.go ================================================ package openai import ( "testing" . "github.com/onsi/ginkgo/v2" . 
	"github.com/onsi/gomega"
)

// TestOpenAI is the `go test` entry point for this package: it registers
// Gomega's fail handler and runs the Ginkgo specs of the suite.
func TestOpenAI(t *testing.T) {
	RegisterFailHandler(Fail)
	RunSpecs(t, "OpenAI Endpoints Suite")
}

================================================
FILE: core/http/endpoints/openai/realtime.go
================================================
package openai

import (
	"context"
	"encoding/base64"
	"encoding/binary"
	"encoding/json"
	"fmt"
	"math"
	"os"
	"sync"
	"time"

	"net/http"

	"github.com/go-audio/audio"
	"github.com/gorilla/websocket"
	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/application"
	"github.com/mudler/LocalAI/core/backend"
	"github.com/mudler/LocalAI/core/config"
	"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/core/templates"
	laudio "github.com/mudler/LocalAI/pkg/audio"
	"github.com/mudler/LocalAI/pkg/functions"
	"github.com/mudler/LocalAI/pkg/grpc"
	"github.com/mudler/LocalAI/pkg/grpc/proto"
	model "github.com/mudler/LocalAI/pkg/model"
	"github.com/mudler/LocalAI/pkg/reasoning"
	"github.com/mudler/LocalAI/pkg/sound"
	"github.com/mudler/LocalAI/pkg/utils"
	"github.com/mudler/xlog"
)

const (
	// XXX: Presently it seems all ASR/VAD backends use 16Khz. If a backend uses 24Khz then it will likely still work, but have reduced performance
	localSampleRate = 16000
	// Sample rate (Hz) a new session starts with for both input and output
	// audio; runRealtimeSession assigns it to InputSampleRate/OutputSampleRate.
	defaultRemoteSampleRate = 24000

	// Maximum audio buffer size in bytes (100MB) to prevent memory exhaustion
	maxAudioBufferSize = 100 * 1024 * 1024

	// Maximum WebSocket message size in bytes (10MB) to prevent DoS attacks
	maxWebSocketMessageSize = 10 * 1024 * 1024

	// Instructions every new session starts with (see runRealtimeSession).
	defaultInstructions = "You are a helpful voice assistant. " +
		"Your responses will be spoken aloud using text-to-speech, so keep them concise and conversational. " +
		"Do not use markdown formatting, bullet points, numbered lists, code blocks, or special characters. " +
		"Speak naturally as you would in a phone conversation. " +
		"Avoid parenthetical asides, URLs, and anything that cannot be clearly vocalized."
)

// A model can be "emulated", that is: transcribe audio to text -> feed text to the LLM -> generate audio as result
// If the model instead supports audio-to-audio, we will use the specific gRPC calls instead

// Session represents a single WebSocket connection and its state
type Session struct {
	// ID is the identifier reported to the client in session events.
	ID string
	// TranscriptionOnly selects which session shape ToServer reports:
	// a transcription session when true, a full realtime session otherwise.
	TranscriptionOnly bool
	// The pipeline or any-to-any model name (full realtime mode)
	Model string
	// The voice may be a TTS model name or a parameter passed to a TTS model
	Voice                   string
	TurnDetection           *types.TurnDetectionUnion // "server_vad", "semantic_vad" or "none"
	InputAudioTranscription *types.AudioTranscription
	Tools                   []types.ToolUnion
	ToolChoice              *types.ToolChoiceUnion
	// Conversations is keyed by conversation ID; DefaultConversationID names
	// the entry created when the session starts.
	Conversations map[string]*Conversation
	// InputAudioBuffer accumulates the decoded `input_audio_buffer.append`
	// payloads; the event loop accesses it while holding AudioBufferLock.
	InputAudioBuffer []byte
	AudioBufferLock  sync.Mutex
	// NOTE(review): OpusFrames/OpusFramesLock are not used in this part of the
	// file; presumably they queue inbound WebRTC Opus frames - confirm.
	OpusFrames            [][]byte
	OpusFramesLock        sync.Mutex
	Instructions          string
	DefaultConversationID string
	ModelInterface        Model
	// The pipeline model config or the config for an any-to-any model
	ModelConfig *config.ModelConfig
	// Sample rates in Hz. Both start at defaultRemoteSampleRate; the input
	// rate is switched to localSampleRate for WebRTC transports.
	InputSampleRate  int
	OutputSampleRate int
	MaxOutputTokens  types.IntOrInf

	// Response cancellation: protects activeResponseCancel/activeResponseDone
	responseMu           sync.Mutex
	activeResponseCancel context.CancelFunc
	activeResponseDone   chan struct{}
}

// cancelActiveResponse cancels any in-flight response and waits for its
// goroutine to exit. This ensures we never have overlapping responses and
// that interrupted responses are fully cleaned up before starting a new one.
func (s *Session) cancelActiveResponse() {
	// Snapshot the handles under the mutex and release it before blocking, so
	// the lock is never held while waiting on the done channel.
	s.responseMu.Lock()
	cancel := s.activeResponseCancel
	done := s.activeResponseDone
	s.responseMu.Unlock()

	if cancel != nil {
		cancel()
	}
	if done != nil {
		// Blocks until the response goroutine closes its done channel.
		<-done
	}
}

// startResponse cancels any active response and returns a new context for
// the replacement response. The caller MUST close the returned done channel
// when the response goroutine exits.
func (s *Session) startResponse(parent context.Context) (context.Context, chan struct{}) { s.cancelActiveResponse() ctx, cancel := context.WithCancel(parent) done := make(chan struct{}) s.responseMu.Lock() s.activeResponseCancel = cancel s.activeResponseDone = done s.responseMu.Unlock() return ctx, done } func (s *Session) FromClient(session *types.SessionUnion) { } func (s *Session) ToServer() types.SessionUnion { if s.TranscriptionOnly { return types.SessionUnion{ Transcription: &types.TranscriptionSession{ ID: s.ID, Object: "realtime.transcription_session", Audio: &types.TranscriptionSessionAudio{ Input: &types.SessionAudioInput{ Transcription: s.InputAudioTranscription, }, }, }, } } else { return types.SessionUnion{ Realtime: &types.RealtimeSession{ ID: s.ID, Object: "realtime.session", Model: s.Model, Instructions: s.Instructions, Tools: s.Tools, ToolChoice: s.ToolChoice, MaxOutputTokens: s.MaxOutputTokens, Audio: &types.RealtimeSessionAudio{ Input: &types.SessionAudioInput{ TurnDetection: s.TurnDetection, Transcription: s.InputAudioTranscription, }, Output: &types.SessionAudioOutput{ Voice: types.Voice(s.Voice), }, }, }, } } } // Conversation represents a conversation with a list of items type Conversation struct { ID string Items []*types.MessageItemUnion Lock sync.Mutex } func (c *Conversation) ToServer() types.Conversation { return types.Conversation{ ID: c.ID, Object: "realtime.conversation", } } // Map to store sessions (in-memory) var sessions = make(map[string]*Session) var sessionLock sync.Mutex type Model interface { VAD(ctx context.Context, request *schema.VADRequest) (*schema.VADResponse, error) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, 
topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) PredictConfig() *config.ModelConfig } var upgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true // Allow all origins }, } // TODO: Implement ephemeral keys to allow these endpoints to be used func RealtimeSessions(application *application.Application) echo.HandlerFunc { return func(c echo.Context) error { return c.NoContent(501) } } func RealtimeTranscriptionSession(application *application.Application) echo.HandlerFunc { return func(c echo.Context) error { return c.NoContent(501) } } func Realtime(application *application.Application) echo.HandlerFunc { return func(c echo.Context) error { ws, err := upgrader.Upgrade(c.Response(), c.Request(), nil) if err != nil { return err } defer ws.Close() // Set maximum message size to prevent DoS attacks ws.SetReadLimit(maxWebSocketMessageSize) // Extract query parameters from Echo context before passing to websocket handler model := c.QueryParam("model") registerRealtime(application, model)(ws) return nil } } func registerRealtime(application *application.Application, model string) func(c *websocket.Conn) { return func(conn *websocket.Conn) { t := NewWebSocketTransport(conn) evaluator := application.TemplatesEvaluator() xlog.Debug("Realtime WebSocket connection established", "address", conn.RemoteAddr().String(), "model", model) runRealtimeSession(application, t, model, evaluator) } } // runRealtimeSession runs the main event loop for a realtime session. // It is transport-agnostic and works with both WebSocket and WebRTC. 
func runRealtimeSession(application *application.Application, t Transport, model string, evaluator *templates.Evaluator) {
	// Flow: load and validate the pipeline config, build the Session and its
	// default conversation, load the model, register the session, start the
	// optional Opus-decode and VAD goroutines, then process client events
	// until the transport read fails. Teardown is at the end of the function.

	// TODO: Allow any-to-any model to be specified
	cl := application.ModelConfigLoader()
	cfg, err := cl.LoadModelConfigFileByNameDefaultOptions(model, application.ApplicationConfig())
	if err != nil {
		xlog.Error("failed to load model config", "error", err)
		sendError(t, "model_load_error", "Failed to load model config", "", "")
		return
	}

	// Rejected only when none of the pipeline stages is configured.
	if cfg == nil || (cfg.Pipeline.VAD == "" && cfg.Pipeline.Transcription == "" && cfg.Pipeline.TTS == "" && cfg.Pipeline.LLM == "") {
		xlog.Error("model is not a pipeline", "model", model)
		sendError(t, "invalid_model", "Model is not a pipeline model", "", "")
		return
	}

	sttModel := cfg.Pipeline.Transcription

	sessionID := generateSessionID()
	session := &Session{
		ID:                sessionID,
		TranscriptionOnly: false,
		Model:             model,
		Voice:             cfg.TTSConfig.Voice,
		Instructions:      defaultInstructions,
		ModelConfig:       cfg,
		// Server-side VAD is enabled by default.
		TurnDetection: &types.TurnDetectionUnion{
			ServerVad: &types.ServerVad{
				Threshold:         0.5,
				PrefixPaddingMs:   300,
				SilenceDurationMs: 500,
				CreateResponse:    true,
			},
		},
		InputAudioTranscription: &types.AudioTranscription{
			Model: sttModel,
		},
		Conversations:    make(map[string]*Conversation),
		InputSampleRate:  defaultRemoteSampleRate,
		OutputSampleRate: defaultRemoteSampleRate,
	}

	// Create a default conversation
	conversationID := generateConversationID()
	conversation := &Conversation{
		ID: conversationID,
		// TODO: We need to truncate the conversation items when a new item is added and we have run out of space. There are multiple places where items
		// can be added so we could use a datastructure here that enforces truncation upon addition
		Items: []*types.MessageItemUnion{},
	}
	session.Conversations[conversationID] = conversation
	session.DefaultConversationID = conversationID

	m, err := newModel(
		&cfg.Pipeline,
		application.ModelConfigLoader(),
		application.ModelLoader(),
		application.ApplicationConfig(),
		evaluator,
	)
	if err != nil {
		xlog.Error("failed to load model", "error", err)
		sendError(t, "model_load_error", "Failed to load model", "", "")
		return
	}
	session.ModelInterface = m

	// Store the session and notify the transport (for WebRTC audio track handling)
	sessionLock.Lock()
	sessions[sessionID] = session
	sessionLock.Unlock()

	// For WebRTC, inbound audio arrives as Opus (48kHz) and is decoded+resampled
	// to localSampleRate in handleIncomingAudioTrack. Set InputSampleRate to
	// match so handleVAD doesn't needlessly double-resample.
	if _, ok := t.(*WebRTCTransport); ok {
		session.InputSampleRate = localSampleRate
	}

	// Transports that expose SetSession are handed the session.
	if sn, ok := t.(interface{ SetSession(*Session) }); ok {
		sn.SetSession(session)
	}

	sendEvent(t, types.SessionCreatedEvent{
		ServerEventBase: types.ServerEventBase{
			EventID: "event_TODO",
		},
		Session: session.ToServer(),
	})

	var (
		msg  []byte
		wg   sync.WaitGroup
		done = make(chan struct{})
	)

	// toggleVAD starts the VAD goroutine when server VAD is configured and not
	// yet running, and stops it (by closing done) when server VAD has been
	// switched off. It is called once at start-up and after every successful
	// session update.
	// NOTE(review): the goroutine reads the shared `done` variable when it
	// starts running; a stop immediately followed by a start could let it pick
	// up the replacement channel - confirm whether that can happen.
	vadServerStarted := false
	toggleVAD := func() {
		if session.TurnDetection != nil && session.TurnDetection.ServerVad != nil && !vadServerStarted {
			xlog.Debug("Starting VAD goroutine...")
			done = make(chan struct{})
			wg.Add(1)
			go func() {
				defer wg.Done()
				conversation := session.Conversations[session.DefaultConversationID]
				handleVAD(session, conversation, t, done)
			}()
			vadServerStarted = true
		} else if (session.TurnDetection == nil || session.TurnDetection.ServerVad == nil) && vadServerStarted {
			xlog.Debug("Stopping VAD goroutine...")
			close(done)
			vadServerStarted = false
		}
	}

	// For WebRTC sessions, start the Opus decode loop before VAD so that
	// decoded PCM is already flowing when VAD's first tick fires.
	var decodeDone chan struct{}
	if wt, ok := t.(*WebRTCTransport); ok {
		decodeDone = make(chan struct{})
		go decodeOpusLoop(session, wt.opusBackend, decodeDone)
	}

	toggleVAD()

	// Main event loop: a read error ends the session; `continue` inside the
	// switch skips straight to the next incoming message.
	for {
		msg, err = t.ReadEvent()
		if err != nil {
			xlog.Error("read error", "error", err)
			break
		}

		// Handle diagnostic events that aren't part of the OpenAI protocol
		var rawType struct {
			Type string `json:"type"`
		}
		if json.Unmarshal(msg, &rawType) == nil && rawType.Type == "test_tone" {
			if _, ok := t.(*WebSocketTransport); ok {
				sendError(t, "not_supported", "test_tone is only supported on WebRTC connections", "", "")
			} else {
				xlog.Debug("Generating test tone")
				go sendTestTone(t)
			}
			continue
		}

		// Parse the incoming message
		event, err := types.UnmarshalClientEvent(msg)
		if err != nil {
			xlog.Error("invalid json", "error", err)
			sendError(t, "invalid_json", "Invalid JSON format", "", "")
			continue
		}

		switch e := event.(type) {
		case types.SessionUpdateEvent:
			xlog.Debug("recv", "message", string(msg))

			// Handle transcription session update
			if e.Session.Transcription != nil {
				if err := updateTransSession(
					session,
					&e.Session,
					application.ModelConfigLoader(),
					application.ModelLoader(),
					application.ApplicationConfig(),
				); err != nil {
					xlog.Error("failed to update session", "error", err)
					sendError(t, "session_update_error", "Failed to update session", "", "")
					continue
				}

				// Turn detection may have changed: start/stop VAD accordingly.
				toggleVAD()

				sendEvent(t, types.SessionUpdatedEvent{
					ServerEventBase: types.ServerEventBase{
						EventID: "event_TODO",
					},
					Session: session.ToServer(),
				})
			}

			// Handle realtime session update
			if e.Session.Realtime != nil {
				if err := updateSession(
					session,
					&e.Session,
					application.ModelConfigLoader(),
					application.ModelLoader(),
					application.ApplicationConfig(),
					evaluator,
				); err != nil {
					xlog.Error("failed to update session", "error", err)
					sendError(t, "session_update_error", "Failed to update session", "", "")
					continue
				}

				// Turn detection may have changed: start/stop VAD accordingly.
				toggleVAD()

				sendEvent(t, types.SessionUpdatedEvent{
					ServerEventBase: types.ServerEventBase{
						EventID: "event_TODO",
					},
					Session: session.ToServer(),
				})
			}

		case types.InputAudioBufferAppendEvent:
			// Handle 'input_audio_buffer.append'
			if e.Audio == "" {
				xlog.Error("Audio data is missing in 'input_audio_buffer.append'")
				sendError(t, "missing_audio_data", "Audio data is missing", "", "")
				continue
			}

			// Decode base64 audio data
			decodedAudio, err := base64.StdEncoding.DecodeString(e.Audio)
			if err != nil {
				xlog.Error("failed to decode audio data", "error", err)
				sendError(t, "invalid_audio_data", "Failed to decode audio data", "", "")
				continue
			}

			// Check buffer size limits before appending
			session.AudioBufferLock.Lock()
			newSize := len(session.InputAudioBuffer) + len(decodedAudio)
			if newSize > maxAudioBufferSize {
				session.AudioBufferLock.Unlock()
				// NOTE(review): the buffer length logged below is read after
				// the lock has been released.
				xlog.Error("audio buffer size limit exceeded", "current_size", len(session.InputAudioBuffer), "incoming_size", len(decodedAudio), "limit", maxAudioBufferSize)
				sendError(t, "buffer_size_exceeded", fmt.Sprintf("Audio buffer size limit exceeded (max %d bytes)", maxAudioBufferSize), "", "")
				continue
			}

			// Append to InputAudioBuffer
			session.InputAudioBuffer = append(session.InputAudioBuffer, decodedAudio...)
			session.AudioBufferLock.Unlock()

		case types.InputAudioBufferCommitEvent:
			xlog.Debug("recv", "message", string(msg))

			sessionLock.Lock()
			isServerVAD := session.TurnDetection != nil && session.TurnDetection.ServerVad != nil
			sessionLock.Unlock()

			// TODO: At the least need to check locking and timer state in the VAD Go routine before allowing this
			if isServerVAD {
				sendNotImplemented(t, "input_audio_buffer.commit in conjunction with VAD")
				continue
			}

			// Take a private copy of the buffered audio and reset the buffer
			// while holding the lock.
			session.AudioBufferLock.Lock()
			allAudio := make([]byte, len(session.InputAudioBuffer))
			copy(allAudio, session.InputAudioBuffer)
			session.InputAudioBuffer = nil
			session.AudioBufferLock.Unlock()

			sendEvent(t, types.InputAudioBufferCommittedEvent{
				ServerEventBase: types.ServerEventBase{},
				ItemID:          generateItemID(),
			})

			// startResponse cancels any response still running; the goroutine
			// must close respDone when it finishes.
			respCtx, respDone := session.startResponse(context.Background())
			go func() {
				defer close(respDone)
				commitUtterance(respCtx, allAudio, session, conversation, t)
			}()

		case types.ConversationItemCreateEvent:
			xlog.Debug("recv", "message", string(msg))
			// Add the item to the conversation
			item := e.Item
			// Ensure IDs are present
			if item.User != nil && item.User.ID == "" {
				item.User.ID = generateItemID()
			}
			if item.Assistant != nil && item.Assistant.ID == "" {
				item.Assistant.ID = generateItemID()
			}
			if item.System != nil && item.System.ID == "" {
				item.System.ID = generateItemID()
			}
			if item.FunctionCall != nil && item.FunctionCall.ID == "" {
				item.FunctionCall.ID = generateItemID()
			}
			if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" {
				item.FunctionCallOutput.ID = generateItemID()
			}

			conversation.Lock.Lock()
			conversation.Items = append(conversation.Items, &item)
			conversation.Lock.Unlock()

			sendEvent(t, types.ConversationItemAddedEvent{
				ServerEventBase: types.ServerEventBase{
					EventID: e.EventID,
				},
				PreviousItemID: e.PreviousItemID,
				Item:           item,
			})

		case types.ConversationItemDeleteEvent:
			sendError(t, "not_implemented", "Deleting items not implemented", "", "event_TODO")

		case types.ConversationItemRetrieveEvent:
			xlog.Debug("recv", "message", string(msg))

			if e.ItemID == "" {
				sendError(t, "invalid_item_id", "Need item_id, but none specified", "", "event_TODO")
				continue
			}

			// Linear search for the item; if nothing matches, the zero-value
			// item is sent back.
			conversation.Lock.Lock()
			var retrievedItem types.MessageItemUnion
			for _, item := range conversation.Items {
				// We need to check ID in the union
				var id string
				if item.System != nil {
					id = item.System.ID
				} else if item.User != nil {
					id = item.User.ID
				} else if item.Assistant != nil {
					id = item.Assistant.ID
				} else if item.FunctionCall != nil {
					id = item.FunctionCall.ID
				} else if item.FunctionCallOutput != nil {
					id = item.FunctionCallOutput.ID
				}

				if id == e.ItemID {
					retrievedItem = *item
					break
				}
			}
			conversation.Lock.Unlock()

			sendEvent(t, types.ConversationItemRetrievedEvent{
				ServerEventBase: types.ServerEventBase{
					EventID: "event_TODO",
				},
				Item: retrievedItem,
			})

		case types.ResponseCreateEvent:
			xlog.Debug("recv", "message", string(msg))

			// Handle optional items to add to context
			if len(e.Response.Input) > 0 {
				conversation.Lock.Lock()
				// NOTE(review): storing &item relies on per-iteration loop
				// variables (Go 1.22+); with older semantics every entry would
				// alias one variable - confirm the module's Go version.
				for _, item := range e.Response.Input {
					// Ensure IDs are present
					if item.User != nil && item.User.ID == "" {
						item.User.ID = generateItemID()
					}
					if item.Assistant != nil && item.Assistant.ID == "" {
						item.Assistant.ID = generateItemID()
					}
					if item.System != nil && item.System.ID == "" {
						item.System.ID = generateItemID()
					}
					if item.FunctionCall != nil && item.FunctionCall.ID == "" {
						item.FunctionCall.ID = generateItemID()
					}
					if item.FunctionCallOutput != nil && item.FunctionCallOutput.ID == "" {
						item.FunctionCallOutput.ID = generateItemID()
					}

					conversation.Items = append(conversation.Items, &item)
				}
				conversation.Lock.Unlock()
			}

			// startResponse cancels any response still running; the goroutine
			// must close respDone when it finishes.
			respCtx, respDone := session.startResponse(context.Background())
			go func() {
				defer close(respDone)
				triggerResponse(respCtx, session, conversation, t, &e.Response)
			}()

		case types.ResponseCancelEvent:
			xlog.Debug("recv", "message", string(msg))
			session.cancelActiveResponse()

		default:
			xlog.Error("unknown message type")
			// sendError(t, "unknown_message_type", fmt.Sprintf("Unknown message type: %s", incomingMsg.Type), "", "")
		}
	}

	// Cancel any in-flight response before tearing down
	session.cancelActiveResponse()

	// Stop the Opus decode goroutine (if running)
	if decodeDone != nil {
		close(decodeDone)
	}

	// Signal any running VAD goroutine to exit.
	if vadServerStarted {
		close(done)
	}
	wg.Wait()

	// Remove the session from the sessions map
	sessionLock.Lock()
	delete(sessions, sessionID)
	sessionLock.Unlock()
}

// sendEvent sends a server event via the transport, logging any errors.
func sendEvent(t Transport, event types.ServerEvent) {
	if err := t.SendEvent(event); err != nil {
		xlog.Error("write error", "error", err)
	}
}

// sendError sends an error event to the client.
// The error type is always "invalid_request_error"; eventID is used both as
// the server event ID and as the error's EventID field.
func sendError(t Transport, code, message, param, eventID string) {
	errorEvent := types.ErrorEvent{
		ServerEventBase: types.ServerEventBase{
			EventID: eventID,
		},
		Error: types.Error{
			Type:    "invalid_request_error",
			Code:    code,
			Message: message,
			Param:   param,
			EventID: eventID,
		},
	}

	sendEvent(t, errorEvent)
}

// sendNotImplemented reports a "not_implemented" error with the given message.
func sendNotImplemented(t Transport, message string) {
	sendError(t, "not_implemented", message, "", "event_TODO")
}

// sendTestTone generates a 1-second 440 Hz sine wave and sends it through
// the transport's audio path. This exercises the full Opus encode → RTP →
// browser decode pipeline without involving TTS.
func sendTestTone(t Transport) { const ( freq = 440.0 sampleRate = 24000 duration = 1 // seconds amplitude = 16000 numSamples = sampleRate * duration ) pcm := make([]byte, numSamples*2) // 16-bit samples = 2 bytes each for i := 0; i < numSamples; i++ { sample := int16(amplitude * math.Sin(2*math.Pi*freq*float64(i)/sampleRate)) binary.LittleEndian.PutUint16(pcm[i*2:], uint16(sample)) } xlog.Debug("Sending test tone", "samples", numSamples, "sample_rate", sampleRate, "freq", freq) if err := t.SendAudio(context.Background(), pcm, sampleRate); err != nil { xlog.Error("test tone send failed", "error", err) } } func updateTransSession(session *Session, update *types.SessionUnion, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) error { sessionLock.Lock() defer sessionLock.Unlock() // In transcription session update, we look at Transcription field if update.Transcription == nil || update.Transcription.Audio == nil || update.Transcription.Audio.Input == nil { return nil } trUpd := update.Transcription.Audio.Input.Transcription trCur := session.InputAudioTranscription session.TranscriptionOnly = true if trUpd != nil && trUpd.Model != "" && trUpd.Model != trCur.Model { cfg, err := cl.LoadModelConfigFileByNameDefaultOptions(trUpd.Model, appConfig) if err != nil { return err } if cfg == nil || (cfg.Pipeline.VAD == "" || cfg.Pipeline.Transcription == "") { return fmt.Errorf("model is not a valid pipeline model: %s", trUpd.Model) } m, cfg, err := newTranscriptionOnlyModel(&cfg.Pipeline, cl, ml, appConfig) if err != nil { return err } session.ModelInterface = m session.ModelConfig = cfg } if trUpd != nil { trCur.Language = trUpd.Language trCur.Prompt = trUpd.Prompt } if update.Transcription.Audio.Input.TurnDetectionSet { session.TurnDetection = update.Transcription.Audio.Input.TurnDetection } if update.Transcription.Audio.Input.Format != nil && update.Transcription.Audio.Input.Format.PCM != nil { if 
update.Transcription.Audio.Input.Format.PCM.Rate > 0 { session.InputSampleRate = update.Transcription.Audio.Input.Format.PCM.Rate } } return nil } func updateSession(session *Session, update *types.SessionUnion, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, evaluator *templates.Evaluator) error { sessionLock.Lock() defer sessionLock.Unlock() if update.Realtime == nil { return nil } session.TranscriptionOnly = false rt := update.Realtime if rt.Model != "" { cfg, err := cl.LoadModelConfigFileByNameDefaultOptions(rt.Model, appConfig) if err != nil { return err } if cfg == nil || (cfg.Pipeline.VAD == "" || cfg.Pipeline.Transcription == "" || cfg.Pipeline.TTS == "" || cfg.Pipeline.LLM == "") { return fmt.Errorf("model is not a valid pipeline model: %s", rt.Model) } if session.InputAudioTranscription == nil { session.InputAudioTranscription = &types.AudioTranscription{} } session.InputAudioTranscription.Model = cfg.Pipeline.Transcription session.Voice = cfg.TTSConfig.Voice session.Model = rt.Model session.ModelConfig = cfg } if rt.Audio != nil && rt.Audio.Output != nil && rt.Audio.Output.Voice != "" { session.Voice = string(rt.Audio.Output.Voice) } if rt.Audio != nil && rt.Audio.Input != nil && rt.Audio.Input.Transcription != nil { session.InputAudioTranscription = rt.Audio.Input.Transcription session.ModelConfig.Pipeline.Transcription = rt.Audio.Input.Transcription.Model } if rt.Model != "" || (rt.Audio != nil && rt.Audio.Output != nil && rt.Audio.Output.Voice != "") || (rt.Audio != nil && rt.Audio.Input != nil && rt.Audio.Input.Transcription != nil) { m, err := newModel(&session.ModelConfig.Pipeline, cl, ml, appConfig, evaluator) if err != nil { return err } session.ModelInterface = m } if rt.Audio != nil && rt.Audio.Input != nil && rt.Audio.Input.TurnDetectionSet { session.TurnDetection = rt.Audio.Input.TurnDetection } if rt.Audio != nil && rt.Audio.Input != nil && rt.Audio.Input.Format != nil && rt.Audio.Input.Format.PCM != 
nil { if rt.Audio.Input.Format.PCM.Rate > 0 { session.InputSampleRate = rt.Audio.Input.Format.PCM.Rate } } if rt.Audio != nil && rt.Audio.Output != nil && rt.Audio.Output.Format != nil && rt.Audio.Output.Format.PCM != nil { if rt.Audio.Output.Format.PCM.Rate > 0 { session.OutputSampleRate = rt.Audio.Output.Format.PCM.Rate } } if rt.Instructions != "" { session.Instructions = rt.Instructions } if rt.Tools != nil { session.Tools = rt.Tools } if rt.ToolChoice != nil { session.ToolChoice = rt.ToolChoice } if rt.MaxOutputTokens != 0 { session.MaxOutputTokens = rt.MaxOutputTokens } return nil } // decodeOpusLoop runs a ticker that drains buffered raw Opus frames from the // session, decodes them in a single batched gRPC call, and appends the // resulting PCM to InputAudioBuffer. This gives ~3 gRPC calls/sec instead of // 50 (one per RTP packet) and keeps decode diagnostics once-per-batch. func decodeOpusLoop(session *Session, opusBackend grpc.Backend, done chan struct{}) { ticker := time.NewTicker(300 * time.Millisecond) defer ticker.Stop() for { select { case <-ticker.C: session.OpusFramesLock.Lock() frames := session.OpusFrames session.OpusFrames = nil session.OpusFramesLock.Unlock() if len(frames) == 0 { continue } result, err := opusBackend.AudioDecode(context.Background(), &proto.AudioDecodeRequest{ Frames: frames, Options: map[string]string{ "session_id": session.ID, }, }) if err != nil { xlog.Warn("opus decode batch error", "error", err, "frames", len(frames)) continue } samples := sound.BytesToInt16sLE(result.PcmData) xlog.Debug("opus decode batch", "frames", len(frames), "decoded_samples", len(samples), "sample_rate", result.SampleRate, ) // Resample from 48kHz to session input rate (16kHz) if needed if result.SampleRate != int32(session.InputSampleRate) { samples = sound.ResampleInt16(samples, int(result.SampleRate), session.InputSampleRate) } pcmBytes := sound.Int16toBytesLE(samples) session.AudioBufferLock.Lock() newSize := len(session.InputAudioBuffer) + 
len(pcmBytes) if newSize <= maxAudioBufferSize { session.InputAudioBuffer = append(session.InputAudioBuffer, pcmBytes...) } session.AudioBufferLock.Unlock() case <-done: return } } } // handleVAD is a goroutine that listens for audio data from the client, // runs VAD on the audio data, and commits utterances to the conversation func handleVAD(session *Session, conv *Conversation, t Transport, done chan struct{}) { vadContext, cancel := context.WithCancel(context.Background()) go func() { <-done cancel() }() silenceThreshold := 0.5 // Default 500ms if session.TurnDetection != nil && session.TurnDetection.ServerVad != nil { silenceThreshold = float64(session.TurnDetection.ServerVad.SilenceDurationMs) / 1000 } speechStarted := false startTime := time.Now() ticker := time.NewTicker(300 * time.Millisecond) defer ticker.Stop() for { select { case <-done: return case <-ticker.C: session.AudioBufferLock.Lock() allAudio := make([]byte, len(session.InputAudioBuffer)) copy(allAudio, session.InputAudioBuffer) session.AudioBufferLock.Unlock() aints := sound.BytesToInt16sLE(allAudio) if len(aints) == 0 || len(aints) < int(silenceThreshold*float64(session.InputSampleRate)) { continue } // Resample from InputSampleRate to 16kHz aints = sound.ResampleInt16(aints, session.InputSampleRate, localSampleRate) segments, err := runVAD(vadContext, session, aints) if err != nil { if err.Error() == "unexpected speech end" { xlog.Debug("VAD cancelled") continue } xlog.Error("failed to process audio", "error", err) sendError(t, "processing_error", "Failed to process audio: "+err.Error(), "", "") continue } audioLength := float64(len(aints)) / localSampleRate // TODO: When resetting the buffer we should retain a small postfix if len(segments) == 0 && audioLength > silenceThreshold { session.AudioBufferLock.Lock() session.InputAudioBuffer = nil session.AudioBufferLock.Unlock() continue } else if len(segments) == 0 { continue } if !speechStarted { // Barge-in: cancel any in-flight response so we 
stop // sending audio and don't keep the interrupted reply in history. session.cancelActiveResponse() sendEvent(t, types.InputAudioBufferSpeechStartedEvent{ ServerEventBase: types.ServerEventBase{ EventID: "event_TODO", }, AudioStartMs: time.Since(startTime).Milliseconds(), }) speechStarted = true } // Segment still in progress when audio ended segEndTime := segments[len(segments)-1].End if segEndTime == 0 { continue } if float32(audioLength)-segEndTime > float32(silenceThreshold) { xlog.Debug("Detected end of speech segment") session.AudioBufferLock.Lock() session.InputAudioBuffer = nil session.AudioBufferLock.Unlock() sendEvent(t, types.InputAudioBufferSpeechStoppedEvent{ ServerEventBase: types.ServerEventBase{ EventID: "event_TODO", }, AudioEndMs: time.Since(startTime).Milliseconds(), }) speechStarted = false sendEvent(t, types.InputAudioBufferCommittedEvent{ ServerEventBase: types.ServerEventBase{ EventID: "event_TODO", }, ItemID: generateItemID(), PreviousItemID: "TODO", }) abytes := sound.Int16toBytesLE(aints) // TODO: Remove prefix silence that is is over TurnDetectionParams.PrefixPaddingMs respCtx, respDone := session.startResponse(vadContext) go func() { defer close(respDone) commitUtterance(respCtx, abytes, session, conv, t) }() } } } } func commitUtterance(ctx context.Context, utt []byte, session *Session, conv *Conversation, t Transport) { if len(utt) == 0 { return } f, err := os.CreateTemp("", "realtime-audio-chunk-*.wav") if err != nil { xlog.Error("failed to create temp file", "error", err) return } defer f.Close() defer os.Remove(f.Name()) xlog.Debug("Writing to file", "file", f.Name()) hdr := laudio.NewWAVHeader(uint32(len(utt))) if err := hdr.Write(f); err != nil { xlog.Error("Failed to write WAV header", "error", err) return } if _, err := f.Write(utt); err != nil { xlog.Error("Failed to write audio data", "error", err) return } f.Sync() // TODO: If we have a real any-to-any model then transcription is optional var transcript string if 
session.InputAudioTranscription != nil { tr, err := session.ModelInterface.Transcribe(ctx, f.Name(), session.InputAudioTranscription.Language, false, false, session.InputAudioTranscription.Prompt) if err != nil { sendError(t, "transcription_failed", err.Error(), "", "event_TODO") return } else if tr == nil { sendError(t, "transcription_failed", "trancribe result is nil", "", "event_TODO") return } transcript = tr.Text sendEvent(t, types.ConversationItemInputAudioTranscriptionCompletedEvent{ ServerEventBase: types.ServerEventBase{ EventID: "event_TODO", }, ItemID: generateItemID(), // ResponseID: "resp_TODO", // Not needed for transcription completed event // OutputIndex: 0, ContentIndex: 0, Transcript: transcript, }) } else { sendNotImplemented(t, "any-to-any models") return } if !session.TranscriptionOnly { generateResponse(ctx, session, utt, transcript, conv, t) } } func runVAD(ctx context.Context, session *Session, adata []int16) ([]schema.VADSegment, error) { soundIntBuffer := &audio.IntBuffer{ Format: &audio.Format{SampleRate: localSampleRate, NumChannels: 1}, SourceBitDepth: 16, Data: sound.ConvertInt16ToInt(adata), } float32Data := soundIntBuffer.AsFloat32Buffer().Data resp, err := session.ModelInterface.VAD(ctx, &schema.VADRequest{ Audio: float32Data, }) if err != nil { return nil, err } // If resp.Segments is empty => no speech return resp.Segments, nil } // Function to generate a response based on the conversation func generateResponse(ctx context.Context, session *Session, utt []byte, transcript string, conv *Conversation, t Transport) { xlog.Debug("Generating realtime response...") // Create user message item item := types.MessageItemUnion{ User: &types.MessageItemUser{ ID: generateItemID(), Status: types.ItemStatusCompleted, Content: []types.MessageContentInput{ { Type: types.MessageContentTypeInputAudio, Audio: base64.StdEncoding.EncodeToString(utt), Transcript: transcript, }, }, }, } conv.Lock.Lock() conv.Items = append(conv.Items, &item) 
conv.Lock.Unlock() sendEvent(t, types.ConversationItemAddedEvent{ Item: item, }) triggerResponse(ctx, session, conv, t, nil) } func triggerResponse(ctx context.Context, session *Session, conv *Conversation, t Transport, overrides *types.ResponseCreateParams) { config := session.ModelInterface.PredictConfig() // Default values tools := session.Tools toolChoice := session.ToolChoice instructions := session.Instructions maxOutputTokens := session.MaxOutputTokens // Overrides if overrides != nil { if overrides.Tools != nil { tools = overrides.Tools } if overrides.ToolChoice != nil { toolChoice = overrides.ToolChoice } if overrides.Instructions != "" { instructions = overrides.Instructions } if overrides.MaxOutputTokens != 0 { maxOutputTokens = overrides.MaxOutputTokens } } // Apply MaxOutputTokens to model config if specified // Save original value to restore after prediction var originalMaxTokens *int if config != nil { originalMaxTokens = config.Maxtokens if maxOutputTokens != 0 && !maxOutputTokens.IsInf() { tokenValue := int(maxOutputTokens) config.Maxtokens = &tokenValue xlog.Debug("Applied max_output_tokens to config", "value", tokenValue) } } // Defer restoration of original value defer func() { if config != nil { config.Maxtokens = originalMaxTokens } }() var conversationHistory schema.Messages conversationHistory = append(conversationHistory, schema.Message{ Role: string(types.MessageRoleSystem), StringContent: instructions, Content: instructions, }) imgIndex := 0 conv.Lock.Lock() for _, item := range conv.Items { if item.User != nil { msg := schema.Message{ Role: string(types.MessageRoleUser), } textContent := "" nrOfImgsInMessage := 0 for _, content := range item.User.Content { switch content.Type { case types.MessageContentTypeInputText: textContent += content.Text case types.MessageContentTypeInputAudio: textContent += content.Transcript case types.MessageContentTypeInputImage: img, err := utils.GetContentURIAsBase64(content.ImageURL) if err != nil { 
xlog.Warn("Failed to process image", "error", err) continue } msg.StringImages = append(msg.StringImages, img) imgIndex++ nrOfImgsInMessage++ } } if nrOfImgsInMessage > 0 { templated, err := templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{ TotalImages: imgIndex, ImagesInMessage: nrOfImgsInMessage, }, textContent) if err != nil { xlog.Warn("Failed to apply multimodal template", "error", err) templated = textContent } msg.StringContent = templated msg.Content = templated } else { msg.StringContent = textContent msg.Content = textContent } conversationHistory = append(conversationHistory, msg) } else if item.Assistant != nil { for _, content := range item.Assistant.Content { switch content.Type { case types.MessageContentTypeOutputText: conversationHistory = append(conversationHistory, schema.Message{ Role: string(types.MessageRoleAssistant), StringContent: content.Text, Content: content.Text, }) case types.MessageContentTypeOutputAudio: conversationHistory = append(conversationHistory, schema.Message{ Role: string(types.MessageRoleAssistant), StringContent: content.Transcript, Content: content.Transcript, StringAudios: []string{content.Audio}, }) } } } else if item.System != nil { for _, content := range item.System.Content { conversationHistory = append(conversationHistory, schema.Message{ Role: string(types.MessageRoleSystem), StringContent: content.Text, Content: content.Text, }) } } else if item.FunctionCall != nil { conversationHistory = append(conversationHistory, schema.Message{ Role: string(types.MessageRoleAssistant), ToolCalls: []schema.ToolCall{ { ID: item.FunctionCall.CallID, Type: "function", FunctionCall: schema.FunctionCall{ Name: item.FunctionCall.Name, Arguments: item.FunctionCall.Arguments, }, }, }, }) } else if item.FunctionCallOutput != nil { conversationHistory = append(conversationHistory, schema.Message{ Role: "tool", Name: item.FunctionCallOutput.CallID, Content: item.FunctionCallOutput.Output, 
StringContent: item.FunctionCallOutput.Output, }) } } conv.Lock.Unlock() var images []string for _, m := range conversationHistory { images = append(images, m.StringImages...) } responseID := generateUniqueID() sendEvent(t, types.ResponseCreatedEvent{ ServerEventBase: types.ServerEventBase{}, Response: types.Response{ ID: responseID, Object: "realtime.response", Status: types.ResponseStatusInProgress, }, }) predFunc, err := session.ModelInterface.Predict(ctx, conversationHistory, images, nil, nil, nil, tools, toolChoice, nil, nil, nil) if err != nil { sendError(t, "inference_failed", fmt.Sprintf("backend error: %v", err), "", "") // item.Assistant.ID is unknown here return } pred, err := predFunc() if err != nil { sendError(t, "prediction_failed", fmt.Sprintf("backend error: %v", err), "", "") return } // Check for cancellation after LLM inference (barge-in may have fired) if ctx.Err() != nil { xlog.Debug("Response cancelled after LLM inference (barge-in)") sendEvent(t, types.ResponseDoneEvent{ ServerEventBase: types.ServerEventBase{}, Response: types.Response{ ID: responseID, Object: "realtime.response", Status: types.ResponseStatusCancelled, }, }) return } xlog.Debug("Function config for parsing", "function_name_key", config.FunctionsConfig.FunctionNameKey, "function_arguments_key", config.FunctionsConfig.FunctionArgumentsKey) xlog.Debug("LLM raw response", "text", pred.Response, "response_length", len(pred.Response), "usage", pred.Usage) // Safely dereference pointer fields for logging maxTokens := "nil" if config.Maxtokens != nil { maxTokens = fmt.Sprintf("%d", *config.Maxtokens) } contextSize := "nil" if config.ContextSize != nil { contextSize = fmt.Sprintf("%d", *config.ContextSize) } xlog.Debug("Model parameters", "max_tokens", maxTokens, "context_size", contextSize, "stopwords", config.StopWords) rawResponse := pred.Response if config.TemplateConfig.ReplyPrefix != "" { rawResponse = config.TemplateConfig.ReplyPrefix + rawResponse } // Detect thinking start 
token from template for reasoning extraction var template string if config.TemplateConfig.UseTokenizerTemplate { template = config.GetModelTemplate() } else { template = config.TemplateConfig.Chat } thinkingStartToken := reasoning.DetectThinkingStartToken(template, &config.ReasoningConfig) reasoningText, responseWithoutReasoning := reasoning.ExtractReasoningWithConfig(rawResponse, thinkingStartToken, config.ReasoningConfig) xlog.Debug("LLM Response", "reasoning", reasoningText, "response_without_reasoning", responseWithoutReasoning) textContent := functions.ParseTextContent(responseWithoutReasoning, config.FunctionsConfig) cleanedResponse := functions.CleanupLLMResult(responseWithoutReasoning, config.FunctionsConfig) toolCalls := functions.ParseFunctionCall(cleanedResponse, config.FunctionsConfig) xlog.Debug("Function call parsing", "textContent", textContent, "cleanedResponse", cleanedResponse, "toolCallsCount", len(toolCalls)) noActionName := "answer" if config.FunctionsConfig.NoActionFunctionName != "" { noActionName = config.FunctionsConfig.NoActionFunctionName } isNoAction := len(toolCalls) > 0 && toolCalls[0].Name == noActionName var finalSpeech string var finalToolCalls []functions.FuncCallResults if isNoAction { arg := toolCalls[0].Arguments arguments := map[string]interface{}{} if err := json.Unmarshal([]byte(arg), &arguments); err == nil { if m, exists := arguments["message"]; exists { if message, ok := m.(string); ok { finalSpeech = message } else { xlog.Warn("NoAction function message field is not a string", "type", fmt.Sprintf("%T", m)) } } else { xlog.Warn("NoAction function missing 'message' field in arguments") } } else { xlog.Warn("Failed to unmarshal NoAction function arguments", "error", err, "arguments", arg) } if finalSpeech == "" { // Fallback if parsing failed xlog.Warn("NoAction function did not produce speech, using cleaned response as fallback") finalSpeech = cleanedResponse } } else { finalToolCalls = toolCalls xlog.Debug("Setting 
finalToolCalls", "count", len(finalToolCalls)) if len(toolCalls) > 0 { finalSpeech = textContent } else { finalSpeech = cleanedResponse } } if finalSpeech != "" { // Create the assistant item now that we have content item := types.MessageItemUnion{ Assistant: &types.MessageItemAssistant{ ID: generateItemID(), Status: types.ItemStatusInProgress, Content: []types.MessageContentOutput{ { Type: types.MessageContentTypeOutputAudio, Transcript: finalSpeech, }, }, }, } conv.Lock.Lock() conv.Items = append(conv.Items, &item) conv.Lock.Unlock() sendEvent(t, types.ResponseOutputItemAddedEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, OutputIndex: 0, Item: item, }) sendEvent(t, types.ResponseContentPartAddedEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: item.Assistant.ID, OutputIndex: 0, ContentIndex: 0, Part: item.Assistant.Content[0], }) // removeItemFromConv removes the last occurrence of an item with // the given assistant ID from conversation history. removeItemFromConv := func(assistantID string) { conv.Lock.Lock() for i := len(conv.Items) - 1; i >= 0; i-- { if conv.Items[i].Assistant != nil && conv.Items[i].Assistant.ID == assistantID { conv.Items = append(conv.Items[:i], conv.Items[i+1:]...) break } } conv.Lock.Unlock() } // sendCancelledResponse emits the cancelled status and cleans up the // assistant item so the interrupted reply is not in chat history. 
sendCancelledResponse := func() { removeItemFromConv(item.Assistant.ID) sendEvent(t, types.ResponseDoneEvent{ ServerEventBase: types.ServerEventBase{}, Response: types.Response{ ID: responseID, Object: "realtime.response", Status: types.ResponseStatusCancelled, }, }) } // Check for cancellation before TTS if ctx.Err() != nil { xlog.Debug("Response cancelled before TTS (barge-in)") sendCancelledResponse() return } audioFilePath, res, err := session.ModelInterface.TTS(ctx, finalSpeech, session.Voice, session.InputAudioTranscription.Language) if err != nil { if ctx.Err() != nil { xlog.Debug("TTS cancelled (barge-in)") sendCancelledResponse() return } xlog.Error("TTS failed", "error", err) sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %v", err), "", item.Assistant.ID) return } if !res.Success { xlog.Error("TTS failed", "message", res.Message) sendError(t, "tts_error", fmt.Sprintf("TTS generation failed: %s", res.Message), "", item.Assistant.ID) return } defer os.Remove(audioFilePath) audioBytes, err := os.ReadFile(audioFilePath) if err != nil { xlog.Error("failed to read TTS file", "error", err) sendError(t, "tts_error", fmt.Sprintf("Failed to read TTS audio: %v", err), "", item.Assistant.ID) return } // Parse WAV header to get raw PCM and the actual sample rate from the TTS backend. pcmData, ttsSampleRate := laudio.ParseWAV(audioBytes) if ttsSampleRate == 0 { ttsSampleRate = localSampleRate } xlog.Debug("TTS audio parsed", "raw_bytes", len(audioBytes), "pcm_bytes", len(pcmData), "sample_rate", ttsSampleRate) // SendAudio (WebRTC) passes PCM at the TTS sample rate directly to the // Opus encoder, which resamples to 48kHz internally. This avoids a // lossy intermediate resample through 16kHz. 
// XXX: This is a noop in websocket mode; it's included in the JSON instead if err := t.SendAudio(ctx, pcmData, ttsSampleRate); err != nil { if ctx.Err() != nil { xlog.Debug("Audio playback cancelled (barge-in)") sendCancelledResponse() return } xlog.Error("failed to send audio via transport", "error", err) } _, isWebRTC := t.(*WebRTCTransport) // For WebSocket clients, resample to the session's output rate and // deliver audio as base64 in JSON events. WebRTC clients already // received audio over the RTP track, so skip the base64 payload. var audioString string if !isWebRTC { wsPCM := pcmData if ttsSampleRate != session.OutputSampleRate { samples := sound.BytesToInt16sLE(pcmData) resampled := sound.ResampleInt16(samples, ttsSampleRate, session.OutputSampleRate) wsPCM = sound.Int16toBytesLE(resampled) } audioString = base64.StdEncoding.EncodeToString(wsPCM) } sendEvent(t, types.ResponseOutputAudioTranscriptDeltaEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: item.Assistant.ID, OutputIndex: 0, ContentIndex: 0, Delta: finalSpeech, }) sendEvent(t, types.ResponseOutputAudioTranscriptDoneEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: item.Assistant.ID, OutputIndex: 0, ContentIndex: 0, Transcript: finalSpeech, }) if !isWebRTC { sendEvent(t, types.ResponseOutputAudioDeltaEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: item.Assistant.ID, OutputIndex: 0, ContentIndex: 0, Delta: audioString, }) sendEvent(t, types.ResponseOutputAudioDoneEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: item.Assistant.ID, OutputIndex: 0, ContentIndex: 0, }) } sendEvent(t, types.ResponseContentPartDoneEvent{ ServerEventBase: types.ServerEventBase{}, ResponseID: responseID, ItemID: item.Assistant.ID, OutputIndex: 0, ContentIndex: 0, Part: item.Assistant.Content[0], }) conv.Lock.Lock() item.Assistant.Status = types.ItemStatusCompleted if !isWebRTC { 
// uniqueIDSequence holds the next sequence number used by generateUniqueID.
// The one-slot buffered channel acts as a lock around the counter: taking the
// value out blocks other callers until the incremented value is put back.
var uniqueIDSequence = func() chan uint64 {
	seq := make(chan uint64, 1)
	seq <- 0
	return seq
}()

// generateSessionID returns a new session identifier ("sess_" prefix).
func generateSessionID() string {
	return "sess_" + generateUniqueID()
}

// generateConversationID returns a new conversation identifier ("conv_" prefix).
func generateConversationID() string {
	return "conv_" + generateUniqueID()
}

// generateItemID returns a new conversation item identifier ("item_" prefix).
func generateItemID() string {
	return "item_" + generateUniqueID()
}

// generateUniqueID returns an identifier that is unique within this process.
//
// The ID combines the current time in nanoseconds with a per-process sequence
// number, both in hexadecimal. The sequence number guarantees that two calls
// never return the same value, even when they happen within the same clock
// tick or concurrently; the timestamp keeps IDs from repeating across
// restarts. The IDs are not random and must not be used as secrets.
func generateUniqueID() string {
	n := <-uniqueIDSequence
	uniqueIDSequence <- n + 1
	return fmt.Sprintf("%x_%x", time.Now().UnixNano(), n)
}
conversation ID // Implement as needed return "conv_" + generateUniqueID() } func generateItemID() string { // Generate a unique item ID // Implement as needed return "item_" + generateUniqueID() } func generateUniqueID() string { // Generate a unique ID string // For simplicity, use a counter or UUID // Implement as needed return "unique_id" } ================================================ FILE: core/http/endpoints/openai/realtime_model.go ================================================ package openai import ( "context" "encoding/json" "fmt" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/endpoints/openai/types" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/templates" "github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/pkg/grpc/proto" model "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) var ( _ Model = new(wrappedModel) _ Model = new(transcriptOnlyModel) ) // wrappedModel represent a model which does not support Any-to-Any operations // This means that we will fake an Any-to-Any model by overriding some of the gRPC client methods // which are for Any-To-Any models, but instead we will call a pipeline (for e.g STT->LLM->TTS) type wrappedModel struct { TTSConfig *config.ModelConfig TranscriptionConfig *config.ModelConfig LLMConfig *config.ModelConfig VADConfig *config.ModelConfig appConfig *config.ApplicationConfig modelLoader *model.ModelLoader confLoader *config.ModelConfigLoader evaluator *templates.Evaluator } // anyToAnyModel represent a model which supports Any-to-Any operations // We have to wrap this out as well because we want to load two models one for VAD and one for the actual model. 
// In the future there could be models that accept continous audio input only so this design will be useful for that type anyToAnyModel struct { LLMConfig *config.ModelConfig VADConfig *config.ModelConfig appConfig *config.ApplicationConfig modelLoader *model.ModelLoader confLoader *config.ModelConfigLoader } type transcriptOnlyModel struct { TranscriptionConfig *config.ModelConfig VADConfig *config.ModelConfig appConfig *config.ApplicationConfig modelLoader *model.ModelLoader confLoader *config.ModelConfigLoader } func (m *transcriptOnlyModel) VAD(ctx context.Context, request *schema.VADRequest) (*schema.VADResponse, error) { return backend.VAD(request, ctx, m.modelLoader, m.appConfig, *m.VADConfig) } func (m *transcriptOnlyModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize bool, prompt string) (*schema.TranscriptionResult, error) { return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig) } func (m *transcriptOnlyModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) { return nil, fmt.Errorf("predict operation not supported in transcript-only mode") } func (m *transcriptOnlyModel) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) { return "", nil, fmt.Errorf("TTS not supported in transcript-only mode") } func (m *transcriptOnlyModel) PredictConfig() *config.ModelConfig { return nil } func (m *wrappedModel) VAD(ctx context.Context, request *schema.VADRequest) (*schema.VADResponse, error) { return backend.VAD(request, ctx, m.modelLoader, m.appConfig, *m.VADConfig) } func (m *wrappedModel) Transcribe(ctx context.Context, audio, language string, translate bool, diarize 
bool, prompt string) (*schema.TranscriptionResult, error) { return backend.ModelTranscription(audio, language, translate, diarize, prompt, m.modelLoader, *m.TranscriptionConfig, m.appConfig) } func (m *wrappedModel) Predict(ctx context.Context, messages schema.Messages, images, videos, audios []string, tokenCallback func(string, backend.TokenUsage) bool, tools []types.ToolUnion, toolChoice *types.ToolChoiceUnion, logprobs *int, topLogprobs *int, logitBias map[string]float64) (func() (backend.LLMResponse, error), error) { input := schema.OpenAIRequest{ Messages: messages, } var predInput string var funcs []functions.Function if !m.LLMConfig.TemplateConfig.UseTokenizerTemplate { if len(tools) > 0 { for _, t := range tools { if t.Function != nil { var params map[string]any switch p := t.Function.Parameters.(type) { case map[string]any: params = p case string: if err := json.Unmarshal([]byte(p), ¶ms); err != nil { xlog.Warn("Failed to parse parameters JSON string", "error", err, "function", t.Function.Name) } } funcs = append(funcs, functions.Function{ Name: t.Function.Name, Description: t.Function.Description, Parameters: params, }) } } // Add noAction function before templating so it's included in the prompt // Allow the user to set custom actions via config file noActionName := "answer" noActionDescription := "use this action to answer without performing any action" if m.LLMConfig.FunctionsConfig.NoActionFunctionName != "" { noActionName = m.LLMConfig.FunctionsConfig.NoActionFunctionName } if m.LLMConfig.FunctionsConfig.NoActionDescriptionName != "" { noActionDescription = m.LLMConfig.FunctionsConfig.NoActionDescriptionName } noActionGrammar := functions.Function{ Name: noActionName, Description: noActionDescription, Parameters: map[string]interface{}{ "properties": map[string]interface{}{ "message": map[string]interface{}{ "type": "string", "description": "The message to reply the user with", }, }, }, } if !m.LLMConfig.FunctionsConfig.DisableNoAction { funcs = 
append(funcs, noActionGrammar) } } predInput = m.evaluator.TemplateMessages(input, input.Messages, m.LLMConfig, funcs, len(funcs) > 0) xlog.Debug("Prompt (after templating)", "prompt", predInput) if m.LLMConfig.Grammar != "" { xlog.Debug("Grammar", "grammar", m.LLMConfig.Grammar) } } // Handle tool_choice parameter similar to the chat endpoint if toolChoice != nil { if toolChoice.Mode != "" { // String values: "auto", "required", "none" switch toolChoice.Mode { case types.ToolChoiceModeRequired: m.LLMConfig.SetFunctionCallString("required") case types.ToolChoiceModeNone: // Don't use tools m.LLMConfig.SetFunctionCallString("none") case types.ToolChoiceModeAuto: // Default behavior - let model decide } } else if toolChoice.Function != nil { // Specific function specified m.LLMConfig.SetFunctionCallString(toolChoice.Function.Name) } } // Generate grammar for function calling if tools are provided and grammar generation is enabled shouldUseFn := len(tools) > 0 && m.LLMConfig.ShouldUseFunctions() if !m.LLMConfig.FunctionsConfig.GrammarConfig.NoGrammar && shouldUseFn { // Force picking one of the functions by the request if m.LLMConfig.FunctionToCall() != "" { funcs = functions.Functions(funcs).Select(m.LLMConfig.FunctionToCall()) } // Generate grammar from function definitions jsStruct := functions.Functions(funcs).ToJSONStructure(m.LLMConfig.FunctionsConfig.FunctionNameKey, m.LLMConfig.FunctionsConfig.FunctionNameKey) g, err := jsStruct.Grammar(m.LLMConfig.FunctionsConfig.GrammarOptions()...) 
if err == nil { m.LLMConfig.Grammar = g xlog.Debug("Generated grammar for function calling", "grammar", g) } else { xlog.Error("Failed generating grammar", "error", err) } } var toolsJSON string if len(tools) > 0 { // Convert tools to OpenAI Chat Completions format (nested) // as expected by most backends (including llama.cpp) var chatTools []functions.Tool for _, t := range tools { if t.Function != nil { var params map[string]interface{} switch p := t.Function.Parameters.(type) { case map[string]interface{}: params = p case string: if err := json.Unmarshal([]byte(p), ¶ms); err != nil { xlog.Warn("Failed to parse parameters JSON string", "error", err, "function", t.Function.Name) } case nil: params = map[string]interface{}{} default: // Try to marshal/unmarshal to get map b, err := json.Marshal(p) if err == nil { _ = json.Unmarshal(b, ¶ms) } } chatTools = append(chatTools, functions.Tool{ Type: "function", Function: functions.Function{ Name: t.Function.Name, Description: t.Function.Description, Parameters: params, }, }) } } b, _ := json.Marshal(chatTools) toolsJSON = string(b) } var toolChoiceJSON string if toolChoice != nil { b, _ := json.Marshal(toolChoice) toolChoiceJSON = string(b) } return backend.ModelInference(ctx, predInput, messages, images, videos, audios, m.modelLoader, m.LLMConfig, m.confLoader, m.appConfig, tokenCallback, toolsJSON, toolChoiceJSON, logprobs, topLogprobs, logitBias, nil) } func (m *wrappedModel) TTS(ctx context.Context, text, voice, language string) (string, *proto.Result, error) { return backend.ModelTTS(text, voice, language, m.modelLoader, m.appConfig, *m.TTSConfig) } func (m *wrappedModel) PredictConfig() *config.ModelConfig { return m.LLMConfig } func newTranscriptionOnlyModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) (Model, *config.ModelConfig, error) { cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath) if err != nil { return nil, nil, 
fmt.Errorf("failed to load backend config: %w", err) } if valid, _ := cfgVAD.Validate(); !valid { return nil, nil, fmt.Errorf("failed to validate config: %w", err) } cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath) if err != nil { return nil, nil, fmt.Errorf("failed to load backend config: %w", err) } if valid, _ := cfgSST.Validate(); !valid { return nil, nil, fmt.Errorf("failed to validate config: %w", err) } return &transcriptOnlyModel{ TranscriptionConfig: cfgSST, VADConfig: cfgVAD, confLoader: cl, modelLoader: ml, appConfig: appConfig, }, cfgSST, nil } // returns and loads either a wrapped model or a model that support audio-to-audio func newModel(pipeline *config.Pipeline, cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig, evaluator *templates.Evaluator) (Model, error) { xlog.Debug("Creating new model pipeline model", "pipeline", pipeline) cfgVAD, err := cl.LoadModelConfigFileByName(pipeline.VAD, ml.ModelPath) if err != nil { return nil, fmt.Errorf("failed to load backend config: %w", err) } if valid, _ := cfgVAD.Validate(); !valid { return nil, fmt.Errorf("failed to validate config: %w", err) } // TODO: Do we always need a transcription model? It can be disabled. 
Note that any-to-any instruction following models don't transcribe as such, so if transcription is required it is a separate process cfgSST, err := cl.LoadModelConfigFileByName(pipeline.Transcription, ml.ModelPath) if err != nil { return nil, fmt.Errorf("failed to load backend config: %w", err) } if valid, _ := cfgSST.Validate(); !valid { return nil, fmt.Errorf("failed to validate config: %w", err) } // TODO: Decide when we have a real any-to-any model // if false { // // cfgAnyToAny, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath) // if err != nil { // // return nil, fmt.Errorf("failed to load backend config: %w", err) // } // // if valid, _ := cfgAnyToAny.Validate(); !valid { // return nil, fmt.Errorf("failed to validate config: %w", err) // } // // return &anyToAnyModel{ // LLMConfig: cfgAnyToAny, // VADConfig: cfgVAD, // }, nil // } xlog.Debug("Loading a wrapped model") // Otherwise we want to return a wrapped model, which is a "virtual" model that re-uses other models to perform operations cfgLLM, err := cl.LoadModelConfigFileByName(pipeline.LLM, ml.ModelPath) if err != nil { return nil, fmt.Errorf("failed to load backend config: %w", err) } if valid, _ := cfgLLM.Validate(); !valid { return nil, fmt.Errorf("failed to validate config: %w", err) } cfgTTS, err := cl.LoadModelConfigFileByName(pipeline.TTS, ml.ModelPath) if err != nil { return nil, fmt.Errorf("failed to load backend config: %w", err) } if valid, _ := cfgTTS.Validate(); !valid { return nil, fmt.Errorf("failed to validate config: %w", err) } return &wrappedModel{ TTSConfig: cfgTTS, TranscriptionConfig: cfgSST, LLMConfig: cfgLLM, VADConfig: cfgVAD, confLoader: cl, modelLoader: ml, appConfig: appConfig, evaluator: evaluator, }, nil } ================================================ FILE: core/http/endpoints/openai/realtime_transport.go ================================================ package openai import ( "context" "github.com/mudler/LocalAI/core/http/endpoints/openai/types" ) // 
// Transport abstracts event and audio I/O so the same session logic
// can serve both WebSocket and WebRTC connections.
type Transport interface {
	// SendEvent marshals and sends a server event to the client.
	// It returns an error when the event cannot be marshalled or delivered.
	SendEvent(event types.ServerEvent) error
	// ReadEvent reads the next raw client event (JSON bytes).
	// It returns an error once the underlying connection can no longer be read.
	ReadEvent() ([]byte, error)
	// SendAudio sends raw PCM audio to the client at the given sample rate.
	// For WebSocket this is a no-op (audio is sent via JSON events).
	// For WebRTC this encodes to Opus and writes to the media track.
	// The context allows cancellation for barge-in support.
	SendAudio(ctx context.Context, pcmData []byte, sampleRate int) error
	// Close tears down the underlying connection.
	Close() error
}

================================================
FILE: core/http/endpoints/openai/realtime_transport_webrtc.go
================================================
package openai

import (
	"context"
	"encoding/json"
	"fmt"
	"math/rand/v2"
	"sync"
	"time"

	"github.com/mudler/LocalAI/core/http/endpoints/openai/types"
	"github.com/mudler/LocalAI/pkg/grpc"
	pb "github.com/mudler/LocalAI/pkg/grpc/proto"
	"github.com/mudler/xlog"
	"github.com/pion/rtp"
	"github.com/pion/webrtc/v4"
)

// WebRTCTransport implements Transport over a pion/webrtc PeerConnection.
// Events travel via the "oai-events" DataChannel; audio goes over an RTP track.
type WebRTCTransport struct { pc *webrtc.PeerConnection dc *webrtc.DataChannel audioTrack *webrtc.TrackLocalStaticRTP opusBackend grpc.Backend inEvents chan []byte outEvents chan []byte // buffered outbound event queue closed chan struct{} closeOnce sync.Once flushed chan struct{} // closed when sender goroutine has drained outEvents dcReady chan struct{} // closed when data channel is open dcReadyOnce sync.Once sessionCh chan *Session // delivers session from runRealtimeSession to handleIncomingAudioTrack // RTP state for outbound audio — protected by rtpMu rtpMu sync.Mutex rtpSeqNum uint16 rtpTimestamp uint32 rtpMarker bool // true → next packet gets marker bit set } func NewWebRTCTransport(pc *webrtc.PeerConnection, audioTrack *webrtc.TrackLocalStaticRTP, opusBackend grpc.Backend) *WebRTCTransport { t := &WebRTCTransport{ pc: pc, audioTrack: audioTrack, opusBackend: opusBackend, inEvents: make(chan []byte, 256), outEvents: make(chan []byte, 256), closed: make(chan struct{}), flushed: make(chan struct{}), dcReady: make(chan struct{}), sessionCh: make(chan *Session, 1), rtpSeqNum: uint16(rand.UintN(65536)), rtpTimestamp: rand.Uint32(), rtpMarker: true, // first packet of the stream gets marker } // The client creates the "oai-events" data channel (so m=application is // included in the SDP offer). We receive it here via OnDataChannel. 
pc.OnDataChannel(func(dc *webrtc.DataChannel) { if dc.Label() != "oai-events" { return } t.dc = dc dc.OnOpen(func() { t.dcReadyOnce.Do(func() { close(t.dcReady) }) }) dc.OnMessage(func(msg webrtc.DataChannelMessage) { select { case t.inEvents <- msg.Data: case <-t.closed: } }) // The channel may already be open by the time OnDataChannel fires if dc.ReadyState() == webrtc.DataChannelStateOpen { t.dcReadyOnce.Do(func() { close(t.dcReady) }) } }) pc.OnConnectionStateChange(func(state webrtc.PeerConnectionState) { xlog.Debug("WebRTC connection state", "state", state.String()) if state == webrtc.PeerConnectionStateFailed || state == webrtc.PeerConnectionStateClosed || state == webrtc.PeerConnectionStateDisconnected { t.closeOnce.Do(func() { close(t.closed) }) } }) go t.sendLoop() return t } // sendLoop is a dedicated goroutine that drains outEvents and sends them // over the data channel. It waits for the data channel to open before // sending, and drains any remaining events when closed is signalled. 
func (t *WebRTCTransport) sendLoop() { defer close(t.flushed) // Wait for data channel to be ready select { case <-t.dcReady: case <-t.closed: return } for { select { case data, ok := <-t.outEvents: if !ok { return } if err := t.dc.SendText(string(data)); err != nil { xlog.Error("data channel send failed", "error", err) return } case <-t.closed: // Drain any remaining queued events before exiting for { select { case data := <-t.outEvents: if err := t.dc.SendText(string(data)); err != nil { return } default: return } } } } } func (t *WebRTCTransport) SendEvent(event types.ServerEvent) error { data, err := json.Marshal(event) if err != nil { return fmt.Errorf("marshal event: %w", err) } select { case t.outEvents <- data: return nil case <-t.closed: return fmt.Errorf("transport closed") } } func (t *WebRTCTransport) ReadEvent() ([]byte, error) { select { case msg := <-t.inEvents: return msg, nil case <-t.closed: return nil, fmt.Errorf("transport closed") } } // SendAudio encodes raw PCM int16 LE to Opus and writes RTP packets to the // audio track. The encoder resamples from the given sampleRate to 48kHz // internally. Frames are paced at real-time intervals (20ms per frame) to // avoid overwhelming the browser's jitter buffer with a burst of packets. // // The context allows callers to cancel mid-stream for barge-in support. // When cancelled, the marker bit is set so the next audio segment starts // cleanly in the browser's jitter buffer. // // RTP packets are constructed manually (rather than via WriteSample) so we // can control the marker bit. pion's WriteSample sets the marker bit on // every Opus packet, which causes Chrome's NetEq jitter buffer to reset // its timing estimation for each frame, producing severe audio distortion. 
func (t *WebRTCTransport) SendAudio(ctx context.Context, pcmData []byte, sampleRate int) error { result, err := t.opusBackend.AudioEncode(ctx, &pb.AudioEncodeRequest{ PcmData: pcmData, SampleRate: int32(sampleRate), Channels: 1, }) if err != nil { return fmt.Errorf("opus encode: %w", err) } frames := result.Frames const frameDuration = 20 * time.Millisecond const samplesPerFrame = 960 // 20ms at 48kHz ticker := time.NewTicker(frameDuration) defer ticker.Stop() for i, frame := range frames { t.rtpMu.Lock() pkt := &rtp.Packet{ Header: rtp.Header{ Version: 2, Marker: t.rtpMarker, SequenceNumber: t.rtpSeqNum, Timestamp: t.rtpTimestamp, // SSRC and PayloadType are overridden by pion's writeRTP }, Payload: frame, } t.rtpSeqNum++ t.rtpTimestamp += samplesPerFrame t.rtpMarker = false // only the first packet gets marker t.rtpMu.Unlock() if err := t.audioTrack.WriteRTP(pkt); err != nil { return fmt.Errorf("write rtp: %w", err) } // Pace output at ~real-time so the browser's jitter buffer // receives packets at the expected rate. Skip wait after last frame. if i < len(frames)-1 { select { case <-ticker.C: case <-ctx.Done(): // Barge-in: mark the next packet so the browser knows // a new audio segment is starting after the interruption. t.rtpMu.Lock() t.rtpMarker = true t.rtpMu.Unlock() return ctx.Err() case <-t.closed: return fmt.Errorf("transport closed during audio send") } } } return nil } // SetSession delivers the session to any goroutine waiting in WaitForSession. func (t *WebRTCTransport) SetSession(s *Session) { select { case t.sessionCh <- s: case <-t.closed: } } // WaitForSession blocks until the session is available or the transport closes. 
func (t *WebRTCTransport) WaitForSession() *Session { select { case s := <-t.sessionCh: return s case <-t.closed: return nil } } func (t *WebRTCTransport) Close() error { // Signal no more events and unblock the sender if it's waiting t.closeOnce.Do(func() { close(t.closed) }) // Wait for the sender to drain any remaining queued events <-t.flushed return t.pc.Close() } ================================================ FILE: core/http/endpoints/openai/realtime_transport_ws.go ================================================ package openai import ( "context" "encoding/json" "sync" "github.com/gorilla/websocket" "github.com/mudler/LocalAI/core/http/endpoints/openai/types" "github.com/mudler/xlog" ) // WebSocketTransport implements Transport over a gorilla/websocket connection. type WebSocketTransport struct { conn *websocket.Conn mu sync.Mutex } func NewWebSocketTransport(conn *websocket.Conn) *WebSocketTransport { return &WebSocketTransport{conn: conn} } func (t *WebSocketTransport) SendEvent(event types.ServerEvent) error { eventBytes, err := json.Marshal(event) if err != nil { xlog.Error("failed to marshal event", "error", err) return err } t.mu.Lock() defer t.mu.Unlock() return t.conn.WriteMessage(websocket.TextMessage, eventBytes) } func (t *WebSocketTransport) ReadEvent() ([]byte, error) { _, msg, err := t.conn.ReadMessage() return msg, err } // SendAudio is a no-op for WebSocket — audio is delivered via JSON events // (base64-encoded in response.audio.delta). 
func (t *WebSocketTransport) SendAudio(_ context.Context, _ []byte, _ int) error { return nil } func (t *WebSocketTransport) Close() error { return t.conn.Close() } ================================================ FILE: core/http/endpoints/openai/realtime_webrtc.go ================================================ package openai import ( "net/http" "time" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" model "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" "github.com/pion/webrtc/v4" ) // RealtimeCallRequest is the JSON body for POST /v1/realtime/calls. type RealtimeCallRequest struct { SDP string `json:"sdp"` Model string `json:"model"` } // RealtimeCallResponse is the JSON response for POST /v1/realtime/calls. type RealtimeCallResponse struct { SDP string `json:"sdp"` SessionID string `json:"session_id"` } // RealtimeCalls handles POST /v1/realtime/calls for WebRTC signaling. func RealtimeCalls(application *application.Application) echo.HandlerFunc { return func(c echo.Context) error { var req RealtimeCallRequest if err := c.Bind(&req); err != nil { return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid request body"}) } if req.SDP == "" { return c.JSON(http.StatusBadRequest, map[string]string{"error": "sdp is required"}) } if req.Model == "" { return c.JSON(http.StatusBadRequest, map[string]string{"error": "model is required"}) } // Create a MediaEngine with Opus support m := &webrtc.MediaEngine{} if err := m.RegisterDefaultCodecs(); err != nil { xlog.Error("failed to register codecs", "error", err) return c.JSON(http.StatusInternalServerError, map[string]string{"error": "codec registration failed"}) } api := webrtc.NewAPI(webrtc.WithMediaEngine(m)) pc, err := api.NewPeerConnection(webrtc.Configuration{}) if err != nil { xlog.Error("failed to create peer connection", "error", err) return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create peer connection"}) } // Create 
outbound audio track (Opus, 48kHz). // We use TrackLocalStaticRTP (not TrackLocalStaticSample) so that // SendAudio can construct RTP packets directly and control the marker // bit. pion's WriteSample sets the marker bit on every Opus packet, // which causes Chrome's NetEq jitter buffer to reset for each frame. audioTrack, err := webrtc.NewTrackLocalStaticRTP( webrtc.RTPCodecCapability{ MimeType: webrtc.MimeTypeOpus, ClockRate: 48000, Channels: 2, // Opus in WebRTC is always signaled as 2 channels per RFC 7587 }, "audio", "localai", ) if err != nil { pc.Close() xlog.Error("failed to create audio track", "error", err) return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create audio track"}) } rtpSender, err := pc.AddTrack(audioTrack) if err != nil { pc.Close() xlog.Error("failed to add audio track", "error", err) return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to add audio track"}) } // Drain RTCP (control protocol) packets we don't have anyting useful to do with go func() { buf := make([]byte, 1500) for { if _, _, err := rtpSender.Read(buf); err != nil { return } } }() // Load the Opus backend opusBackend, err := application.ModelLoader().Load( model.WithBackendString("opus"), model.WithModelID("__opus_codec__"), model.WithModel("opus"), ) if err != nil { pc.Close() xlog.Error("failed to load opus backend", "error", err) return c.JSON(http.StatusInternalServerError, map[string]string{"error": "opus backend not available"}) } // Create the transport (the data channel is created by the client and // received via pc.OnDataChannel inside NewWebRTCTransport) transport := NewWebRTCTransport(pc, audioTrack, opusBackend) // Handle incoming audio track from the client pc.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { codec := track.Codec() if codec.MimeType != webrtc.MimeTypeOpus { xlog.Warn("unexpected track codec, ignoring", "mime", codec.MimeType) return } xlog.Debug("Received audio 
track from client", "codec", codec.MimeType, "clock_rate", codec.ClockRate, "channels", codec.Channels, "sdp_fmtp", codec.SDPFmtpLine, "payload_type", codec.PayloadType, ) handleIncomingAudioTrack(track, transport) }) // Set the remote SDP (client's offer) if err := pc.SetRemoteDescription(webrtc.SessionDescription{ Type: webrtc.SDPTypeOffer, SDP: req.SDP, }); err != nil { transport.Close() xlog.Error("failed to set remote description", "error", err) return c.JSON(http.StatusBadRequest, map[string]string{"error": "invalid SDP offer"}) } // Create answer answer, err := pc.CreateAnswer(nil) if err != nil { transport.Close() xlog.Error("failed to create answer", "error", err) return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to create answer"}) } if err := pc.SetLocalDescription(answer); err != nil { transport.Close() xlog.Error("failed to set local description", "error", err) return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to set local description"}) } // Wait for ICE gathering to complete (with timeout) gatherDone := webrtc.GatheringCompletePromise(pc) select { case <-gatherDone: case <-time.After(10 * time.Second): xlog.Warn("ICE gathering timed out, using partial candidates") } localDesc := pc.LocalDescription() if localDesc == nil { transport.Close() return c.JSON(http.StatusInternalServerError, map[string]string{"error": "no local description"}) } sessionID := generateSessionID() // Start the realtime session in a goroutine evaluator := application.TemplatesEvaluator() go func() { defer transport.Close() runRealtimeSession(application, transport, req.Model, evaluator) }() return c.JSON(http.StatusCreated, RealtimeCallResponse{ SDP: localDesc.SDP, SessionID: sessionID, }) } } // handleIncomingAudioTrack reads RTP packets from a remote WebRTC track // and buffers the raw Opus payloads on the session. Decoding is done in // batches by decodeOpusLoop in realtime.go. 
func handleIncomingAudioTrack(track *webrtc.TrackRemote, transport *WebRTCTransport) { session := transport.WaitForSession() if session == nil { xlog.Error("could not find session for incoming audio track (transport closed)") sendError(transport, "session_error", "Session failed to start — check server logs", "", "") return } for { pkt, _, err := track.ReadRTP() if err != nil { xlog.Debug("audio track read ended", "error", err) return } // Copy the payload — pion's ReadRTP may back it by a reusable buffer payload := make([]byte, len(pkt.Payload)) copy(payload, pkt.Payload) session.OpusFramesLock.Lock() session.OpusFrames = append(session.OpusFrames, payload) session.OpusFramesLock.Unlock() } } ================================================ FILE: core/http/endpoints/openai/transcription.go ================================================ package openai import ( "errors" "io" "net/http" "os" "path" "path/filepath" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/pkg/format" model "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) // TranscriptEndpoint is the OpenAI Whisper API endpoint https://platform.openai.com/docs/api-reference/audio/create // @Summary Transcribes audio into the input language. 
// @accept multipart/form-data // @Param model formData string true "model" // @Param file formData file true "file" // @Success 200 {object} map[string]string "Response" // @Router /v1/audio/transcriptions [post] func TranscriptEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, appConfig *config.ApplicationConfig) echo.HandlerFunc { return func(c echo.Context) error { input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { return echo.ErrBadRequest } config, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || config == nil { return echo.ErrBadRequest } diarize := c.FormValue("diarize") != "false" prompt := c.FormValue("prompt") responseFormat := schema.TranscriptionResponseFormatType(c.FormValue("response_format")) // retrieve the file data from the request file, err := c.FormFile("file") if err != nil { return err } f, err := file.Open() if err != nil { return err } defer f.Close() dir, err := os.MkdirTemp("", "whisper") if err != nil { return err } defer os.RemoveAll(dir) dst := filepath.Join(dir, path.Base(file.Filename)) dstFile, err := os.Create(dst) if err != nil { return err } if _, err := io.Copy(dstFile, f); err != nil { xlog.Debug("Audio file copying error", "filename", file.Filename, "dst", dst, "error", err) return err } xlog.Debug("Audio file copied", "dst", dst) tr, err := backend.ModelTranscription(dst, input.Language, input.Translate, diarize, prompt, ml, *config, appConfig) if err != nil { return err } xlog.Debug("Transcribed", "transcription", tr) switch responseFormat { case schema.TranscriptionResponseFormatLrc, schema.TranscriptionResponseFormatText, schema.TranscriptionResponseFormatSrt, schema.TranscriptionResponseFormatVtt: return c.String(http.StatusOK, format.TranscriptionResponse(tr, responseFormat)) case schema.TranscriptionResponseFormatJson: tr.Segments = nil fallthrough case schema.TranscriptionResponseFormatJsonVerbose, "": 
// maintain backwards compatibility return c.JSON(http.StatusOK, tr) default: return errors.New("invalid response_format") } } } ================================================ FILE: core/http/endpoints/openai/types/client_events.go ================================================ package types import "encoding/json" // ClientEventType is the type of client event. See https://platform.openai.com/docs/guides/realtime/client-events type ClientEventType string const ( ClientEventTypeSessionUpdate ClientEventType = "session.update" ClientEventTypeInputAudioBufferAppend ClientEventType = "input_audio_buffer.append" ClientEventTypeInputAudioBufferCommit ClientEventType = "input_audio_buffer.commit" ClientEventTypeInputAudioBufferClear ClientEventType = "input_audio_buffer.clear" ClientEventTypeConversationItemCreate ClientEventType = "conversation.item.create" ClientEventTypeConversationItemRetrieve ClientEventType = "conversation.item.retrieve" ClientEventTypeConversationItemTruncate ClientEventType = "conversation.item.truncate" ClientEventTypeConversationItemDelete ClientEventType = "conversation.item.delete" ClientEventTypeResponseCreate ClientEventType = "response.create" ClientEventTypeResponseCancel ClientEventType = "response.cancel" ClientEventTypeOutputAudioBufferClear ClientEventType = "output_audio_buffer.clear" ) // ClientEvent is the interface for client event. type ClientEvent interface { ClientEventType() ClientEventType } // EventBase is the base struct for all client events. type EventBase struct { Type string `json:"type"` // Optional client-generated ID used to identify this event. EventID string `json:"event_id,omitempty"` } // Send this event to update the session’s configuration. The client may send this event at any time to update any field except for voice and model. voice can be updated only if there have been no other audio outputs yet. 
// // When the server receives a session.update, it will respond with a session.updated event showing the full, effective configuration. Only the fields that are present in the session.update are updated. To clear a field like instructions, pass an empty string. To clear a field like tools, pass an empty array. To clear a field like turn_detection, pass null.// // // See https://platform.openai.com/docs/api-reference/realtime-client-events/session/update type SessionUpdateEvent struct { EventBase // Session configuration to update. Session SessionUnion `json:"session"` } func (m SessionUpdateEvent) ClientEventType() ClientEventType { return ClientEventTypeSessionUpdate } func (m SessionUpdateEvent) MarshalJSON() ([]byte, error) { type typeAlias SessionUpdateEvent type typeWrapper struct { typeAlias Type ClientEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ClientEventType(), } return json.Marshal(shadow) } type NoiseReductionType string const ( NoiseReductionNearField NoiseReductionType = "near_field" NoiseReductionFarField NoiseReductionType = "far_field" ) // Send this event to append audio bytes to the input audio buffer. The audio buffer is temporary storage you can write to and later commit. A "commit" will create a new user message item in the conversation history from the buffer content and clear the buffer. Input audio transcription (if enabled) will be generated when the buffer is committed. // // If VAD is enabled the audio buffer is used to detect speech and the server will decide when to commit. When Server VAD is disabled, you must commit the audio buffer manually. Input audio noise reduction operates on writes to the audio buffer. // // The client may choose how much audio to place in each event up to a maximum of 15 MiB, for example streaming smaller chunks from the client may allow the VAD to be more responsive. Unlike most other client events, the server will not send a confirmation response to this event. 
// // See https://platform.openai.com/docs/api-reference/realtime-client-events/input_audio_buffer/append type InputAudioBufferAppendEvent struct { EventBase Audio string `json:"audio"` // Base64-encoded audio bytes. } func (m InputAudioBufferAppendEvent) ClientEventType() ClientEventType { return ClientEventTypeInputAudioBufferAppend } func (m InputAudioBufferAppendEvent) MarshalJSON() ([]byte, error) { type typeAlias InputAudioBufferAppendEvent type typeWrapper struct { typeAlias Type ClientEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ClientEventType(), } return json.Marshal(shadow) } // Send this event to commit the user input audio buffer, which will create a new user message item in the conversation. This event will produce an error if the input audio buffer is empty. When in Server VAD mode, the client does not need to send this event, the server will commit the audio buffer automatically. // // Committing the input audio buffer will trigger input audio transcription (if enabled in session configuration), but it will not create a response from the model. The server will respond with an input_audio_buffer.committed event. // // See https://platform.openai.com/docs/api-reference/realtime-client-events/input_audio_buffer/commit type InputAudioBufferCommitEvent struct { EventBase } func (m InputAudioBufferCommitEvent) ClientEventType() ClientEventType { return ClientEventTypeInputAudioBufferCommit } func (m InputAudioBufferCommitEvent) MarshalJSON() ([]byte, error) { type typeAlias InputAudioBufferCommitEvent type typeWrapper struct { typeAlias Type ClientEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ClientEventType(), } return json.Marshal(shadow) } // Send this event to clear the audio bytes in the buffer. The server will respond with an input_audio_buffer.cleared event. 
// // See https://platform.openai.com/docs/api-reference/realtime-client-events/input_audio_buffer/clear type InputAudioBufferClearEvent struct { EventBase } func (m InputAudioBufferClearEvent) ClientEventType() ClientEventType { return ClientEventTypeInputAudioBufferClear } func (m InputAudioBufferClearEvent) MarshalJSON() ([]byte, error) { type typeAlias InputAudioBufferClearEvent type typeWrapper struct { typeAlias Type ClientEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ClientEventType(), } return json.Marshal(shadow) } // Send this event to clear the audio bytes in the buffer. The server will respond with an input_audio_buffer.cleared event. // // See https://platform.openai.com/docs/api-reference/realtime-client-events/output_audio_buffer/clear type OutputAudioBufferClearEvent struct { EventBase } func (m OutputAudioBufferClearEvent) ClientEventType() ClientEventType { return ClientEventTypeOutputAudioBufferClear } func (m OutputAudioBufferClearEvent) MarshalJSON() ([]byte, error) { type typeAlias OutputAudioBufferClearEvent type typeWrapper struct { typeAlias Type ClientEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ClientEventType(), } return json.Marshal(shadow) } // Add a new Item to the Conversation's context, including messages, function calls, and function call responses. This event can be used both to populate a "history" of the conversation and to add new items mid-stream, but has the current limitation that it cannot populate assistant audio messages. // // If successful, the server will respond with a conversation.item.created event, otherwise an error event will be sent. // // See https://platform.openai.com/docs/api-reference/realtime-client-events/conversation/item/create type ConversationItemCreateEvent struct { EventBase // The ID of the preceding item after which the new item will be inserted. 
PreviousItemID string `json:"previous_item_id,omitempty"` // The item to add to the conversation. Item MessageItemUnion `json:"item"` } func (m ConversationItemCreateEvent) ClientEventType() ClientEventType { return ClientEventTypeConversationItemCreate } func (m ConversationItemCreateEvent) MarshalJSON() ([]byte, error) { type typeAlias ConversationItemCreateEvent type typeWrapper struct { typeAlias Type ClientEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ClientEventType(), } return json.Marshal(shadow) } // Send this event when you want to retrieve the server's representation of a specific item in the conversation history. This is useful, for example, to inspect user audio after noise cancellation and VAD. The server will respond with a conversation.item.retrieved event, unless the item does not exist in the conversation history, in which case the server will respond with an error. // // See https://platform.openai.com/docs/api-reference/realtime-client-events/conversation/item/retrieve type ConversationItemRetrieveEvent struct { EventBase // The ID of the item to retrieve. ItemID string `json:"item_id"` } func (m ConversationItemRetrieveEvent) ClientEventType() ClientEventType { return ClientEventTypeConversationItemRetrieve } func (m ConversationItemRetrieveEvent) MarshalJSON() ([]byte, error) { type typeAlias ConversationItemRetrieveEvent type typeWrapper struct { typeAlias Type ClientEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ClientEventType(), } return json.Marshal(shadow) } // Send this event to truncate a previous assistant message’s audio. The server will produce audio faster than realtime, so this event is useful when the user interrupts to truncate audio that has already been sent to the client but not yet played. This will synchronize the server's understanding of the audio with the client's playback. 
// // Truncating audio will delete the server-side text transcript to ensure there is not text in the context that hasn't been heard by the user. // // If successful, the server will respond with a conversation.item.truncated event. // // See https://platform.openai.com/docs/api-reference/realtime-client-events/conversation/item/truncate type ConversationItemTruncateEvent struct { EventBase // The ID of the assistant message item to truncate. ItemID string `json:"item_id"` // The index of the content part to truncate. ContentIndex int `json:"content_index"` // Inclusive duration up to which audio is truncated, in milliseconds. AudioEndMs int `json:"audio_end_ms"` } func (m ConversationItemTruncateEvent) ClientEventType() ClientEventType { return ClientEventTypeConversationItemTruncate } func (m ConversationItemTruncateEvent) MarshalJSON() ([]byte, error) { type typeAlias ConversationItemTruncateEvent type typeWrapper struct { typeAlias Type ClientEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ClientEventType(), } return json.Marshal(shadow) } // Send this event when you want to remove any item from the conversation history. The server will respond with a conversation.item.deleted event, unless the item does not exist in the conversation history, in which case the server will respond with an error. // // See https://platform.openai.com/docs/api-reference/realtime-client-events/conversation/item/delete type ConversationItemDeleteEvent struct { EventBase // The ID of the item to delete. 
ItemID string `json:"item_id"` } func (m ConversationItemDeleteEvent) ClientEventType() ClientEventType { return ClientEventTypeConversationItemDelete } func (m ConversationItemDeleteEvent) MarshalJSON() ([]byte, error) { type typeAlias ConversationItemDeleteEvent type typeWrapper struct { typeAlias Type ClientEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ClientEventType(), } return json.Marshal(shadow) } // This event instructs the server to create a Response, which means triggering model inference. When in Server VAD mode, the server will create Responses automatically. // // A Response will include at least one Item, and may have two, in which case the second will be a function call. These Items will be appended to the conversation history by default. // // The server will respond with a response.created event, events for Items and content created, and finally a response.done event to indicate the Response is complete. // // The response.create event includes inference configuration like instructions and tools. If these are set, they will override the Session's configuration for this Response only. // // Responses can be created out-of-band of the default Conversation, meaning that they can have arbitrary input, and it's possible to disable writing the output to the Conversation. Only one Response can write to the default Conversation at a time, but otherwise multiple Responses can be created in parallel. The metadata field is a good way to disambiguate multiple simultaneous Responses. // // Clients can set conversation to none to create a Response that does not write to the default Conversation. Arbitrary input can be provided with the input field, which is an array accepting raw Items and references to existing Items. // // See https://platform.openai.com/docs/api-reference/realtime-client-events/response/create type ResponseCreateEvent struct { EventBase // Configuration for the response. 
Response ResponseCreateParams `json:"response"` } func (m ResponseCreateEvent) ClientEventType() ClientEventType { return ClientEventTypeResponseCreate } func (m ResponseCreateEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseCreateEvent type typeWrapper struct { typeAlias Type ClientEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ClientEventType(), } return json.Marshal(shadow) } // Send this event to cancel an in-progress response. The server will respond with a response.done event with a status of response.status=cancelled. If there is no response to cancel, the server will respond with an error. It's safe to call response.cancel even if no response is in progress, an error will be returned the session will remain unaffected. // // See https://platform.openai.com/docs/api-reference/realtime-client-events/response/cancel type ResponseCancelEvent struct { EventBase // A specific response ID to cancel - if not provided, will cancel an in-progress response in the default conversation. 
ResponseID string `json:"response_id,omitempty"` } func (m ResponseCancelEvent) ClientEventType() ClientEventType { return ClientEventTypeResponseCancel } func (m ResponseCancelEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseCancelEvent type typeWrapper struct { typeAlias Type ClientEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ClientEventType(), } return json.Marshal(shadow) } type ClientEventInterface interface { SessionUpdateEvent | InputAudioBufferAppendEvent | InputAudioBufferCommitEvent | InputAudioBufferClearEvent | OutputAudioBufferClearEvent | ConversationItemCreateEvent | ConversationItemRetrieveEvent | ConversationItemTruncateEvent | ConversationItemDeleteEvent | ResponseCreateEvent | ResponseCancelEvent } func unmarshalClientEvent[T ClientEventInterface](data []byte) (T, error) { var t T err := json.Unmarshal(data, &t) if err != nil { return t, err } return t, nil } // UnmarshalClientEvent unmarshals the client event from the given JSON data. 
func UnmarshalClientEvent(data []byte) (ClientEvent, error) { var eventType struct { Type ClientEventType `json:"type"` } err := json.Unmarshal(data, &eventType) if err != nil { return nil, err } switch eventType.Type { case ClientEventTypeSessionUpdate: return unmarshalClientEvent[SessionUpdateEvent](data) case ClientEventTypeInputAudioBufferAppend: return unmarshalClientEvent[InputAudioBufferAppendEvent](data) case ClientEventTypeInputAudioBufferCommit: return unmarshalClientEvent[InputAudioBufferCommitEvent](data) case ClientEventTypeInputAudioBufferClear: return unmarshalClientEvent[InputAudioBufferClearEvent](data) case ClientEventTypeOutputAudioBufferClear: return unmarshalClientEvent[OutputAudioBufferClearEvent](data) case ClientEventTypeConversationItemCreate: return unmarshalClientEvent[ConversationItemCreateEvent](data) case ClientEventTypeConversationItemRetrieve: return unmarshalClientEvent[ConversationItemRetrieveEvent](data) case ClientEventTypeConversationItemTruncate: return unmarshalClientEvent[ConversationItemTruncateEvent](data) case ClientEventTypeConversationItemDelete: return unmarshalClientEvent[ConversationItemDeleteEvent](data) case ClientEventTypeResponseCreate: return unmarshalClientEvent[ResponseCreateEvent](data) case ClientEventTypeResponseCancel: return unmarshalClientEvent[ResponseCancelEvent](data) default: // We should probably return a generic event or error here, but for now just nil. // Or maybe a "UnknownEvent" struct? // For now matching the existing pattern return nil, nil } } ================================================ FILE: core/http/endpoints/openai/types/int_or_inf.go ================================================ package types import ( "encoding/json" "math" ) const ( // Inf is the maximum value for an IntOrInf. Inf IntOrInf = math.MaxInt ) // IntOrInf is a type that can be either an int or "inf". type IntOrInf int // IsInf returns true if the value is "inf". 
func (m IntOrInf) IsInf() bool { return m == Inf } // MarshalJSON marshals the IntOrInf to JSON. func (m IntOrInf) MarshalJSON() ([]byte, error) { if m == Inf { return []byte("\"inf\""), nil } return json.Marshal(int(m)) } // UnmarshalJSON unmarshals the IntOrInf from JSON. func (m *IntOrInf) UnmarshalJSON(data []byte) error { if string(data) == "\"inf\"" { *m = Inf return nil } if len(data) == 0 { return nil } return json.Unmarshal(data, (*int)(m)) } ================================================ FILE: core/http/endpoints/openai/types/message_item.go ================================================ package types import ( "encoding/json" "errors" "fmt" ) type MessageItemType string const ( MessageItemTypeMessage MessageItemType = "message" MessageItemTypeFunctionCall MessageItemType = "function_call" MessageItemTypeFunctionCallOutput MessageItemType = "function_call_output" MessageItemTypeMCPApprovalResponse MessageItemType = "mcp_approval_response" MessageItemTypeMCPListTools MessageItemType = "mcp_list_tools" MessageItemTypeMCPCall MessageItemType = "mcp_call" MessageItemTypeMCPApprovalRequest MessageItemType = "mcp_approval_request" ) type MessageContentType string const ( MessageContentTypeText MessageContentType = "text" MessageContentTypeAudio MessageContentType = "audio" MessageContentTypeTranscript MessageContentType = "transcript" MessageContentTypeInputText MessageContentType = "input_text" MessageContentTypeInputAudio MessageContentType = "input_audio" MessageContentTypeInputImage MessageContentType = "input_image" MessageContentTypeOutputText MessageContentType = "output_text" MessageContentTypeOutputAudio MessageContentType = "output_audio" ) type MessageContentText struct { Text string `json:"text,omitempty"` } type MessageContentAudio struct { Type MessageContentType `json:"type,omitempty"` Audio string `json:"audio,omitempty"` } type MessageContentTranscript struct { Type MessageContentType `json:"type,omitempty"` Transcript string 
`json:"transcript,omitempty"` } type MessageContentImage struct { Type MessageContentType `json:"type,omitempty"` ImageURL string `json:"image_url,omitempty"` Detail ImageDetail `json:"detail,omitempty"` } type MessageContentSystem MessageContentText type MessageItemSystem struct { // The unique ID of the item. This may be provided by the client or generated by the server. ID string `json:"id,omitempty"` // The content of the message. Content []MessageContentSystem `json:"content,omitempty"` // Identifier for the API object being returned - always realtime.item. Optional when creating a new item. Object string `json:"object,omitempty"` // The status of the item. Has no effect on the conversation. Status ItemStatus `json:"status,omitempty"` } func (m MessageItemSystem) MessageItemType() MessageItemType { return MessageItemTypeMessage } func (m MessageItemSystem) Role() MessageRole { return MessageRoleSystem } func (m MessageItemSystem) MarshalJSON() ([]byte, error) { type typeAlias MessageItemSystem type typeWrapper struct { typeAlias Type MessageItemType `json:"type"` Role MessageRole `json:"role"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.MessageItemType(), Role: m.Role(), } return json.Marshal(shadow) } type MessageItemUser struct { // The unique ID of the item. This may be provided by the client or generated by the server. ID string `json:"id,omitempty"` // The content of the message. Content []MessageContentInput `json:"content,omitempty"` // Identifier for the API object being returned - always realtime.item. Optional when creating a new item. Object string `json:"object,omitempty"` // The status of the item. Has no effect on the conversation. 
Status ItemStatus `json:"status,omitempty"` } func (m MessageItemUser) MessageItemType() MessageItemType { return MessageItemTypeMessage } func (m MessageItemUser) Role() MessageRole { return MessageRoleUser } func (m MessageItemUser) MarshalJSON() ([]byte, error) { type typeAlias MessageItemUser type typeWrapper struct { typeAlias Type MessageItemType `json:"type"` Role MessageRole `json:"role"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.MessageItemType(), Role: m.Role(), } return json.Marshal(shadow) } type MessageItemAssistant struct { // The unique ID of the item. This may be provided by the client or generated by the server. ID string `json:"id,omitempty"` // The content of the message. Content []MessageContentOutput `json:"content,omitempty"` // Identifier for the API object being returned - always realtime.item. Optional when creating a new item. Object string `json:"object,omitempty"` // The status of the item. Has no effect on the conversation. Status ItemStatus `json:"status,omitempty"` } func (m MessageItemAssistant) MessageItemType() MessageItemType { return MessageItemTypeMessage } func (m MessageItemAssistant) Role() MessageRole { return MessageRoleAssistant } func (m MessageItemAssistant) MarshalJSON() ([]byte, error) { type typeAlias MessageItemAssistant type typeWrapper struct { typeAlias Type MessageItemType `json:"type"` Role MessageRole `json:"role"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.MessageItemType(), Role: m.Role(), } return json.Marshal(shadow) } type MessageContentInput struct { // The content type (input_text, input_audio, or input_image). Type MessageContentType `json:"type"` // Base64-encoded audio bytes (for input_audio), these will be parsed as the format specified in the session input audio type configuration. This defaults to PCM 16-bit 24kHz mono if not specified. Audio string `json:"audio,omitempty"` // The detail level of the image (for input_image). auto will default to high. 
Detail ImageDetail `json:"detail,omitempty"` // Base64-encoded image bytes (for input_image) as a data URI. For example data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAA.... Supported formats are PNG and JPEG. ImageURL string `json:"image_url,omitempty"` // The text content (for input_text). Text string `json:"text,omitempty"` // Transcript of the audio (for input_audio). This is not sent to the model, but will be attached to the message item for reference. Transcript string `json:"transcript,omitempty"` } type MessageContentOutput struct { // The content type (input_text, input_audio, or input_image). Type MessageContentType `json:"type,omitempty"` // Base64-encoded audio bytes (for input_audio), these will be parsed as the format specified in the session input audio type configuration. This defaults to PCM 16-bit 24kHz mono if not specified. Audio string `json:"audio,omitempty"` // The text content (for input_text). Text string `json:"text,omitempty"` // Transcript of the audio (for input_audio). This is not sent to the model, but will be attached to the message item for reference. Transcript string `json:"transcript,omitempty"` } type MessageItemFunctionCall struct { // The unique ID of the item. This may be provided by the client or generated by the server. ID string `json:"id,omitempty"` // The ID of the function call. CallID string `json:"call_id,omitempty"` // The arguments of the function call. This is a JSON-encoded string representing the arguments passed to the function, for example {"arg1": "value1", "arg2": 42}. Arguments string `json:"arguments,omitempty"` // The name of the function being called. Name string `json:"name,omitempty"` // Identifier for the API object being returned - always realtime.item. Optional when creating a new item. Object string `json:"object,omitempty"` // The status of the item. Has no effect on the conversation. 
Status ItemStatus `json:"status,omitempty"` } func (m MessageItemFunctionCall) MessageItemType() MessageItemType { return MessageItemTypeFunctionCall } func (m MessageItemFunctionCall) MarshalJSON() ([]byte, error) { type typeAlias MessageItemFunctionCall type typeWrapper struct { typeAlias Type MessageItemType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.MessageItemType(), } return json.Marshal(shadow) } type MessageItemFunctionCallOutput struct { // The unique ID of the item. This may be provided by the client or generated by the server. ID string `json:"id,omitempty"` // The ID of the function call this output is for. CallID string `json:"call_id,omitempty"` // The output of the function call, this is free text and can contain any information or simply be empty. Output string `json:"output,omitempty"` // Identifier for the API object being returned - always realtime.item. Optional when creating a new item. Object string `json:"object,omitempty"` // The status of the item. Has no effect on the conversation. Status ItemStatus `json:"status,omitempty"` } func (m MessageItemFunctionCallOutput) MessageItemType() MessageItemType { return MessageItemTypeFunctionCallOutput } func (m MessageItemFunctionCallOutput) MarshalJSON() ([]byte, error) { type typeAlias MessageItemFunctionCallOutput type typeWrapper struct { typeAlias Type MessageItemType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.MessageItemType(), } return json.Marshal(shadow) } type MessageItemMCPApprovalResponse struct { // The unique ID of the approval response. ID string `json:"id,omitempty"` // The ID of the approval request being answered. ApprovalRequestID string `json:"approval_request_id,omitempty"` // Whether the request was approved. Approve bool `json:"approve,omitempty"` // Optional reason for the decision. 
Reason string `json:"reason,omitempty"` } func (m MessageItemMCPApprovalResponse) MessageItemType() MessageItemType { return MessageItemTypeMCPApprovalResponse } func (m MessageItemMCPApprovalResponse) MarshalJSON() ([]byte, error) { type typeAlias MessageItemMCPApprovalResponse type typeWrapper struct { typeAlias Type MessageItemType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.MessageItemType(), } return json.Marshal(shadow) } type MCPTool struct { // JSON schema describing the tool's expected input shape. InputSchema string `json:"input_schema,omitempty"` // The name of the MCP tool. Name string `json:"name,omitempty"` // A human-readable description of what the tool does. Description string `json:"description,omitempty"` // Additional metadata or annotations supplied by the server. Annotations any `json:"annotations,omitempty"` } type MessageItemMCPListTools struct { // The unique ID of the list. ID string `json:"id,omitempty"` // The label of the MCP server. ServerLabel string `json:"server_label,omitempty"` // The tools available on the server. Tools []MCPTool `json:"tools,omitempty"` } func (m MessageItemMCPListTools) MessageItemType() MessageItemType { return MessageItemTypeMCPListTools } func (m MessageItemMCPListTools) MarshalJSON() ([]byte, error) { type typeAlias MessageItemMCPListTools type typeWrapper struct { typeAlias Type MessageItemType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.MessageItemType(), } return json.Marshal(shadow) } type MCPErrorType string const ( MCPErrorTypeProtocolError MCPErrorType = "protocol_error" MCPErrorTypeToolExecution MCPErrorType = "tool_execution_error" MCPErrorTypeHTTPError MCPErrorType = "http_error" ) type MCPProtocolError struct { // Numeric error code (protocol-specific). Code int `json:"code,omitempty"` // Human-readable error message. 
Message string `json:"message,omitempty"` } func (m MCPProtocolError) ErrorType() MCPErrorType { return MCPErrorTypeProtocolError } func (m MCPProtocolError) MarshalJSON() ([]byte, error) { type typeAlias MCPProtocolError type typeWrapper struct { typeAlias Type MCPErrorType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ErrorType(), } return json.Marshal(shadow) } type MCPToolExecutionError struct { // Human-readable error message from tool execution. Message string `json:"message,omitempty"` } func (m MCPToolExecutionError) ErrorType() MCPErrorType { return MCPErrorTypeToolExecution } func (m MCPToolExecutionError) MarshalJSON() ([]byte, error) { type typeAlias MCPToolExecutionError type typeWrapper struct { typeAlias Type MCPErrorType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ErrorType(), } return json.Marshal(shadow) } type MCPHTTPError struct { // HTTP status code returned by the upstream call. Code int `json:"code,omitempty"` // Human-readable HTTP error message. Message string `json:"message,omitempty"` } func (m MCPHTTPError) ErrorType() MCPErrorType { return MCPErrorTypeHTTPError } func (m MCPHTTPError) MarshalJSON() ([]byte, error) { type typeAlias MCPHTTPError type typeWrapper struct { typeAlias Type MCPErrorType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ErrorType(), } return json.Marshal(shadow) } type MCPError struct { // Details when type is protocol_error. Protocol *MCPProtocolError `json:",omitempty"` // Details when type is tool_execution_error. ToolExecution *MCPToolExecutionError `json:",omitempty"` // Details when type is http_error. 
HTTP *MCPHTTPError `json:",omitempty"` } func (m MCPError) MarshalJSON() ([]byte, error) { if m.Protocol != nil { return json.Marshal(m.Protocol) } if m.ToolExecution != nil { return json.Marshal(m.ToolExecution) } return json.Marshal(m.HTTP) } func (m *MCPError) UnmarshalJSON(data []byte) error { if isNull(data) { return nil } var u typeStruct if err := json.Unmarshal(data, &u); err != nil { return err } switch MCPErrorType(u.Type) { case MCPErrorTypeProtocolError: return json.Unmarshal(data, &m.Protocol) case MCPErrorTypeToolExecution: return json.Unmarshal(data, &m.ToolExecution) case MCPErrorTypeHTTPError: return json.Unmarshal(data, &m.HTTP) default: return errors.New("unknown error type: " + u.Type) } } type MessageItemMCPToolCall struct { // The unique ID of the tool call. ID string `json:"id,omitempty"` // The label of the MCP server running the tool. ServerLabel string `json:"server_label,omitempty"` // A JSON string of the arguments passed to the tool. Arguments string `json:"arguments,omitempty"` // The name of the tool that was run. Name string `json:"name,omitempty"` // The ID of an associated approval request, if any. ApprovalRequestID string `json:"approval_request_id,omitempty"` // The error from the tool call, if any. Error *MCPProtocolError `json:"error,omitempty"` // The output from the tool call. Output string `json:"output,omitempty"` } func (m MessageItemMCPToolCall) MessageItemType() MessageItemType { return MessageItemTypeMCPCall } func (m MessageItemMCPToolCall) MarshalJSON() ([]byte, error) { type typeAlias MessageItemMCPToolCall type typeWrapper struct { typeAlias Type MessageItemType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.MessageItemType(), } return json.Marshal(shadow) } type MessageItemMCPApprovalRequest struct { // The unique ID of the approval request. ID string `json:"id,omitempty"` // The name of the tool to run. Name string `json:"name,omitempty"` // A JSON string of arguments for the tool. 
Arguments string `json:"arguments,omitempty"` // The label of the MCP server making the request. ServerLabel string `json:"server_label,omitempty"` } func (m MessageItemMCPApprovalRequest) MessageItemType() MessageItemType { return MessageItemTypeMCPApprovalRequest } func (m MessageItemMCPApprovalRequest) MarshalJSON() ([]byte, error) { type typeAlias MessageItemMCPApprovalRequest type typeWrapper struct { typeAlias Type MessageItemType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.MessageItemType(), } return json.Marshal(shadow) } type MessageItemUnion struct { // A system message in a Realtime conversation can be used to provide additional context or instructions to the model. This is similar but distinct from the instruction prompt provided at the start of a conversation, as system messages can be added at any point in the conversation. For major changes to the conversation's behavior, use instructions, but for smaller updates (e.g. "the user is now asking about a different topic"), use system messages. System *MessageItemSystem `json:",omitempty"` // A user message item in a Realtime conversation. User *MessageItemUser `json:",omitempty"` // An assistant message item in a Realtime conversation. Assistant *MessageItemAssistant `json:",omitempty"` // A function call item in a Realtime conversation. FunctionCall *MessageItemFunctionCall `json:",omitempty"` // A function call output item in a Realtime conversation. FunctionCallOutput *MessageItemFunctionCallOutput `json:",omitempty"` // A Realtime item responding to an MCP approval request. MCPApprovalResponse *MessageItemMCPApprovalResponse `json:",omitempty"` // A Realtime item listing tools available on an MCP server. MCPListTools *MessageItemMCPListTools `json:",omitempty"` // A Realtime item representing an invocation of a tool on an MCP server. MCPToolCall *MessageItemMCPToolCall `json:",omitempty"` // A Realtime item requesting human approval of a tool invocation. 
MCPApprovalRequest *MessageItemMCPApprovalRequest `json:",omitempty"` } func (m MessageItemUnion) MarshalJSON() ([]byte, error) { switch { case m.System != nil: return json.Marshal(m.System) case m.User != nil: return json.Marshal(m.User) case m.Assistant != nil: return json.Marshal(m.Assistant) case m.FunctionCall != nil: return json.Marshal(m.FunctionCall) case m.FunctionCallOutput != nil: return json.Marshal(m.FunctionCallOutput) case m.MCPApprovalResponse != nil: return json.Marshal(m.MCPApprovalResponse) case m.MCPListTools != nil: return json.Marshal(m.MCPListTools) case m.MCPToolCall != nil: return json.Marshal(m.MCPToolCall) case m.MCPApprovalRequest != nil: return json.Marshal(m.MCPApprovalRequest) default: return nil, errors.New("unknown message item type") } } func (m *MessageItemUnion) UnmarshalJSON(data []byte) error { if isNull(data) { return nil } var t struct { Type string `json:"type"` Role string `json:"role"` } if err := json.Unmarshal(data, &t); err != nil { return err } switch MessageItemType(t.Type) { case MessageItemTypeMessage: switch MessageRole(t.Role) { case MessageRoleUser: return json.Unmarshal(data, &m.User) case MessageRoleAssistant: return json.Unmarshal(data, &m.Assistant) case MessageRoleSystem: return json.Unmarshal(data, &m.System) default: return fmt.Errorf("unknown message role: %s", t.Role) } case MessageItemTypeFunctionCall: return json.Unmarshal(data, &m.FunctionCall) case MessageItemTypeFunctionCallOutput: return json.Unmarshal(data, &m.FunctionCallOutput) case MessageItemTypeMCPApprovalResponse: return json.Unmarshal(data, &m.MCPApprovalResponse) case MessageItemTypeMCPListTools: return json.Unmarshal(data, &m.MCPListTools) case MessageItemTypeMCPCall: return json.Unmarshal(data, &m.MCPToolCall) case MessageItemTypeMCPApprovalRequest: return json.Unmarshal(data, &m.MCPApprovalRequest) default: return fmt.Errorf("unknown message item type: %s", t.Type) } } ================================================ FILE: 
core/http/endpoints/openai/types/server_events.go ================================================ package types import ( "encoding/json" "fmt" ) type ServerEventType string const ( ServerEventTypeError ServerEventType = "error" ServerEventTypeSessionCreated ServerEventType = "session.created" ServerEventTypeSessionUpdated ServerEventType = "session.updated" ServerEventTypeConversationItemAdded ServerEventType = "conversation.item.added" ServerEventTypeConversationItemDone ServerEventType = "conversation.item.done" ServerEventTypeConversationItemRetrieved ServerEventType = "conversation.item.retrieved" ServerEventTypeConversationItemInputAudioTranscriptionCompleted ServerEventType = "conversation.item.input_audio_transcription.completed" ServerEventTypeConversationItemInputAudioTranscriptionDelta ServerEventType = "conversation.item.input_audio_transcription.delta" ServerEventTypeConversationItemInputAudioTranscriptionSegment ServerEventType = "conversation.item.input_audio_transcription.segment" ServerEventTypeConversationItemInputAudioTranscriptionFailed ServerEventType = "conversation.item.input_audio_transcription.failed" ServerEventTypeConversationItemTruncated ServerEventType = "conversation.item.truncated" ServerEventTypeConversationItemDeleted ServerEventType = "conversation.item.deleted" ServerEventTypeInputAudioBufferCommitted ServerEventType = "input_audio_buffer.committed" ServerEventTypeInputAudioBufferCleared ServerEventType = "input_audio_buffer.cleared" ServerEventTypeInputAudioBufferSpeechStarted ServerEventType = "input_audio_buffer.speech_started" ServerEventTypeInputAudioBufferSpeechStopped ServerEventType = "input_audio_buffer.speech_stopped" ServerEventTypeInputAudioBufferTimeoutTriggered ServerEventType = "input_audio_buffer.timeout_triggered" ServerEventTypeResponseCreated ServerEventType = "response.created" ServerEventTypeResponseDone ServerEventType = "response.done" ServerEventTypeResponseOutputItemAdded ServerEventType = 
"response.output_item.added" ServerEventTypeResponseOutputItemDone ServerEventType = "response.output_item.done" ServerEventTypeResponseContentPartAdded ServerEventType = "response.content_part.added" ServerEventTypeResponseContentPartDone ServerEventType = "response.content_part.done" ServerEventTypeResponseOutputTextDelta ServerEventType = "response.output_text.delta" ServerEventTypeResponseOutputTextDone ServerEventType = "response.output_text.done" ServerEventTypeResponseOutputAudioTranscriptDelta ServerEventType = "response.output_audio_transcript.delta" ServerEventTypeResponseOutputAudioTranscriptDone ServerEventType = "response.output_audio_transcript.done" ServerEventTypeResponseOutputAudioDelta ServerEventType = "response.output_audio.delta" ServerEventTypeResponseOutputAudioDone ServerEventType = "response.output_audio.done" ServerEventTypeResponseFunctionCallArgumentsDelta ServerEventType = "response.function_call_arguments.delta" ServerEventTypeResponseFunctionCallArgumentsDone ServerEventType = "response.function_call_arguments.done" ServerEventTypeResponseMcpCallArgumentsDelta ServerEventType = "response.mcp_call_arguments.delta" ServerEventTypeResponseMcpCallArgumentsDone ServerEventType = "response.mcp_call_arguments.done" ServerEventTypeResponseMcpCallInProgress ServerEventType = "response.mcp_call.in_progress" ServerEventTypeResponseMcpCallCompleted ServerEventType = "response.mcp_call.completed" ServerEventTypeResponseMcpCallFailed ServerEventType = "response.mcp_call.failed" ServerEventTypeMcpListToolsInProgress ServerEventType = "mcp_list_tools.in_progress" ServerEventTypeMcpListToolsCompleted ServerEventType = "mcp_list_tools.completed" ServerEventTypeMcpListToolsFailed ServerEventType = "mcp_list_tools.failed" ServerEventTypeRateLimitsUpdated ServerEventType = "rate_limits.updated" ) // ServerEvent is the interface for server events. 
type ServerEvent interface { ServerEventType() ServerEventType } // ServerEventBase is the base struct for all server events. type ServerEventBase struct { EventID string `json:"event_id,omitempty"` } // Returned when an error occurs, which could be a client problem or a server problem. Most errors are recoverable and the session will stay open, we recommend to implementors to monitor and log error messages by default. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/error type ErrorEvent struct { ServerEventBase // Details of the error. Error Error `json:"error"` } func (m ErrorEvent) ServerEventType() ServerEventType { return ServerEventTypeError } func (m ErrorEvent) MarshalJSON() ([]byte, error) { type typeAlias ErrorEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when a Session is created. Emitted automatically when a new connection is established as the first server event. This event will contain the default Session configuration. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/session/created type SessionCreatedEvent struct { ServerEventBase // The session resource. Session SessionUnion `json:"session"` } func (m SessionCreatedEvent) ServerEventType() ServerEventType { return ServerEventTypeSessionCreated } func (m SessionCreatedEvent) MarshalJSON() ([]byte, error) { type typeAlias SessionCreatedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when a session is updated with a `session.update` event, unless there is an error. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/session/updated type SessionUpdatedEvent struct { ServerEventBase // The updated session resource. 
Session SessionUnion `json:"session"` } func (m SessionUpdatedEvent) ServerEventType() ServerEventType { return ServerEventTypeSessionUpdated } func (m SessionUpdatedEvent) MarshalJSON() ([]byte, error) { type typeAlias SessionUpdatedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when an input audio buffer is committed, either by the client or automatically in server VAD mode. // // The `item_id` property is the ID of the user message item that will be created, thus a `conversation.item.created` event will also be sent to the client. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/input_audio_buffer/committed type InputAudioBufferCommittedEvent struct { ServerEventBase // The ID of the preceding item after which the new item will be inserted. PreviousItemID string `json:"previous_item_id,omitempty"` // The ID of the user message item that will be created. ItemID string `json:"item_id"` } func (m InputAudioBufferCommittedEvent) ServerEventType() ServerEventType { return ServerEventTypeInputAudioBufferCommitted } func (m InputAudioBufferCommittedEvent) MarshalJSON() ([]byte, error) { type typeAlias InputAudioBufferCommittedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when the input audio buffer is cleared by the client with a `input_audio_buffer.clear` event. 
// // See https://platform.openai.com/docs/api-reference/realtime-server-events/input_audio_buffer/cleared type InputAudioBufferClearedEvent struct { ServerEventBase } func (m InputAudioBufferClearedEvent) ServerEventType() ServerEventType { return ServerEventTypeInputAudioBufferCleared } func (m InputAudioBufferClearedEvent) MarshalJSON() ([]byte, error) { type typeAlias InputAudioBufferClearedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Sent by the server when in `server_vad` mode to indicate that speech has been detected in the audio buffer. // // This can happen any time audio is added to the buffer (unless speech is already detected). The client may want to use this event to interrupt audio playback or provide visual feedback to the user. // // The client should expect to receive a `input_audio_buffer.speech_stopped` event when speech stops. // // The `item_id` property is the ID of the user message item that will be created when speech stops and will also be included in the `input_audio_buffer.speech_stopped` event (unless the client manually commits the audio buffer during VAD activation). // // See https://platform.openai.com/docs/api-reference/realtime-server-events/input_audio_buffer/speech_started type InputAudioBufferSpeechStartedEvent struct { ServerEventBase // Milliseconds since the session started when speech was detected. AudioStartMs int64 `json:"audio_start_ms"` // The ID of the user message item that will be created when speech stops. 
ItemID string `json:"item_id"` } func (m InputAudioBufferSpeechStartedEvent) ServerEventType() ServerEventType { return ServerEventTypeInputAudioBufferSpeechStarted } func (m InputAudioBufferSpeechStartedEvent) MarshalJSON() ([]byte, error) { type typeAlias InputAudioBufferSpeechStartedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned in `server_vad` mode when the server detects the end of speech in the audio buffer. // // The server will also send an `conversation.item.created` event with the user message item that is created from the audio buffer. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/input_audio_buffer/speech_stopped type InputAudioBufferSpeechStoppedEvent struct { ServerEventBase // Milliseconds since the session started when speech stopped. AudioEndMs int64 `json:"audio_end_ms"` // The ID of the user message item that will be created. ItemID string `json:"item_id"` } func (m InputAudioBufferSpeechStoppedEvent) ServerEventType() ServerEventType { return ServerEventTypeInputAudioBufferSpeechStopped } func (m InputAudioBufferSpeechStoppedEvent) MarshalJSON() ([]byte, error) { type typeAlias InputAudioBufferSpeechStoppedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when the Server VAD timeout is triggered for the input audio buffer. // // This is configured with `idle_timeout_ms` in the `turn_detection` settings of the session, and it indicates that there hasn't been any speech detected for the configured duration. // // The `audio_start_ms` and `audio_end_ms` fields indicate the segment of audio after the last model response up to the triggering time, as an offset from the beginning of audio written to the input audio buffer. 
// // This means it demarcates the segment of audio that was silent and the difference between the start and end values will roughly match the configured timeout. // // The empty audio will be committed to the conversation as an `input_audio` item (there will be a `input_audio_buffer.committed` event) and a model response will be generated. // // There may be speech that didn't trigger VAD but is still detected by the model, so the model may respond with something relevant to the conversation or a prompt to continue speaking. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/input_audio_buffer/timeout_triggered type InputAudioBufferTimeoutTriggeredEvent struct { ServerEventBase // Milliseconds since the session started when speech started. AudioStartMs int64 `json:"audio_start_ms"` // Milliseconds since the session started when speech stopped. AudioEndMs int64 `json:"audio_end_ms"` // The ID of the user message item that will be created. ItemID string `json:"item_id"` } func (m InputAudioBufferTimeoutTriggeredEvent) ServerEventType() ServerEventType { return ServerEventTypeInputAudioBufferTimeoutTriggered } func (m InputAudioBufferTimeoutTriggeredEvent) MarshalJSON() ([]byte, error) { type typeAlias InputAudioBufferTimeoutTriggeredEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Sent by the server when an Item is added to the default Conversation. // // This can happen in several cases: // // - When the client sends a `conversation.item.create` event. // // - When the input audio buffer is committed. In this case the item will be a user message containing the audio from the buffer. // // - When the model is generating a Response. 
In this case the `conversation.item.added` event will be sent when the model starts generating a specific Item, and thus it will not yet have any content (and `status` will be `in_progress`). // // The event will include the full content of the Item (except when model is generating a Response) except for audio data, which can be retrieved separately with a `conversation.item.retrieve` event if necessary. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/conversation/item/added type ConversationItemAddedEvent struct { ServerEventBase // The ID of the preceding item after which the new item will be inserted. PreviousItemID string `json:"previous_item_id,omitempty"` // The item that was added. Item MessageItemUnion `json:"item"` } func (m ConversationItemAddedEvent) ServerEventType() ServerEventType { return ServerEventTypeConversationItemAdded } func (m ConversationItemAddedEvent) MarshalJSON() ([]byte, error) { type typeAlias ConversationItemAddedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when a conversation item is finalized. // // The event will include the full content of the Item except for audio data, which can be retrieved separately with a `conversation.item.retrieve` event if needed. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/conversation/item/done type ConversationItemDoneEvent struct { ServerEventBase // The ID of the preceding item after which the item appears. PreviousItemID string `json:"previous_item_id,omitempty"` // The completed item. 
Item MessageItemUnion `json:"item"` } func (m ConversationItemDoneEvent) ServerEventType() ServerEventType { return ServerEventTypeConversationItemDone } func (m ConversationItemDoneEvent) MarshalJSON() ([]byte, error) { type typeAlias ConversationItemDoneEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when a conversation item is retrieved with `conversation.item.retrieve`. This is provided as a way to fetch the server's representation of an item, for example to get access to the post-processed audio data after noise cancellation and VAD. It includes the full content of the Item, including audio data. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/conversation/item/retrieved type ConversationItemRetrievedEvent struct { ServerEventBase // The item that was retrieved. Item MessageItemUnion `json:"item"` } func (m ConversationItemRetrievedEvent) ServerEventType() ServerEventType { return ServerEventTypeConversationItemRetrieved } func (m ConversationItemRetrievedEvent) MarshalJSON() ([]byte, error) { type typeAlias ConversationItemRetrievedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } type Logprobs struct { // Raw byte sequence corresponding to the token (if applicable). Bytes []byte `json:"bytes,omitempty"` // Log probability of the token or segment. Logprob float64 `json:"logprob,omitempty"` // The decoded token text. Token string `json:"token,omitempty"` } // This event is the output of audio transcription for user audio written to the user audio buffer. Transcription begins when the input audio buffer is committed by the client or server (in `server_vad` mode). 
Transcription runs asynchronously with Response creation, so this event may come before or after the Response events. // Realtime API models accept audio natively, and thus input transcription is a separate process run on a separate ASR (Automatic Speech Recognition) model. The transcript may diverge somewhat from the model's interpretation, and should be treated as a rough guide. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/conversation/item/input_audio_transcription/completed type ConversationItemInputAudioTranscriptionCompletedEvent struct { ServerEventBase // The ID of the item. ItemID string `json:"item_id"` // The index of the content part in the item's content array. ContentIndex int `json:"content_index"` // The final transcript of the audio. Transcript string `json:"transcript"` // Log probability information for the transcription, if available. Logprobs []Logprobs `json:"logprobs,omitempty"` // Usage information for the transcription, if available. Usage *UsageUnion `json:"usage,omitempty"` } func (m ConversationItemInputAudioTranscriptionCompletedEvent) ServerEventType() ServerEventType { return ServerEventTypeConversationItemInputAudioTranscriptionCompleted } func (m ConversationItemInputAudioTranscriptionCompletedEvent) MarshalJSON() ([]byte, error) { type typeAlias ConversationItemInputAudioTranscriptionCompletedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when the text value of an input audio transcription content part is updated with incremental transcription results. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/conversation/item/input_audio_transcription/delta type ConversationItemInputAudioTranscriptionDeltaEvent struct { ServerEventBase // The ID of the item. 
ItemID string `json:"item_id"` // The index of the content part in the item's content array. ContentIndex int `json:"content_index"` // The transcript delta. Delta string `json:"delta"` // Log probability updates for the delta, if available. Logprobs []Logprobs `json:"logprobs,omitempty"` } func (m ConversationItemInputAudioTranscriptionDeltaEvent) ServerEventType() ServerEventType { return ServerEventTypeConversationItemInputAudioTranscriptionDelta } func (m ConversationItemInputAudioTranscriptionDeltaEvent) MarshalJSON() ([]byte, error) { type typeAlias ConversationItemInputAudioTranscriptionDeltaEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when an input audio transcription segment is identified for an item. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/conversation/item/input_audio_transcription/segment type ConversationItemInputAudioTranscriptionSegmentEvent struct { ServerEventBase // The ID of the item. ItemID string `json:"item_id"` // The index of the content part in the item's content array. ContentIndex int `json:"content_index"` // Log probability information for the segment, if available. Logprobs []Logprobs `json:"logprobs,omitempty"` // The unique ID of the transcript segment. ID string `json:"id,omitempty"` // The speaker label for the segment, if available. Speaker string `json:"speaker,omitempty"` // The start time of the segment in seconds. Start float64 `json:"start,omitempty"` // The end time of the segment in seconds. End float64 `json:"end,omitempty"` // The text content of the segment. 
Text string `json:"text,omitempty"` } func (m ConversationItemInputAudioTranscriptionSegmentEvent) ServerEventType() ServerEventType { return ServerEventTypeConversationItemInputAudioTranscriptionSegment } func (m ConversationItemInputAudioTranscriptionSegmentEvent) MarshalJSON() ([]byte, error) { type typeAlias ConversationItemInputAudioTranscriptionSegmentEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when input audio transcription is configured, and a transcription request for a user message failed. These events are separate from other error events so that the client can identify the related Item. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/conversation/item/input_audio_transcription/failed type ConversationItemInputAudioTranscriptionFailedEvent struct { ServerEventBase // The ID of the item. ItemID string `json:"item_id"` // The index of the content part in the item's content array. ContentIndex int `json:"content_index"` // Details of the failure. Error Error `json:"error"` } func (m ConversationItemInputAudioTranscriptionFailedEvent) ServerEventType() ServerEventType { return ServerEventTypeConversationItemInputAudioTranscriptionFailed } func (m ConversationItemInputAudioTranscriptionFailedEvent) MarshalJSON() ([]byte, error) { type typeAlias ConversationItemInputAudioTranscriptionFailedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when an earlier assistant audio message item is truncated by the client with a `conversation.item.truncate` event. This event is used to synchronize the server's understanding of the audio with the client's playback. 
// // This action will truncate the audio and remove the server-side text transcript to ensure there is no text in the context that hasn't been heard by the user. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/conversation/item/truncated type ConversationItemTruncatedEvent struct { ServerEventBase // The ID of the assistant message item that was truncated. ItemID string `json:"item_id"` // The index of the content part that was truncated. ContentIndex int `json:"content_index"` // The duration up to which the audio was truncated, in milliseconds. AudioEndMs int `json:"audio_end_ms"` } func (m ConversationItemTruncatedEvent) ServerEventType() ServerEventType { return ServerEventTypeConversationItemTruncated } func (m ConversationItemTruncatedEvent) MarshalJSON() ([]byte, error) { type typeAlias ConversationItemTruncatedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when an item in the conversation is deleted by the client with a `conversation.item.delete` event. This event is used to synchronize the server's understanding of the conversation history with the client's view. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/conversation/item/deleted type ConversationItemDeletedEvent struct { ServerEventBase // The ID of the item that was deleted. ItemID string `json:"item_id"` } func (m ConversationItemDeletedEvent) ServerEventType() ServerEventType { return ServerEventTypeConversationItemDeleted } func (m ConversationItemDeletedEvent) MarshalJSON() ([]byte, error) { type typeAlias ConversationItemDeletedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when a new Response is created. 
The first event of response creation, where the response is in an initial state of in_progress. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/created type ResponseCreatedEvent struct { ServerEventBase // The response resource. Response Response `json:"response"` } func (m ResponseCreatedEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseCreated } func (m ResponseCreatedEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseCreatedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when a Response is done streaming. Always emitted, no matter the final state. The Response object included in the response.done event will include all output Items in the Response but will omit the raw audio data. // // Clients should check the status field of the Response to determine if it was successful (completed) or if there was another outcome: cancelled, failed, or incomplete. // // A response will contain all output items that were generated during the response, excluding any audio content. // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/done type ResponseDoneEvent struct { ServerEventBase // The response resource. Response Response `json:"response"` } func (m ResponseDoneEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseDone } func (m ResponseDoneEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseDoneEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when a new Item is created during Response generation. 
// // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/output_item/added type ResponseOutputItemAddedEvent struct { ServerEventBase // The ID of the response to which the item belongs. ResponseID string `json:"response_id"` // The index of the output item in the response. OutputIndex int `json:"output_index"` // The item that was added. Item MessageItemUnion `json:"item"` } func (m ResponseOutputItemAddedEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseOutputItemAdded } func (m ResponseOutputItemAddedEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseOutputItemAddedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when an Item is done streaming. Also emitted when a Response is interrupted, incomplete, or cancelled. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/output_item/done type ResponseOutputItemDoneEvent struct { ServerEventBase // The ID of the response to which the item belongs. ResponseID string `json:"response_id"` // The index of the output item in the response. OutputIndex int `json:"output_index"` // The completed item. Item MessageItemUnion `json:"item"` } func (m ResponseOutputItemDoneEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseOutputItemDone } func (m ResponseOutputItemDoneEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseOutputItemDoneEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when a new content part is added to an assistant message item during response generation. 
// // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/content_part/added type ResponseContentPartAddedEvent struct { ServerEventBase ResponseID string `json:"response_id"` ItemID string `json:"item_id"` OutputIndex int `json:"output_index"` ContentIndex int `json:"content_index"` Part MessageContentOutput `json:"part"` } func (m ResponseContentPartAddedEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseContentPartAdded } func (m ResponseContentPartAddedEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseContentPartAddedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when a content part is done streaming in an assistant message item. Also emitted when a Response is interrupted, incomplete, or cancelled. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/content_part/done type ResponseContentPartDoneEvent struct { ServerEventBase // The ID of the response. ResponseID string `json:"response_id"` // The ID of the item to which the content part was added. ItemID string `json:"item_id"` // The index of the output item in the response. OutputIndex int `json:"output_index"` // The index of the content part in the item's content array. ContentIndex int `json:"content_index"` // The content part that was added. 
Part MessageContentOutput `json:"part"` } func (m ResponseContentPartDoneEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseContentPartDone } func (m ResponseContentPartDoneEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseContentPartDoneEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when the text value of an "output_text" content part is updated. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/output_text/delta type ResponseOutputTextDeltaEvent struct { ServerEventBase ResponseID string `json:"response_id"` ItemID string `json:"item_id"` OutputIndex int `json:"output_index"` ContentIndex int `json:"content_index"` Delta string `json:"delta"` } func (m ResponseOutputTextDeltaEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseOutputTextDelta } func (m ResponseOutputTextDeltaEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseOutputTextDeltaEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when the text value of an "output_text" content part is done streaming. Also emitted when a Response is interrupted, incomplete, or cancelled. 
// // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/output_text/done type ResponseOutputTextDoneEvent struct { ServerEventBase ResponseID string `json:"response_id"` ItemID string `json:"item_id"` OutputIndex int `json:"output_index"` ContentIndex int `json:"content_index"` Text string `json:"text"` } func (m ResponseOutputTextDoneEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseOutputTextDone } func (m ResponseOutputTextDoneEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseOutputTextDoneEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when the model-generated transcription of audio output is updated. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/output_audio_transcript/delta type ResponseOutputAudioTranscriptDeltaEvent struct { ServerEventBase // The ID of the response. ResponseID string `json:"response_id"` // The ID of the item. ItemID string `json:"item_id"` // The index of the output item in the response. OutputIndex int `json:"output_index"` // The index of the content part in the item's content array. ContentIndex int `json:"content_index"` // The transcript delta. Delta string `json:"delta"` } func (m ResponseOutputAudioTranscriptDeltaEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseOutputAudioTranscriptDelta } func (m ResponseOutputAudioTranscriptDeltaEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseOutputAudioTranscriptDeltaEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when the model-generated transcription of audio output is done streaming. 
Also emitted when a Response is interrupted, incomplete, or cancelled. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/output_audio_transcript/done type ResponseOutputAudioTranscriptDoneEvent struct { ServerEventBase // The ID of the response. ResponseID string `json:"response_id"` // The ID of the item. ItemID string `json:"item_id"` // The index of the output item in the response. OutputIndex int `json:"output_index"` // The index of the content part in the item's content array. ContentIndex int `json:"content_index"` // The final transcript of the audio. Transcript string `json:"transcript"` } func (m ResponseOutputAudioTranscriptDoneEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseOutputAudioTranscriptDone } func (m ResponseOutputAudioTranscriptDoneEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseOutputAudioTranscriptDoneEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when the model-generated audio is updated. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/output_audio/delta type ResponseOutputAudioDeltaEvent struct { ServerEventBase // The ID of the response. ResponseID string `json:"response_id"` // The ID of the item. ItemID string `json:"item_id"` // The index of the output item in the response. OutputIndex int `json:"output_index"` // The index of the content part in the item's content array. ContentIndex int `json:"content_index"` // Base64-encoded audio data delta. 
Delta string `json:"delta"` } func (m ResponseOutputAudioDeltaEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseOutputAudioDelta } func (m ResponseOutputAudioDeltaEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseOutputAudioDeltaEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when the model-generated audio is done. Also emitted when a Response is interrupted, incomplete, or cancelled. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/output_audio/done type ResponseOutputAudioDoneEvent struct { ServerEventBase // The ID of the response. ResponseID string `json:"response_id"` // The ID of the item. ItemID string `json:"item_id"` // The index of the output item in the response. OutputIndex int `json:"output_index"` // The index of the content part in the item's content array. ContentIndex int `json:"content_index"` } func (m ResponseOutputAudioDoneEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseOutputAudioDone } func (m ResponseOutputAudioDoneEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseOutputAudioDoneEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when the model-generated function call arguments are updated. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/function_call_arguments/delta type ResponseFunctionCallArgumentsDeltaEvent struct { ServerEventBase // The ID of the response. ResponseID string `json:"response_id"` // The ID of the item. ItemID string `json:"item_id"` // The index of the output item in the response. OutputIndex int `json:"output_index"` // The ID of the function call. 
CallID string `json:"call_id"` // The arguments delta as a JSON string. Delta string `json:"delta"` } func (m ResponseFunctionCallArgumentsDeltaEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseFunctionCallArgumentsDelta } func (m ResponseFunctionCallArgumentsDeltaEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseFunctionCallArgumentsDeltaEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when the model-generated function call arguments are done streaming. Also emitted when a Response is interrupted, incomplete, or cancelled. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/function_call_arguments/done type ResponseFunctionCallArgumentsDoneEvent struct { ServerEventBase // The ID of the response. ResponseID string `json:"response_id"` // The ID of the item. ItemID string `json:"item_id"` // The index of the output item in the response. OutputIndex int `json:"output_index"` // The ID of the function call. CallID string `json:"call_id"` // The final arguments as a JSON string. Arguments string `json:"arguments"` // The name of the function. Not shown in API reference but present in the actual event. Name string `json:"name"` } func (m ResponseFunctionCallArgumentsDoneEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseFunctionCallArgumentsDone } func (m ResponseFunctionCallArgumentsDoneEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseFunctionCallArgumentsDoneEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when MCP tool call arguments are updated during response generation. 
// // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/mcp_call_arguments/delta type ResponseMcpCallArgumentsDeltaEvent struct { ServerEventBase // The ID of the response. ResponseID string `json:"response_id"` // The ID of the item. ItemID string `json:"item_id"` // The index of the output item in the response. OutputIndex int `json:"output_index"` // The arguments delta as a JSON string. Delta string `json:"delta"` // Obfuscation Obfuscation string `json:"obfuscation"` } func (m ResponseMcpCallArgumentsDeltaEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseMcpCallArgumentsDelta } func (m ResponseMcpCallArgumentsDeltaEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseMcpCallArgumentsDeltaEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when MCP tool call arguments are finalized during response generation. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/mcp_call_arguments/done type ResponseMcpCallArgumentsDoneEvent struct { ServerEventBase // The ID of the response. ResponseID string `json:"response_id"` // The ID of the item. ItemID string `json:"item_id"` // The index of the output item in the response. OutputIndex int `json:"output_index"` // The final arguments as a JSON string. Arguments string `json:"arguments"` } func (m ResponseMcpCallArgumentsDoneEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseMcpCallArgumentsDone } func (m ResponseMcpCallArgumentsDoneEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseMcpCallArgumentsDoneEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when an MCP tool call has started and is in progress. 
// // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/mcp_call/in_progress type ResponseMcpCallInProgressEvent struct { ServerEventBase // The ID of the item. ItemID string `json:"item_id"` // The index of the output item in the response. OutputIndex int `json:"output_index"` } func (m ResponseMcpCallInProgressEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseMcpCallInProgress } func (m ResponseMcpCallInProgressEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseMcpCallInProgressEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when an MCP tool call has completed successfully. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/mcp_call/completed type ResponseMcpCallCompletedEvent struct { ServerEventBase // The ID of the item. ItemID string `json:"item_id"` // The index of the output item in the response. OutputIndex int `json:"output_index"` } func (m ResponseMcpCallCompletedEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseMcpCallCompleted } func (m ResponseMcpCallCompletedEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseMcpCallCompletedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when an MCP tool call has failed. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/response/mcp_call/failed type ResponseMcpCallFailedEvent struct { ServerEventBase // The ID of the item. ItemID string `json:"item_id"` // The index of the output item in the response. 
OutputIndex int `json:"output_index"` } func (m ResponseMcpCallFailedEvent) ServerEventType() ServerEventType { return ServerEventTypeResponseMcpCallFailed } func (m ResponseMcpCallFailedEvent) MarshalJSON() ([]byte, error) { type typeAlias ResponseMcpCallFailedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when listing MCP tools is in progress for an item. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/mcp_list_tools/in_progress type McpListToolsInProgressEvent struct { ServerEventBase // The ID of the MCP list tools item. ItemID string `json:"item_id"` } func (m McpListToolsInProgressEvent) ServerEventType() ServerEventType { return ServerEventTypeMcpListToolsInProgress } func (m McpListToolsInProgressEvent) MarshalJSON() ([]byte, error) { type typeAlias McpListToolsInProgressEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when listing MCP tools has completed for an item. // // See https://platform.openai.com/docs/api-reference/realtime-server-events/mcp_list_tools/completed type McpListToolsCompletedEvent struct { ServerEventBase // The ID of the MCP list tools item. ItemID string `json:"item_id"` } func (m McpListToolsCompletedEvent) ServerEventType() ServerEventType { return ServerEventTypeMcpListToolsCompleted } func (m McpListToolsCompletedEvent) MarshalJSON() ([]byte, error) { type typeAlias McpListToolsCompletedEvent type typeWrapper struct { typeAlias Type ServerEventType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(m), Type: m.ServerEventType(), } return json.Marshal(shadow) } // Returned when listing MCP tools has failed for an item. 
//
// See https://platform.openai.com/docs/api-reference/realtime-server-events/mcp_list_tools/failed
type McpListToolsFailedEvent struct {
	ServerEventBase
	// The ID of the MCP list tools item.
	ItemID string `json:"item_id"`
}

// ServerEventType returns ServerEventTypeMcpListToolsFailed.
func (m McpListToolsFailedEvent) ServerEventType() ServerEventType {
	return ServerEventTypeMcpListToolsFailed
}

// MarshalJSON encodes the event's fields together with a "type" key taken from ServerEventType.
func (m McpListToolsFailedEvent) MarshalJSON() ([]byte, error) {
	// typeAlias has the event's fields but not its declared methods, so
	// json.Marshal does not call back into this MarshalJSON.
	type typeAlias McpListToolsFailedEvent
	type typeWrapper struct {
		typeAlias
		Type ServerEventType `json:"type"`
	}
	shadow := typeWrapper{
		typeAlias: typeAlias(m),
		Type:      m.ServerEventType(),
	}
	return json.Marshal(shadow)
}

// Emitted at the beginning of a Response to indicate the updated rate limits. When a Response is created some tokens will be "reserved" for the output tokens, the rate limits shown here reflect that reservation, which is then adjusted accordingly once the Response is completed.
//
// See https://platform.openai.com/docs/api-reference/realtime-server-events/rate_limits/updated
type RateLimitsUpdatedEvent struct {
	ServerEventBase
	// List of rate limit information.
	RateLimits []RateLimit `json:"rate_limits"`
}

// ServerEventType returns ServerEventTypeRateLimitsUpdated.
func (m RateLimitsUpdatedEvent) ServerEventType() ServerEventType {
	return ServerEventTypeRateLimitsUpdated
}

// MarshalJSON encodes the event's fields together with a "type" key taken from ServerEventType.
func (m RateLimitsUpdatedEvent) MarshalJSON() ([]byte, error) {
	// typeAlias has the event's fields but not its declared methods, so
	// json.Marshal does not call back into this MarshalJSON.
	type typeAlias RateLimitsUpdatedEvent
	type typeWrapper struct {
		typeAlias
		Type ServerEventType `json:"type"`
	}
	shadow := typeWrapper{
		typeAlias: typeAlias(m),
		Type:      m.ServerEventType(),
	}
	return json.Marshal(shadow)
}

// ServerEventInterface is the type-set constraint listing every concrete server event type
// that unmarshalServerEvent may be instantiated with.
type ServerEventInterface interface {
	ErrorEvent |
		SessionCreatedEvent |
		SessionUpdatedEvent |
		ConversationItemAddedEvent |
		ConversationItemDoneEvent |
		ConversationItemRetrievedEvent |
		ConversationItemInputAudioTranscriptionCompletedEvent |
		ConversationItemInputAudioTranscriptionDeltaEvent |
		ConversationItemInputAudioTranscriptionSegmentEvent |
		ConversationItemInputAudioTranscriptionFailedEvent |
		ConversationItemTruncatedEvent |
		ConversationItemDeletedEvent |
		InputAudioBufferCommittedEvent |
		InputAudioBufferClearedEvent |
		InputAudioBufferSpeechStartedEvent |
		InputAudioBufferSpeechStoppedEvent |
		InputAudioBufferTimeoutTriggeredEvent |
		ResponseCreatedEvent |
		ResponseDoneEvent |
		ResponseOutputItemAddedEvent |
		ResponseOutputItemDoneEvent |
		ResponseContentPartAddedEvent |
		ResponseContentPartDoneEvent |
		ResponseOutputTextDeltaEvent |
		ResponseOutputTextDoneEvent |
		ResponseOutputAudioTranscriptDeltaEvent |
		ResponseOutputAudioTranscriptDoneEvent |
		ResponseOutputAudioDeltaEvent |
		ResponseOutputAudioDoneEvent |
		ResponseFunctionCallArgumentsDeltaEvent |
		ResponseFunctionCallArgumentsDoneEvent |
		ResponseMcpCallArgumentsDeltaEvent |
		ResponseMcpCallArgumentsDoneEvent |
		ResponseMcpCallInProgressEvent |
		ResponseMcpCallCompletedEvent |
		ResponseMcpCallFailedEvent |
		McpListToolsInProgressEvent |
		McpListToolsCompletedEvent |
		McpListToolsFailedEvent |
		RateLimitsUpdatedEvent
}

// unmarshalServerEvent decodes data into a value of the concrete event type T.
// On failure it returns the (possibly partially filled) value of T together with the decode error.
func unmarshalServerEvent[T ServerEventInterface](data []byte) (T, error) {
	var t T
	err := json.Unmarshal(data, &t)
	if err != nil {
		return t, err
	}
	return t, nil
}

// UnmarshalServerEvent unmarshals the server event from the given JSON data.
//
// It first decodes only the "type" field, then decodes the whole payload into the
// concrete event struct matching that type. An unrecognised type yields a nil event
// and an "unknown server event type" error. When decoding the concrete struct fails,
// the returned ServerEvent is a non-nil interface wrapping that struct, so callers
// must check the error before using the event.
func UnmarshalServerEvent(data []byte) (ServerEvent, error) { //nolint:funlen,cyclop,gocyclo // TODO: optimize
	// Probe just the discriminator; every other key is ignored at this stage.
	var eventType struct {
		Type ServerEventType `json:"type"`
	}
	err := json.Unmarshal(data, &eventType)
	if err != nil {
		return nil, err
	}
	switch eventType.Type {
	case ServerEventTypeError:
		return unmarshalServerEvent[ErrorEvent](data)
	case ServerEventTypeSessionCreated:
		return unmarshalServerEvent[SessionCreatedEvent](data)
	case ServerEventTypeSessionUpdated:
		return unmarshalServerEvent[SessionUpdatedEvent](data)
	case ServerEventTypeConversationItemAdded:
		return unmarshalServerEvent[ConversationItemAddedEvent](data)
	case ServerEventTypeConversationItemDone:
		return unmarshalServerEvent[ConversationItemDoneEvent](data)
	case ServerEventTypeConversationItemRetrieved:
		return unmarshalServerEvent[ConversationItemRetrievedEvent](data)
	case ServerEventTypeConversationItemInputAudioTranscriptionCompleted:
		return unmarshalServerEvent[ConversationItemInputAudioTranscriptionCompletedEvent](data)
	case ServerEventTypeConversationItemInputAudioTranscriptionDelta:
		return unmarshalServerEvent[ConversationItemInputAudioTranscriptionDeltaEvent](data)
	case ServerEventTypeConversationItemInputAudioTranscriptionSegment:
		return unmarshalServerEvent[ConversationItemInputAudioTranscriptionSegmentEvent](data)
	case ServerEventTypeConversationItemInputAudioTranscriptionFailed:
		return unmarshalServerEvent[ConversationItemInputAudioTranscriptionFailedEvent](data)
	case ServerEventTypeConversationItemTruncated:
		return unmarshalServerEvent[ConversationItemTruncatedEvent](data)
	case ServerEventTypeConversationItemDeleted:
		return unmarshalServerEvent[ConversationItemDeletedEvent](data)
	case ServerEventTypeInputAudioBufferCommitted:
		return unmarshalServerEvent[InputAudioBufferCommittedEvent](data)
	case ServerEventTypeInputAudioBufferCleared:
		return unmarshalServerEvent[InputAudioBufferClearedEvent](data)
	case ServerEventTypeInputAudioBufferSpeechStarted:
		return unmarshalServerEvent[InputAudioBufferSpeechStartedEvent](data)
	case ServerEventTypeInputAudioBufferSpeechStopped:
		return unmarshalServerEvent[InputAudioBufferSpeechStoppedEvent](data)
	case ServerEventTypeInputAudioBufferTimeoutTriggered:
		return unmarshalServerEvent[InputAudioBufferTimeoutTriggeredEvent](data)
	case ServerEventTypeResponseCreated:
		return unmarshalServerEvent[ResponseCreatedEvent](data)
	case ServerEventTypeResponseDone:
		return unmarshalServerEvent[ResponseDoneEvent](data)
	case ServerEventTypeResponseOutputItemAdded:
		return unmarshalServerEvent[ResponseOutputItemAddedEvent](data)
	case ServerEventTypeResponseOutputItemDone:
		return unmarshalServerEvent[ResponseOutputItemDoneEvent](data)
	case ServerEventTypeResponseContentPartAdded:
		return unmarshalServerEvent[ResponseContentPartAddedEvent](data)
	case ServerEventTypeResponseContentPartDone:
		return unmarshalServerEvent[ResponseContentPartDoneEvent](data)
	case ServerEventTypeResponseOutputTextDelta:
		return unmarshalServerEvent[ResponseOutputTextDeltaEvent](data)
	case ServerEventTypeResponseOutputTextDone:
		return unmarshalServerEvent[ResponseOutputTextDoneEvent](data)
	case ServerEventTypeResponseOutputAudioTranscriptDelta:
		return unmarshalServerEvent[ResponseOutputAudioTranscriptDeltaEvent](data)
	case ServerEventTypeResponseOutputAudioTranscriptDone:
		return unmarshalServerEvent[ResponseOutputAudioTranscriptDoneEvent](data)
	case ServerEventTypeResponseOutputAudioDelta:
		return unmarshalServerEvent[ResponseOutputAudioDeltaEvent](data)
	case ServerEventTypeResponseOutputAudioDone:
		return unmarshalServerEvent[ResponseOutputAudioDoneEvent](data)
	case ServerEventTypeResponseFunctionCallArgumentsDelta:
		return unmarshalServerEvent[ResponseFunctionCallArgumentsDeltaEvent](data)
	case ServerEventTypeResponseFunctionCallArgumentsDone:
		return unmarshalServerEvent[ResponseFunctionCallArgumentsDoneEvent](data)
	case ServerEventTypeResponseMcpCallArgumentsDelta:
		return unmarshalServerEvent[ResponseMcpCallArgumentsDeltaEvent](data)
	case ServerEventTypeResponseMcpCallArgumentsDone:
		return unmarshalServerEvent[ResponseMcpCallArgumentsDoneEvent](data)
	case ServerEventTypeResponseMcpCallInProgress:
		return unmarshalServerEvent[ResponseMcpCallInProgressEvent](data)
	case ServerEventTypeResponseMcpCallCompleted:
		return unmarshalServerEvent[ResponseMcpCallCompletedEvent](data)
	case ServerEventTypeResponseMcpCallFailed:
		return unmarshalServerEvent[ResponseMcpCallFailedEvent](data)
	case ServerEventTypeMcpListToolsInProgress:
		return unmarshalServerEvent[McpListToolsInProgressEvent](data)
	case ServerEventTypeMcpListToolsCompleted:
		return unmarshalServerEvent[McpListToolsCompletedEvent](data)
	case ServerEventTypeMcpListToolsFailed:
		return unmarshalServerEvent[McpListToolsFailedEvent](data)
	case ServerEventTypeRateLimitsUpdated:
		return unmarshalServerEvent[RateLimitsUpdatedEvent](data)
	default:
		// This should never happen.
		return nil, fmt.Errorf("unknown server event type: %s", eventType.Type)
	}
}

================================================
FILE: core/http/endpoints/openai/types/types.go
================================================
package types

import (
	"bytes"
	"encoding/json"
	"errors"
	"fmt"
)

// The voice the model uses to respond. Voice cannot be changed during the session once the model has responded with audio at least once. Current voice options are alloy, ash, ballad, coral, echo, sage, shimmer, verse, marin, and cedar. We recommend marin and cedar for best quality.
type Voice string const ( VoiceAlloy Voice = "alloy" VoiceAsh Voice = "ash" VoiceBallad Voice = "ballad" VoiceCoral Voice = "coral" VoiceEcho Voice = "echo" VoiceSage Voice = "sage" VoiceShimmer Voice = "shimmer" VoiceVerse Voice = "verse" VoiceMarin Voice = "marin" VoiceCedar Voice = "cedar" VoiceFable Voice = "fable" VoiceOnyx Voice = "onyx" VoiceNova Voice = "nova" ) type AudioFormat string const ( AudioFormatPcm16 AudioFormat = "pcm16" AudioFormatG711Ulaw AudioFormat = "g711_ulaw" AudioFormatG711Alaw AudioFormat = "g711_alaw" ) type Modality string const ( ModalityText Modality = "text" ModalityAudio Modality = "audio" ) type TurnDetectionType string const ( TurnDetectionTypeServerVad TurnDetectionType = "server_vad" TurnDetectionTypeSemanticVad TurnDetectionType = "semantic_vad" ) type ToolChoiceMode string const ( ToolChoiceModeNone ToolChoiceMode = "none" ToolChoiceModeAuto ToolChoiceMode = "auto" ToolChoiceModeRequired ToolChoiceMode = "required" ) func (t ToolChoiceMode) ToolChoiceType() string { return string(t) } type ToolChoiceType string const ( ToolChoiceTypeFunction ToolChoiceType = "function" ToolChoiceTypeMCP ToolChoiceType = "mcp" ) type ToolChoiceFunction struct { // The name of the function to call. Name string `json:"name,omitempty"` } func (t ToolChoiceFunction) ToolChoiceType() string { return string(ToolChoiceTypeFunction) } func (t ToolChoiceFunction) MarshalJSON() ([]byte, error) { type typeAlias ToolChoiceFunction type typeWrapper struct { typeAlias Type string `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(t), Type: t.ToolChoiceType(), } return json.Marshal(shadow) } type ToolChoiceMCP struct { // The label of the MCP server to use. ServerLabel string `json:"server_label,omitempty"` // The name of the tool to call on the server. 
Name string `json:"name,omitempty"` } func (t ToolChoiceMCP) ToolChoiceType() string { return string(ToolChoiceTypeMCP) } func (t ToolChoiceMCP) MarshalJSON() ([]byte, error) { type typeAlias ToolChoiceMCP type typeWrapper struct { typeAlias Type string `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(t), Type: t.ToolChoiceType(), } return json.Marshal(shadow) } type ToolChoiceUnion struct { // Controls which (if any) tool is called by the model. // // none means the model will not call any tool and instead generates a message. // // auto means the model can pick between generating a message or calling one or more tools. // // required means the model must call one or more tools. Mode ToolChoiceMode `json:",omitempty"` // Use this option to force the model to call a specific function. Function *ToolChoiceFunction `json:",omitempty"` // Use this option to force the model to call a specific tool on a remote MCP server. MCP *ToolChoiceMCP `json:",omitempty"` } func (t ToolChoiceUnion) MarshalJSON() ([]byte, error) { if t.Function != nil { return json.Marshal(t.Function) } if t.MCP != nil { return json.Marshal(t.MCP) } return json.Marshal(t.Mode) } func (t *ToolChoiceUnion) UnmarshalJSON(data []byte) error { if isNull(data) { return nil } var u typeStruct if err := json.Unmarshal(data, &u); err != nil { t.Mode = ToolChoiceMode(bytes.Trim(data, "\"")) return nil //nolint: nilerr // data is string instead of object } switch ToolChoiceType(u.Type) { case ToolChoiceTypeFunction: return json.Unmarshal(data, &t.Function) case ToolChoiceTypeMCP: return json.Unmarshal(data, &t.MCP) default: t.Mode = ToolChoiceMode(u.Type) } return nil } type ToolType string const ( ToolTypeFunction ToolType = "function" ToolTypeMCP ToolType = "mcp" ) type ToolFunction struct { // The name of the function. Name string `json:"name"` // The description of the function, including guidance on when and how to call it, and guidance about what to tell the user when calling (if anything). 
Description string `json:"description"` // The jsonschema representing the parameters Parameters any `json:"parameters,omitempty"` } func (t ToolFunction) ToolType() ToolType { return ToolTypeFunction } func (t ToolFunction) MarshalJSON() ([]byte, error) { type typeAlias ToolFunction type toolFunction struct { typeAlias Type ToolType `json:"type"` } shadow := toolFunction{ typeAlias: typeAlias(t), Type: t.ToolType(), } return json.Marshal(shadow) } type MCPToolFilter struct { // Indicates whether or not a tool modifies data or is read-only. If an MCP server is annotated with readOnlyHint, it will match this filter. ReadOnly bool `json:"read_only,omitempty"` // List of allowed tool names. ToolNames []string `json:"tool_names,omitempty"` } type MCPAllowedToolsUnion struct { // A string array of allowed tool names ToolNames []string `json:",omitempty"` // A filter object to specify which tools are allowed. Filter *MCPToolFilter `json:",omitempty"` } func (t MCPAllowedToolsUnion) MarshalJSON() ([]byte, error) { if len(t.ToolNames) > 0 { return json.Marshal(t.ToolNames) } return json.Marshal(t.Filter) } func (t *MCPAllowedToolsUnion) UnmarshalJSON(data []byte) error { if isNull(data) { return nil } if err := json.Unmarshal(data, &t.Filter); err == nil { return nil } return json.Unmarshal(data, &t.ToolNames) } type MCPRequireApprovalFilter struct { // A filter object to specify which tools are allowed. Always *MCPToolFilter `json:",omitempty"` // A filter object to specify which tools are allowed. Never *MCPToolFilter `json:",omitempty"` } type MCPToolRequireApprovalUnion struct { // Specify which of the MCP server's tools require approval. Can be always, never, or a filter object associated with tools that require approval. Filter *MCPRequireApprovalFilter `json:",omitempty"` // Specify a single approval policy for all tools. One of always or never. When set to always, all tools will require approval. When set to never, all tools will not require approval. 
Setting string `json:",omitempty"` } func (t MCPToolRequireApprovalUnion) MarshalJSON() ([]byte, error) { if t.Filter != nil { return json.Marshal(t.Filter) } return json.Marshal(t.Setting) } func (t *MCPToolRequireApprovalUnion) UnmarshalJSON(data []byte) error { if isNull(data) { return nil } if err := json.Unmarshal(data, &t.Filter); err == nil { return nil } return json.Unmarshal(data, &t.Setting) } type ToolMCP struct { // A label for this MCP server, used to identify it in tool calls. ServerLabel string `json:"server_label,omitempty"` // An OAuth access token that can be used with a remote MCP server, either with a custom MCP server URL or a service connector. Your application must handle the OAuth authorization flow and provide the token here. Authorization string `json:"authorization,omitempty"` // Optional description of the MCP server, used to provide more context. ServerDescription string `json:"server_description,omitempty"` // The URL for the MCP server. One of server_url or connector_id must be provided. ServerURL string `json:"server_url,omitempty"` // List of allowed tool names or a filter object. AllowedTools *MCPAllowedToolsUnion `json:"allowed_tools,omitempty"` // Optional HTTP headers to send to the MCP server. Use for authentication or other purposes. Headers map[string]string `json:"headers,omitempty"` // Specify which of the MCP server's tools require approval. RequireApproval *MCPToolRequireApprovalUnion `json:"require_approval,omitempty"` // Identifier for service connectors, like those available in ChatGPT. One of server_url or connector_id must be provided. Learn more about service connectors here. 
// // Currently supported connector_id values are: // // Dropbox: connector_dropbox // Gmail: connector_gmail // Google Calendar: connector_googlecalendar // Google Drive: connector_googledrive // Microsoft Teams: connector_microsoftteams // Outlook Calendar: connector_outlookcalendar // Outlook Email: connector_outlookemail // SharePoint: connector_sharepoint ConnectorID string `json:"connector_id,omitempty"` } func (t ToolMCP) ToolType() ToolType { return ToolTypeMCP } func (t ToolMCP) MarshalJSON() ([]byte, error) { type typeAlias ToolMCP type toolMCP struct { typeAlias Type ToolType `json:"type"` } shadow := toolMCP{ typeAlias: typeAlias(t), Type: t.ToolType(), } return json.Marshal(shadow) } type TracingConfiguration struct { GroupID string `json:"group_id,omitempty"` Metadata any `json:"metadata,omitempty"` WorkflowName string `json:"workflow_name,omitempty"` } type ToolUnion struct { Function *ToolFunction `json:",omitempty"` // Give the model access to additional tools via remote Model Context Protocol (MCP) servers. Learn more about MCP. 
MCP *ToolMCP `json:",omitempty"`
}

// MarshalJSON emits whichever tool variant is set, Function first. It returns
// an error when neither is set.
func (t ToolUnion) MarshalJSON() ([]byte, error) {
	if t.Function != nil {
		return json.Marshal(t.Function)
	}
	if t.MCP != nil {
		return json.Marshal(t.MCP)
	}
	return nil, errors.New("no tool")
}

// UnmarshalJSON reads the "type" field and decodes into the matching variant.
// null is a no-op; an unrecognised type is an error.
func (t *ToolUnion) UnmarshalJSON(data []byte) error {
	if isNull(data) {
		return nil
	}
	var u typeStruct
	if err := json.Unmarshal(data, &u); err != nil {
		return err
	}
	switch ToolType(u.Type) {
	case ToolTypeFunction:
		return json.Unmarshal(data, &t.Function)
	case ToolTypeMCP:
		return json.Unmarshal(data, &t.MCP)
	default:
		return fmt.Errorf("unknown tool type: %s", u.Type)
	}
}

type TracingMode string

const (
	TracingModeAuto = "auto"
)

type TracingUnion struct {
	Mode          TracingMode           `json:",omitempty"`
	Configuration *TracingConfiguration `json:",omitempty"`
}

type TruncationStrategy string

const (
	TruncationStrategyAuto           TruncationStrategy = "auto"
	TruncationStrategyDisabled       TruncationStrategy = "disabled"
	TruncationStrategyRetentionRatio TruncationStrategy = "retention_ratio"
)

// TruncationStrategy returns the strategy as a plain string.
func (t TruncationStrategy) TruncationStrategy() string {
	return string(t)
}

type RetentionRatioTruncation struct {
	Ratio float32 `json:"retention_ratio,omitempty"`
}

// TruncationStrategy always reports the retention-ratio strategy.
func (t RetentionRatioTruncation) TruncationStrategy() string {
	return string(TruncationStrategyRetentionRatio)
}

type TruncationUnion struct {
	Strategy                 TruncationStrategy        `json:",omitempty"`
	RetentionRatioTruncation *RetentionRatioTruncation `json:",omitempty"`
}

const nullString = "null"

// isNull reports whether data is exactly the JSON literal null.
func isNull(data []byte) bool {
	return len(data) == len(nullString) && string(data) == nullString
}

// UnmarshalJSON accepts null (no-op), a bare strategy string, or an object
// with a "type" field. A retention_ratio object fills RetentionRatioTruncation;
// an auto or disabled object fills Strategy.
func (t *TruncationUnion) UnmarshalJSON(data []byte) error {
	if isNull(data) {
		return nil
	}
	var u typeStruct
	if err := json.Unmarshal(data, &u); err != nil {
		t.Strategy = TruncationStrategy(bytes.Trim(data, "\""))
		return nil //nolint: nilerr // data is string instead of object
	}
	switch TruncationStrategy(u.Type) {
	case TruncationStrategyRetentionRatio:
		return json.Unmarshal(data, &t.RetentionRatioTruncation)
	case TruncationStrategyDisabled, TruncationStrategyAuto:
		// Store the decoded type name. Converting data itself would store the
		// whole raw object text, e.g. `{"type":"auto"}`, as the strategy.
		t.Strategy = TruncationStrategy(u.Type)
	default:
		return fmt.Errorf("unknown truncation strategy: %s", u.Type)
	}
	return nil
}

type ResponseAudioOutput struct {
	// The format of the output audio.
	Format *AudioFormatUnion `json:"format,omitempty"`
	// The voice the model uses to respond. Voice cannot be changed during the session once the model has responded with audio at least once. Current voice options are alloy, ash, ballad, coral, echo, sage, shimmer, verse, marin, and cedar. We recommend marin and cedar for best quality.
	Voice Voice `json:"voice,omitempty"`
}

type ResponseAudio struct {
	Output *ResponseAudioOutput `json:"output,omitempty"`
}

type MessageRole string

const (
	MessageRoleSystem    MessageRole = "system"
	MessageRoleAssistant MessageRole = "assistant"
	MessageRoleUser      MessageRole = "user"
)

type Tool struct {
	Type        ToolType `json:"type"`
	Name        string   `json:"name"`
	Description string   `json:"description"`
	Parameters  any      `json:"parameters"`
}

type ResponseMessageItem struct {
	MessageItemUnion
	// The object type, must be "realtime.item".
	Object string `json:"object,omitempty"`
}

type Error struct {
	// A human-readable error message.
	Message string `json:"message,omitempty"`
	// The type of error (e.g., "invalid_request_error", "server_error").
	Type string `json:"type,omitempty"`
	// Error code, if any.
	Code string `json:"code,omitempty"`
	// Parameter related to the error, if any.
	Param string `json:"param,omitempty"`
	// The event_id of the client event that caused the error, if applicable.
	EventID string `json:"event_id,omitempty"`
}

type AudioFormatType string

const (
	AudioFormatTypePCM  AudioFormatType = "audio/pcm"
	AudioFormatTypePCMU AudioFormatType = "audio/pcmu"
	AudioFormatTypePCMA AudioFormatType = "audio/pcma"
)

// The PCM audio format. Only a 24kHz sample rate is supported.
type AudioFormatPCM struct {
	// The sample rate of the audio. Always 24000.
Rate int `json:"rate,omitempty"` } func (p AudioFormatPCM) AudioFormat() string { return string(AudioFormatTypePCM) } func (p AudioFormatPCM) MarshalJSON() ([]byte, error) { type typeAlias AudioFormatPCM type typeWrapper struct { typeAlias Type string `json:"type,omitempty"` } return json.Marshal(typeWrapper{ typeAlias: typeAlias(p), Type: p.AudioFormat(), }) } // The G.711 μ-law format. type AudioFormatPCMU struct { } func (p AudioFormatPCMU) AudioFormat() string { return string(AudioFormatTypePCMU) } func (p AudioFormatPCMU) MarshalJSON() ([]byte, error) { type typeAlias AudioFormatPCMU type typeWrapper struct { typeAlias Type string `json:"type,omitempty"` } return json.Marshal(typeWrapper{ typeAlias: typeAlias(p), Type: p.AudioFormat(), }) } // The G.711 A-law format. type AudioFormatPCMA struct { } func (p AudioFormatPCMA) AudioFormat() string { return string(AudioFormatTypePCMA) } func (p AudioFormatPCMA) MarshalJSON() ([]byte, error) { type typeAlias AudioFormatPCMA type typeWrapper struct { typeAlias Type string `json:"type,omitempty"` } return json.Marshal(typeWrapper{ typeAlias: typeAlias(p), Type: p.AudioFormat(), }) } type AudioFormatUnion struct { // The PCM audio format. Only a 24kHz sample rate is supported. PCM *AudioFormatPCM `json:",omitempty"` // The G.711 μ-law format. PCMU *AudioFormatPCMU `json:",omitempty"` // The G.711 A-law format. 
PCMA *AudioFormatPCMA `json:",omitempty"` } func (r AudioFormatUnion) MarshalJSON() ([]byte, error) { if r.PCM != nil { return json.Marshal(r.PCM) } if r.PCMU != nil { return json.Marshal(r.PCMU) } if r.PCMA != nil { return json.Marshal(r.PCMA) } return nil, errors.New("no audio format") } func (r *AudioFormatUnion) UnmarshalJSON(data []byte) error { if isNull(data) { return nil } type typeStruct struct { Type string `json:"type"` } var t typeStruct if err := json.Unmarshal(data, &t); err != nil { return err } switch AudioFormatType(t.Type) { case AudioFormatTypePCM: r.PCM = &AudioFormatPCM{} return json.Unmarshal(data, r.PCM) case AudioFormatTypePCMU: r.PCMU = &AudioFormatPCMU{} return json.Unmarshal(data, r.PCMU) case AudioFormatTypePCMA: r.PCMA = &AudioFormatPCMA{} return json.Unmarshal(data, r.PCMA) default: return fmt.Errorf("unknown audio format: %s", t.Type) } } type AudioNoiseReduction struct { // Type of noise reduction. near_field is for close-talking microphones such as headphones, far_field is for far-field microphones such as laptop or conference room microphones. Type NoiseReductionType `json:"type,omitempty"` } type ServerVad struct { // Optional timeout after which a model response will be triggered automatically. This is useful for situations in which a long pause from the user is unexpected, such as a phone call. The model will effectively prompt the user to continue the conversation based on the current context. // // The timeout value will be applied after the last model response's audio has finished playing, i.e. it's set to the response.done time plus audio playback duration. // // An input_audio_buffer.timeout_triggered event (plus events associated with the Response) will be emitted when the timeout is reached. Idle timeout is currently only supported for server_vad mode. IdleTimeoutMs int64 `json:"idle_timeout_ms,omitempty"` // Whether or not to automatically generate a response when a VAD stop event occurs. 
CreateResponse bool `json:"create_response,omitempty"`
	// Whether or not to automatically interrupt any ongoing response with output to the default conversation (i.e. conversation of auto) when a VAD start event occurs.
	InterruptResponse bool `json:"interrupt_response,omitempty"`
	// Used only for server_vad mode. Amount of audio to include before the VAD detected speech (in milliseconds). Defaults to 300ms.
	PrefixPaddingMs int64 `json:"prefix_padding_ms,omitempty"`
	// Used only for server_vad mode. Duration of silence to detect speech stop (in milliseconds). Defaults to 500ms. With shorter values the model will respond more quickly, but may jump in on short pauses from the user.
	SilenceDurationMs int64 `json:"silence_duration_ms,omitempty"`
	// Used only for server_vad mode. Activation threshold for VAD (0.0 to 1.0), this defaults to 0.5. A higher threshold will require louder audio to activate the model, and thus might perform better in noisy environments.
	Threshold float64 `json:"threshold,omitempty"`
}

// VadType reports the server VAD turn-detection type.
func (r ServerVad) VadType() TurnDetectionType {
	return TurnDetectionTypeServerVad
}

// MarshalJSON encodes the server VAD settings together with their "type" field.
func (r ServerVad) MarshalJSON() ([]byte, error) {
	// fields has the struct's data but not its methods, which keeps
	// json.Marshal from calling back into this MarshalJSON.
	type fields ServerVad
	return json.Marshal(struct {
		fields
		Type TurnDetectionType `json:"type,omitempty"`
	}{
		fields: fields(r),
		Type:   TurnDetectionTypeServerVad,
	})
}

type RealtimeSessionSemanticVad struct {
	// Whether or not to automatically generate a response when a VAD stop event occurs.
	CreateResponse bool `json:"create_response,omitempty"`
	// Whether or not to automatically interrupt any ongoing response with output to the default conversation (i.e. conversation of auto) when a VAD start event occurs.
	InterruptResponse bool `json:"interrupt_response,omitempty"`
	// Used only for semantic_vad mode. The eagerness of the model to respond. low will wait longer for the user to continue speaking, high will respond more quickly. auto is the default and is equivalent to medium. low, medium, and high have max timeouts of 8s, 4s, and 2s respectively.
	Eagerness string `json:"eagerness,omitempty"`
}

// VadType reports the semantic VAD turn-detection type.
func (r RealtimeSessionSemanticVad) VadType() TurnDetectionType {
	return TurnDetectionTypeSemanticVad
}

// MarshalJSON encodes the semantic VAD settings together with their "type"
// field.
func (r RealtimeSessionSemanticVad) MarshalJSON() ([]byte, error) {
	type fields RealtimeSessionSemanticVad
	return json.Marshal(struct {
		fields
		Type TurnDetectionType `json:"type,omitempty"`
	}{
		fields: fields(r),
		Type:   TurnDetectionTypeSemanticVad,
	})
}

type TurnDetectionUnion struct {
	// Server-side voice activity detection (VAD) which flips on when user speech is detected and off after a period of silence.
	ServerVad *ServerVad `json:",omitempty"`
	// Server-side semantic turn detection which uses a model to determine when the user has finished speaking.
	SemanticVad *RealtimeSessionSemanticVad `json:",omitempty"`
}

// MarshalJSON emits whichever variant is set, ServerVad first. It returns an
// error when neither is set.
func (r TurnDetectionUnion) MarshalJSON() ([]byte, error) {
	switch {
	case r.ServerVad != nil:
		return json.Marshal(r.ServerVad)
	case r.SemanticVad != nil:
		return json.Marshal(r.SemanticVad)
	}
	return nil, errors.New("no turn detection")
}

// UnmarshalJSON reads the "type" field and decodes into the matching variant.
// null is a no-op; an unrecognised type is an error.
func (r *TurnDetectionUnion) UnmarshalJSON(data []byte) error {
	if isNull(data) {
		return nil
	}
	var probe typeStruct
	if err := json.Unmarshal(data, &probe); err != nil {
		return err
	}
	kind := TurnDetectionType(probe.Type)
	if kind == TurnDetectionTypeServerVad {
		return json.Unmarshal(data, &r.ServerVad)
	}
	if kind == TurnDetectionTypeSemanticVad {
		return json.Unmarshal(data, &r.SemanticVad)
	}
	return fmt.Errorf("unknown turn detection type: %s", probe.Type)
}

type AudioTranscription struct {
	// The language of the input audio. Supplying the input language in ISO-639-1 (e.g. en) format will improve accuracy and latency.
	Language string `json:"language,omitempty"`
	// An optional text to guide the model's style or continue a previous audio segment. For whisper-1, the prompt is a list of keywords.
For gpt-4o-transcribe models (excluding gpt-4o-transcribe-diarize), the prompt is a free text string, for example "expect words related to technology". Prompt string `json:"prompt,omitempty"` // The model to use for transcription. Current options are whisper-1, gpt-4o-mini-transcribe, gpt-4o-transcribe, and gpt-4o-transcribe-diarize. Use gpt-4o-transcribe-diarize when you need diarization with speaker labels. Model string `json:"model,omitempty"` } type SessionAudioInput struct { Format *AudioFormatUnion `json:"format,omitempty"` // Configuration for input audio noise reduction. This can be set to null to turn off. Noise reduction filters audio added to the input audio buffer before it is sent to VAD and the model. Filtering the audio can improve VAD and turn detection accuracy (reducing false positives) and model performance by improving perception of the input audio. NoiseReduction *AudioNoiseReduction `json:"noise_reduction,omitempty"` // Configuration for turn detection: Server VAD or Semantic VAD. Set to null // to turn off, in which case the client must manually trigger model response. TurnDetection *TurnDetectionUnion `json:"turn_detection,omitempty"` // True when the JSON payload explicitly included "turn_detection" (even as null). // Standard Go JSON can't distinguish absent from null for pointer fields. TurnDetectionSet bool `json:"-"` // Configuration for input audio transcription, defaults to off and can be // set to null to turn off once on. Transcription *AudioTranscription `json:"transcription,omitempty"` } func (s *SessionAudioInput) UnmarshalJSON(data []byte) error { // Check whether turn_detection key exists in the raw JSON. 
var raw map[string]json.RawMessage if err := json.Unmarshal(data, &raw); err != nil { return err } type alias SessionAudioInput var a alias if err := json.Unmarshal(data, &a); err != nil { return err } *s = SessionAudioInput(a) if _, ok := raw["turn_detection"]; ok { s.TurnDetectionSet = true } return nil } type SessionAudioOutput struct { Format *AudioFormatUnion `json:"format,omitempty"` Speed float32 `json:"speed,omitempty"` Voice Voice `json:"voice,omitempty"` } type RealtimeSessionAudio struct { Input *SessionAudioInput `json:"input,omitempty"` Output *SessionAudioOutput `json:"output,omitempty"` } type TranscriptionSessionAudio struct { Input *SessionAudioInput `json:"input,omitempty"` } type PromptInputType string const ( PromptInputTypeText PromptInputType = "input_text" PromptInputTypeImage PromptInputType = "input_image" PromptInputTypeFile PromptInputType = "input_file" ) // The detail level of the image to be sent to the model. One of `high`, `low`, or // `auto`. Defaults to `auto`. 
type ImageDetail string const ( ImageDetailLow ImageDetail = "low" ImageDetailHigh ImageDetail = "high" ImageDetailAuto ImageDetail = "auto" ) type PromptInputText struct { Text string `json:"text"` } func (r PromptInputText) PromptInputType() PromptInputType { return PromptInputTypeText } func (r PromptInputText) MarshalJSON() ([]byte, error) { type typeAlias PromptInputText type typeWrapper struct { typeAlias Type PromptInputType `json:"type,omitempty"` } shadow := typeWrapper{ typeAlias: typeAlias(r), Type: r.PromptInputType(), } return json.Marshal(shadow) } type PromptInputImage struct { Detail ImageDetail `json:"detail,omitempty"` FileID string `json:"file_id,omitempty"` ImageURL string `json:"image_url,omitempty"` } func (r PromptInputImage) PromptInputType() PromptInputType { return PromptInputTypeImage } func (r PromptInputImage) MarshalJSON() ([]byte, error) { type typeAlias PromptInputImage type typeWrapper struct { typeAlias Type PromptInputType `json:"type,omitempty"` } shadow := typeWrapper{ typeAlias: typeAlias(r), Type: r.PromptInputType(), } return json.Marshal(shadow) } type PromptInputFile struct { FileID string `json:"file_id,omitempty"` FileData string `json:"file_data,omitempty"` FileURL string `json:"file_url,omitempty"` Filename string `json:"filename,omitempty"` } func (r PromptInputFile) PromptInputType() PromptInputType { return PromptInputTypeFile } func (r PromptInputFile) MarshalJSON() ([]byte, error) { type typeAlias PromptInputFile type typeWrapper struct { typeAlias Type PromptInputType `json:"type,omitempty"` } shadow := typeWrapper{ typeAlias: typeAlias(r), Type: r.PromptInputType(), } return json.Marshal(shadow) } type PromptVariableUnion struct { String string `json:",omitempty"` InputText *PromptInputText `json:",omitempty"` InputImage *PromptInputImage `json:",omitempty"` InputFile *PromptInputFile `json:",omitempty"` } type typeStruct struct { Type string `json:"type"` } func (u *PromptVariableUnion) UnmarshalJSON(data 
[]byte) error { if isNull(data) { return nil } var t typeStruct if err := json.Unmarshal(data, &t); err != nil { return err } switch PromptInputType(t.Type) { case PromptInputTypeText: u.InputText = &PromptInputText{} return json.Unmarshal(data, u.InputText) case PromptInputTypeImage: u.InputImage = &PromptInputImage{} return json.Unmarshal(data, u.InputImage) case PromptInputTypeFile: u.InputFile = &PromptInputFile{} return json.Unmarshal(data, u.InputFile) default: return fmt.Errorf("unknown input type: %s", t.Type) } } type PromptReference struct { // The unique identifier of the prompt template to use. ID string `json:"id,omitempty"` // Optional version of the prompt template. Version string `json:"version,omitempty"` // Optional map of values to substitute in for variables in your prompt. The substitution values can either be strings, or other Response input types like images or files. Variables map[string]PromptVariableUnion `json:"variables,omitempty"` } type SessionType string const ( SessionTypeRealtime SessionType = "realtime" SessionTypeTranscription SessionType = "transcription" ) type RealtimeSession struct { // Unique identifier for the session that looks like sess_1234567890abcdef. ID string `json:"id,omitempty"` // Expiration timestamp for the session, in seconds since epoch. ExpiresAt int64 `json:"expires_at,omitempty"` // The object type. Always realtime.session. Object string `json:"object,omitempty"` // Configuration for input and output audio. Audio *RealtimeSessionAudio `json:"audio,omitempty"` // Additional fields to include in server outputs. // // `item.input_audio_transcription.logprobs`: Include logprobs for input audio // transcription. // // Any of "item.input_audio_transcription.logprobs". Include []string `json:"include,omitempty"` // The default system instructions (i.e. system message) prepended to model calls. This field allows the client to guide the model on desired responses. 
The model can be instructed on response content and format, (e.g. "be extremely succinct", "act friendly", "here are examples of good responses") and on audio behavior (e.g. "talk quickly", "inject emotion into your voice", "laugh frequently"). The instructions are not guaranteed to be followed by the model, but they provide guidance to the model on the desired behavior. // // Note that the server sets default instructions which will be used if this field is not set and are visible in the session.created event at the start of the session. Instructions string `json:"instructions,omitempty"` // Maximum number of output tokens for a single assistant response, inclusive of tool calls. Provide an integer between 1 and 4096 to limit output tokens, or inf for the maximum available tokens for a given model. Defaults to inf. MaxOutputTokens IntOrInf `json:"max_output_tokens,omitempty"` // The Realtime model used for this session. Model string `json:"model,omitempty"` // The set of modalities the model can respond with. It defaults to ["audio"], indicating that the model will respond with audio plus a transcript. ["text"] can be used to make the model respond with text only. It is not possible to request both text and audio at the same time. OutputModalities []Modality `json:"output_modalities,omitempty"` // Reference to a prompt template and its variables. Prompt *PromptReference `json:"prompt,omitempty"` // How the model chooses tools. Provide one of the string modes or force a specific function/MCP tool. ToolChoice *ToolChoiceUnion `json:"tool_choice,omitempty"` // Tools available to the model. Tools []ToolUnion `json:"tools,omitempty"` // Realtime API can write session traces to the Traces Dashboard. Set to null to disable tracing. Once tracing is enabled for a session, the configuration cannot be modified. // // auto will create a trace for the session with default values for the workflow name, group id, and metadata. 
Tracing *TracingUnion `json:"tracing,omitempty"` // Controls how the realtime conversation is truncated prior to model inference. The default is auto. Truncation *TruncationUnion `json:"truncation,omitempty"` } func (r RealtimeSession) Type() SessionType { return SessionTypeRealtime } func (r RealtimeSession) MarshalJSON() ([]byte, error) { type typeAlias RealtimeSession type typeWrapper struct { typeAlias Type SessionType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(r), Type: r.Type(), } return json.Marshal(shadow) } type TranscriptionSession struct { // Unique identifier for the session that looks like sess_1234567890abcdef. ID string `json:"id,omitempty"` // Expiration timestamp for the session, in seconds since epoch. ExpiresAt int64 `json:"expires_at,omitempty"` // The object type. Always realtime.transcription_session. Object string `json:"object,omitempty"` // Configuration for input audio. Audio *TranscriptionSessionAudio `json:"audio,omitempty"` // Additional fields to include in server outputs. // // `item.input_audio_transcription.logprobs`: Include logprobs for input audio // transcription. // // Any of "item.input_audio_transcription.logprobs". Include []string `json:"include,omitempty"` } func (r TranscriptionSession) Type() SessionType { return SessionTypeTranscription } func (r TranscriptionSession) MarshalJSON() ([]byte, error) { type typeAlias TranscriptionSession type typeWrapper struct { typeAlias Type SessionType `json:"type"` } shadow := typeWrapper{ typeAlias: typeAlias(r), Type: r.Type(), } return json.Marshal(shadow) } type SessionUnion struct { // Realtime session object configuration. Realtime *RealtimeSession `json:"realtime,omitempty"` // Realtime transcription session object configuration. 
Transcription *TranscriptionSession `json:"transcription,omitempty"` } func (r SessionUnion) MarshalJSON() ([]byte, error) { if r.Realtime != nil { return json.Marshal(r.Realtime) } if r.Transcription != nil { return json.Marshal(r.Transcription) } return nil, errors.New("no session type") } func (r *SessionUnion) UnmarshalJSON(data []byte) error { if isNull(data) { return nil } var t typeStruct if err := json.Unmarshal(data, &t); err != nil { return err } switch SessionType(t.Type) { case SessionTypeRealtime, "": // Default to realtime when no type field is present (e.g. session.update events). r.Realtime = &RealtimeSession{} return json.Unmarshal(data, r.Realtime) case SessionTypeTranscription: r.Transcription = &TranscriptionSession{} return json.Unmarshal(data, r.Transcription) default: return fmt.Errorf("unknown session type: %s", t.Type) } } type ItemStatus string const ( ItemStatusInProgress ItemStatus = "in_progress" ItemStatusCompleted ItemStatus = "completed" ItemStatusIncomplete ItemStatus = "incomplete" ) type Conversation struct { // The unique ID of the conversation. ID string `json:"id"` // The object type, must be "realtime.conversation". 
Object string `json:"object"`
}

type ResponseStatus string

const (
	ResponseStatusInProgress ResponseStatus = "in_progress"
	ResponseStatusCompleted  ResponseStatus = "completed"
	ResponseStatusCancelled  ResponseStatus = "cancelled"
	ResponseStatusIncomplete ResponseStatus = "incomplete"
	ResponseStatusFailed     ResponseStatus = "failed"
)

type UsageType string

const (
	UsageTypeTokens   UsageType = "tokens"
	UsageTypeDuration UsageType = "duration"
)

type CachedTokensDetails struct {
	TextTokens  int `json:"text_tokens"`
	AudioTokens int `json:"audio_tokens"`
}

type InputTokenDetails struct {
	CachedTokens        int                  `json:"cached_tokens"`
	TextTokens          int                  `json:"text_tokens"`
	AudioTokens         int                  `json:"audio_tokens"`
	CachedTokensDetails *CachedTokensDetails `json:"cached_tokens_details,omitempty"`
}

type OutputTokenDetails struct {
	TextTokens  int `json:"text_tokens"`
	AudioTokens int `json:"audio_tokens"`
}

type TokenUsage struct {
	TotalTokens  int `json:"total_tokens"`
	InputTokens  int `json:"input_tokens"`
	OutputTokens int `json:"output_tokens"`
	// Input token details.
	InputTokenDetails *InputTokenDetails `json:"input_token_details,omitempty"`
	// Output token details.
	OutputTokenDetails *OutputTokenDetails `json:"output_token_details,omitempty"`
}

// UsageType reports the token-based usage discriminator.
func (u TokenUsage) UsageType() UsageType {
	return UsageTypeTokens
}

type DurationUsage struct {
	Seconds float64 `json:"seconds"`
}

// UsageType reports the duration-based usage discriminator.
func (u DurationUsage) UsageType() UsageType {
	return UsageTypeDuration
}

type UsageUnion struct {
	Tokens   *TokenUsage    `json:",omitempty"`
	Duration *DurationUsage `json:",omitempty"`
}

// UnmarshalJSON reads the "type" field and decodes into Tokens or Duration.
// null is a no-op; an unrecognised type is an error.
func (u *UsageUnion) UnmarshalJSON(data []byte) error {
	if isNull(data) {
		return nil
	}
	var probe typeStruct
	if err := json.Unmarshal(data, &probe); err != nil {
		return err
	}
	kind := UsageType(probe.Type)
	if kind == UsageTypeTokens {
		return json.Unmarshal(data, &u.Tokens)
	}
	if kind == UsageTypeDuration {
		return json.Unmarshal(data, &u.Duration)
	}
	return fmt.Errorf("unknown usage type: %s", probe.Type)
}

type StatusDetail struct {
	Error  *Error `json:"error,omitempty"`
	Reason string `json:"reason,omitempty"`
	Type   string `json:"type,omitempty"`
}

type ResponseCreateParams struct {
	// Configuration for audio input and output.
	Audio *ResponseAudio `json:"audio,omitempty"`
	// Controls which conversation the response is added to. Currently supports auto and none, with auto as the default value. The auto value means that the contents of the response will be added to the default conversation. Set this to none to create an out-of-band response which will not add items to default conversation.
	Conversation string `json:"conversation,omitempty"`
	// Input items to include in the prompt for the model. Using this field creates a new context for this Response instead of using the default conversation. An empty array [] will clear the context for this Response. Note that this can include references to items that previously appeared in the session using their id.
	Input []MessageItemUnion `json:"input,omitempty"`
	// The default system instructions (i.e. system message) prepended to model calls. This field allows the client to guide the model on desired responses.
// The model can be instructed on response content and format, (e.g. "be extremely succinct", "act friendly", "here are examples of good responses") and on audio behavior (e.g. "talk quickly", "inject emotion into your voice", "laugh frequently"). The instructions are not guaranteed to be followed by the model, but they provide guidance to the model on the desired behavior. Note that the server sets default instructions which will be used if this field is not set and are visible in the session.created event at the start of the session.
	Instructions string `json:"instructions,omitempty"`
	// Maximum number of output tokens for a single assistant response, inclusive of tool calls. Provide an integer between 1 and 4096 to limit output tokens, or inf for the maximum available tokens for a given model. Defaults to inf.
	MaxOutputTokens IntOrInf `json:"max_output_tokens,omitempty"`
	// Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard.
	//
	// Keys are strings with a maximum length of 64 characters. Values are strings with a maximum length of 512 characters.
	Metadata map[string]string `json:"metadata,omitempty"`
	// The set of modalities the model used to respond, currently the only possible values are [\"audio\"], [\"text\"]. Audio output always include a text transcript. Setting the output to mode text will disable audio output from the model.
	OutputModalities []Modality `json:"output_modalities,omitempty"`
	// Reference to a prompt template and its variables.
	//
	// See https://platform.openai.com/docs/guides/text?api-mode=responses#reusable-prompts.
	Prompt *PromptReference `json:"prompt,omitempty"`
	// How the model chooses tools. Provide one of the string modes or force a specific function/MCP tool.
	ToolChoice *ToolChoiceUnion `json:"tool_choice,omitempty"`
	// Tools available to the model.
	Tools []ToolUnion `json:"tools,omitempty"`
}

// Response is the JSON shape of a realtime response object. Only ID is always
// emitted; every other field is dropped from the encoding when empty.
type Response struct {
	Audio          *ResponseAudio `json:"audio,omitempty"`
	ConversationID string         `json:"conversation_id,omitempty"`
	// The unique ID of the response.
	ID              string            `json:"id"`
	MaxOutputTokens IntOrInf          `json:"max_output_tokens,omitempty"`
	Metadata        map[string]string `json:"metadata,omitempty"`
	// The object type, must be "realtime.response".
	Object           string             `json:"object,omitempty"`
	Output           []MessageItemUnion `json:"output,omitempty"`
	OutputModalities []Modality         `json:"output_modalities,omitempty"`
	// The status of the response.
	Status ResponseStatus `json:"status,omitempty"`
	// Additional details about the status.
	StatusDetails *StatusDetail `json:"status_details,omitempty"`
	// Usage is typed as token usage only; the duration variant of UsageUnion
	// is not representable here.
	Usage *TokenUsage `json:"usage,omitempty"`
}

// RateLimit describes one rate-limit counter.
type RateLimit struct {
	// The name of the rate limit ("requests", "tokens", "input_tokens", "output_tokens").
	Name string `json:"name,omitempty"`
	// The maximum allowed value for the rate limit.
	Limit int `json:"limit,omitempty"`
	// The remaining value before the limit is reached.
	Remaining int `json:"remaining,omitempty"`
	// Seconds until the rate limit resets.
ResetSeconds float64 `json:"reset_seconds,omitempty"`
}

================================================
FILE: core/http/endpoints/openresponses/responses.go
================================================
package openresponses

import (
	"context"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/labstack/echo/v4"
	"github.com/mudler/LocalAI/core/backend"
	"github.com/mudler/LocalAI/core/config"
	mcpTools "github.com/mudler/LocalAI/core/http/endpoints/mcp"
	openaiEndpoint "github.com/mudler/LocalAI/core/http/endpoints/openai"
	"github.com/mudler/LocalAI/core/http/middleware"
	"github.com/mudler/LocalAI/core/schema"
	"github.com/mudler/LocalAI/core/templates"
	"github.com/mudler/LocalAI/pkg/functions"
	"github.com/mudler/LocalAI/pkg/model"
	reason "github.com/mudler/LocalAI/pkg/reasoning"
	"github.com/mudler/LocalAI/pkg/utils"
	"github.com/mudler/xlog"
)

// ResponsesEndpoint is the Open Responses API endpoint
// https://www.openresponses.org/specification
//
// The returned handler reads the parsed request and the model configuration
// from the echo context, rebuilds the conversation (previous response input
// and output, then the new input, with instructions prepended as a system
// message), optionally injects MCP prompts/resources/tools, generates a
// function-calling grammar, templates the prompt and finally dispatches to the
// background, streaming or non-streaming handler.
// @Summary Create a response using the Open Responses API
// @Param request body schema.OpenResponsesRequest true "Request body"
// @Success 200 {object} schema.ORResponseResource "Response"
// @Router /v1/responses [post]
func ResponsesEndpoint(cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) echo.HandlerFunc {
	return func(c echo.Context) error {
		createdAt := time.Now().Unix()
		responseID := fmt.Sprintf("resp_%s", uuid.New().String())

		// Both values are read from the echo context; a failed type assertion
		// is reported to the client as a 400.
		input, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenResponsesRequest)
		if !ok || input.Model == "" {
			return sendOpenResponsesError(c, 400, "invalid_request", "model is required", "")
		}

		cfg, ok := c.Get(middleware.CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig)
		if !ok || cfg == nil {
			return sendOpenResponsesError(c, 400, "invalid_request", "model configuration not found", "")
		}

		// Initialize store with TTL from appConfig
		store := GetGlobalStore()
		if appConfig.OpenResponsesStoreTTL > 0 {
			store.SetTTL(appConfig.OpenResponsesStoreTTL)
		}

		// Check if storage is disabled for this request
		// (storage is on unless the request explicitly sends store=false)
		shouldStore := true
		if input.Store != nil && !*input.Store {
			shouldStore = false
		}

		// Handle previous_response_id if provided
		var previousResponse *schema.ORResponseResource
		var messages []schema.Message
		if input.PreviousResponseID != "" {
			stored, err := store.Get(input.PreviousResponseID)
			if err != nil {
				return sendOpenResponsesError(c, 404, "not_found", fmt.Sprintf("previous response not found: %s", input.PreviousResponseID), "previous_response_id")
			}
			previousResponse = stored.Response

			// Also convert previous response input to messages
			previousInputMessages, err := convertORInputToMessages(stored.Request.Input, cfg)
			if err != nil {
				return sendOpenResponsesError(c, 400, "invalid_request", fmt.Sprintf("failed to convert previous input: %v", err), "")
			}

			// Convert previous response output items to messages
			previousOutputMessages, err := convertOROutputItemsToMessages(previousResponse.Output)
			if err != nil {
				return sendOpenResponsesError(c, 400, "invalid_request", fmt.Sprintf("failed to convert previous response: %v", err), "")
			}

			// Concatenate: previous_input + previous_output + new_input
			// Start with previous input messages
			messages = previousInputMessages
			// Add previous output as assistant messages
			messages = append(messages, previousOutputMessages...)
		}

		// Convert Open Responses input to internal Messages
		newMessages, err := convertORInputToMessages(input.Input, cfg)
		if err != nil {
			return sendOpenResponsesError(c, 400, "invalid_request", fmt.Sprintf("failed to parse input: %v", err), "")
		}
		// Append new input messages
		messages = append(messages, newMessages...)

		// Add instructions as system message if provided
		// (prepended, so it precedes any history restored above)
		if input.Instructions != "" {
			messages = append([]schema.Message{{Role: "system", StringContent: input.Instructions}}, messages...)
		}

		// Handle tools
		var funcs functions.Functions
		var shouldUseFn bool
		var mcpToolInfos []mcpTools.MCPToolInfo

		if len(input.Tools) > 0 {
			funcs, shouldUseFn = convertORToolsToFunctions(input, cfg)
		}

		// MCP injection: prompts, resources, and tools
		// The MCP selection is read from the request metadata.
		mcpServers := mcpTools.MCPServersFromMetadata(input.Metadata)
		mcpPromptName, mcpPromptArgs := mcpTools.MCPPromptFromMetadata(input.Metadata)
		mcpResourceURIs := mcpTools.MCPResourcesFromMetadata(input.Metadata)

		hasMCPRequest := len(mcpServers) > 0 || mcpPromptName != "" || len(mcpResourceURIs) > 0
		hasMCPConfig := cfg.MCP.Servers != "" || cfg.MCP.Stdio != ""

		if hasMCPRequest && hasMCPConfig {
			remote, stdio, mcpErr := cfg.MCP.MCPConfigFromYAML()
			if mcpErr == nil {
				// A session error or an empty session list skips all MCP
				// injection without logging.
				namedSessions, sessErr := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, mcpServers)
				if sessErr == nil && len(namedSessions) > 0 {
					// Prompt injection
					if mcpPromptName != "" {
						prompts, discErr := mcpTools.DiscoverMCPPrompts(c.Request().Context(), namedSessions)
						if discErr == nil {
							promptMsgs, getErr := mcpTools.GetMCPPrompt(c.Request().Context(), prompts, mcpPromptName, mcpPromptArgs)
							if getErr == nil {
								var injected []schema.Message
								for _, pm := range promptMsgs {
									injected = append(injected, schema.Message{
										Role:    string(pm.Role),
										Content: mcpTools.PromptMessageToText(pm),
									})
								}
								// Prompt messages go in front of everything
								// collected so far, including the system message.
								messages = append(injected, messages...)
								xlog.Debug("Open Responses MCP prompt injected", "prompt", mcpPromptName, "messages", len(injected))
							} else {
								xlog.Error("Failed to get MCP prompt", "error", getErr)
							}
						}
					}

					// Resource injection
					if len(mcpResourceURIs) > 0 {
						resources, discErr := mcpTools.DiscoverMCPResources(c.Request().Context(), namedSessions)
						if discErr == nil {
							var resourceTexts []string
							for _, uri := range mcpResourceURIs {
								content, readErr := mcpTools.ReadMCPResource(c.Request().Context(), resources, uri)
								if readErr != nil {
									xlog.Error("Failed to read MCP resource", "error", readErr, "uri", uri)
									continue
								}
								// Label the resource with its discovered name,
								// falling back to the URI when none matches.
								name := uri
								for _, r := range resources {
									if r.URI == uri {
										name = r.Name
										break
									}
								}
								resourceTexts = append(resourceTexts, fmt.Sprintf("--- MCP Resource: %s ---\n%s", name, content))
							}
							// Resource text is appended to the Content of the
							// last message; StringContent is left untouched here.
							if len(resourceTexts) > 0 && len(messages) > 0 {
								lastIdx := len(messages) - 1
								suffix := "\n\n" + strings.Join(resourceTexts, "\n\n")
								switch ct := messages[lastIdx].Content.(type) {
								case string:
									messages[lastIdx].Content = ct + suffix
								default:
									messages[lastIdx].Content = fmt.Sprintf("%v%s", ct, suffix)
								}
								xlog.Debug("Open Responses MCP resources injected", "count", len(resourceTexts))
							}
						}
					}

					// Tool injection
					if len(mcpServers) > 0 {
						discovered, discErr := mcpTools.DiscoverMCPTools(c.Request().Context(), namedSessions)
						if discErr == nil {
							mcpToolInfos = discovered
							// Each discovered tool is added to the internal
							// function list and to input.Tools (the request
							// object is mutated).
							for _, ti := range mcpToolInfos {
								funcs = append(funcs, ti.Function)
								input.Tools = append(input.Tools, schema.ORFunctionTool{
									Type:        "function",
									Name:        ti.Function.Name,
									Description: ti.Function.Description,
									Parameters:  ti.Function.Parameters,
								})
							}
							shouldUseFn = len(funcs) > 0 && cfg.ShouldUseFunctions()
							xlog.Debug("Open Responses MCP tools injected", "count", len(mcpToolInfos), "total_funcs", len(funcs))
						} else {
							xlog.Error("Failed to discover MCP tools", "error", discErr)
						}
					}
				}
			} else {
				xlog.Error("Failed to parse MCP config", "error", mcpErr)
			}
		} else if len(input.Tools) == 0 && hasMCPConfig {
			// Backward compat: model has MCP config, no user tools and no mcp_servers field
			// Errors on this path are ignored silently.
			remote, stdio, mcpErr := cfg.MCP.MCPConfigFromYAML()
			if mcpErr == nil {
				namedSessions, sessErr := mcpTools.NamedSessionsFromMCPConfig(cfg.Name, remote, stdio, nil)
				if sessErr == nil && len(namedSessions) > 0 {
					discovered, discErr := mcpTools.DiscoverMCPTools(c.Request().Context(), namedSessions)
					if discErr == nil {
						mcpToolInfos = discovered
						for _, ti := range mcpToolInfos {
							funcs = append(funcs, ti.Function)
							input.Tools = append(input.Tools, schema.ORFunctionTool{
								Type:        "function",
								Name:        ti.Function.Name,
								Description: ti.Function.Description,
								Parameters:  ti.Function.Parameters,
							})
						}
						shouldUseFn = len(funcs) > 0 && cfg.ShouldUseFunctions()
						xlog.Debug("Open Responses MCP tools auto-activated", "count", len(mcpToolInfos))
					}
				}
			}
		}

		// Create OpenAI-compatible request for internal processing
		openAIReq := &schema.OpenAIRequest{
			PredictionOptions: schema.PredictionOptions{
				BasicModelRequest: schema.BasicModelRequest{Model: input.Model},
				Temperature:       input.Temperature,
				TopP:              input.TopP,
				Maxtokens:         input.MaxOutputTokens,
			},
			Messages:  messages,
			Stream:    input.Stream,
			Context:   input.Context,
			Cancel:    input.Cancel,
			Functions: funcs,
		}

		// Handle text_format -> response_format conversion
		if input.TextFormat != nil {
			openAIReq.ResponseFormat = convertTextFormatToResponseFormat(input.TextFormat)
		}

		// Generate grammar for function calling (similar to OpenAI chat endpoint)
		if shouldUseFn && !cfg.FunctionsConfig.GrammarConfig.NoGrammar {
			// Add no-action function to allow model to respond without calling a tool
			noActionName := "answer"
			noActionDescription := "use this action to answer without performing any action"
			if cfg.FunctionsConfig.NoActionFunctionName != "" {
				noActionName = cfg.FunctionsConfig.NoActionFunctionName
			}
			if cfg.FunctionsConfig.NoActionDescriptionName != "" {
				noActionDescription = cfg.FunctionsConfig.NoActionDescriptionName
			}

			noActionGrammar := functions.Function{
				Name:        noActionName,
				Description: noActionDescription,
				Parameters: map[string]interface{}{
					"properties": map[string]interface{}{
						"message": map[string]interface{}{
							"type":        "string",
							"description": "The message to reply the user with",
						},
					},
				},
			}

			// Make a copy of funcs to avoid modifying the original
			funcsWithNoAction := make(functions.Functions, len(funcs))
			copy(funcsWithNoAction, funcs)

			// Append no-action function unless disabled
			if !cfg.FunctionsConfig.DisableNoAction {
				funcsWithNoAction = append(funcsWithNoAction, noActionGrammar)
			}

			// Force picking one of the functions by the request
			if cfg.FunctionToCall() != "" {
				funcsWithNoAction = funcsWithNoAction.Select(cfg.FunctionToCall())
			}

			// Generate grammar to constrain model output to valid function calls
			// NOTE(review): FunctionNameKey is passed for both arguments; the
			// second one presumably should be the arguments key - confirm
			// against functions.Functions.ToJSONStructure.
			jsStruct := funcsWithNoAction.ToJSONStructure(cfg.FunctionsConfig.FunctionNameKey, cfg.FunctionsConfig.FunctionNameKey)
			g, err := jsStruct.Grammar(cfg.FunctionsConfig.GrammarOptions()...)
			if err == nil {
				// The grammar is written onto cfg itself.
				cfg.Grammar = g
				xlog.Debug("Open Responses - Generated grammar for function calling")
			} else {
				xlog.Error("Open Responses - Failed generating grammar for function calling", "error", err)
			}
		}

		// Template the prompt
		predInput := evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn)
		xlog.Debug("Open Responses - Prompt (after templating)", "prompt", predInput)

		// Handle background mode
		isBackground := input.Background != nil && *input.Background
		if isBackground {
			// Background mode requires storage
			if !shouldStore {
				return sendOpenResponsesError(c, 400, "invalid_request_error", "background=true requires store=true", "background")
			}

			// Create initial response with "queued" status
			queuedResponse := buildORResponse(responseID, createdAt, nil, schema.ORStatusQueued, input, []schema.ORItemField{}, nil, true)

			// Create cancellable context for background execution
			// (derived from context.Background(), not from the HTTP request,
			// so it is not cancelled when the request context ends)
			bgCtx, bgCancel := context.WithCancel(context.Background())

			// Store the background response
			store.StoreBackground(responseID, input, queuedResponse, bgCancel, input.Stream)

			// Start background processing goroutine
			go func() {
				defer bgCancel()

				// Update status to in_progress
				store.UpdateStatus(responseID, schema.ORStatusInProgress, nil)

				var finalResponse *schema.ORResponseResource
				var bgErr error

				if input.Stream {
					// Background streaming processing (buffer events)
					finalResponse, bgErr = handleBackgroundStream(bgCtx, store, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpToolInfos, evaluator)
				} else {
					// Background non-streaming processing
					finalResponse, bgErr = handleBackgroundNonStream(bgCtx, store, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, mcpToolInfos, evaluator)
				}

				if bgErr != nil {
					// Any error (including context cancellation returned by
					// the handlers) marks the stored response as failed.
					xlog.Error("Background response failed", "response_id", responseID, "error", bgErr)
					now := time.Now().Unix()
					store.UpdateStatus(responseID, schema.ORStatusFailed, &now)
					return
				}

				// Update final response in store
				if finalResponse != nil {
					store.UpdateResponse(responseID, finalResponse)
				}
			}()

			// Return immediately with queued response
			return c.JSON(200, queuedResponse)
		}

		if input.Stream {
			return handleOpenResponsesStream(c, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, shouldStore, mcpToolInfos, evaluator)
		}

		return handleOpenResponsesNonStream(c, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, shouldStore, mcpToolInfos, evaluator, 0)
	}
}

// convertORInputToMessages converts Open Responses input to internal Messages
//
// input is either a plain string (returned as a single user message with only
// StringContent set) or a []interface{} of item maps dispatched on their
// "type" field. Array entries that are not maps, and items whose type is not
// handled below, are skipped without error. Any other input type is an error.
// For array input, contiguous assistant messages are merged before returning.
func convertORInputToMessages(input interface{}, cfg *config.ModelConfig) ([]schema.Message, error) {
	var messages []schema.Message

	switch v := input.(type) {
	case string:
		// Simple string = user message
		return []schema.Message{{Role: "user", StringContent: v}}, nil
	case []interface{}:
		// Array of items
		for _, itemRaw := range v {
			itemMap, ok := itemRaw.(map[string]interface{})
			if !ok {
				continue
			}

			itemType, _ := itemMap["type"].(string)
			switch itemType {
			case "message":
				msg, err := convertORMessageItem(itemMap, cfg)
				if err != nil {
					return nil, err
				}
				messages = append(messages, msg)
			case "reasoning":
				msg, err := convertORReasoningItemToMessage(itemMap)
				if err != nil {
					return nil, err
				}
				messages = append(messages, msg)
			case "function_call":
				msg, err := convertORFunctionCallItemToMessage(itemMap)
				if err != nil {
					return nil, err
				}
				messages = append(messages, msg)
			case "function_call_output":
				// Convert function call output to tool role message
				callID, _ := itemMap["call_id"].(string)
				output := itemMap["output"]
				var outputStr string
				if str, ok := output.(string); ok {
					outputStr = str
				} else {
					// Convert to JSON string
					// (the marshal error is ignored; outputStr is then empty)
					outputBytes, _ := json.Marshal(output)
					outputStr = string(outputBytes)
				}
				// For tool messages, we use the Name field to store the call ID
				messages = append(messages, schema.Message{
					Role:          "tool",
					Name:          callID,
					Content:       outputStr,
					StringContent: outputStr,
				})
			case "item_reference":
				// Handle item references - look up item in stored responses
				// According to spec, item_reference uses "id" field, not "item_id"
				itemID, ok := itemMap["id"].(string)
				if !ok || itemID == "" {
					return nil, fmt.Errorf("item_reference missing id")
				}

				store := GetGlobalStore()
				item, responseID, err := store.FindItem(itemID)
				if err != nil {
					return nil, fmt.Errorf("item not found: %s (from response %s): %w", itemID, responseID, err)
				}

				// Log item reference resolution for debugging
				xlog.Debug("Resolved item reference", "item_id", itemID, "response_id", responseID, "item_type", item.Type)

				// Convert referenced item to message based on its type
				msg, err := convertORItemToMessage(item, responseID)
				if err != nil {
					return nil, fmt.Errorf("failed to convert referenced item %s from response %s: %w", itemID, responseID, err)
				}
				messages = append(messages, msg)
			}
		}
		return mergeContiguousAssistantMessages(messages), nil
	default:
		return nil, fmt.Errorf("unsupported input type: %T", input)
	}
}

// convertORReasoningItemToMessage converts an Open Responses
reasoning item to an assistant Message fragment (for merging). func convertORReasoningItemToMessage(itemMap map[string]interface{}) (schema.Message, error) { var reasoning string if content := itemMap["content"]; content != nil { if s, ok := content.(string); ok { reasoning = s } else if parts, ok := content.([]interface{}); ok { for _, p := range parts { if partMap, ok := p.(map[string]interface{}); ok { if t, _ := partMap["type"].(string); (t == "output_text" || t == "input_text") && partMap["text"] != nil { if tStr, ok := partMap["text"].(string); ok { reasoning += tStr } } } } } } return schema.Message{Role: "assistant", Reasoning: stringPtr(reasoning)}, nil } // convertORFunctionCallItemToMessage converts an Open Responses function_call item to an assistant Message fragment (for merging). func convertORFunctionCallItemToMessage(itemMap map[string]interface{}) (schema.Message, error) { callID, _ := itemMap["call_id"].(string) name, _ := itemMap["name"].(string) arguments, _ := itemMap["arguments"].(string) if callID == "" { callID = fmt.Sprintf("call_%s", name) } return schema.Message{ Role: "assistant", ToolCalls: []schema.ToolCall{{ Index: 0, ID: callID, Type: "function", FunctionCall: schema.FunctionCall{Name: name, Arguments: arguments}, }}, }, nil } func stringPtr(s string) *string { if s == "" { return nil } return &s } // convertORItemToMessage converts a single ORItemField to a Message // responseID is the ID of the response where this item was found (for logging/debugging) func convertORItemToMessage(item *schema.ORItemField, responseID string) (schema.Message, error) { switch item.Type { case "message": // Convert message item to message var textContent string if contentParts, ok := item.Content.([]schema.ORContentPart); ok { for _, part := range contentParts { if part.Type == "output_text" || part.Type == "input_text" { textContent += part.Text } } } else if str, ok := item.Content.(string); ok { textContent = str } return schema.Message{ Role: 
item.Role, StringContent: textContent, Content: textContent, }, nil case "function_call_output": // Convert function call output to tool role message var outputStr string if str, ok := item.Output.(string); ok { outputStr = str } else { // Convert to JSON string outputBytes, _ := json.Marshal(item.Output) outputStr = string(outputBytes) } return schema.Message{ Role: "tool", Name: item.CallID, Content: outputStr, StringContent: outputStr, }, nil case "reasoning": reasoning := extractReasoningContentFromORItem(item) return schema.Message{Role: "assistant", Reasoning: stringPtr(reasoning)}, nil case "function_call": callID := item.CallID if callID == "" { callID = fmt.Sprintf("call_%s", item.Name) } return schema.Message{ Role: "assistant", ToolCalls: []schema.ToolCall{{ Index: 0, ID: callID, Type: "function", FunctionCall: schema.FunctionCall{Name: item.Name, Arguments: item.Arguments}, }}, }, nil default: return schema.Message{}, fmt.Errorf("unsupported item type for conversion: %s (from response %s)", item.Type, responseID) } } func extractReasoningContentFromORItem(item *schema.ORItemField) string { if contentParts, ok := item.Content.([]schema.ORContentPart); ok { var s string for _, part := range contentParts { if part.Type == "output_text" || part.Type == "input_text" { s += part.Text } } return s } if s, ok := item.Content.(string); ok { return s } return "" } // convertOROutputItemsToMessages converts Open Responses output items to internal Messages. // Contiguous assistant items (message, reasoning, function_call) are merged into a single message. 
func convertOROutputItemsToMessages(outputItems []schema.ORItemField) ([]schema.Message, error) { var messages []schema.Message for _, item := range outputItems { switch item.Type { case "message": var textContent string if contentParts, ok := item.Content.([]schema.ORContentPart); ok && len(contentParts) > 0 { for _, part := range contentParts { if part.Type == "output_text" { textContent += part.Text } } } messages = append(messages, schema.Message{ Role: item.Role, StringContent: textContent, Content: textContent, }) case "reasoning": reasoning := extractReasoningContentFromORItem(&item) messages = append(messages, schema.Message{Role: "assistant", Reasoning: stringPtr(reasoning)}) case "function_call": msg := schema.Message{ Role: "assistant", ToolCalls: []schema.ToolCall{{ Index: 0, ID: item.CallID, Type: "function", FunctionCall: schema.FunctionCall{Name: item.Name, Arguments: item.Arguments}, }}, } if msg.ToolCalls[0].ID == "" { msg.ToolCalls[0].ID = fmt.Sprintf("call_%s", item.Name) } messages = append(messages, msg) case "function_call_output": // Convert function call output to tool role message var outputStr string if str, ok := item.Output.(string); ok { outputStr = str } else { // Convert to JSON string outputBytes, _ := json.Marshal(item.Output) outputStr = string(outputBytes) } messages = append(messages, schema.Message{ Role: "tool", Name: item.CallID, Content: outputStr, StringContent: outputStr, }) } } return mergeContiguousAssistantMessages(messages), nil } // mergeContiguousAssistantMessages merges contiguous assistant messages into one. // Many chat templates expect content, reasoning, and tool calls in a single assistant message // (see e.g. llama.cpp PR 19773). This avoids creating separate messages per input item. 
func mergeContiguousAssistantMessages(messages []schema.Message) []schema.Message { if len(messages) == 0 { return messages } var out []schema.Message var acc *schema.Message for i := range messages { m := &messages[i] if m.Role != "assistant" { flushAssistantAccumulator(&out, &acc) out = append(out, *m) continue } if acc == nil { acc = &schema.Message{Role: "assistant"} } if m.StringContent != "" { if acc.StringContent != "" { acc.StringContent += "\n" + m.StringContent } else { acc.StringContent = m.StringContent } if acc.Content == nil { acc.Content = m.Content } else if _, ok := m.Content.(string); ok { acc.Content = acc.StringContent } } if m.Reasoning != nil && *m.Reasoning != "" { if acc.Reasoning == nil { acc.Reasoning = m.Reasoning } else { combined := *acc.Reasoning + "\n" + *m.Reasoning acc.Reasoning = &combined } } if len(m.ToolCalls) > 0 { acc.ToolCalls = append(acc.ToolCalls, m.ToolCalls...) } } flushAssistantAccumulator(&out, &acc) return out } func flushAssistantAccumulator(out *[]schema.Message, acc **schema.Message) { if acc == nil || *acc == nil { return } m := *acc if m.StringContent == "" && (m.Reasoning == nil || *m.Reasoning == "") && len(m.ToolCalls) == 0 { *acc = nil return } if m.Content == nil { m.Content = m.StringContent } // Re-index tool calls after merge (each may have been 0) for i := range m.ToolCalls { m.ToolCalls[i].Index = i } *out = append(*out, *m) *acc = nil } // convertORMessageItem converts an Open Responses message item to internal Message func convertORMessageItem(itemMap map[string]interface{}, cfg *config.ModelConfig) (schema.Message, error) { role, _ := itemMap["role"].(string) msg := schema.Message{Role: role} content := itemMap["content"] switch contentVal := content.(type) { case string: msg.StringContent = contentVal msg.Content = contentVal case []interface{}: // Array of content parts var textContent string var stringImages []string var stringVideos []string var stringAudios []string for _, partRaw := range 
contentVal { partMap, ok := partRaw.(map[string]interface{}) if !ok { continue } partType, _ := partMap["type"].(string) switch partType { case "input_text": if text, ok := partMap["text"].(string); ok { textContent += text } case "input_image": if imageURL, ok := partMap["image_url"].(string); ok { // Convert to base64 data URI base64, err := utils.GetContentURIAsBase64(imageURL) if err != nil { xlog.Error("Failed encoding image", "error", err) continue } stringImages = append(stringImages, base64) } case "input_file": if fileURL, ok := partMap["file_url"].(string); ok { // Convert to base64 base64, err := utils.GetContentURIAsBase64(fileURL) if err != nil { xlog.Error("Failed encoding file", "error", err) continue } // For now, treat files as text content textContent += base64 } else if fileData, ok := partMap["file_data"].(string); ok { // Already base64 textContent += fileData } case "input_video": if videoURL, ok := partMap["video_url"].(string); ok { // Convert to base64 data URI base64, err := utils.GetContentURIAsBase64(videoURL) if err != nil { xlog.Error("Failed encoding video", "error", err) continue } stringVideos = append(stringVideos, base64) } case "input_audio": if audioURL, ok := partMap["audio_url"].(string); ok { // Convert to base64 data URI base64, err := utils.GetContentURIAsBase64(audioURL) if err != nil { xlog.Error("Failed encoding audio", "error", err) continue } stringAudios = append(stringAudios, base64) } } } msg.StringContent = textContent msg.Content = textContent msg.StringImages = stringImages msg.StringVideos = stringVideos msg.StringAudios = stringAudios // Template multimodal content if len(stringImages) > 0 || len(stringVideos) > 0 || len(stringAudios) > 0 { msg.StringContent, _ = templates.TemplateMultiModal(cfg.TemplateConfig.Multimodal, templates.MultiModalOptions{ TotalImages: len(stringImages), TotalVideos: len(stringVideos), TotalAudios: len(stringAudios), ImagesInMessage: len(stringImages), VideosInMessage: 
len(stringVideos), AudiosInMessage: len(stringAudios), }, textContent) } } return msg, nil } // convertORToolsToFunctions converts Open Responses tools to internal Functions func convertORToolsToFunctions(input *schema.OpenResponsesRequest, cfg *config.ModelConfig) (functions.Functions, bool) { if len(input.Tools) == 0 { return nil, false } // Build allowed tools set if specified allowedSet := make(map[string]bool) if len(input.AllowedTools) > 0 { for _, name := range input.AllowedTools { allowedSet[name] = true } } var funcs functions.Functions for _, tool := range input.Tools { if tool.Type == "function" { // Skip if not in allowed list (when allowed_tools is specified) if len(allowedSet) > 0 && !allowedSet[tool.Name] { continue } f := functions.Function{ Name: tool.Name, Description: tool.Description, Parameters: tool.Parameters, } funcs = append(funcs, f) } } // Handle tool_choice if input.ToolChoice != nil { switch tc := input.ToolChoice.(type) { case string: switch tc { case "required": cfg.SetFunctionCallString("required") case "none": return nil, false case "auto": // "auto" is the default - let model decide whether to use tools // Tools are available but not forced } case map[string]interface{}: if tcType, ok := tc["type"].(string); ok && tcType == "function" { if name, ok := tc["name"].(string); ok { cfg.SetFunctionCallString(name) } } } } return funcs, len(funcs) > 0 && cfg.ShouldUseFunctions() } // convertTextFormatToResponseFormat converts Open Responses text_format to OpenAI response_format func convertTextFormatToResponseFormat(textFormat interface{}) interface{} { switch tf := textFormat.(type) { case map[string]interface{}: if tfType, ok := tf["type"].(string); ok { if tfType == "json_schema" { return map[string]interface{}{ "type": "json_schema", "json_schema": tf, } } return map[string]interface{}{"type": tfType} } case string: return map[string]interface{}{"type": tf} } return nil } // handleBackgroundNonStream handles background non-streaming 
// responses
//
// It runs inference in a loop so that MCP tool calls can be executed and their
// results fed back to the model. Iterations run from 0 through the limit
// inclusive (default 10, overridden by cfg.Agent.MaxIterations when > 0). The
// loop ends with a completed response as soon as an iteration produces no MCP
// tool call, or with an error when the limit is exhausted, inference fails, or
// ctx is cancelled before an inference pass starts.
// openAIReq is mutated: tool settings are populated and assistant/tool
// messages are appended on every MCP iteration. The store parameter is not
// used by this function.
func handleBackgroundNonStream(ctx context.Context, store *ResponseStore, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) (*schema.ORResponseResource, error) {
	mcpMaxIterations := 10
	if cfg.Agent.MaxIterations > 0 {
		mcpMaxIterations = cfg.Agent.MaxIterations
	}
	hasMCPTools := len(mcpToolInfos) > 0
	// Output items accumulate across iterations.
	var allOutputItems []schema.ORItemField

	for mcpIteration := 0; mcpIteration <= mcpMaxIterations; mcpIteration++ {
		// After the first pass the prompt is rebuilt from the messages, which
		// now include the tool calls and tool results appended below.
		if mcpIteration > 0 {
			predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn)
			xlog.Debug("Background MCP re-templating", "iteration", mcpIteration)
		}

		// Populate openAIReq fields for ComputeChoices
		openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools)
		openAIReq.ToolsChoice = input.ToolChoice
		if input.TopLogprobs != nil && *input.TopLogprobs > 0 {
			openAIReq.TopLogprobs = input.TopLogprobs
			openAIReq.Logprobs = schema.LogprobsValue{Enabled: true}
		}
		openAIReq.LogitBias = input.LogitBias

		// Non-blocking cancellation check before each inference pass.
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		default:
		}

		// The callback stores the string it is called with as the raw result.
		var result string
		cb := func(s string, c *[]schema.Choice) {
			result = s
		}

		choices, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil)
		if err != nil {
			return nil, fmt.Errorf("model inference failed: %w", err)
		}

		// Extract logprobs from choices if available
		var resultLogprobs *schema.Logprobs
		if len(choices) > 0 {
			resultLogprobs = choices[0].Logprobs
		}

		// Parse tool calls
		var funcCallResults []functions.FuncCallResults
		var textContent string

		if shouldUseFn {
			// Prefer tool calls found in the chat deltas; otherwise parse them
			// out of the cleaned-up raw result text.
			if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 {
				funcCallResults = deltaToolCalls
				textContent = functions.ContentFromChatDeltas(chatDeltas)
			} else {
				cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig)
				funcCallResults = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig)
				textContent = functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)
			}

			noActionName := "answer"
			if cfg.FunctionsConfig.NoActionFunctionName != "" {
				noActionName = cfg.FunctionsConfig.NoActionFunctionName
			}

			var toolCalls []schema.ToolCall
			for i, fc := range funcCallResults {
				// The no-action function is not a real tool call: its
				// "message" argument, when present, becomes the text reply.
				if fc.Name == noActionName {
					if fc.Arguments != "" {
						var args map[string]interface{}
						if err := json.Unmarshal([]byte(fc.Arguments), &args); err == nil {
							if msg, ok := args["message"].(string); ok && msg != "" {
								textContent = msg
							}
						}
					}
					continue
				}
				// Index is the position in funcCallResults, so it can have
				// gaps when no-action entries were skipped.
				toolCalls = append(toolCalls, schema.ToolCall{
					Index: i,
					ID:    fmt.Sprintf("fc_%s", uuid.New().String()),
					Type:  "function",
					FunctionCall: schema.FunctionCall{
						Name:      fc.Name,
						Arguments: fc.Arguments,
					},
				})
			}

			// MCP tool execution
			if hasMCPTools && len(toolCalls) > 0 {
				var hasMCPCalls bool
				for _, tc := range toolCalls {
					if mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) {
						hasMCPCalls = true
						break
					}
				}
				if hasMCPCalls {
					// Record the assistant turn (raw result plus tool calls)
					// so the next templating pass sees it.
					assistantMsg := schema.Message{Role: "assistant", Content: result, ToolCalls: toolCalls}
					openAIReq.Messages = append(openAIReq.Messages, assistantMsg)

					for _, tc := range toolCalls {
						// Emit function_call + function_call_output items
						allOutputItems = append(allOutputItems, schema.ORItemField{
							Type:      "function_call",
							ID:        fmt.Sprintf("fc_%s", uuid.New().String()),
							Status:    "completed",
							CallID:    tc.ID,
							Name:      tc.FunctionCall.Name,
							Arguments: tc.FunctionCall.Arguments,
						})

						// Non-MCP tool calls are reported but not executed here.
						if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) {
							continue
						}

						// A tool failure is passed back to the model as the
						// tool result text instead of aborting the response.
						toolResult, toolErr := mcpTools.ExecuteMCPToolCall(ctx, mcpToolInfos, tc.FunctionCall.Name, tc.FunctionCall.Arguments)
						if toolErr != nil {
							toolResult = fmt.Sprintf("Error: %v", toolErr)
						}
						openAIReq.Messages = append(openAIReq.Messages, schema.Message{
							Role:          "tool",
							Content:       toolResult,
							StringContent: toolResult,
							ToolCallID:    tc.ID,
							Name:          tc.FunctionCall.Name,
						})
						allOutputItems = append(allOutputItems, schema.ORItemField{
							Type:   "function_call_output",
							ID:     fmt.Sprintf("fco_%s", uuid.New().String()),
							Status: "completed",
							CallID: tc.ID,
							Output: toolResult,
						})
					}
					// textContent from this iteration is not added to the
					// output items on this path.
					continue // next MCP iteration
				}
			}

			// No MCP calls, build output items
			if textContent != "" {
				allOutputItems = append(allOutputItems, schema.ORItemField{
					Type:    "message",
					ID:      fmt.Sprintf("msg_%s", uuid.New().String()),
					Status:  "completed",
					Role:    "assistant",
					Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, resultLogprobs)},
				})
			}
			for _, tc := range toolCalls {
				allOutputItems = append(allOutputItems, schema.ORItemField{
					Type:      "function_call",
					ID:        fmt.Sprintf("fc_%s", uuid.New().String()),
					Status:    "completed",
					CallID:    tc.ID,
					Name:      tc.FunctionCall.Name,
					Arguments: tc.FunctionCall.Arguments,
				})
			}
			// Fall back to the raw result when nothing at all was produced
			// (including by earlier iterations).
			if len(allOutputItems) == 0 && result != "" {
				allOutputItems = append(allOutputItems, schema.ORItemField{
					Type:    "message",
					ID:      fmt.Sprintf("msg_%s", uuid.New().String()),
					Status:  "completed",
					Role:    "assistant",
					Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, resultLogprobs)},
				})
			}
		} else {
			// Function calling disabled: the raw result is the whole answer.
			allOutputItems = append(allOutputItems, schema.ORItemField{
				Type:    "message",
				ID:      fmt.Sprintf("msg_%s", uuid.New().String()),
				Status:  "completed",
				Role:    "assistant",
				Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, resultLogprobs)},
			})
		}

		// Usage is taken from this (final) iteration's tokenUsage only.
		now := time.Now().Unix()
		return buildORResponse(responseID, createdAt, &now, schema.ORStatusCompleted, input, allOutputItems, &schema.ORUsage{
			InputTokens:  tokenUsage.Prompt,
			OutputTokens: tokenUsage.Completion,
			TotalTokens:  tokenUsage.Prompt + tokenUsage.Completion,
		}, true), nil
	} // end MCP iteration loop

	return nil, fmt.Errorf("MCP iteration limit reached")
}

// handleBackgroundStream handles background streaming responses with event buffering
func handleBackgroundStream(ctx context.Context, store *ResponseStore,
responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) (*schema.ORResponseResource, error) { // Populate openAIReq fields for ComputeChoices openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools) openAIReq.ToolsChoice = input.ToolChoice if input.TopLogprobs != nil && *input.TopLogprobs > 0 { openAIReq.TopLogprobs = input.TopLogprobs openAIReq.Logprobs = schema.LogprobsValue{Enabled: true} } openAIReq.LogitBias = input.LogitBias sequenceNumber := 0 // Emit response.created responseCreated := buildORResponse(responseID, createdAt, nil, schema.ORStatusInProgress, input, []schema.ORItemField{}, nil, true) bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.created", SequenceNumber: sequenceNumber, Response: responseCreated, }) sequenceNumber++ // Emit response.in_progress bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.in_progress", SequenceNumber: sequenceNumber, Response: responseCreated, }) sequenceNumber++ var accumulatedText string var collectedOutputItems []schema.ORItemField outputIndex := 0 mcpBgStreamMaxIterations := 10 if cfg.Agent.MaxIterations > 0 { mcpBgStreamMaxIterations = cfg.Agent.MaxIterations } hasMCPTools := len(mcpToolInfos) > 0 var lastTokenUsage backend.TokenUsage var lastLogprobs *schema.Logprobs for mcpIter := 0; mcpIter <= mcpBgStreamMaxIterations; mcpIter++ { if mcpIter > 0 { predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn) xlog.Debug("Background stream MCP re-templating", "iteration", mcpIter) } accumulatedText = "" currentMessageID := fmt.Sprintf("msg_%s", uuid.New().String()) // Emit output_item.added messageItem := &schema.ORItemField{ Type: "message", ID: 
currentMessageID, Status: "in_progress", Role: "assistant", Content: []schema.ORContentPart{}, } bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.output_item.added", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: messageItem, }) sequenceNumber++ // Emit content_part.added currentContentIndex := 0 emptyPart := makeOutputTextPart("") bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.content_part.added", SequenceNumber: sequenceNumber, ItemID: currentMessageID, OutputIndex: &outputIndex, ContentIndex: ¤tContentIndex, Part: &emptyPart, }) sequenceNumber++ // Token callback for streaming tokenCallback := func(token string, tokenUsage backend.TokenUsage) bool { select { case <-ctx.Done(): return false default: } accumulatedText += token // Buffer text delta bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.output_text.delta", SequenceNumber: sequenceNumber, ItemID: currentMessageID, OutputIndex: &outputIndex, ContentIndex: ¤tContentIndex, Delta: strPtr(token), Logprobs: emptyLogprobs(), }) sequenceNumber++ return true } var result string cb := func(s string, c *[]schema.Choice) { result = s } choices, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, tokenCallback) if err != nil { return nil, fmt.Errorf("model inference failed: %w", err) } lastTokenUsage = tokenUsage if len(choices) > 0 { lastLogprobs = choices[0].Logprobs } // Check for MCP tool calls in the streamed result if shouldUseFn && hasMCPTools { var funcCallResults []functions.FuncCallResults if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 { funcCallResults = deltaToolCalls } else { cleanedResult := functions.CleanupLLMResult(result, cfg.FunctionsConfig) funcCallResults = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig) } noActionName := "answer" if cfg.FunctionsConfig.NoActionFunctionName != "" { noActionName = 
cfg.FunctionsConfig.NoActionFunctionName } var toolCalls []schema.ToolCall for i, fc := range funcCallResults { if fc.Name == noActionName { continue } toolCalls = append(toolCalls, schema.ToolCall{ Index: i, ID: fmt.Sprintf("fc_%s", uuid.New().String()), Type: "function", FunctionCall: schema.FunctionCall{Name: fc.Name, Arguments: fc.Arguments}, }) } var hasMCPCalls bool for _, tc := range toolCalls { if mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { hasMCPCalls = true break } } if hasMCPCalls { // Close the current message bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.output_text.done", SequenceNumber: sequenceNumber, ItemID: currentMessageID, OutputIndex: &outputIndex, ContentIndex: ¤tContentIndex, Text: strPtr(accumulatedText), Logprobs: emptyLogprobs(), }) sequenceNumber++ textPart := makeOutputTextPart(accumulatedText) bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.content_part.done", SequenceNumber: sequenceNumber, ItemID: currentMessageID, OutputIndex: &outputIndex, ContentIndex: ¤tContentIndex, Part: &textPart, }) sequenceNumber++ completedMsg := &schema.ORItemField{ Type: "message", ID: currentMessageID, Status: "completed", Role: "assistant", Content: []schema.ORContentPart{textPart}, } bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.output_item.done", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: completedMsg, }) sequenceNumber++ collectedOutputItems = append(collectedOutputItems, *completedMsg) // Append assistant message with tool calls assistantMsg := schema.Message{Role: "assistant", Content: result, ToolCalls: toolCalls} openAIReq.Messages = append(openAIReq.Messages, assistantMsg) // Execute MCP tools and emit events for _, tc := range toolCalls { outputIndex++ functionCallItem := &schema.ORItemField{ Type: "function_call", ID: tc.ID, Status: "completed", CallID: tc.ID, Name: tc.FunctionCall.Name, Arguments: tc.FunctionCall.Arguments, } 
bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.output_item.added", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: functionCallItem, }) sequenceNumber++ bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.output_item.done", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: functionCallItem, }) sequenceNumber++ collectedOutputItems = append(collectedOutputItems, *functionCallItem) if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { continue } xlog.Debug("Executing MCP tool (background stream)", "tool", tc.FunctionCall.Name, "iteration", mcpIter) toolResult, toolErr := mcpTools.ExecuteMCPToolCall(ctx, mcpToolInfos, tc.FunctionCall.Name, tc.FunctionCall.Arguments) if toolErr != nil { toolResult = fmt.Sprintf("Error: %v", toolErr) } openAIReq.Messages = append(openAIReq.Messages, schema.Message{ Role: "tool", Content: toolResult, StringContent: toolResult, ToolCallID: tc.ID, Name: tc.FunctionCall.Name, }) outputIndex++ outputItem := &schema.ORItemField{ Type: "function_call_output", ID: fmt.Sprintf("fco_%s", uuid.New().String()), Status: "completed", CallID: tc.ID, Output: toolResult, } bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.output_item.added", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: outputItem, }) sequenceNumber++ bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.output_item.done", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: outputItem, }) sequenceNumber++ collectedOutputItems = append(collectedOutputItems, *outputItem) } continue // next MCP iteration } } // No MCP tools — close the message and break streamEventLogprobs := convertLogprobsForStreaming(lastLogprobs) bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.output_text.done", SequenceNumber: sequenceNumber, ItemID: currentMessageID, OutputIndex: &outputIndex, ContentIndex: ¤tContentIndex, Text: 
strPtr(accumulatedText), Logprobs: logprobsPtr(streamEventLogprobs), }) sequenceNumber++ textPart := makeOutputTextPartWithLogprobs(accumulatedText, lastLogprobs) bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.content_part.done", SequenceNumber: sequenceNumber, ItemID: currentMessageID, OutputIndex: &outputIndex, ContentIndex: ¤tContentIndex, Part: &textPart, }) sequenceNumber++ completedMessageItem := &schema.ORItemField{ Type: "message", ID: currentMessageID, Status: "completed", Role: "assistant", Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(accumulatedText, lastLogprobs)}, } bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.output_item.done", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: completedMessageItem, }) sequenceNumber++ collectedOutputItems = append(collectedOutputItems, *completedMessageItem) break } // end MCP background stream iteration loop // Build final response now := time.Now().Unix() response := buildORResponse(responseID, createdAt, &now, schema.ORStatusCompleted, input, collectedOutputItems, &schema.ORUsage{ InputTokens: lastTokenUsage.Prompt, OutputTokens: lastTokenUsage.Completion, TotalTokens: lastTokenUsage.Prompt + lastTokenUsage.Completion, }, true) // Emit response.completed bufferEvent(store, responseID, &schema.ORStreamEvent{ Type: "response.completed", SequenceNumber: sequenceNumber, Response: response, }) return response, nil } // bufferEvent stores an SSE event in the response store for streaming resume func bufferEvent(store *ResponseStore, responseID string, event *schema.ORStreamEvent) { normalizeORStreamEvent(event) if err := store.AppendEvent(responseID, event); err != nil { xlog.Error("Failed to buffer event", "response_id", responseID, "error", err) } } // handleOpenResponsesNonStream handles non-streaming responses func handleOpenResponsesNonStream(c echo.Context, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg 
*config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, shouldStore bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator, mcpIteration int) error { mcpMaxIterations := 10 if cfg.Agent.MaxIterations > 0 { mcpMaxIterations = cfg.Agent.MaxIterations } if mcpIteration > mcpMaxIterations { return sendOpenResponsesError(c, 500, "server_error", "MCP iteration limit reached", "") } // Populate openAIReq fields for ComputeChoices openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools) openAIReq.ToolsChoice = input.ToolChoice if input.TopLogprobs != nil && *input.TopLogprobs > 0 { openAIReq.TopLogprobs = input.TopLogprobs openAIReq.Logprobs = schema.LogprobsValue{Enabled: true} } openAIReq.LogitBias = input.LogitBias var result string cb := func(s string, c *[]schema.Choice) { result = s } choices, tokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, cb, nil) if err != nil { xlog.Error("Open Responses model inference failed", "error", err) return sendOpenResponsesError(c, 500, "model_error", fmt.Sprintf("model inference failed: %v", err), "") } var resultLogprobs *schema.Logprobs if len(choices) > 0 { resultLogprobs = choices[0].Logprobs } xlog.Debug("Open Responses - Raw model result", "result", result, "shouldUseFn", shouldUseFn) // Detect if thinking token is already in prompt or template var template string if cfg.TemplateConfig.UseTokenizerTemplate { template = cfg.GetModelTemplate() } else { template = predInput } thinkingStartToken := reason.DetectThinkingStartToken(template, &cfg.ReasoningConfig) // Extract reasoning from result before cleaning reasoningContent, cleanedResult := reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig) // Parse tool calls if using functions var outputItems []schema.ORItemField var 
toolCalls []schema.ToolCall // Add reasoning item if reasoning was found (reasoning comes first per spec) if reasoningContent != "" { reasoningItem := schema.ORItemField{ Type: "reasoning", ID: fmt.Sprintf("reasoning_%s", uuid.New().String()), Status: "completed", Content: []schema.ORContentPart{makeOutputTextPart(reasoningContent)}, } outputItems = append(outputItems, reasoningItem) xlog.Debug("Open Responses - Extracted reasoning", "reasoning_length", len(reasoningContent)) } if shouldUseFn { var funcCallResults []functions.FuncCallResults var textContent string // Try pre-parsed tool calls from C++ autoparser first if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 { xlog.Debug("[ChatDeltas] OpenResponses: using pre-parsed tool calls", "count", len(deltaToolCalls)) funcCallResults = deltaToolCalls textContent = functions.ContentFromChatDeltas(chatDeltas) } else { xlog.Debug("[ChatDeltas] OpenResponses: no pre-parsed tool calls, falling back to Go-side text parsing") // Clean up the result (already extracted reasoning above) cleanedResult = functions.CleanupLLMResult(cleanedResult, cfg.FunctionsConfig) funcCallResults = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig) textContent = functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig) } xlog.Debug("[ChatDeltas] OpenResponses: final tool call decision", "count", len(funcCallResults), "textContent", textContent) // Check for noAction function (model chose to respond without tool) noActionName := "answer" if cfg.FunctionsConfig.NoActionFunctionName != "" { noActionName = cfg.FunctionsConfig.NoActionFunctionName } // Filter out noAction calls and extract the message for i, fc := range funcCallResults { if fc.Name == noActionName { // This is a text response, not a tool call // Try to extract the message from the arguments if fc.Arguments != "" { var args map[string]interface{} if err := json.Unmarshal([]byte(fc.Arguments), &args); err == nil { if msg, 
ok := args["message"].(string); ok && msg != "" { textContent = msg } } } continue } toolCalls = append(toolCalls, schema.ToolCall{ Index: i, ID: fmt.Sprintf("fc_%s", uuid.New().String()), Type: "function", FunctionCall: schema.FunctionCall{ Name: fc.Name, Arguments: fc.Arguments, }, }) } // MCP server-side tool execution: if any tool calls are MCP tools, execute and re-run if len(mcpToolInfos) > 0 && len(toolCalls) > 0 { var hasMCPCalls bool for _, tc := range toolCalls { if mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { hasMCPCalls = true break } } if hasMCPCalls { // Append assistant message with tool_calls to conversation assistantMsg := schema.Message{Role: "assistant", Content: result, ToolCalls: toolCalls} openAIReq.Messages = append(openAIReq.Messages, assistantMsg) // Execute each MCP tool call and append results for _, tc := range toolCalls { if !mcpTools.IsMCPTool(mcpToolInfos, tc.FunctionCall.Name) { continue } xlog.Debug("Executing MCP tool (Open Responses)", "tool", tc.FunctionCall.Name) toolResult, toolErr := mcpTools.ExecuteMCPToolCall( c.Request().Context(), mcpToolInfos, tc.FunctionCall.Name, tc.FunctionCall.Arguments, ) if toolErr != nil { xlog.Error("MCP tool execution failed", "tool", tc.FunctionCall.Name, "error", toolErr) toolResult = fmt.Sprintf("Error: %v", toolErr) } openAIReq.Messages = append(openAIReq.Messages, schema.Message{ Role: "tool", Content: toolResult, StringContent: toolResult, ToolCallID: tc.ID, Name: tc.FunctionCall.Name, }) // Collect function_call + function_call_output items for the response outputItems = append(outputItems, schema.ORItemField{ Type: "function_call", ID: fmt.Sprintf("fc_%s", uuid.New().String()), Status: "completed", CallID: tc.ID, Name: tc.FunctionCall.Name, Arguments: tc.FunctionCall.Arguments, }) outputItems = append(outputItems, schema.ORItemField{ Type: "function_call_output", ID: fmt.Sprintf("fco_%s", uuid.New().String()), Status: "completed", CallID: tc.ID, Output: toolResult, }) } // 
Re-template and re-run inference predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn) return handleOpenResponsesNonStream(c, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, shouldStore, mcpToolInfos, evaluator, mcpIteration+1) } } // Add message item with text content (include logprobs if available) if textContent != "" { outputItems = append(outputItems, schema.ORItemField{ Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()), Status: "completed", Role: "assistant", Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, resultLogprobs)}, }) } // Add function call items for _, tc := range toolCalls { outputItems = append(outputItems, schema.ORItemField{ Type: "function_call", ID: fmt.Sprintf("fc_%s", uuid.New().String()), Status: "completed", CallID: tc.ID, Name: tc.FunctionCall.Name, Arguments: tc.FunctionCall.Arguments, }) } // If we have no output items but the model did produce output, include the cleaned result as a message hasMessageItem := false for _, item := range outputItems { if item.Type == "message" { hasMessageItem = true break } } if !hasMessageItem && cleanedResult != "" { xlog.Debug("Open Responses - No parsed output, falling back to cleaned result") outputItems = append(outputItems, schema.ORItemField{ Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()), Status: "completed", Role: "assistant", Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, resultLogprobs)}, }) } } else { // Simple text response (include logprobs if available) messageItem := schema.ORItemField{ Type: "message", ID: fmt.Sprintf("msg_%s", uuid.New().String()), Status: "completed", Role: "assistant", Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(cleanedResult, resultLogprobs)}, } outputItems = append(outputItems, messageItem) } // Calculate reasoning tokens (approximate: character count / 4) 
reasoningTokens := 0 if reasoningContent != "" { // Simple estimation: ~4 characters per token reasoningTokens = len(reasoningContent) / 4 if reasoningTokens == 0 && len(reasoningContent) > 0 { reasoningTokens = 1 } } // Build response with all required fields now := time.Now().Unix() response := buildORResponse(responseID, createdAt, &now, "completed", input, outputItems, &schema.ORUsage{ InputTokens: tokenUsage.Prompt, OutputTokens: tokenUsage.Completion, TotalTokens: tokenUsage.Prompt + tokenUsage.Completion, OutputTokensDetails: &schema.OROutputTokensDetails{ ReasoningTokens: reasoningTokens, }, }, shouldStore) // Store response for future reference (if enabled) if shouldStore { store := GetGlobalStore() store.Store(responseID, input, response) } return c.JSON(200, response) } // handleOpenResponsesStream handles streaming responses func handleOpenResponsesStream(c echo.Context, responseID string, createdAt int64, input *schema.OpenResponsesRequest, cfg *config.ModelConfig, ml *model.ModelLoader, cl *config.ModelConfigLoader, appConfig *config.ApplicationConfig, predInput string, openAIReq *schema.OpenAIRequest, funcs functions.Functions, shouldUseFn bool, shouldStore bool, mcpToolInfos []mcpTools.MCPToolInfo, evaluator *templates.Evaluator) error { c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Cache-Control", "no-cache") c.Response().Header().Set("Connection", "keep-alive") sequenceNumber := 0 // Emit response.created - use helper to create response with all required fields responseCreated := buildORResponse(responseID, createdAt, nil, "in_progress", input, []schema.ORItemField{}, nil, shouldStore) sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.created", SequenceNumber: sequenceNumber, Response: responseCreated, }) sequenceNumber++ // Emit response.in_progress sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.in_progress", SequenceNumber: sequenceNumber, Response: responseCreated, }) sequenceNumber++ // 
Populate openAIReq fields for ComputeChoices openAIReq.Tools = convertORToolsToOpenAIFormat(input.Tools) openAIReq.ToolsChoice = input.ToolChoice if input.TopLogprobs != nil && *input.TopLogprobs > 0 { openAIReq.TopLogprobs = input.TopLogprobs openAIReq.Logprobs = schema.LogprobsValue{Enabled: true} } openAIReq.LogitBias = input.LogitBias // Detect if thinking token is already in prompt or template var template string if cfg.TemplateConfig.UseTokenizerTemplate { template = cfg.GetModelTemplate() } else { template = predInput } thinkingStartToken := reason.DetectThinkingStartToken(template, &cfg.ReasoningConfig) // Track state for streaming var currentMessageID string var currentContentIndex int var accumulatedText string var lastEmittedToolCallCount int outputIndex := 0 inToolCallMode := false // Track reasoning state for streaming var currentReasoningID string var currentReasoningContentIndex int var reasoningTokens int extractor := reason.NewReasoningExtractor(thinkingStartToken, cfg.ReasoningConfig) // Collect all output items for storage var collectedOutputItems []schema.ORItemField if shouldUseFn { mcpStreamMaxIterations := 10 if cfg.Agent.MaxIterations > 0 { mcpStreamMaxIterations = cfg.Agent.MaxIterations } hasMCPToolsStream := len(mcpToolInfos) > 0 var result, finalReasoning, finalCleanedResult string var textContent string var parsedToolCalls []functions.FuncCallResults var toolCalls []functions.FuncCallResults var lastStreamTokenUsage backend.TokenUsage var lastStreamLogprobs *schema.Logprobs for mcpStreamIter := 0; mcpStreamIter <= mcpStreamMaxIterations; mcpStreamIter++ { if mcpStreamIter > 0 { // Reset reasoning and tool-call state for re-inference so reasoning // extraction runs again on subsequent iterations inToolCallMode = false extractor.Reset() currentMessageID = "" lastEmittedToolCallCount = 0 currentReasoningID = "" predInput = evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn) xlog.Debug("Open Responses stream 
MCP re-templating", "iteration", mcpStreamIter) } // For tool calls, we need to track accumulated result and parse incrementally // We'll handle this differently - track the full result and parse tool calls accumulatedResult := "" tokenCallback := func(token string, tokenUsage backend.TokenUsage) bool { accumulatedResult += token accumulatedText += token // Try to parse tool calls incrementally cleanedResult := functions.CleanupLLMResult(accumulatedResult, cfg.FunctionsConfig) // Determine XML format from config var xmlFormat *functions.XMLToolCallFormat if cfg.FunctionsConfig.XMLFormat != nil { xmlFormat = cfg.FunctionsConfig.XMLFormat } else if cfg.FunctionsConfig.XMLFormatPreset != "" { xmlFormat = functions.GetXMLFormatPreset(cfg.FunctionsConfig.XMLFormatPreset) } // Try XML parsing first partialResults, parseErr := functions.ParseXMLIterative(cleanedResult, xmlFormat, true) if parseErr == nil && len(partialResults) > lastEmittedToolCallCount { // New tool calls detected if !inToolCallMode && currentMessageID != "" { // Close the current message content part textPart := makeOutputTextPart(functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig)) sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.content_part.done", SequenceNumber: sequenceNumber, ItemID: currentMessageID, OutputIndex: &outputIndex, ContentIndex: ¤tContentIndex, Part: &textPart, }) sequenceNumber++ inToolCallMode = true } // Emit new tool calls for i := lastEmittedToolCallCount; i < len(partialResults); i++ { tc := partialResults[i] toolCallID := fmt.Sprintf("fc_%s", uuid.New().String()) outputIndex++ // Emit function_call item added functionCallItem := &schema.ORItemField{ Type: "function_call", ID: toolCallID, Status: "in_progress", CallID: toolCallID, Name: tc.Name, Arguments: "", } sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.added", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: functionCallItem, }) sequenceNumber++ // Emit arguments delta 
if tc.Arguments != "" { sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.function_call_arguments.delta", SequenceNumber: sequenceNumber, ItemID: toolCallID, OutputIndex: &outputIndex, Delta: strPtr(tc.Arguments), }) sequenceNumber++ // Emit arguments done sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.function_call_arguments.done", SequenceNumber: sequenceNumber, ItemID: toolCallID, OutputIndex: &outputIndex, Arguments: strPtr(tc.Arguments), }) sequenceNumber++ // Emit function_call item done functionCallItem.Status = "completed" functionCallItem.Arguments = tc.Arguments sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.done", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: functionCallItem, }) sequenceNumber++ // Collect item for storage collectedOutputItems = append(collectedOutputItems, *functionCallItem) } } lastEmittedToolCallCount = len(partialResults) c.Response().Flush() return true } // Try JSON parsing as fallback jsonResults, jsonErr := functions.ParseJSONIterative(cleanedResult, true) if jsonErr == nil && len(jsonResults) > lastEmittedToolCallCount { for i := lastEmittedToolCallCount; i < len(jsonResults); i++ { jsonObj := jsonResults[i] if name, ok := jsonObj["name"].(string); ok && name != "" { args := "{}" if argsVal, ok := jsonObj["arguments"]; ok { if argsStr, ok := argsVal.(string); ok { args = argsStr } else { argsBytes, _ := json.Marshal(argsVal) args = string(argsBytes) } } toolCallID := fmt.Sprintf("fc_%s", uuid.New().String()) outputIndex++ functionCallItem := &schema.ORItemField{ Type: "function_call", ID: toolCallID, Status: "completed", CallID: toolCallID, Name: name, Arguments: args, } sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.added", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: functionCallItem, }) sequenceNumber++ sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.done", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, 
Item: functionCallItem, }) sequenceNumber++ } } lastEmittedToolCallCount = len(jsonResults) c.Response().Flush() return true } // If no tool calls detected yet, handle reasoning and text if !inToolCallMode { reasoningDelta, contentDelta := extractor.ProcessToken(token) // Handle reasoning item if extractor.Reasoning() != "" { // Check if we need to create reasoning item if currentReasoningID == "" { outputIndex++ currentReasoningID = fmt.Sprintf("reasoning_%s", uuid.New().String()) reasoningItem := &schema.ORItemField{ Type: "reasoning", ID: currentReasoningID, Status: "in_progress", } sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.added", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: reasoningItem, }) sequenceNumber++ // Emit content_part.added for reasoning currentReasoningContentIndex = 0 emptyPart := makeOutputTextPart("") sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.content_part.added", SequenceNumber: sequenceNumber, ItemID: currentReasoningID, OutputIndex: &outputIndex, ContentIndex: ¤tReasoningContentIndex, Part: &emptyPart, }) sequenceNumber++ } // Emit reasoning delta if there's new content if reasoningDelta != "" { sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_text.delta", SequenceNumber: sequenceNumber, ItemID: currentReasoningID, OutputIndex: &outputIndex, ContentIndex: ¤tReasoningContentIndex, Delta: strPtr(reasoningDelta), Logprobs: emptyLogprobs(), }) sequenceNumber++ c.Response().Flush() } } // Only emit message content if there's actual content (not just reasoning) if contentDelta != "" { if currentMessageID == "" { // Emit output_item.added for message outputIndex++ currentMessageID = fmt.Sprintf("msg_%s", uuid.New().String()) messageItem := &schema.ORItemField{ Type: "message", ID: currentMessageID, Status: "in_progress", Role: "assistant", Content: []schema.ORContentPart{}, } sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.added", SequenceNumber: sequenceNumber, 
OutputIndex: &outputIndex, Item: messageItem, }) sequenceNumber++ // Emit content_part.added currentContentIndex = 0 emptyPart := makeOutputTextPart("") sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.content_part.added", SequenceNumber: sequenceNumber, ItemID: currentMessageID, OutputIndex: &outputIndex, ContentIndex: ¤tContentIndex, Part: &emptyPart, }) sequenceNumber++ } // Emit text delta sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_text.delta", SequenceNumber: sequenceNumber, ItemID: currentMessageID, OutputIndex: &outputIndex, ContentIndex: ¤tContentIndex, Delta: strPtr(contentDelta), Logprobs: emptyLogprobs(), }) sequenceNumber++ c.Response().Flush() } } return true } var ccResult string ccCb := func(s string, c *[]schema.Choice) { ccResult = s } choices, ccTokenUsage, chatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, ccCb, tokenCallback) if err != nil { xlog.Error("Open Responses stream model inference failed", "error", err) sendSSEEvent(c, &schema.ORStreamEvent{ Type: "error", SequenceNumber: sequenceNumber, Error: &schema.ORErrorPayload{ Type: "model_error", Message: fmt.Sprintf("model inference failed: %v", err), }, }) sequenceNumber++ responseFailed := responseCreated responseFailed.Status = "failed" sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.failed", SequenceNumber: sequenceNumber, Response: responseFailed, }) // Send [DONE] even on error fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") c.Response().Flush() return nil } result = ccResult lastStreamTokenUsage = ccTokenUsage if len(choices) > 0 { lastStreamLogprobs = choices[0].Logprobs } // Source reasoning from: (1) ChatDeltas from C++ autoparser, (2) extractor's // streaming state, (3) final extraction from the finetuned result. 
if chatDeltaReasoning := functions.ReasoningFromChatDeltas(chatDeltas); chatDeltaReasoning != "" { finalReasoning = chatDeltaReasoning finalCleanedResult = functions.ContentFromChatDeltas(chatDeltas) if finalCleanedResult == "" { finalCleanedResult = extractor.CleanedContent() } } else { finalReasoning = extractor.Reasoning() finalCleanedResult = extractor.CleanedContent() } if finalReasoning == "" && finalCleanedResult == "" { finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig) } // Close reasoning item if it exists and wasn't closed yet if currentReasoningID != "" && finalReasoning != "" { // Emit output_text.done for reasoning sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_text.done", SequenceNumber: sequenceNumber, ItemID: currentReasoningID, OutputIndex: &outputIndex, ContentIndex: ¤tReasoningContentIndex, Text: strPtr(finalReasoning), Logprobs: emptyLogprobs(), }) sequenceNumber++ // Emit content_part.done for reasoning reasoningPart := makeOutputTextPart(finalReasoning) sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.content_part.done", SequenceNumber: sequenceNumber, ItemID: currentReasoningID, OutputIndex: &outputIndex, ContentIndex: ¤tReasoningContentIndex, Part: &reasoningPart, }) sequenceNumber++ // Emit output_item.done for reasoning reasoningItem := &schema.ORItemField{ Type: "reasoning", ID: currentReasoningID, Status: "completed", Content: []schema.ORContentPart{reasoningPart}, } sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.done", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: reasoningItem, }) sequenceNumber++ // Collect reasoning item for storage collectedOutputItems = append(collectedOutputItems, *reasoningItem) // Calculate reasoning tokens reasoningTokens = len(finalReasoning) / 4 if reasoningTokens == 0 && len(finalReasoning) > 0 { reasoningTokens = 1 } } parsedToolCalls = nil textContent = "" // Try pre-parsed tool 
calls from C++ autoparser first if deltaToolCalls := functions.ToolCallsFromChatDeltas(chatDeltas); len(deltaToolCalls) > 0 { xlog.Debug("[ChatDeltas] OpenResponses Stream: using pre-parsed tool calls", "count", len(deltaToolCalls)) parsedToolCalls = deltaToolCalls textContent = functions.ContentFromChatDeltas(chatDeltas) } else { xlog.Debug("[ChatDeltas] OpenResponses Stream: no pre-parsed tool calls, falling back to Go-side text parsing") cleanedResult := functions.CleanupLLMResult(finalCleanedResult, cfg.FunctionsConfig) parsedToolCalls = functions.ParseFunctionCall(cleanedResult, cfg.FunctionsConfig) textContent = functions.ParseTextContent(cleanedResult, cfg.FunctionsConfig) } // Handle noAction function (model chose to respond without tool) noActionName := "answer" if cfg.FunctionsConfig.NoActionFunctionName != "" { noActionName = cfg.FunctionsConfig.NoActionFunctionName } // Filter out noAction calls and extract the message toolCalls = nil for _, fc := range parsedToolCalls { if fc.Name == noActionName { // This is a text response, not a tool call if fc.Arguments != "" { var args map[string]interface{} if err := json.Unmarshal([]byte(fc.Arguments), &args); err == nil { if msg, ok := args["message"].(string); ok && msg != "" { textContent = msg } } } continue } toolCalls = append(toolCalls, fc) } xlog.Debug("Open Responses Stream - Parsed", "toolCalls", len(toolCalls), "textContent", textContent) // MCP streaming tool execution: check if any tool calls are MCP tools if hasMCPToolsStream && len(toolCalls) > 0 { var hasMCPCalls bool for _, tc := range toolCalls { if mcpTools.IsMCPTool(mcpToolInfos, tc.Name) { hasMCPCalls = true break } } if hasMCPCalls { // Build schema.ToolCall list for the assistant message var schemaToolCalls []schema.ToolCall for i, tc := range toolCalls { schemaToolCalls = append(schemaToolCalls, schema.ToolCall{ Index: i, ID: fmt.Sprintf("fc_%s", uuid.New().String()), Type: "function", FunctionCall: schema.FunctionCall{Name: tc.Name, 
Arguments: tc.Arguments}, }) } assistantMsg := schema.Message{Role: "assistant", Content: result, ToolCalls: schemaToolCalls} openAIReq.Messages = append(openAIReq.Messages, assistantMsg) for idx, tc := range toolCalls { tcID := schemaToolCalls[idx].ID // Emit function_call item outputIndex++ functionCallItem := &schema.ORItemField{ Type: "function_call", ID: tcID, Status: "completed", CallID: tcID, Name: tc.Name, Arguments: tc.Arguments, } sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.added", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: functionCallItem, }) sequenceNumber++ sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.done", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: functionCallItem, }) sequenceNumber++ collectedOutputItems = append(collectedOutputItems, *functionCallItem) if !mcpTools.IsMCPTool(mcpToolInfos, tc.Name) { continue } // Execute MCP tool xlog.Debug("Executing MCP tool (Open Responses stream)", "tool", tc.Name, "iteration", mcpStreamIter) toolResult, toolErr := mcpTools.ExecuteMCPToolCall( input.Context, mcpToolInfos, tc.Name, tc.Arguments, ) if toolErr != nil { xlog.Error("MCP tool execution failed", "tool", tc.Name, "error", toolErr) toolResult = fmt.Sprintf("Error: %v", toolErr) } openAIReq.Messages = append(openAIReq.Messages, schema.Message{ Role: "tool", Content: toolResult, StringContent: toolResult, ToolCallID: tcID, Name: tc.Name, }) // Emit function_call_output item outputIndex++ outputItem := &schema.ORItemField{ Type: "function_call_output", ID: fmt.Sprintf("fco_%s", uuid.New().String()), Status: "completed", CallID: tcID, Output: toolResult, } sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.added", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: outputItem, }) sequenceNumber++ sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.done", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: 
outputItem, }) sequenceNumber++ collectedOutputItems = append(collectedOutputItems, *outputItem) } c.Response().Flush() xlog.Debug("MCP streaming tools executed, re-running inference", "iteration", mcpStreamIter) continue // next MCP stream iteration } } // Convert logprobs for streaming events streamEventLogprobs := convertLogprobsForStreaming(lastStreamLogprobs) // If we have no output but the model did produce something, use the cleaned result (without reasoning tags) if textContent == "" && len(toolCalls) == 0 && finalCleanedResult != "" { xlog.Debug("Open Responses Stream - No parsed output, using cleaned result") textContent = finalCleanedResult } // Close message if we have text content if currentMessageID != "" && textContent != "" && !inToolCallMode { // Emit output_text.done sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_text.done", SequenceNumber: sequenceNumber, ItemID: currentMessageID, OutputIndex: &outputIndex, ContentIndex: ¤tContentIndex, Text: strPtr(textContent), Logprobs: logprobsPtr(streamEventLogprobs), }) sequenceNumber++ // Emit content_part.done (with actual logprobs) textPart := makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs) sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.content_part.done", SequenceNumber: sequenceNumber, ItemID: currentMessageID, OutputIndex: &outputIndex, ContentIndex: ¤tContentIndex, Part: &textPart, }) sequenceNumber++ // Emit output_item.done for message (with actual logprobs) messageItem := &schema.ORItemField{ Type: "message", ID: currentMessageID, Status: "completed", Role: "assistant", Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs)}, } sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.done", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: messageItem, }) sequenceNumber++ // Collect message item for storage collectedOutputItems = append(collectedOutputItems, *messageItem) } // Emit any 
remaining tool calls that weren't streamed for i := lastEmittedToolCallCount; i < len(toolCalls); i++ { tc := toolCalls[i] toolCallID := fmt.Sprintf("fc_%s", uuid.New().String()) outputIndex++ functionCallItem := &schema.ORItemField{ Type: "function_call", ID: toolCallID, Status: "completed", CallID: toolCallID, Name: tc.Name, Arguments: tc.Arguments, } sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.added", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: functionCallItem, }) sequenceNumber++ sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.done", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: functionCallItem, }) sequenceNumber++ // Collect function call item for storage collectedOutputItems = append(collectedOutputItems, *functionCallItem) } break // no MCP tools to execute, exit loop } // end MCP stream iteration loop // Build final response with all items (include reasoning first, then messages, then tool calls) var allOutputItems []schema.ORItemField // Add reasoning item if it exists if currentReasoningID != "" && finalReasoning != "" { allOutputItems = append(allOutputItems, schema.ORItemField{ Type: "reasoning", ID: currentReasoningID, Status: "completed", Content: []schema.ORContentPart{makeOutputTextPart(finalReasoning)}, }) } // Add message item if currentMessageID != "" && textContent != "" { allOutputItems = append(allOutputItems, schema.ORItemField{ Type: "message", ID: currentMessageID, Status: "completed", Role: "assistant", Content: []schema.ORContentPart{makeOutputTextPartWithLogprobs(textContent, lastStreamLogprobs)}, }) } // Add tool call items for _, tc := range toolCalls { toolCallID := fmt.Sprintf("fc_%s", uuid.New().String()) allOutputItems = append(allOutputItems, schema.ORItemField{ Type: "function_call", ID: toolCallID, Status: "completed", CallID: toolCallID, Name: tc.Name, Arguments: tc.Arguments, }) } // Emit response.completed now := time.Now().Unix() 
responseCompleted := buildORResponse(responseID, createdAt, &now, "completed", input, allOutputItems, &schema.ORUsage{ InputTokens: lastStreamTokenUsage.Prompt, OutputTokens: lastStreamTokenUsage.Completion, TotalTokens: lastStreamTokenUsage.Prompt + lastStreamTokenUsage.Completion, OutputTokensDetails: &schema.OROutputTokensDetails{ ReasoningTokens: reasoningTokens, }, }, shouldStore) sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.completed", SequenceNumber: sequenceNumber, Response: responseCompleted, }) // Store response for future reference (if enabled) if shouldStore { store := GetGlobalStore() store.Store(responseID, input, responseCompleted) } // Send [DONE] fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") c.Response().Flush() return nil } // Non-tool-call streaming path // Emit output_item.added for message currentMessageID = fmt.Sprintf("msg_%s", uuid.New().String()) messageItem := &schema.ORItemField{ Type: "message", ID: currentMessageID, Status: "in_progress", Role: "assistant", Content: []schema.ORContentPart{}, } sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.added", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: messageItem, }) sequenceNumber++ // Emit content_part.added currentContentIndex = 0 emptyTextPart := makeOutputTextPart("") sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.content_part.added", SequenceNumber: sequenceNumber, ItemID: currentMessageID, OutputIndex: &outputIndex, ContentIndex: ¤tContentIndex, Part: &emptyTextPart, }) sequenceNumber++ // Stream text deltas with reasoning extraction tokenCallback := func(token string, tokenUsage backend.TokenUsage) bool { accumulatedText += token reasoningDelta, contentDelta := extractor.ProcessToken(token) // Handle reasoning item if extractor.Reasoning() != "" { // Check if we need to create reasoning item if currentReasoningID == "" { outputIndex++ currentReasoningID = fmt.Sprintf("reasoning_%s", uuid.New().String()) reasoningItem := 
&schema.ORItemField{ Type: "reasoning", ID: currentReasoningID, Status: "in_progress", } sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.added", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: reasoningItem, }) sequenceNumber++ // Emit content_part.added for reasoning currentReasoningContentIndex = 0 emptyPart := makeOutputTextPart("") sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.content_part.added", SequenceNumber: sequenceNumber, ItemID: currentReasoningID, OutputIndex: &outputIndex, ContentIndex: ¤tReasoningContentIndex, Part: &emptyPart, }) sequenceNumber++ } // Emit reasoning delta if there's new content if reasoningDelta != "" { sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_text.delta", SequenceNumber: sequenceNumber, ItemID: currentReasoningID, OutputIndex: &outputIndex, ContentIndex: ¤tReasoningContentIndex, Delta: strPtr(reasoningDelta), Logprobs: emptyLogprobs(), }) sequenceNumber++ c.Response().Flush() } } // Only emit message content if there's actual content (not just reasoning) if contentDelta != "" { // Emit text delta sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_text.delta", SequenceNumber: sequenceNumber, ItemID: currentMessageID, OutputIndex: &outputIndex, ContentIndex: ¤tContentIndex, Delta: strPtr(contentDelta), Logprobs: emptyLogprobs(), }) sequenceNumber++ c.Response().Flush() } return true } var noToolResult string noToolCb := func(s string, c *[]schema.Choice) { noToolResult = s } noToolChoices, noToolTokenUsage, noToolChatDeltas, err := openaiEndpoint.ComputeChoices(openAIReq, predInput, cfg, cl, appConfig, ml, noToolCb, tokenCallback) if err != nil { xlog.Error("Open Responses stream model inference failed", "error", err) sendSSEEvent(c, &schema.ORStreamEvent{ Type: "error", SequenceNumber: sequenceNumber, Error: &schema.ORErrorPayload{ Type: "model_error", Message: fmt.Sprintf("model inference failed: %v", err), }, }) sequenceNumber++ responseFailed := 
responseCreated responseFailed.Status = "failed" sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.failed", SequenceNumber: sequenceNumber, Response: responseFailed, }) // Send [DONE] even on error fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") c.Response().Flush() return nil } result := noToolResult var noToolLogprobs *schema.Logprobs if len(noToolChoices) > 0 { noToolLogprobs = noToolChoices[0].Logprobs } // Source reasoning from: (1) ChatDeltas from C++ autoparser, (2) extractor's // streaming state, (3) final extraction from the finetuned result. var finalReasoning, finalCleanedResult string if chatDeltaReasoning := functions.ReasoningFromChatDeltas(noToolChatDeltas); chatDeltaReasoning != "" { finalReasoning = chatDeltaReasoning finalCleanedResult = functions.ContentFromChatDeltas(noToolChatDeltas) if finalCleanedResult == "" { finalCleanedResult = extractor.CleanedContent() } } else { finalReasoning = extractor.Reasoning() finalCleanedResult = extractor.CleanedContent() } if finalReasoning == "" && finalCleanedResult == "" { finalReasoning, finalCleanedResult = reason.ExtractReasoningWithConfig(result, thinkingStartToken, cfg.ReasoningConfig) } // Close reasoning item if it exists and wasn't closed yet if currentReasoningID != "" && finalReasoning != "" { // Emit output_text.done for reasoning sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_text.done", SequenceNumber: sequenceNumber, ItemID: currentReasoningID, OutputIndex: &outputIndex, ContentIndex: ¤tReasoningContentIndex, Text: strPtr(finalReasoning), Logprobs: emptyLogprobs(), }) sequenceNumber++ // Emit content_part.done for reasoning reasoningPart := makeOutputTextPart(finalReasoning) sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.content_part.done", SequenceNumber: sequenceNumber, ItemID: currentReasoningID, OutputIndex: &outputIndex, ContentIndex: ¤tReasoningContentIndex, Part: &reasoningPart, }) sequenceNumber++ // Emit output_item.done for reasoning reasoningItem := 
&schema.ORItemField{ Type: "reasoning", ID: currentReasoningID, Status: "completed", Content: []schema.ORContentPart{reasoningPart}, } sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.done", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: reasoningItem, }) sequenceNumber++ // Collect reasoning item for storage collectedOutputItems = append(collectedOutputItems, *reasoningItem) // Calculate reasoning tokens reasoningTokens = len(finalReasoning) / 4 if reasoningTokens == 0 && len(finalReasoning) > 0 { reasoningTokens = 1 } } result = finalCleanedResult // Convert logprobs for streaming events mcpStreamLogprobs := convertLogprobsForStreaming(noToolLogprobs) // Emit output_text.done sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_text.done", SequenceNumber: sequenceNumber, ItemID: currentMessageID, OutputIndex: &outputIndex, ContentIndex: ¤tContentIndex, Text: strPtr(result), Logprobs: logprobsPtr(mcpStreamLogprobs), }) sequenceNumber++ // Emit content_part.done (with actual logprobs) resultPart := makeOutputTextPartWithLogprobs(result, noToolLogprobs) sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.content_part.done", SequenceNumber: sequenceNumber, ItemID: currentMessageID, OutputIndex: &outputIndex, ContentIndex: ¤tContentIndex, Part: &resultPart, }) sequenceNumber++ // Emit output_item.done (with actual logprobs) messageItem.Status = "completed" messageItem.Content = []schema.ORContentPart{makeOutputTextPartWithLogprobs(result, noToolLogprobs)} sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.output_item.done", SequenceNumber: sequenceNumber, OutputIndex: &outputIndex, Item: messageItem, }) sequenceNumber++ // Emit response.completed now := time.Now().Unix() // Collect final output items (reasoning first, then message) var finalOutputItems []schema.ORItemField // Add reasoning item if it exists if currentReasoningID != "" && finalReasoning != "" { finalOutputItems = append(finalOutputItems, 
schema.ORItemField{ Type: "reasoning", ID: currentReasoningID, Status: "completed", Content: []schema.ORContentPart{makeOutputTextPart(finalReasoning)}, }) } // Add message item if len(collectedOutputItems) > 0 { // Use collected items (may include reasoning already) for _, item := range collectedOutputItems { if item.Type == "message" { finalOutputItems = append(finalOutputItems, item) } } } else { finalOutputItems = append(finalOutputItems, *messageItem) } responseCompleted := buildORResponse(responseID, createdAt, &now, "completed", input, finalOutputItems, &schema.ORUsage{ InputTokens: noToolTokenUsage.Prompt, OutputTokens: noToolTokenUsage.Completion, TotalTokens: noToolTokenUsage.Prompt + noToolTokenUsage.Completion, OutputTokensDetails: &schema.OROutputTokensDetails{ ReasoningTokens: reasoningTokens, }, }, shouldStore) sendSSEEvent(c, &schema.ORStreamEvent{ Type: "response.completed", SequenceNumber: sequenceNumber, Response: responseCompleted, }) // Store response for future reference (if enabled) if shouldStore { store := GetGlobalStore() store.Store(responseID, input, responseCompleted) } // Send [DONE] fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") c.Response().Flush() return nil } // sendSSEEvent sends a Server-Sent Event func sendSSEEvent(c echo.Context, event *schema.ORStreamEvent) { normalizeORStreamEvent(event) data, err := json.Marshal(event) if err != nil { xlog.Error("Failed to marshal SSE event", "error", err) return } fmt.Fprintf(c.Response().Writer, "event: %s\ndata: %s\n\n", event.Type, string(data)) } // normalizeORStreamEvent ensures required fields like Summary are never null. 
func normalizeORStreamEvent(event *schema.ORStreamEvent) { if event.Item != nil && event.Item.Summary == nil { event.Item.Summary = []schema.ORContentPart{} } } // getTopLogprobs returns the top_logprobs value, defaulting to 0 if nil func getTopLogprobs(topLogprobs *int) int { if topLogprobs != nil { return *topLogprobs } return 0 } // Helper functions for pointer types in streaming events func strPtr(s string) *string { return &s } func logprobsPtr(lp []schema.ORLogProb) *[]schema.ORLogProb { return &lp } func emptyLogprobs() *[]schema.ORLogProb { empty := []schema.ORLogProb{} return &empty } // makeOutputTextPart creates an output_text content part with all required fields per Open Responses spec func makeOutputTextPart(text string) schema.ORContentPart { return schema.ORContentPartWithLogprobs(text, nil) } // makeOutputTextPartWithLogprobs creates an output_text content part with actual logprobs data func makeOutputTextPartWithLogprobs(text string, logprobs *schema.Logprobs) schema.ORContentPart { return schema.ORContentPartWithLogprobs(text, logprobs) } // convertLogprobsForStreaming converts OpenAI-style logprobs to Open Responses format for streaming events func convertLogprobsForStreaming(logprobs *schema.Logprobs) []schema.ORLogProb { if logprobs == nil || len(logprobs.Content) == 0 { return []schema.ORLogProb{} } result := make([]schema.ORLogProb, 0, len(logprobs.Content)) for _, lp := range logprobs.Content { topLPs := make([]schema.ORTopLogProb, 0, len(lp.TopLogprobs)) for _, tlp := range lp.TopLogprobs { topLPs = append(topLPs, schema.ORTopLogProb{ Token: tlp.Token, Logprob: tlp.Logprob, Bytes: tlp.Bytes, }) } result = append(result, schema.ORLogProb{ Token: lp.Token, Logprob: lp.Logprob, Bytes: lp.Bytes, TopLogprobs: topLPs, }) } return result } // ensureUsageDetails ensures usage has all required detail fields func ensureUsageDetails(usage *schema.ORUsage) *schema.ORUsage { if usage == nil { return nil } // Ensure details are always present (not nil) 
if usage.InputTokensDetails == nil { usage.InputTokensDetails = &schema.ORInputTokensDetails{CachedTokens: 0} } if usage.OutputTokensDetails == nil { usage.OutputTokensDetails = &schema.OROutputTokensDetails{ReasoningTokens: 0} } return usage } // buildORResponse creates a complete ORResponseResource with all required fields func buildORResponse(responseID string, createdAt int64, completedAt *int64, status string, input *schema.OpenResponsesRequest, outputItems []schema.ORItemField, usage *schema.ORUsage, shouldStore bool) *schema.ORResponseResource { // Ensure output is never null - always an array if outputItems == nil { outputItems = []schema.ORItemField{} } // Ensure Summary is never null on any output item for i := range outputItems { if outputItems[i].Summary == nil { outputItems[i].Summary = []schema.ORContentPart{} } } // Ensure tools is never null - always an array tools := input.Tools if tools == nil { tools = []schema.ORFunctionTool{} } // Ensure metadata is never null - always a map metadata := input.Metadata if metadata == nil { metadata = map[string]string{} } // Set default values for sampling parameters temperature := 1.0 if input.Temperature != nil { temperature = *input.Temperature } topP := 1.0 if input.TopP != nil { topP = *input.TopP } presencePenalty := 0.0 if input.PresencePenalty != nil { presencePenalty = *input.PresencePenalty } frequencyPenalty := 0.0 if input.FrequencyPenalty != nil { frequencyPenalty = *input.FrequencyPenalty } // Default truncation to "auto" truncation := "auto" if input.Truncation != "" { truncation = input.Truncation } // Default service_tier to "default" serviceTier := "default" if input.ServiceTier != "" { serviceTier = input.ServiceTier } // Default parallel_tool_calls to true parallelToolCalls := true if input.ParallelToolCalls != nil { parallelToolCalls = *input.ParallelToolCalls } // Default tool_choice: "auto" if tools are present, "none" otherwise var toolChoice interface{} if input.ToolChoice != nil { 
toolChoice = input.ToolChoice } else if len(tools) > 0 { toolChoice = "auto" } else { toolChoice = "none" } // Background defaults to false background := false if input.Background != nil { background = *input.Background } // Convert nullable string fields var previousResponseID *string if input.PreviousResponseID != "" { previousResponseID = &input.PreviousResponseID } var instructions *string if input.Instructions != "" { instructions = &input.Instructions } // Convert reasoning var reasoning *schema.ORReasoning if input.Reasoning != nil { reasoning = &schema.ORReasoning{ Effort: input.Reasoning.Effort, Summary: input.Reasoning.Summary, } } // Build default text config textConfig := &schema.ORTextConfig{ Format: &schema.ORTextFormat{ Type: "text", }, } return &schema.ORResponseResource{ ID: responseID, Object: "response", CreatedAt: createdAt, CompletedAt: completedAt, Status: status, Model: input.Model, Output: outputItems, Error: nil, // null when no error IncompleteDetails: nil, // null when complete PreviousResponseID: previousResponseID, Instructions: instructions, // Tool-related fields Tools: tools, ToolChoice: toolChoice, ParallelToolCalls: parallelToolCalls, MaxToolCalls: input.MaxToolCalls, // Sampling parameters Temperature: temperature, TopP: topP, PresencePenalty: presencePenalty, FrequencyPenalty: frequencyPenalty, TopLogprobs: getTopLogprobs(input.TopLogprobs), MaxOutputTokens: input.MaxOutputTokens, // Text format Text: textConfig, // Truncation and reasoning Truncation: truncation, Reasoning: reasoning, // Usage Usage: ensureUsageDetails(usage), // Metadata and operational flags Metadata: metadata, Store: shouldStore, Background: background, ServiceTier: serviceTier, // Safety and caching (nullable, not yet implemented) SafetyIdentifier: nil, PromptCacheKey: nil, } } // sendOpenResponsesError sends an error response func sendOpenResponsesError(c echo.Context, statusCode int, errorType, message, param string) error { errorResp := 
map[string]interface{}{ "error": map[string]interface{}{ "type": errorType, "message": message, }, } if param != "" { errorResp["error"].(map[string]interface{})["param"] = param } return c.JSON(statusCode, errorResp) } // convertORToolsToOpenAIFormat converts Open Responses tools to OpenAI format for the backend // Open Responses format: { type, name, description, parameters } // OpenAI format: { type, function: { name, description, parameters } } func convertORToolsToOpenAIFormat(orTools []schema.ORFunctionTool) []functions.Tool { result := make([]functions.Tool, 0, len(orTools)) for _, t := range orTools { result = append(result, functions.Tool{ Type: "function", Function: functions.Function{ Name: t.Name, Description: t.Description, Parameters: t.Parameters, }, }) } return result } // GetResponseEndpoint returns a handler for GET /responses/:id // This endpoint is used for polling background responses or resuming streaming // @Summary Get a response by ID // @Description Retrieve a response by ID. Can be used for polling background responses or resuming streaming responses. 
// @Param id path string true "Response ID" // @Param stream query string false "Set to 'true' to resume streaming" // @Param starting_after query int false "Sequence number to resume from (for streaming)" // @Success 200 {object} schema.ORResponseResource "Response" // @Failure 400 {object} map[string]interface{} "Bad Request" // @Failure 404 {object} map[string]interface{} "Not Found" // @Router /v1/responses/{id} [get] func GetResponseEndpoint() func(c echo.Context) error { return func(c echo.Context) error { responseID := c.Param("id") if responseID == "" { return sendOpenResponsesError(c, 400, "invalid_request_error", "response ID is required", "id") } store := GetGlobalStore() stored, err := store.Get(responseID) if err != nil { return sendOpenResponsesError(c, 404, "not_found", fmt.Sprintf("response not found: %s", responseID), "id") } // Check if streaming resume is requested streamParam := c.QueryParam("stream") if streamParam == "true" { // Validate that the response was created with streaming enabled if !stored.StreamEnabled { return sendOpenResponsesError(c, 400, "invalid_request_error", "cannot stream a response that was not created with stream=true", "stream") } // Get starting_after parameter startingAfter := 0 startingAfterParam := c.QueryParam("starting_after") if startingAfterParam != "" { if _, err := fmt.Sscanf(startingAfterParam, "%d", &startingAfter); err != nil { return sendOpenResponsesError(c, 400, "invalid_request_error", "starting_after must be an integer", "starting_after") } } return handleStreamResume(c, store, responseID, stored, startingAfter) } // Non-streaming: return the current response state stored.mu.RLock() response := stored.Response stored.mu.RUnlock() return c.JSON(200, response) } } // handleStreamResume handles resuming a streaming response from a specific sequence number func handleStreamResume(c echo.Context, store *ResponseStore, responseID string, stored *StoredResponse, startingAfter int) error { 
c.Response().Header().Set("Content-Type", "text/event-stream") c.Response().Header().Set("Cache-Control", "no-cache") c.Response().Header().Set("Connection", "keep-alive") // Get buffered events after the starting point events, err := store.GetEventsAfter(responseID, startingAfter) if err != nil { return sendOpenResponsesError(c, 500, "server_error", fmt.Sprintf("failed to get events: %v", err), "") } // Send all buffered events for _, event := range events { fmt.Fprintf(c.Response().Writer, "event: %s\ndata: %s\n\n", event.EventType, string(event.Data)) c.Response().Flush() } // Get the current status stored.mu.RLock() status := stored.Response.Status stored.mu.RUnlock() // If response is still in progress, subscribe to new events if status == schema.ORStatusQueued || status == schema.ORStatusInProgress { eventsChan, err := store.GetEventsChan(responseID) if err != nil { // Response might have completed, just finish fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") c.Response().Flush() return nil } // Track last sent sequence number lastSeq := startingAfter if len(events) > 0 { lastSeq = events[len(events)-1].SequenceNumber } // Wait for new events or completion for { select { case <-c.Request().Context().Done(): // Client disconnected return nil case <-eventsChan: // New events available newEvents, err := store.GetEventsAfter(responseID, lastSeq) if err != nil { break } for _, event := range newEvents { fmt.Fprintf(c.Response().Writer, "event: %s\ndata: %s\n\n", event.EventType, string(event.Data)) c.Response().Flush() lastSeq = event.SequenceNumber } // Check if response is now complete stored.mu.RLock() status = stored.Response.Status stored.mu.RUnlock() if status != schema.ORStatusQueued && status != schema.ORStatusInProgress { fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") c.Response().Flush() return nil } case <-time.After(30 * time.Second): // Timeout - send keepalive or check status stored.mu.RLock() status = stored.Response.Status 
stored.mu.RUnlock() if status != schema.ORStatusQueued && status != schema.ORStatusInProgress { fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") c.Response().Flush() return nil } } } } // Response already complete fmt.Fprintf(c.Response().Writer, "data: [DONE]\n\n") c.Response().Flush() return nil } // CancelResponseEndpoint returns a handler for POST /responses/:id/cancel // This endpoint cancels a background response if it's still in progress // @Summary Cancel a response // @Description Cancel a background response if it's still in progress // @Param id path string true "Response ID" // @Success 200 {object} schema.ORResponseResource "Response" // @Failure 400 {object} map[string]interface{} "Bad Request" // @Failure 404 {object} map[string]interface{} "Not Found" // @Router /v1/responses/{id}/cancel [post] func CancelResponseEndpoint() func(c echo.Context) error { return func(c echo.Context) error { responseID := c.Param("id") if responseID == "" { return sendOpenResponsesError(c, 400, "invalid_request_error", "response ID is required", "id") } store := GetGlobalStore() response, err := store.Cancel(responseID) if err != nil { return sendOpenResponsesError(c, 404, "not_found", fmt.Sprintf("response not found: %s", responseID), "id") } // Return the final response object return c.JSON(200, response) } } ================================================ FILE: core/http/endpoints/openresponses/store.go ================================================ package openresponses import ( "context" "encoding/json" "fmt" "sync" "time" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/xlog" ) // ResponseStore provides thread-safe storage for Open Responses API responses type ResponseStore struct { mu sync.RWMutex responses map[string]*StoredResponse ttl time.Duration // Time-to-live for stored responses (0 = no expiration) cleanupCtx context.Context cleanupCancel context.CancelFunc } // StreamedEvent represents a buffered SSE event for streaming resume type 
StreamedEvent struct { SequenceNumber int `json:"sequence_number"` EventType string `json:"event_type"` Data []byte `json:"data"` // JSON-serialized event } // StoredResponse contains a complete response with its input request and output items type StoredResponse struct { Request *schema.OpenResponsesRequest Response *schema.ORResponseResource Items map[string]*schema.ORItemField // item_id -> item mapping for quick lookup StoredAt time.Time ExpiresAt *time.Time // nil if no expiration // Background execution support CancelFunc context.CancelFunc // For cancellation of background tasks StreamEvents []StreamedEvent // Buffered events for streaming resume StreamEnabled bool // Was created with stream=true IsBackground bool // Was created with background=true EventsChan chan struct{} // Signals new events for live subscribers mu sync.RWMutex // Protect concurrent access to this response } var ( globalStore *ResponseStore storeOnce sync.Once ) // GetGlobalStore returns the singleton response store instance func GetGlobalStore() *ResponseStore { storeOnce.Do(func() { globalStore = NewResponseStore(0) // Default: no TTL, will be updated from appConfig }) return globalStore } // SetTTL updates the TTL for the store // This will affect all new responses stored after this call func (s *ResponseStore) SetTTL(ttl time.Duration) { s.mu.Lock() defer s.mu.Unlock() // Stop existing cleanup loop if running if s.cleanupCancel != nil { s.cleanupCancel() s.cleanupCancel = nil s.cleanupCtx = nil } s.ttl = ttl // If TTL > 0, start cleanup loop if ttl > 0 { s.cleanupCtx, s.cleanupCancel = context.WithCancel(context.Background()) go s.cleanupLoop(s.cleanupCtx) } xlog.Debug("Updated Open Responses store TTL", "ttl", ttl, "cleanup_running", ttl > 0) } // NewResponseStore creates a new response store with optional TTL // If ttl is 0, responses are stored indefinitely func NewResponseStore(ttl time.Duration) *ResponseStore { store := &ResponseStore{ responses: 
make(map[string]*StoredResponse), ttl: ttl, } // Start cleanup goroutine if TTL is set if ttl > 0 { store.cleanupCtx, store.cleanupCancel = context.WithCancel(context.Background()) go store.cleanupLoop(store.cleanupCtx) } return store } // Store stores a response with its request and items func (s *ResponseStore) Store(responseID string, request *schema.OpenResponsesRequest, response *schema.ORResponseResource) { s.mu.Lock() defer s.mu.Unlock() // Build item index for quick lookup items := make(map[string]*schema.ORItemField) for i := range response.Output { item := &response.Output[i] if item.ID != "" { items[item.ID] = item } } stored := &StoredResponse{ Request: request, Response: response, Items: items, StoredAt: time.Now(), ExpiresAt: nil, } // Set expiration if TTL is configured if s.ttl > 0 { expiresAt := time.Now().Add(s.ttl) stored.ExpiresAt = &expiresAt } s.responses[responseID] = stored xlog.Debug("Stored Open Responses response", "response_id", responseID, "items_count", len(items)) } // Get retrieves a stored response by ID func (s *ResponseStore) Get(responseID string) (*StoredResponse, error) { s.mu.RLock() defer s.mu.RUnlock() stored, exists := s.responses[responseID] if !exists { return nil, fmt.Errorf("response not found: %s", responseID) } // Check expiration if stored.ExpiresAt != nil && time.Now().After(*stored.ExpiresAt) { // Expired, but we'll return it anyway and let caller handle cleanup return nil, fmt.Errorf("response expired: %s", responseID) } return stored, nil } // GetItem retrieves a specific item from a stored response func (s *ResponseStore) GetItem(responseID, itemID string) (*schema.ORItemField, error) { stored, err := s.Get(responseID) if err != nil { return nil, err } item, exists := stored.Items[itemID] if !exists { return nil, fmt.Errorf("item not found: %s in response %s", itemID, responseID) } return item, nil } // FindItem searches for an item across all stored responses // Returns the item and the response ID it was found 
in func (s *ResponseStore) FindItem(itemID string) (*schema.ORItemField, string, error) { s.mu.RLock() defer s.mu.RUnlock() now := time.Now() for responseID, stored := range s.responses { // Skip expired responses if stored.ExpiresAt != nil && now.After(*stored.ExpiresAt) { continue } if item, exists := stored.Items[itemID]; exists { return item, responseID, nil } } return nil, "", fmt.Errorf("item not found in any stored response: %s", itemID) } // Delete removes a response from storage func (s *ResponseStore) Delete(responseID string) { s.mu.Lock() defer s.mu.Unlock() delete(s.responses, responseID) xlog.Debug("Deleted Open Responses response", "response_id", responseID) } // Cleanup removes expired responses func (s *ResponseStore) Cleanup() int { if s.ttl == 0 { return 0 } s.mu.Lock() defer s.mu.Unlock() now := time.Now() count := 0 for id, stored := range s.responses { if stored.ExpiresAt != nil && now.After(*stored.ExpiresAt) { delete(s.responses, id) count++ } } if count > 0 { xlog.Debug("Cleaned up expired Open Responses", "count", count) } return count } // cleanupLoop runs periodic cleanup of expired responses func (s *ResponseStore) cleanupLoop(ctx context.Context) { if s.ttl == 0 { return } ticker := time.NewTicker(s.ttl / 2) // Cleanup at half TTL interval defer ticker.Stop() for { select { case <-ctx.Done(): xlog.Debug("Stopped Open Responses store cleanup loop") return case <-ticker.C: s.Cleanup() } } } // Count returns the number of stored responses func (s *ResponseStore) Count() int { s.mu.RLock() defer s.mu.RUnlock() return len(s.responses) } // StoreBackground stores a background response with cancel function and optional streaming support func (s *ResponseStore) StoreBackground(responseID string, request *schema.OpenResponsesRequest, response *schema.ORResponseResource, cancelFunc context.CancelFunc, streamEnabled bool) { s.mu.Lock() defer s.mu.Unlock() // Build item index for quick lookup items := make(map[string]*schema.ORItemField) for i := 
range response.Output { item := &response.Output[i] if item.ID != "" { items[item.ID] = item } } stored := &StoredResponse{ Request: request, Response: response, Items: items, StoredAt: time.Now(), ExpiresAt: nil, CancelFunc: cancelFunc, StreamEvents: []StreamedEvent{}, StreamEnabled: streamEnabled, IsBackground: true, EventsChan: make(chan struct{}, 100), // Buffered channel for event notifications } // Set expiration if TTL is configured if s.ttl > 0 { expiresAt := time.Now().Add(s.ttl) stored.ExpiresAt = &expiresAt } s.responses[responseID] = stored xlog.Debug("Stored background Open Responses response", "response_id", responseID, "stream_enabled", streamEnabled) } // UpdateStatus updates the status of a stored response func (s *ResponseStore) UpdateStatus(responseID string, status string, completedAt *int64) error { s.mu.RLock() stored, exists := s.responses[responseID] s.mu.RUnlock() if !exists { return fmt.Errorf("response not found: %s", responseID) } stored.mu.Lock() defer stored.mu.Unlock() stored.Response.Status = status stored.Response.CompletedAt = completedAt xlog.Debug("Updated response status", "response_id", responseID, "status", status) return nil } // UpdateResponse updates the entire response object for a stored response func (s *ResponseStore) UpdateResponse(responseID string, response *schema.ORResponseResource) error { s.mu.RLock() stored, exists := s.responses[responseID] s.mu.RUnlock() if !exists { return fmt.Errorf("response not found: %s", responseID) } stored.mu.Lock() defer stored.mu.Unlock() // Rebuild item index items := make(map[string]*schema.ORItemField) for i := range response.Output { item := &response.Output[i] if item.ID != "" { items[item.ID] = item } } stored.Response = response stored.Items = items xlog.Debug("Updated response", "response_id", responseID, "status", response.Status, "items_count", len(items)) return nil } // AppendEvent appends a streaming event to the buffer for resume support func (s *ResponseStore) 
AppendEvent(responseID string, event *schema.ORStreamEvent) error { s.mu.RLock() stored, exists := s.responses[responseID] s.mu.RUnlock() if !exists { return fmt.Errorf("response not found: %s", responseID) } // Serialize the event data, err := json.Marshal(event) if err != nil { return fmt.Errorf("failed to marshal event: %w", err) } stored.mu.Lock() stored.StreamEvents = append(stored.StreamEvents, StreamedEvent{ SequenceNumber: event.SequenceNumber, EventType: event.Type, Data: data, }) stored.mu.Unlock() // Notify any subscribers of new event select { case stored.EventsChan <- struct{}{}: default: // Channel full, subscribers will catch up } return nil } // GetEventsAfter returns all events with sequence number greater than startingAfter func (s *ResponseStore) GetEventsAfter(responseID string, startingAfter int) ([]StreamedEvent, error) { s.mu.RLock() stored, exists := s.responses[responseID] s.mu.RUnlock() if !exists { return nil, fmt.Errorf("response not found: %s", responseID) } stored.mu.RLock() defer stored.mu.RUnlock() var result []StreamedEvent for _, event := range stored.StreamEvents { if event.SequenceNumber > startingAfter { result = append(result, event) } } return result, nil } // Cancel cancels a background response if it's still in progress func (s *ResponseStore) Cancel(responseID string) (*schema.ORResponseResource, error) { s.mu.RLock() stored, exists := s.responses[responseID] s.mu.RUnlock() if !exists { return nil, fmt.Errorf("response not found: %s", responseID) } stored.mu.Lock() defer stored.mu.Unlock() // If already in a terminal state, just return the response (idempotent) status := stored.Response.Status if status == schema.ORStatusCompleted || status == schema.ORStatusFailed || status == schema.ORStatusIncomplete || status == schema.ORStatusCancelled { xlog.Debug("Response already in terminal state", "response_id", responseID, "status", status) return stored.Response, nil } // Cancel the context if available if stored.CancelFunc != 
nil { stored.CancelFunc() xlog.Debug("Cancelled background response", "response_id", responseID) } // Update status to cancelled now := time.Now().Unix() stored.Response.Status = schema.ORStatusCancelled stored.Response.CompletedAt = &now return stored.Response, nil } // GetEventsChan returns the events notification channel for a response func (s *ResponseStore) GetEventsChan(responseID string) (chan struct{}, error) { s.mu.RLock() stored, exists := s.responses[responseID] s.mu.RUnlock() if !exists { return nil, fmt.Errorf("response not found: %s", responseID) } return stored.EventsChan, nil } // IsStreamEnabled checks if a response was created with streaming enabled func (s *ResponseStore) IsStreamEnabled(responseID string) (bool, error) { s.mu.RLock() stored, exists := s.responses[responseID] s.mu.RUnlock() if !exists { return false, fmt.Errorf("response not found: %s", responseID) } stored.mu.RLock() defer stored.mu.RUnlock() return stored.StreamEnabled, nil } ================================================ FILE: core/http/endpoints/openresponses/store_suite_test.go ================================================ package openresponses import ( "testing" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) func TestStore(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "ResponseStore Suite") } ================================================ FILE: core/http/endpoints/openresponses/store_test.go ================================================ package openresponses import ( "context" "fmt" "time" "github.com/mudler/LocalAI/core/schema" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" ) var _ = Describe("ResponseStore", func() { var store *ResponseStore BeforeEach(func() { store = NewResponseStore(0) // No TTL for most tests }) AfterEach(func() { // Clean up }) Describe("Store and Get", func() { It("should store and retrieve a response", func() { responseID := "resp_test123" request := &schema.OpenResponsesRequest{ Model: "test-model", Input: "Hello", } response := &schema.ORResponseResource{ ID: responseID, Object: "response", CreatedAt: time.Now().Unix(), Status: "completed", Model: "test-model", Output: []schema.ORItemField{ { Type: "message", ID: "msg_123", Status: "completed", Role: "assistant", Content: []schema.ORContentPart{{ Type: "output_text", Text: "Hello, world!", Annotations: []schema.ORAnnotation{}, Logprobs: []schema.ORLogProb{}, }}, }, }, } store.Store(responseID, request, response) stored, err := store.Get(responseID) Expect(err).ToNot(HaveOccurred()) Expect(stored).ToNot(BeNil()) Expect(stored.Response.ID).To(Equal(responseID)) Expect(stored.Request.Model).To(Equal("test-model")) Expect(len(stored.Items)).To(Equal(1)) Expect(stored.Items["msg_123"]).ToNot(BeNil()) Expect(stored.Items["msg_123"].ID).To(Equal("msg_123")) }) It("should return error for non-existent response", func() { _, err := store.Get("nonexistent") Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("not found")) }) It("should index all items by ID", func() { responseID := "resp_test456" request := &schema.OpenResponsesRequest{ Model: "test-model", Input: "Test", } response := &schema.ORResponseResource{ ID: responseID, Object: "response", Output: []schema.ORItemField{ { Type: "message", ID: "msg_1", Status: "completed", Role: "assistant", }, { Type: "function_call", ID: "fc_1", Status: "completed", CallID: "fc_1", Name: "test_function", Arguments: `{"arg": "value"}`, }, { Type: "message", ID: "msg_2", Status: "completed", Role: "assistant", }, }, } store.Store(responseID, request, response) stored, err := 
store.Get(responseID) Expect(err).ToNot(HaveOccurred()) Expect(len(stored.Items)).To(Equal(3)) Expect(stored.Items["msg_1"]).ToNot(BeNil()) Expect(stored.Items["fc_1"]).ToNot(BeNil()) Expect(stored.Items["msg_2"]).ToNot(BeNil()) }) It("should handle items without IDs", func() { responseID := "resp_test789" request := &schema.OpenResponsesRequest{ Model: "test-model", Input: "Test", } response := &schema.ORResponseResource{ ID: responseID, Object: "response", Output: []schema.ORItemField{ { Type: "message", ID: "", // No ID Status: "completed", Role: "assistant", }, { Type: "message", ID: "msg_with_id", Status: "completed", Role: "assistant", }, }, } store.Store(responseID, request, response) stored, err := store.Get(responseID) Expect(err).ToNot(HaveOccurred()) // Only items with IDs are indexed Expect(len(stored.Items)).To(Equal(1)) Expect(stored.Items["msg_with_id"]).ToNot(BeNil()) }) }) Describe("GetItem", func() { It("should retrieve a specific item by ID", func() { responseID := "resp_item_test" itemID := "msg_specific" request := &schema.OpenResponsesRequest{ Model: "test-model", Input: "Test", } response := &schema.ORResponseResource{ ID: responseID, Object: "response", Output: []schema.ORItemField{ { Type: "message", ID: itemID, Status: "completed", Role: "assistant", Content: []schema.ORContentPart{{ Type: "output_text", Text: "Specific message", Annotations: []schema.ORAnnotation{}, Logprobs: []schema.ORLogProb{}, }}, }, }, } store.Store(responseID, request, response) item, err := store.GetItem(responseID, itemID) Expect(err).ToNot(HaveOccurred()) Expect(item).ToNot(BeNil()) Expect(item.ID).To(Equal(itemID)) Expect(item.Type).To(Equal("message")) }) It("should return error for non-existent item", func() { responseID := "resp_item_test2" request := &schema.OpenResponsesRequest{ Model: "test-model", Input: "Test", } response := &schema.ORResponseResource{ ID: responseID, Object: "response", Output: []schema.ORItemField{ { Type: "message", ID: 
"msg_existing", Status: "completed", }, }, } store.Store(responseID, request, response) _, err := store.GetItem(responseID, "nonexistent_item") Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("item not found")) }) It("should return error for non-existent response when getting item", func() { _, err := store.GetItem("nonexistent_response", "any_item") Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("response not found")) }) }) Describe("FindItem", func() { It("should find an item across all stored responses", func() { // Store first response responseID1 := "resp_find_1" itemID1 := "msg_find_1" store.Store(responseID1, &schema.OpenResponsesRequest{Model: "test"}, &schema.ORResponseResource{ ID: responseID1, Object: "response", Output: []schema.ORItemField{ {Type: "message", ID: itemID1, Status: "completed"}, }, }) // Store second response responseID2 := "resp_find_2" itemID2 := "msg_find_2" store.Store(responseID2, &schema.OpenResponsesRequest{Model: "test"}, &schema.ORResponseResource{ ID: responseID2, Object: "response", Output: []schema.ORItemField{ {Type: "message", ID: itemID2, Status: "completed"}, }, }) // Find item from first response item, foundResponseID, err := store.FindItem(itemID1) Expect(err).ToNot(HaveOccurred()) Expect(item).ToNot(BeNil()) Expect(item.ID).To(Equal(itemID1)) Expect(foundResponseID).To(Equal(responseID1)) // Find item from second response item, foundResponseID, err = store.FindItem(itemID2) Expect(err).ToNot(HaveOccurred()) Expect(item).ToNot(BeNil()) Expect(item.ID).To(Equal(itemID2)) Expect(foundResponseID).To(Equal(responseID2)) }) It("should return error when item not found in any response", func() { _, _, err := store.FindItem("nonexistent_item") Expect(err).To(HaveOccurred()) Expect(err.Error()).To(ContainSubstring("item not found in any stored response")) }) }) Describe("Delete", func() { It("should delete a stored response", func() { responseID := "resp_delete_test" request := 
&schema.OpenResponsesRequest{Model: "test"} response := &schema.ORResponseResource{ ID: responseID, Object: "response", } store.Store(responseID, request, response) Expect(store.Count()).To(Equal(1)) store.Delete(responseID) Expect(store.Count()).To(Equal(0)) _, err := store.Get(responseID) Expect(err).To(HaveOccurred()) }) It("should handle deleting non-existent response gracefully", func() { // Should not panic store.Delete("nonexistent") Expect(store.Count()).To(Equal(0)) }) }) Describe("Count", func() { It("should return correct count of stored responses", func() { Expect(store.Count()).To(Equal(0)) store.Store("resp_1", &schema.OpenResponsesRequest{Model: "test"}, &schema.ORResponseResource{ID: "resp_1", Object: "response"}) Expect(store.Count()).To(Equal(1)) store.Store("resp_2", &schema.OpenResponsesRequest{Model: "test"}, &schema.ORResponseResource{ID: "resp_2", Object: "response"}) Expect(store.Count()).To(Equal(2)) store.Delete("resp_1") Expect(store.Count()).To(Equal(1)) }) }) Describe("TTL and Expiration", func() { It("should set expiration when TTL is configured", func() { ttlStore := NewResponseStore(100 * time.Millisecond) responseID := "resp_ttl_test" request := &schema.OpenResponsesRequest{Model: "test"} response := &schema.ORResponseResource{ID: responseID, Object: "response"} ttlStore.Store(responseID, request, response) stored, err := ttlStore.Get(responseID) Expect(err).ToNot(HaveOccurred()) Expect(stored.ExpiresAt).ToNot(BeNil()) Expect(stored.ExpiresAt.After(time.Now())).To(BeTrue()) }) It("should not set expiration when TTL is 0", func() { responseID := "resp_no_ttl" request := &schema.OpenResponsesRequest{Model: "test"} response := &schema.ORResponseResource{ID: responseID, Object: "response"} store.Store(responseID, request, response) stored, err := store.Get(responseID) Expect(err).ToNot(HaveOccurred()) Expect(stored.ExpiresAt).To(BeNil()) }) It("should clean up expired responses", func() { ttlStore := NewResponseStore(50 * 
time.Millisecond) responseID := "resp_expire_test" request := &schema.OpenResponsesRequest{Model: "test"} response := &schema.ORResponseResource{ID: responseID, Object: "response"} ttlStore.Store(responseID, request, response) Expect(ttlStore.Count()).To(Equal(1)) // Wait for expiration (longer than TTL and cleanup interval) time.Sleep(150 * time.Millisecond) // Cleanup should remove expired response (may have already been cleaned by goroutine) count := ttlStore.Cleanup() // Count might be 0 if cleanup goroutine already ran, or 1 if we're first Expect(count).To(BeNumerically(">=", 0)) Expect(ttlStore.Count()).To(Equal(0)) _, err := ttlStore.Get(responseID) Expect(err).To(HaveOccurred()) }) It("should return error for expired response", func() { ttlStore := NewResponseStore(50 * time.Millisecond) responseID := "resp_expire_error" request := &schema.OpenResponsesRequest{Model: "test"} response := &schema.ORResponseResource{ID: responseID, Object: "response"} ttlStore.Store(responseID, request, response) // Wait for expiration (but not long enough for cleanup goroutine to remove it) time.Sleep(75 * time.Millisecond) // Try to get before cleanup goroutine removes it _, err := ttlStore.Get(responseID) // Error could be "expired" or "not found" (if cleanup already ran) Expect(err).To(HaveOccurred()) // Either error message is acceptable errMsg := err.Error() Expect(errMsg).To(Or(ContainSubstring("expired"), ContainSubstring("not found"))) }) }) Describe("Thread Safety", func() { It("should handle concurrent stores and gets", func() { // This is a basic concurrency test done := make(chan bool, 10) for i := 0; i < 10; i++ { go func(id int) { responseID := fmt.Sprintf("resp_concurrent_%d", id) request := &schema.OpenResponsesRequest{Model: "test"} response := &schema.ORResponseResource{ ID: responseID, Object: "response", Output: []schema.ORItemField{ {Type: "message", ID: fmt.Sprintf("msg_%d", id), Status: "completed"}, }, } store.Store(responseID, request, response) // 
Retrieve immediately stored, err := store.Get(responseID) Expect(err).ToNot(HaveOccurred()) Expect(stored).ToNot(BeNil()) done <- true }(i) } // Wait for all goroutines for i := 0; i < 10; i++ { <-done } Expect(store.Count()).To(Equal(10)) }) }) Describe("GetGlobalStore", func() { It("should return singleton instance", func() { store1 := GetGlobalStore() store2 := GetGlobalStore() Expect(store1).To(Equal(store2)) }) It("should persist data across GetGlobalStore calls", func() { globalStore := GetGlobalStore() responseID := "resp_global_test" request := &schema.OpenResponsesRequest{Model: "test"} response := &schema.ORResponseResource{ID: responseID, Object: "response"} globalStore.Store(responseID, request, response) // Get store again globalStore2 := GetGlobalStore() stored, err := globalStore2.Get(responseID) Expect(err).ToNot(HaveOccurred()) Expect(stored).ToNot(BeNil()) }) }) Describe("Background Mode Support", func() { It("should store background response with cancel function", func() { responseID := "resp_bg_test" request := &schema.OpenResponsesRequest{Model: "test"} response := &schema.ORResponseResource{ ID: responseID, Object: "response", Status: schema.ORStatusQueued, } _, cancel := context.WithCancel(context.Background()) defer cancel() store.StoreBackground(responseID, request, response, cancel, true) stored, err := store.Get(responseID) Expect(err).ToNot(HaveOccurred()) Expect(stored).ToNot(BeNil()) Expect(stored.IsBackground).To(BeTrue()) Expect(stored.StreamEnabled).To(BeTrue()) Expect(stored.CancelFunc).ToNot(BeNil()) }) It("should update status of stored response", func() { responseID := "resp_status_test" request := &schema.OpenResponsesRequest{Model: "test"} response := &schema.ORResponseResource{ ID: responseID, Object: "response", Status: schema.ORStatusQueued, } store.Store(responseID, request, response) err := store.UpdateStatus(responseID, schema.ORStatusInProgress, nil) Expect(err).ToNot(HaveOccurred()) stored, err := store.Get(responseID) 
Expect(err).ToNot(HaveOccurred()) Expect(stored.Response.Status).To(Equal(schema.ORStatusInProgress)) }) It("should append and retrieve streaming events", func() { responseID := "resp_events_test" request := &schema.OpenResponsesRequest{Model: "test"} response := &schema.ORResponseResource{ ID: responseID, Object: "response", Status: schema.ORStatusInProgress, } _, cancel := context.WithCancel(context.Background()) defer cancel() store.StoreBackground(responseID, request, response, cancel, true) // Append events event1 := &schema.ORStreamEvent{ Type: "response.created", SequenceNumber: 0, } event2 := &schema.ORStreamEvent{ Type: "response.in_progress", SequenceNumber: 1, } event3 := &schema.ORStreamEvent{ Type: "response.output_text.delta", SequenceNumber: 2, } err := store.AppendEvent(responseID, event1) Expect(err).ToNot(HaveOccurred()) err = store.AppendEvent(responseID, event2) Expect(err).ToNot(HaveOccurred()) err = store.AppendEvent(responseID, event3) Expect(err).ToNot(HaveOccurred()) // Get all events after -1 (all events) events, err := store.GetEventsAfter(responseID, -1) Expect(err).ToNot(HaveOccurred()) Expect(events).To(HaveLen(3)) // Get events after sequence 1 events, err = store.GetEventsAfter(responseID, 1) Expect(err).ToNot(HaveOccurred()) Expect(events).To(HaveLen(1)) Expect(events[0].SequenceNumber).To(Equal(2)) }) It("should cancel an in-progress response", func() { responseID := "resp_cancel_test" request := &schema.OpenResponsesRequest{Model: "test"} response := &schema.ORResponseResource{ ID: responseID, Object: "response", Status: schema.ORStatusInProgress, } _, cancel := context.WithCancel(context.Background()) defer cancel() store.StoreBackground(responseID, request, response, cancel, false) // Cancel the response cancelledResponse, err := store.Cancel(responseID) Expect(err).ToNot(HaveOccurred()) Expect(cancelledResponse.Status).To(Equal(schema.ORStatusCancelled)) Expect(cancelledResponse.CompletedAt).ToNot(BeNil()) }) It("should be 
idempotent when cancelling already completed response", func() { responseID := "resp_idempotent_cancel" request := &schema.OpenResponsesRequest{Model: "test"} completedAt := time.Now().Unix() response := &schema.ORResponseResource{ ID: responseID, Object: "response", Status: schema.ORStatusCompleted, CompletedAt: &completedAt, } store.Store(responseID, request, response) // Try to cancel a completed response cancelledResponse, err := store.Cancel(responseID) Expect(err).ToNot(HaveOccurred()) // Status should remain completed (not changed to cancelled) Expect(cancelledResponse.Status).To(Equal(schema.ORStatusCompleted)) }) It("should check if streaming is enabled", func() { responseID := "resp_stream_check" request := &schema.OpenResponsesRequest{Model: "test"} response := &schema.ORResponseResource{ ID: responseID, Object: "response", Status: schema.ORStatusQueued, } _, cancel := context.WithCancel(context.Background()) defer cancel() store.StoreBackground(responseID, request, response, cancel, true) enabled, err := store.IsStreamEnabled(responseID) Expect(err).ToNot(HaveOccurred()) Expect(enabled).To(BeTrue()) // Store another without streaming responseID2 := "resp_no_stream" store.StoreBackground(responseID2, request, response, cancel, false) enabled2, err := store.IsStreamEnabled(responseID2) Expect(err).ToNot(HaveOccurred()) Expect(enabled2).To(BeFalse()) }) It("should notify subscribers of new events", func() { responseID := "resp_events_chan" request := &schema.OpenResponsesRequest{Model: "test"} response := &schema.ORResponseResource{ ID: responseID, Object: "response", Status: schema.ORStatusInProgress, } _, cancel := context.WithCancel(context.Background()) defer cancel() store.StoreBackground(responseID, request, response, cancel, true) eventsChan, err := store.GetEventsChan(responseID) Expect(err).ToNot(HaveOccurred()) Expect(eventsChan).ToNot(BeNil()) // Append an event event := &schema.ORStreamEvent{ Type: "response.output_text.delta", SequenceNumber: 
0, } go func() { time.Sleep(10 * time.Millisecond) store.AppendEvent(responseID, event) }() // Wait for notification select { case <-eventsChan: // Event received case <-time.After(1 * time.Second): Fail("Timeout waiting for event notification") } }) }) }) ================================================ FILE: core/http/endpoints/openresponses/websocket.go ================================================ package openresponses import ( "context" "encoding/json" "fmt" "net/http" "sync" "time" "github.com/google/uuid" "github.com/gorilla/websocket" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/templates" "github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/xlog" ) const ( wsMaxMessageSize = 10 * 1024 * 1024 // 10MB wsConnectionLimit = 60 * time.Minute ) var wsUpgrader = websocket.Upgrader{ CheckOrigin: func(r *http.Request) bool { return true }, } // lockedConn wraps a websocket connection with a mutex for safe concurrent writes type lockedConn struct { *websocket.Conn sync.Mutex } func (lc *lockedConn) writeJSON(v any) error { lc.Lock() defer lc.Unlock() return lc.Conn.WriteJSON(v) } // WebSocketEndpoint handles WebSocket mode for the Responses API. // Clients connect via ws://:/v1/responses and send response.create messages. // Events are streamed back over the WebSocket connection instead of SSE. 
func WebSocketEndpoint(application *application.Application) echo.HandlerFunc { cl := application.ModelConfigLoader() ml := application.ModelLoader() evaluator := application.TemplatesEvaluator() appConfig := application.ApplicationConfig() return func(c echo.Context) error { ws, err := wsUpgrader.Upgrade(c.Response(), c.Request(), nil) if err != nil { return err } defer ws.Close() ws.SetReadLimit(wsMaxMessageSize) // Set absolute deadline so blocking ReadMessage unblocks after the limit deadline := time.Now().Add(wsConnectionLimit) ws.SetReadDeadline(deadline) ws.SetWriteDeadline(deadline) conn := &lockedConn{Conn: ws} // Context for cancelling in-flight work when the connection closes connCtx, connCancel := context.WithDeadline(context.Background(), deadline) defer connCancel() xlog.Debug("WebSocket Responses connection established", "address", ws.RemoteAddr().String()) handleWebSocketConnection(connCtx, conn, cl, ml, evaluator, appConfig) return nil } } // handleWebSocketConnection runs the read loop for a single WebSocket connection. 
func handleWebSocketConnection(connCtx context.Context, conn *lockedConn, cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) { // Track in-flight response to enforce one-at-a-time var inflight sync.Mutex // Read loop for { select { case <-connCtx.Done(): sendWSError(conn, "websocket_connection_limit_reached", "Connection exceeded maximum duration", "") return default: } _, msgBytes, err := conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) { xlog.Debug("WebSocket Responses read error", "error", err) } return } // Parse the envelope to determine message type var envelope struct { Type string `json:"type"` } if err := json.Unmarshal(msgBytes, &envelope); err != nil { sendWSError(conn, "invalid_request", "invalid JSON message", "") continue } if envelope.Type != "response.create" { sendWSError(conn, "invalid_request", fmt.Sprintf("unsupported message type: %s", envelope.Type), "type") continue } // Parse the full request var wsMsg schema.ORWebSocketMessage if err := json.Unmarshal(msgBytes, &wsMsg); err != nil { sendWSError(conn, "invalid_request", fmt.Sprintf("failed to parse request: %v", err), "") continue } // Enforce one in-flight response at a time (non-blocking check) if !inflight.TryLock() { sendWSError(conn, "invalid_request", "a response is already in progress on this connection", "") continue } go func() { defer inflight.Unlock() handleWSResponseCreate(connCtx, conn, &wsMsg.OpenResponsesRequest, cl, ml, evaluator, appConfig) }() } } // handleWSResponseCreate processes a single response.create message and streams events over WebSocket. // It reuses the existing background stream infrastructure: the request is processed via // handleBackgroundStream which buffers events into the store, and a forwarder goroutine // reads those events and sends them over the WebSocket. 
//
// Validation and conversion failures are reported to the client as error events and
// abort the request without storing anything. The function returns once forwardEvents
// returns; the deferred reqCancel then cancels the request context.
func handleWSResponseCreate(connCtx context.Context, conn *lockedConn, input *schema.OpenResponsesRequest, cl *config.ModelConfigLoader, ml *model.ModelLoader, evaluator *templates.Evaluator, appConfig *config.ApplicationConfig) {
	// Creation timestamp and a fresh response ID for this request.
	createdAt := time.Now().Unix()
	responseID := fmt.Sprintf("resp_%s", uuid.New().String())

	if input.Model == "" {
		sendWSError(conn, "invalid_request", "model is required", "model")
		return
	}

	// Resolve model configuration (same logic as middleware.SetModelAndConfig)
	cfg, err := cl.LoadModelConfigFileByNameDefaultOptions(input.Model, appConfig)
	if err != nil {
		xlog.Warn("WebSocket Responses: model config not found", "model", input.Model, "error", err)
		sendWSError(conn, "invalid_request", fmt.Sprintf("model not found: %s", input.Model), "model")
		return
	}
	// Fall back to the requested model name when the loaded config does not set one.
	if cfg.Model == "" {
		cfg.Model = input.Model
	}

	// Merge request params into config (same as mergeOpenResponsesRequestAndModelConfig)
	if err := middleware.MergeOpenResponsesConfig(cfg, input); err != nil {
		sendWSError(conn, "invalid_request", fmt.Sprintf("invalid configuration: %v", err), "")
		return
	}

	// Set up context with cancellation tied to connection lifetime
	reqCtx, reqCancel := context.WithCancel(connCtx)
	defer reqCancel()
	input.Context = reqCtx
	input.Cancel = reqCancel

	store := GetGlobalStore()
	// Apply the configured TTL to the global store when one is set.
	if appConfig.OpenResponsesStoreTTL > 0 {
		store.SetTTL(appConfig.OpenResponsesStoreTTL)
	}

	// Responses are kept unless the request explicitly sets store=false.
	shouldStore := true
	if input.Store != nil && !*input.Store {
		shouldStore = false
	}

	// Handle previous_response_id
	// The earlier request's input and output are converted to messages and placed
	// before the messages of the current request.
	var messages []schema.Message
	if input.PreviousResponseID != "" {
		stored, err := store.Get(input.PreviousResponseID)
		if err != nil {
			sendWSErrorEvent(conn, "previous_response_not_found", fmt.Sprintf("previous response not found: %s", input.PreviousResponseID), "previous_response_id")
			return
		}
		previousInputMessages, err := convertORInputToMessages(stored.Request.Input, cfg)
		if err != nil {
			sendWSError(conn, "invalid_request", fmt.Sprintf("failed to convert previous input: %v", err), "")
			return
		}
		previousOutputMessages, err := convertOROutputItemsToMessages(stored.Response.Output)
		if err != nil {
			sendWSError(conn, "invalid_request", fmt.Sprintf("failed to convert previous response: %v", err), "")
			return
		}
		messages = previousInputMessages
		messages = append(messages, previousOutputMessages...)
	}

	// Convert current input to messages
	newMessages, err := convertORInputToMessages(input.Input, cfg)
	if err != nil {
		sendWSError(conn, "invalid_request", fmt.Sprintf("failed to parse input: %v", err), "")
		return
	}
	messages = append(messages, newMessages...)

	// Instructions become a system message placed ahead of every other message.
	if input.Instructions != "" {
		messages = append([]schema.Message{{Role: "system", StringContent: input.Instructions}}, messages...)
	}

	// Handle tools
	var funcs functions.Functions
	var shouldUseFn bool
	if len(input.Tools) > 0 {
		funcs, shouldUseFn = convertORToolsToFunctions(input, cfg)
	}

	// Create OpenAI-compatible request
	openAIReq := &schema.OpenAIRequest{
		PredictionOptions: schema.PredictionOptions{
			BasicModelRequest: schema.BasicModelRequest{Model: input.Model},
			Temperature:       input.Temperature,
			TopP:              input.TopP,
			Maxtokens:         input.MaxOutputTokens,
		},
		Messages:  messages,
		Stream:    true, // WebSocket mode always streams
		Context:   reqCtx,
		Cancel:    reqCancel,
		Functions: funcs,
	}

	if input.TextFormat != nil {
		openAIReq.ResponseFormat = convertTextFormatToResponseFormat(input.TextFormat)
	}

	// Generate grammar for function calling
	if shouldUseFn && !cfg.FunctionsConfig.GrammarConfig.NoGrammar {
		// Defaults for the "no action" function, overridable through the model config.
		noActionName := "answer"
		noActionDescription := "use this action to answer without performing any action"
		if cfg.FunctionsConfig.NoActionFunctionName != "" {
			noActionName = cfg.FunctionsConfig.NoActionFunctionName
		}
		if cfg.FunctionsConfig.NoActionDescriptionName != "" {
			noActionDescription = cfg.FunctionsConfig.NoActionDescriptionName
		}

		noActionGrammar := functions.Function{
			Name:        noActionName,
			Description: noActionDescription,
			Parameters: map[string]interface{}{
				"properties": map[string]interface{}{
					"message": map[string]interface{}{
						"type":        "string",
						"description": "The message to reply the user with",
					},
				},
			},
		}

		// Work on a copy so that appending the no-action function leaves funcs untouched.
		funcsWithNoAction := make(functions.Functions, len(funcs))
		copy(funcsWithNoAction, funcs)
		if !cfg.FunctionsConfig.DisableNoAction {
			funcsWithNoAction = append(funcsWithNoAction, noActionGrammar)
		}
		// Narrow the set to a single function when the config names one to call.
		if cfg.FunctionToCall() != "" {
			funcsWithNoAction = funcsWithNoAction.Select(cfg.FunctionToCall())
		}

		// NOTE(review): FunctionNameKey is passed for both arguments — confirm this is intended.
		jsStruct := funcsWithNoAction.ToJSONStructure(cfg.FunctionsConfig.FunctionNameKey, cfg.FunctionsConfig.FunctionNameKey)
		g, err := jsStruct.Grammar(cfg.FunctionsConfig.GrammarOptions()...)
		if err == nil {
			cfg.Grammar = g
		} else {
			// Grammar generation failure is logged only; the request continues without a grammar.
			xlog.Error("WebSocket Responses: failed generating grammar", "error", err)
		}
	}

	// Merge contiguous assistant messages
	openAIReq.Messages = mergeContiguousAssistantMessages(openAIReq.Messages)

	predInput := evaluator.TemplateMessages(*openAIReq, openAIReq.Messages, cfg, funcs, shouldUseFn)

	// Use the background stream infrastructure: store the request as a background task,
	// process it via handleBackgroundStream, and forward buffered events over WebSocket.
	queuedResponse := buildORResponse(responseID, createdAt, nil, schema.ORStatusQueued, input, []schema.ORItemField{}, nil, shouldStore)
	store.StoreBackground(responseID, input, queuedResponse, reqCancel, true)

	// Start processing in a goroutine
	// processDone is closed when the goroutine finishes, whether it succeeded or failed.
	processDone := make(chan struct{})
	go func() {
		defer close(processDone)
		store.UpdateStatus(responseID, schema.ORStatusInProgress, nil)

		finalResponse, bgErr := handleBackgroundStream(reqCtx, store, responseID, createdAt, input, cfg, ml, cl, appConfig, predInput, openAIReq, funcs, shouldUseFn, nil, nil)
		if bgErr != nil {
			xlog.Error("WebSocket Responses: processing failed", "response_id", responseID, "error", bgErr)
			now := time.Now().Unix()
			store.UpdateStatus(responseID, schema.ORStatusFailed, &now)

			// Buffer an error event so the client sees the failure
			failedResponse := buildORResponse(responseID, createdAt, &now, schema.ORStatusFailed, input, []schema.ORItemField{}, nil, shouldStore)
			bufferEvent(store, responseID, &schema.ORStreamEvent{
				Type:     "response.failed",
				Response: failedResponse,
				Error: &schema.ORErrorPayload{
					Type:    "server_error",
					Message: bgErr.Error(),
				},
			})
			return
		}

		// Replace the queued placeholder with the final response when one was produced.
		if finalResponse != nil {
			store.UpdateResponse(responseID, finalResponse)
		}
	}()

	// Forward events from the store to the WebSocket connection
	forwardEvents(reqCtx, conn, store, responseID, processDone, shouldStore)
}

// forwardEvents subscribes to events for a response and sends them over the WebSocket.
// This mirrors handleStreamResume but writes JSON to WebSocket instead of SSE.
func forwardEvents(ctx context.Context, conn *lockedConn, store *ResponseStore, responseID string, done <-chan struct{}, shouldStore bool) { eventsChan, err := store.GetEventsChan(responseID) if err != nil { return } lastSeq := -1 for { // Drain all available events events, err := store.GetEventsAfter(responseID, lastSeq) if err != nil { return } for _, event := range events { var parsed schema.ORStreamEvent if err := json.Unmarshal(event.Data, &parsed); err != nil { continue } if err := conn.writeJSON(&parsed); err != nil { return } lastSeq = event.SequenceNumber } // Check if processing is done and all events have been sent select { case <-done: // Drain any final events finalEvents, err := store.GetEventsAfter(responseID, lastSeq) if err == nil { for _, event := range finalEvents { var parsed schema.ORStreamEvent if err := json.Unmarshal(event.Data, &parsed); err != nil { continue } if err := conn.writeJSON(&parsed); err != nil { return } } } // Clean up non-stored responses from the cache if !shouldStore { store.Delete(responseID) } return default: } // Wait for new events, completion, or context cancellation select { case <-ctx.Done(): return case <-done: // Will drain in next iteration case <-eventsChan: // New events available } } } func sendWSError(conn *lockedConn, errType, message, param string) { event := schema.ORStreamEvent{ Type: "error", Error: &schema.ORErrorPayload{ Type: errType, Message: message, Param: param, }, } conn.writeJSON(&event) } func sendWSErrorEvent(conn *lockedConn, code, message, param string) { event := schema.ORStreamEvent{ Type: "error", Error: &schema.ORErrorPayload{ Type: "invalid_request_error", Code: code, Message: message, Param: param, }, } conn.writeJSON(&event) } ================================================ FILE: core/http/explorer.go ================================================ package http import ( "io/fs" "net/http" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/explorer" 
"github.com/mudler/LocalAI/core/http/middleware" "github.com/mudler/LocalAI/core/http/routes" "github.com/mudler/xlog" ) func Explorer(db *explorer.Database) *echo.Echo { e := echo.New() // Set renderer e.Renderer = renderEngine() // Hide banner e.HideBanner = true e.Pre(middleware.StripPathPrefix()) routes.RegisterExplorerRoutes(e, db) // Favicon handler e.GET("/favicon.svg", func(c echo.Context) error { data, err := embedDirStatic.ReadFile("static/favicon.svg") if err != nil { return c.NoContent(http.StatusNotFound) } c.Response().Header().Set("Content-Type", "image/svg+xml") return c.Blob(http.StatusOK, "image/svg+xml", data) }) // Static files - use fs.Sub to create a filesystem rooted at "static" staticFS, err := fs.Sub(embedDirStatic, "static") if err != nil { // Log error but continue - static files might not work xlog.Error("failed to create static filesystem", "error", err) } else { e.StaticFS("/static", staticFS) } // Define a custom 404 handler // Note: keep this at the bottom! e.GET("/*", notFoundHandler) return e } ================================================ FILE: core/http/http_suite_test.go ================================================ package http_test import ( "os" "path/filepath" "testing" . "github.com/onsi/ginkgo/v2" . 
"github.com/onsi/gomega" ) var ( tmpdir string modelDir string ) func TestLocalAI(t *testing.T) { RegisterFailHandler(Fail) var err error tmpdir, err = os.MkdirTemp("", "") Expect(err).ToNot(HaveOccurred()) modelDir = filepath.Join(tmpdir, "models") err = os.Mkdir(modelDir, 0750) Expect(err).ToNot(HaveOccurred()) AfterSuite(func() { err := os.RemoveAll(tmpdir) Expect(err).ToNot(HaveOccurred()) }) RunSpecs(t, "LocalAI HTTP test suite") } ================================================ FILE: core/http/middleware/auth.go ================================================ package middleware import ( "crypto/subtle" "errors" "net/http" "strings" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" ) var ErrMissingOrMalformedAPIKey = errors.New("missing or malformed API Key") // GetKeyAuthConfig returns Echo's KeyAuth middleware configuration func GetKeyAuthConfig(applicationConfig *config.ApplicationConfig) (echo.MiddlewareFunc, error) { // Create validator function validator := getApiKeyValidationFunction(applicationConfig) // Create error handler errorHandler := getApiKeyErrorHandler(applicationConfig) // Create Next function (skip middleware for certain requests) skipper := getApiKeyRequiredFilterFunction(applicationConfig) // Wrap it with our custom key lookup that checks multiple sources return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if len(applicationConfig.ApiKeys) == 0 { return next(c) } // Skip if skipper says so if skipper != nil && skipper(c) { return next(c) } // Try to extract key from multiple sources key, err := extractKeyFromMultipleSources(c) if err != nil { return errorHandler(err, c) } // Validate the key valid, err := validator(key, c) if err != nil || !valid { return errorHandler(ErrMissingOrMalformedAPIKey, c) } // Store key in context for later use c.Set("api_key", key) return next(c) } }, nil } // 
extractKeyFromMultipleSources checks multiple sources for the API key // in order: Authorization header, x-api-key header, xi-api-key header, token cookie func extractKeyFromMultipleSources(c echo.Context) (string, error) { // Check Authorization header first auth := c.Request().Header.Get("Authorization") if auth != "" { // Check for Bearer scheme if strings.HasPrefix(auth, "Bearer ") { return strings.TrimPrefix(auth, "Bearer "), nil } // If no Bearer prefix, return as-is (for backward compatibility) return auth, nil } // Check x-api-key header if key := c.Request().Header.Get("x-api-key"); key != "" { return key, nil } // Check xi-api-key header if key := c.Request().Header.Get("xi-api-key"); key != "" { return key, nil } // Check token cookie cookie, err := c.Cookie("token") if err == nil && cookie != nil && cookie.Value != "" { return cookie.Value, nil } return "", ErrMissingOrMalformedAPIKey } func getApiKeyErrorHandler(applicationConfig *config.ApplicationConfig) func(error, echo.Context) error { return func(err error, c echo.Context) error { if errors.Is(err, ErrMissingOrMalformedAPIKey) { if len(applicationConfig.ApiKeys) == 0 { return nil // if no keys are set up, any error we get here is not an error. 
} c.Response().Header().Set("WWW-Authenticate", "Bearer") if applicationConfig.OpaqueErrors { return c.NoContent(http.StatusUnauthorized) } // Check if the request content type is JSON contentType := c.Request().Header.Get("Content-Type") if strings.Contains(contentType, "application/json") { return c.JSON(http.StatusUnauthorized, schema.ErrorResponse{ Error: &schema.APIError{ Message: "An authentication key is required", Code: 401, Type: "invalid_request_error", }, }) } return c.Render(http.StatusUnauthorized, "views/login", map[string]interface{}{ "BaseURL": BaseURL(c), }) } if applicationConfig.OpaqueErrors { return c.NoContent(http.StatusInternalServerError) } return err } } func getApiKeyValidationFunction(applicationConfig *config.ApplicationConfig) func(string, echo.Context) (bool, error) { if applicationConfig.UseSubtleKeyComparison { return func(key string, c echo.Context) (bool, error) { if len(applicationConfig.ApiKeys) == 0 { return true, nil // If no keys are setup, accept everything } for _, validKey := range applicationConfig.ApiKeys { if subtle.ConstantTimeCompare([]byte(key), []byte(validKey)) == 1 { return true, nil } } return false, ErrMissingOrMalformedAPIKey } } return func(key string, c echo.Context) (bool, error) { if len(applicationConfig.ApiKeys) == 0 { return true, nil // If no keys are setup, accept everything } for _, validKey := range applicationConfig.ApiKeys { if key == validKey { return true, nil } } return false, ErrMissingOrMalformedAPIKey } } func getApiKeyRequiredFilterFunction(applicationConfig *config.ApplicationConfig) middleware.Skipper { return func(c echo.Context) bool { path := c.Request().URL.Path for _, p := range applicationConfig.PathWithoutAuth { if strings.HasPrefix(path, p) { return true } } // Handle GET request exemptions if enabled if applicationConfig.DisableApiKeyRequirementForHttpGet { if c.Request().Method != http.MethodGet { return false } for _, rx := range applicationConfig.HttpGetExemptedEndpoints { if 
rx.MatchString(c.Path()) { return true } } } return false } } ================================================ FILE: core/http/middleware/auth_test.go ================================================ package middleware_test import ( "net/http" "net/http/httptest" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" . "github.com/mudler/LocalAI/core/http/middleware" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) // ok is a simple handler that returns 200 OK. func ok(c echo.Context) error { return c.String(http.StatusOK, "ok") } // newAuthApp creates a minimal Echo app with auth middleware applied. // Requests that fail auth with Content-Type: application/json get a JSON 401 // (no template renderer needed). func newAuthApp(appConfig *config.ApplicationConfig) *echo.Echo { e := echo.New() mw, err := GetKeyAuthConfig(appConfig) Expect(err).ToNot(HaveOccurred()) e.Use(mw) // Sensitive API routes e.GET("/v1/models", ok) e.POST("/v1/chat/completions", ok) // UI routes e.GET("/app", ok) e.GET("/app/*", ok) e.GET("/browse", ok) e.GET("/browse/*", ok) e.GET("/login", ok) e.GET("/explorer", ok) e.GET("/assets/*", ok) e.POST("/app", ok) return e } // doRequest performs an HTTP request against the given Echo app and returns the recorder. 
func doRequest(e *echo.Echo, method, path string, opts ...func(*http.Request)) *httptest.ResponseRecorder { req := httptest.NewRequest(method, path, nil) req.Header.Set("Content-Type", "application/json") for _, opt := range opts { opt(req) } rec := httptest.NewRecorder() e.ServeHTTP(rec, req) return rec } func withBearerToken(token string) func(*http.Request) { return func(req *http.Request) { req.Header.Set("Authorization", "Bearer "+token) } } func withXApiKey(key string) func(*http.Request) { return func(req *http.Request) { req.Header.Set("x-api-key", key) } } func withXiApiKey(key string) func(*http.Request) { return func(req *http.Request) { req.Header.Set("xi-api-key", key) } } func withTokenCookie(token string) func(*http.Request) { return func(req *http.Request) { req.AddCookie(&http.Cookie{Name: "token", Value: token}) } } var _ = Describe("Auth Middleware", func() { Context("when API keys are configured", func() { var app *echo.Echo const validKey = "sk-test-key-123" BeforeEach(func() { appConfig := config.NewApplicationConfig() appConfig.ApiKeys = []string{validKey} app = newAuthApp(appConfig) }) It("returns 401 for GET request without a key", func() { rec := doRequest(app, http.MethodGet, "/v1/models") Expect(rec.Code).To(Equal(http.StatusUnauthorized)) }) It("returns 401 for POST request without a key", func() { rec := doRequest(app, http.MethodPost, "/v1/chat/completions") Expect(rec.Code).To(Equal(http.StatusUnauthorized)) }) It("returns 401 for request with an invalid key", func() { rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken("wrong-key")) Expect(rec.Code).To(Equal(http.StatusUnauthorized)) }) It("passes through with valid Bearer token in Authorization header", func() { rec := doRequest(app, http.MethodGet, "/v1/models", withBearerToken(validKey)) Expect(rec.Code).To(Equal(http.StatusOK)) }) It("passes through with valid x-api-key header", func() { rec := doRequest(app, http.MethodGet, "/v1/models", withXApiKey(validKey)) 
Expect(rec.Code).To(Equal(http.StatusOK)) }) It("passes through with valid xi-api-key header", func() { rec := doRequest(app, http.MethodGet, "/v1/models", withXiApiKey(validKey)) Expect(rec.Code).To(Equal(http.StatusOK)) }) It("passes through with valid token cookie", func() { rec := doRequest(app, http.MethodGet, "/v1/models", withTokenCookie(validKey)) Expect(rec.Code).To(Equal(http.StatusOK)) }) }) Context("when no API keys are configured", func() { var app *echo.Echo BeforeEach(func() { appConfig := config.NewApplicationConfig() app = newAuthApp(appConfig) }) It("passes through without any key", func() { rec := doRequest(app, http.MethodGet, "/v1/models") Expect(rec.Code).To(Equal(http.StatusOK)) }) }) Context("GET exempted endpoints (feature enabled)", func() { var app *echo.Echo const validKey = "sk-test-key-456" BeforeEach(func() { appConfig := config.NewApplicationConfig( config.WithApiKeys([]string{validKey}), config.WithDisableApiKeyRequirementForHttpGet(true), config.WithHttpGetExemptedEndpoints([]string{ "^/$", "^/app(/.*)?$", "^/browse(/.*)?$", "^/login/?$", "^/explorer/?$", "^/assets/.*$", "^/static/.*$", "^/swagger.*$", }), ) app = newAuthApp(appConfig) }) It("allows GET to /app without a key", func() { rec := doRequest(app, http.MethodGet, "/app") Expect(rec.Code).To(Equal(http.StatusOK)) }) It("allows GET to /app/chat/model sub-route without a key", func() { rec := doRequest(app, http.MethodGet, "/app/chat/llama3") Expect(rec.Code).To(Equal(http.StatusOK)) }) It("allows GET to /browse/models without a key", func() { rec := doRequest(app, http.MethodGet, "/browse/models") Expect(rec.Code).To(Equal(http.StatusOK)) }) It("allows GET to /login without a key", func() { rec := doRequest(app, http.MethodGet, "/login") Expect(rec.Code).To(Equal(http.StatusOK)) }) It("allows GET to /explorer without a key", func() { rec := doRequest(app, http.MethodGet, "/explorer") Expect(rec.Code).To(Equal(http.StatusOK)) }) It("allows GET to /assets/main.js without a 
key", func() { rec := doRequest(app, http.MethodGet, "/assets/main.js") Expect(rec.Code).To(Equal(http.StatusOK)) }) It("rejects POST to /app without a key", func() { rec := doRequest(app, http.MethodPost, "/app") Expect(rec.Code).To(Equal(http.StatusUnauthorized)) }) It("rejects GET to /v1/models without a key", func() { rec := doRequest(app, http.MethodGet, "/v1/models") Expect(rec.Code).To(Equal(http.StatusUnauthorized)) }) }) Context("GET exempted endpoints (feature disabled)", func() { var app *echo.Echo const validKey = "sk-test-key-789" BeforeEach(func() { appConfig := config.NewApplicationConfig( config.WithApiKeys([]string{validKey}), // DisableApiKeyRequirementForHttpGet defaults to false config.WithHttpGetExemptedEndpoints([]string{ "^/$", "^/app(/.*)?$", }), ) app = newAuthApp(appConfig) }) It("requires auth for GET to /app even though it matches exempted pattern", func() { rec := doRequest(app, http.MethodGet, "/app") Expect(rec.Code).To(Equal(http.StatusUnauthorized)) }) }) }) ================================================ FILE: core/http/middleware/baseurl.go ================================================ package middleware import ( "strings" "github.com/labstack/echo/v4" ) // BaseURL returns the base URL for the given HTTP request context. // It takes into account that the app may be exposed by a reverse-proxy under a different protocol, host and path. // The returned URL is guaranteed to end with `/`. // The method should be used in conjunction with the StripPathPrefix middleware. 
func BaseURL(c echo.Context) string { path := c.Path() origPath := c.Request().URL.Path // Check if StripPathPrefix middleware stored the original path if storedPath, ok := c.Get("_original_path").(string); ok && storedPath != "" { origPath = storedPath } // Check X-Forwarded-Proto for scheme scheme := "http" if c.Request().Header.Get("X-Forwarded-Proto") == "https" { scheme = "https" } else if c.Request().TLS != nil { scheme = "https" } // Check X-Forwarded-Host for host host := c.Request().Host if forwardedHost := c.Request().Header.Get("X-Forwarded-Host"); forwardedHost != "" { host = forwardedHost } if path != origPath && strings.HasSuffix(origPath, path) && len(path) > 0 { prefixLen := len(origPath) - len(path) if prefixLen > 0 && prefixLen <= len(origPath) { pathPrefix := origPath[:prefixLen] if !strings.HasSuffix(pathPrefix, "/") { pathPrefix += "/" } return scheme + "://" + host + pathPrefix } } return scheme + "://" + host + "/" } ================================================ FILE: core/http/middleware/baseurl_test.go ================================================ package middleware import ( "net/http/httptest" "github.com/labstack/echo/v4" . "github.com/onsi/ginkgo/v2" . 
// Specs for BaseURL: one request whose path equals the route path (no prefix)
// and one where an original, prefixed path has been recorded in the context.
var _ = Describe("BaseURL", func() {
	Context("without prefix", func() {
		It("should return base URL without prefix", func() {
			app := echo.New()
			// Captured inside the handler so it can be asserted after the request.
			actualURL := ""

			// Register route - use the actual request path so routing works
			routePath := "/hello/world"
			app.GET(routePath, func(c echo.Context) error {
				actualURL = BaseURL(c)
				return nil
			})

			req := httptest.NewRequest("GET", "/hello/world", nil)
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(200), "response status code")
			// Route path and request path are identical, so only the root is expected.
			Expect(actualURL).To(Equal("http://example.com/"), "base URL")
		})
	})

	Context("with prefix", func() {
		It("should return base URL with prefix", func() {
			app := echo.New()
			// Captured inside the handler so it can be asserted after the request.
			actualURL := ""

			// Register route with the stripped path (after middleware removes prefix)
			routePath := "/hello/world"
			app.GET(routePath, func(c echo.Context) error {
				// Simulate what StripPathPrefix middleware does - store original path
				c.Set("_original_path", "/myprefix/hello/world")
				// Modify the request path to simulate prefix stripping
				c.Request().URL.Path = "/hello/world"
				actualURL = BaseURL(c)
				return nil
			})

			// Make request with stripped path (middleware would have already processed it)
			req := httptest.NewRequest("GET", "/hello/world", nil)
			rec := httptest.NewRecorder()
			app.ServeHTTP(rec, req)

			Expect(rec.Code).To(Equal(200), "response status code")
			// The leading "/myprefix" of the stored original path must be reported.
			Expect(actualURL).To(Equal("http://example.com/myprefix/"), "base URL")
		})
	})
})
"github.com/onsi/gomega" ) func TestMiddleware(t *testing.T) { RegisterFailHandler(Fail) RunSpecs(t, "Middleware test suite") } ================================================ FILE: core/http/middleware/request.go ================================================ package middleware import ( "context" "encoding/json" "fmt" "net/http" "strconv" "strings" "github.com/google/uuid" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/schema" "github.com/mudler/LocalAI/core/services" "github.com/mudler/LocalAI/core/templates" "github.com/mudler/LocalAI/pkg/functions" "github.com/mudler/LocalAI/pkg/model" "github.com/mudler/LocalAI/pkg/utils" "github.com/mudler/xlog" ) type correlationIDKeyType string // CorrelationIDKey to track request across process boundary const CorrelationIDKey correlationIDKeyType = "correlationID" type RequestExtractor struct { modelConfigLoader *config.ModelConfigLoader modelLoader *model.ModelLoader applicationConfig *config.ApplicationConfig } func NewRequestExtractor(modelConfigLoader *config.ModelConfigLoader, modelLoader *model.ModelLoader, applicationConfig *config.ApplicationConfig) *RequestExtractor { return &RequestExtractor{ modelConfigLoader: modelConfigLoader, modelLoader: modelLoader, applicationConfig: applicationConfig, } } const CONTEXT_LOCALS_KEY_MODEL_NAME = "MODEL_NAME" const CONTEXT_LOCALS_KEY_LOCALAI_REQUEST = "LOCALAI_REQUEST" const CONTEXT_LOCALS_KEY_MODEL_CONFIG = "MODEL_CONFIG" // TODO: Refactor to not return error if unchanged func (re *RequestExtractor) setModelNameFromRequest(c echo.Context) { model, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) if ok && model != "" { return } model = c.Param("model") if model == "" { model = c.QueryParam("model") } // Check FormValue for multipart/form-data requests (e.g., /v1/images/inpainting) if model == "" { model = c.FormValue("model") } if model == "" { // Set model from bearer token, if available auth := 
c.Request().Header.Get("Authorization") bearer := strings.TrimPrefix(auth, "Bearer ") if bearer != "" && bearer != auth { exists, err := services.CheckIfModelExists(re.modelConfigLoader, re.modelLoader, bearer, services.ALWAYS_INCLUDE) if err == nil && exists { model = bearer } } } c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, model) } func (re *RequestExtractor) BuildConstantDefaultModelNameMiddleware(defaultModelName string) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { re.setModelNameFromRequest(c) localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) if !ok || localModelName == "" { c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, defaultModelName) xlog.Debug("context local model name not found, setting to default", "defaultModelName", defaultModelName) } return next(c) } } } func (re *RequestExtractor) BuildFilteredFirstAvailableDefaultModel(filterFn config.ModelConfigFilterFn) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { re.setModelNameFromRequest(c) localModelName := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) if localModelName != "" { // Don't overwrite existing values return next(c) } modelNames, err := services.ListModels(re.modelConfigLoader, re.modelLoader, filterFn, services.SKIP_IF_CONFIGURED) if err != nil { xlog.Error("non-fatal error calling ListModels during SetDefaultModelNameToFirstAvailable()", "error", err) return next(c) } if len(modelNames) == 0 { xlog.Warn("SetDefaultModelNameToFirstAvailable used with no matching models installed") // This is non-fatal - making it so was breaking the case of direct installation of raw models // return errors.New("this endpoint requires at least one model to be installed") return next(c) } c.Set(CONTEXT_LOCALS_KEY_MODEL_NAME, modelNames[0]) xlog.Debug("context local model name not found, setting to the first model", "first model name", modelNames[0]) return next(c) } } } // TODO: 
If context and cancel above belong on all methods, move that part of above into here! // Otherwise, it's in its own method below for now func (re *RequestExtractor) SetModelAndConfig(initializer func() schema.LocalAIRequest) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { input := initializer() if input == nil { return echo.NewHTTPError(http.StatusBadRequest, "unable to initialize body") } if err := c.Bind(input); err != nil { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("failed parsing request body: %v", err)) } // If this request doesn't have an associated model name, fetch it from earlier in the middleware chain if input.ModelName(nil) == "" { localModelName, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_NAME).(string) if ok && localModelName != "" { xlog.Debug("overriding empty model name in request body with value found earlier in middleware chain", "context localModelName", localModelName) input.ModelName(&localModelName) } } cfg, err := re.modelConfigLoader.LoadModelConfigFileByNameDefaultOptions(input.ModelName(nil), re.applicationConfig) if err != nil { xlog.Warn("Model Configuration File not found", "model", input.ModelName(nil), "error", err) } else if cfg.Model == "" && input.ModelName(nil) != "" { xlog.Debug("config does not include model, using input", "input.ModelName", input.ModelName(nil)) cfg.Model = input.ModelName(nil) } c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) return next(c) } } } func (re *RequestExtractor) SetOpenAIRequest(c echo.Context) error { input, ok := c.Get(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST).(*schema.OpenAIRequest) if !ok || input.Model == "" { return echo.ErrBadRequest } cfg, ok := c.Get(CONTEXT_LOCALS_KEY_MODEL_CONFIG).(*config.ModelConfig) if !ok || cfg == nil { return echo.ErrBadRequest } // Extract or generate the correlation ID correlationID := c.Request().Header.Get("X-Correlation-ID") if correlationID == 
// mergeOpenAIRequestAndModelConfig copies the parameters set on an OpenAI-style
// request onto the model configuration and normalises parts of the request in
// place (tools become functions, tool_choice becomes function_call, message
// content is flattened into string/media fields).
//
// Request fields left at their zero value (or nil) do not override the
// configuration. Both config and input are mutated.
//
// It returns an error only when the merged configuration fails
// config.Validate(); the error value returned by Validate itself is discarded.
func mergeOpenAIRequestAndModelConfig(config *config.ModelConfig, input *schema.OpenAIRequest) error {
	if input.Echo {
		config.Echo = input.Echo
	}
	if input.TopK != nil {
		config.TopK = input.TopK
	}
	if input.TopP != nil {
		config.TopP = input.TopP
	}

	if input.Backend != "" {
		config.Backend = input.Backend
	}

	if input.ClipSkip != 0 {
		config.Diffusers.ClipSkip = input.ClipSkip
	}

	if input.NegativePromptScale != 0 {
		config.NegativePromptScale = input.NegativePromptScale
	}

	if input.NegativePrompt != "" {
		config.NegativePrompt = input.NegativePrompt
	}

	if input.RopeFreqBase != 0 {
		config.RopeFreqBase = input.RopeFreqBase
	}

	if input.RopeFreqScale != 0 {
		config.RopeFreqScale = input.RopeFreqScale
	}

	if input.Grammar != "" {
		config.Grammar = input.Grammar
	}

	if input.Temperature != nil {
		config.Temperature = input.Temperature
	}

	if input.Maxtokens != nil {
		config.Maxtokens = input.Maxtokens
	}

	// response_format is accepted either as a plain string or as an object;
	// each form is stored in its own config field. Other types are ignored.
	if input.ResponseFormat != nil {
		switch responseFormat := input.ResponseFormat.(type) {
		case string:
			config.ResponseFormat = responseFormat
		case map[string]interface{}:
			config.ResponseFormatMap = responseFormat
		}
	}

	// stop may be a single string or a list; values are appended to the stop
	// words already present in the configuration. Non-string list items are
	// dropped silently.
	switch stop := input.Stop.(type) {
	case string:
		if stop != "" {
			config.StopWords = append(config.StopWords, stop)
		}
	case []interface{}:
		for _, pp := range stop {
			if s, ok := pp.(string); ok {
				config.StopWords = append(config.StopWords, s)
			}
		}
	}

	// Tools are folded into the functions list so the rest of the function
	// handling only has to look at input.Functions.
	if len(input.Tools) > 0 {
		for _, tool := range input.Tools {
			input.Functions = append(input.Functions, tool.Function)
		}
	}

	// tool_choice is translated into function_call. Decoding errors are
	// ignored, and input.FunctionCall is overwritten unconditionally, so a
	// value that does not decode into a tool yields an empty function name.
	if input.ToolsChoice != nil {
		var toolChoice functions.Tool

		switch content := input.ToolsChoice.(type) {
		case string:
			_ = json.Unmarshal([]byte(content), &toolChoice)
		case map[string]interface{}:
			dat, _ := json.Marshal(content)
			_ = json.Unmarshal(dat, &toolChoice)
		}
		input.FunctionCall = map[string]interface{}{
			"name": toolChoice.Function.Name,
		}
	}

	// Decode each request's message content
	// The *Index counters accumulate across all messages; the nrOf* counters
	// are reset for every message.
	imgIndex, vidIndex, audioIndex := 0, 0, 0
	for i, m := range input.Messages {
		nrOfImgsInMessage := 0
		nrOfVideosInMessage := 0
		nrOfAudiosInMessage := 0

		switch content := m.Content.(type) {
		case string:
			input.Messages[i].StringContent = content
		case []interface{}:
			// Round-trip through JSON to obtain typed content parts. Errors are
			// ignored; a failed decode leaves c empty.
			dat, _ := json.Marshal(content)
			c := []schema.Content{}
			json.Unmarshal(dat, &c)

			textContent := ""
			// we will template this at the end

		CONTENT:
			for _, pp := range c {
				switch pp.Type {
				case "text":
					textContent += pp.Text
					//input.Messages[i].StringContent = pp.Text
				case "video", "video_url":
					// Decode content as base64 either if it's an URL or base64 text
					base64, err := utils.GetContentURIAsBase64(pp.VideoURL.URL)
					if err != nil {
						// A part that cannot be fetched/decoded is logged and skipped.
						xlog.Error("Failed encoding video", "error", err)
						continue CONTENT
					}
					input.Messages[i].StringVideos = append(input.Messages[i].StringVideos, base64) // TODO: make sure that we only return base64 stuff
					vidIndex++
					nrOfVideosInMessage++
				case "audio_url", "audio":
					// Decode content as base64 either if it's an URL or base64 text
					base64, err := utils.GetContentURIAsBase64(pp.AudioURL.URL)
					if err != nil {
						// A part that cannot be fetched/decoded is logged and skipped.
						xlog.Error("Failed encoding audio", "error", err)
						continue CONTENT
					}
					input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, base64) // TODO: make sure that we only return base64 stuff
					audioIndex++
					nrOfAudiosInMessage++
				case "input_audio":
					// The inline audio data is appended as-is, without any decoding step.
					// TODO: make sure that we only return base64 stuff
					input.Messages[i].StringAudios = append(input.Messages[i].StringAudios, pp.InputAudio.Data)
					audioIndex++
					nrOfAudiosInMessage++
				case "image_url", "image":
					// Decode content as base64 either if it's an URL or base64 text
					base64, err := utils.GetContentURIAsBase64(pp.ImageURL.URL)
					if err != nil {
						// A part that cannot be fetched/decoded is logged and skipped.
						xlog.Error("Failed encoding image", "error", err)
						continue CONTENT
					}

					input.Messages[i].StringImages = append(input.Messages[i].StringImages, base64) // TODO: make sure that we only return base64 stuff

					imgIndex++
					nrOfImgsInMessage++
				}
			}

			// Render the collected text through the multimodal template, passing
			// both running totals and per-message media counts. The template
			// error is discarded.
			input.Messages[i].StringContent, _ = templates.TemplateMultiModal(config.TemplateConfig.Multimodal, templates.MultiModalOptions{
				TotalImages:     imgIndex,
				TotalVideos:     vidIndex,
				TotalAudios:     audioIndex,
				ImagesInMessage: nrOfImgsInMessage,
				VideosInMessage: nrOfVideosInMessage,
				AudiosInMessage: nrOfAudiosInMessage,
			}, textContent)
		}
	}

	if input.RepeatPenalty != 0 {
		config.RepeatPenalty = input.RepeatPenalty
	}

	if input.FrequencyPenalty != 0 {
		config.FrequencyPenalty = input.FrequencyPenalty
	}

	if input.PresencePenalty != 0 {
		config.PresencePenalty = input.PresencePenalty
	}

	if input.Keep != 0 {
		config.Keep = input.Keep
	}

	if input.Batch != 0 {
		config.Batch = input.Batch
	}

	if input.IgnoreEOS {
		config.IgnoreEOS = input.IgnoreEOS
	}

	if input.Seed != nil {
		config.Seed = input.Seed
	}

	if input.TypicalP != nil {
		config.TypicalP = input.TypicalP
	}

	xlog.Debug("input.Input", "input", fmt.Sprintf("%+v", input.Input))

	// input may be a string, a list of strings, or a list of lists mixing
	// token ids and strings.
	switch inputs := input.Input.(type) {
	case string:
		if inputs != "" {
			config.InputStrings = append(config.InputStrings, inputs)
		}
	case []any:
		for _, pp := range inputs {
			switch i := pp.(type) {
			case string:
				config.InputStrings = append(config.InputStrings, i)
			case []any:
				tokens := []int{}
				inputStrings := []string{}
				for _, ii := range i {
					switch ii := ii.(type) {
					case int:
						tokens = append(tokens, ii)
					case float64:
						// NOTE(review): presumably covers numbers decoded from JSON
						// into interface{} values - confirm against the binder.
						tokens = append(tokens, int(ii))
					case string:
						inputStrings = append(inputStrings, ii)
					default:
						xlog.Error("Unknown input type", "type", fmt.Sprintf("%T", ii))
					}
				}
				// One token slice is appended per nested list, even when empty.
				config.InputToken = append(config.InputToken, tokens)
				config.InputStrings = append(config.InputStrings, inputStrings...)
			}
		}
	}

	// Can be either a string or an object
	switch fnc := input.FunctionCall.(type) {
	case string:
		if fnc != "" {
			config.SetFunctionCallString(fnc)
		}
	case map[string]interface{}:
		var name string
		n, exists := fnc["name"]
		if exists {
			nn, e := n.(string)
			if e {
				name = nn
			}
		}
		// Called even when no usable name was found, i.e. with an empty string.
		config.SetFunctionCallNameString(name)
	}

	// prompt may be a single string or a list of strings; non-string list
	// items are dropped silently.
	switch p := input.Prompt.(type) {
	case string:
		config.PromptStrings = append(config.PromptStrings, p)
	case []interface{}:
		for _, pp := range p {
			if s, ok := pp.(string); ok {
				config.PromptStrings = append(config.PromptStrings, s)
			}
		}
	}

	// If a quality was defined as number, convert it to step
	// Non-numeric quality values are ignored.
	if input.Quality != "" {
		q, err := strconv.Atoi(input.Quality)
		if err == nil {
			config.Step = q
		}
	}

	if valid, _ := config.Validate(); valid {
		return nil
	}
	return fmt.Errorf("unable to validate configuration after merging")
}
Use the request context directly - Echo properly supports context cancellation! reqCtx := c.Request().Context() c1, cancel := context.WithCancel(re.applicationConfig.Context) // Cancel when request context is cancelled (client disconnects) go func() { select { case <-reqCtx.Done(): cancel() case <-c1.Done(): // Already cancelled } }() // Add the correlation ID to the new context ctxWithCorrelationID := context.WithValue(c1, CorrelationIDKey, correlationID) input.Context = ctxWithCorrelationID input.Cancel = cancel err := MergeOpenResponsesConfig(cfg, input) if err != nil { return err } if cfg.Model == "" { xlog.Debug("replacing empty cfg.Model with input value", "input.Model", input.Model) cfg.Model = input.Model } c.Set(CONTEXT_LOCALS_KEY_LOCALAI_REQUEST, input) c.Set(CONTEXT_LOCALS_KEY_MODEL_CONFIG, cfg) return nil } // MergeOpenResponsesConfig merges request parameters into the model configuration. func MergeOpenResponsesConfig(config *config.ModelConfig, input *schema.OpenResponsesRequest) error { // Temperature if input.Temperature != nil { config.Temperature = input.Temperature } // TopP if input.TopP != nil { config.TopP = input.TopP } // MaxOutputTokens -> Maxtokens if input.MaxOutputTokens != nil { config.Maxtokens = input.MaxOutputTokens } // Convert tools to functions - this will be handled in the endpoint handler // We just validate that tools are present if needed // Handle tool_choice if input.ToolChoice != nil { switch tc := input.ToolChoice.(type) { case string: // "auto", "required", or "none" if tc == "required" { config.SetFunctionCallString("required") } else if tc == "none" { // Don't use tools - handled in endpoint } // "auto" is default - let model decide case map[string]interface{}: // Specific tool: {type:"function", name:"..."} if tcType, ok := tc["type"].(string); ok && tcType == "function" { if name, ok := tc["name"].(string); ok { config.SetFunctionCallString(name) } } } } if valid, _ := config.Validate(); valid { return nil } return 
fmt.Errorf("unable to validate configuration after merging") } ================================================ FILE: core/http/middleware/strippathprefix.go ================================================ package middleware import ( "strings" "github.com/labstack/echo/v4" ) // StripPathPrefix returns middleware that strips a path prefix from the request path. // The path prefix is obtained from the X-Forwarded-Prefix HTTP request header. // This must be registered as Pre middleware (using e.Pre()) to modify the path before routing. func StripPathPrefix() echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { prefixes := c.Request().Header.Values("X-Forwarded-Prefix") originalPath := c.Request().URL.Path for _, prefix := range prefixes { if prefix != "" { normalizedPrefix := prefix if !strings.HasSuffix(prefix, "/") { normalizedPrefix = prefix + "/" } if strings.HasPrefix(originalPath, normalizedPrefix) { // Update the request path by stripping the normalized prefix newPath := originalPath[len(normalizedPrefix):] if newPath == "" { newPath = "/" } // Ensure path starts with / for proper routing if !strings.HasPrefix(newPath, "/") { newPath = "/" + newPath } // Update the URL path - Echo's router uses URL.Path for routing c.Request().URL.Path = newPath c.Request().URL.RawPath = "" // Update RequestURI to match the new path (needed for proper routing) if c.Request().URL.RawQuery != "" { c.Request().RequestURI = newPath + "?" 
+ c.Request().URL.RawQuery } else { c.Request().RequestURI = newPath } // Store original path for BaseURL utility c.Set("_original_path", originalPath) break } else if originalPath == prefix || originalPath == prefix+"/" { // Redirect to prefix with trailing slash (use 302 to match test expectations) return c.Redirect(302, normalizedPrefix) } } } return next(c) } } } ================================================ FILE: core/http/middleware/strippathprefix_test.go ================================================ package middleware import ( "net/http/httptest" "github.com/labstack/echo/v4" . "github.com/onsi/ginkgo/v2" . "github.com/onsi/gomega" ) var _ = Describe("StripPathPrefix", func() { var app *echo.Echo var actualPath string var appInitialized bool BeforeEach(func() { actualPath = "" if !appInitialized { app = echo.New() app.Pre(StripPathPrefix()) app.GET("/hello/world", func(c echo.Context) error { actualPath = c.Request().URL.Path return nil }) app.GET("/", func(c echo.Context) error { actualPath = c.Request().URL.Path return nil }) appInitialized = true } }) Context("without prefix", func() { It("should not modify path when no header is present", func() { req := httptest.NewRequest("GET", "/hello/world", nil) rec := httptest.NewRecorder() app.ServeHTTP(rec, req) Expect(rec.Code).To(Equal(200), "response status code") Expect(actualPath).To(Equal("/hello/world"), "rewritten path") }) It("should not modify root path when no header is present", func() { req := httptest.NewRequest("GET", "/", nil) rec := httptest.NewRecorder() app.ServeHTTP(rec, req) Expect(rec.Code).To(Equal(200), "response status code") Expect(actualPath).To(Equal("/"), "rewritten path") }) It("should not modify path when header does not match", func() { req := httptest.NewRequest("GET", "/hello/world", nil) req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"} rec := httptest.NewRecorder() app.ServeHTTP(rec, req) Expect(rec.Code).To(Equal(200), "response status code") 
Expect(actualPath).To(Equal("/hello/world"), "rewritten path") }) }) Context("with prefix", func() { It("should return 404 when prefix does not match header", func() { req := httptest.NewRequest("GET", "/prefix/hello/world", nil) req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/"} rec := httptest.NewRecorder() app.ServeHTTP(rec, req) Expect(rec.Code).To(Equal(404), "response status code") }) It("should strip matching prefix from path", func() { req := httptest.NewRequest("GET", "/myprefix/hello/world", nil) req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/"} rec := httptest.NewRecorder() app.ServeHTTP(rec, req) Expect(rec.Code).To(Equal(200), "response status code") Expect(actualPath).To(Equal("/hello/world"), "rewritten path") }) It("should strip prefix when it matches the first header value", func() { req := httptest.NewRequest("GET", "/myprefix/hello/world", nil) req.Header["X-Forwarded-Prefix"] = []string{"/myprefix/", "/otherprefix/"} rec := httptest.NewRecorder() app.ServeHTTP(rec, req) Expect(rec.Code).To(Equal(200), "response status code") Expect(actualPath).To(Equal("/hello/world"), "rewritten path") }) It("should strip prefix when it matches the second header value", func() { req := httptest.NewRequest("GET", "/myprefix/hello/world", nil) req.Header["X-Forwarded-Prefix"] = []string{"/otherprefix/", "/myprefix/"} rec := httptest.NewRecorder() app.ServeHTTP(rec, req) Expect(rec.Code).To(Equal(200), "response status code") Expect(actualPath).To(Equal("/hello/world"), "rewritten path") }) It("should strip prefix when header does not end with slash", func() { req := httptest.NewRequest("GET", "/myprefix/hello/world", nil) req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"} rec := httptest.NewRecorder() app.ServeHTTP(rec, req) Expect(rec.Code).To(Equal(200), "response status code") Expect(actualPath).To(Equal("/hello/world"), "rewritten path") }) It("should return 404 when prefix does not match header without trailing slash", func() { req := 
httptest.NewRequest("GET", "/myprefix-suffix/hello/world", nil) req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"} rec := httptest.NewRecorder() app.ServeHTTP(rec, req) Expect(rec.Code).To(Equal(404), "response status code") }) It("should redirect when prefix does not end with a slash", func() { req := httptest.NewRequest("GET", "/myprefix", nil) req.Header["X-Forwarded-Prefix"] = []string{"/myprefix"} rec := httptest.NewRecorder() app.ServeHTTP(rec, req) Expect(rec.Code).To(Equal(302), "response status code") Expect(rec.Header().Get("Location")).To(Equal("/myprefix/"), "redirect location") }) }) }) ================================================ FILE: core/http/middleware/trace.go ================================================ package middleware import ( "bytes" "io" "net/http" "sort" "sync" "time" "github.com/emirpasic/gods/v2/queues/circularbuffer" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/application" "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/xlog" ) type APIExchangeRequest struct { Method string `json:"method"` Path string `json:"path"` Headers *http.Header `json:"headers"` Body *[]byte `json:"body"` } type APIExchangeResponse struct { Status int `json:"status"` Headers *http.Header `json:"headers"` Body *[]byte `json:"body"` } type APIExchange struct { Timestamp time.Time `json:"timestamp"` Duration time.Duration `json:"duration"` Request APIExchangeRequest `json:"request"` Response APIExchangeResponse `json:"response"` Error string `json:"error,omitempty"` UserID string `json:"user_id,omitempty"` UserName string `json:"user_name,omitempty"` } var traceBuffer *circularbuffer.Queue[APIExchange] var mu sync.Mutex var logChan = make(chan APIExchange, 100) var initOnce sync.Once type bodyWriter struct { http.ResponseWriter body *bytes.Buffer } func (w *bodyWriter) Write(b []byte) (int, error) { w.body.Write(b) return w.ResponseWriter.Write(b) } func (w *bodyWriter) Flush() { if flusher, ok := 
w.ResponseWriter.(http.Flusher); ok { flusher.Flush() } } func initializeTracing(maxItems int) { initOnce.Do(func() { if maxItems <= 0 { maxItems = 100 } mu.Lock() traceBuffer = circularbuffer.New[APIExchange](maxItems) mu.Unlock() go func() { for exchange := range logChan { mu.Lock() if traceBuffer != nil { traceBuffer.Enqueue(exchange) } mu.Unlock() } }() }) } // TraceMiddleware intercepts and logs JSON API requests and responses func TraceMiddleware(app *application.Application) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if !app.ApplicationConfig().EnableTracing { return next(c) } initializeTracing(app.ApplicationConfig().TracingMaxItems) if c.Request().Header.Get("Content-Type") != "application/json" { return next(c) } body, err := io.ReadAll(c.Request().Body) if err != nil { xlog.Error("Failed to read request body") return err } // Restore the body for downstream handlers c.Request().Body = io.NopCloser(bytes.NewBuffer(body)) startTime := time.Now() // Wrap response writer to capture body resBody := new(bytes.Buffer) mw := &bodyWriter{ ResponseWriter: c.Response().Writer, body: resBody, } c.Response().Writer = mw handlerErr := next(c) // Restore original writer unconditionally c.Response().Writer = mw.ResponseWriter // Determine response status (use 500 if handler errored and no status was set) status := c.Response().Status if status == 0 && handlerErr != nil { status = http.StatusInternalServerError } // Create exchange log (always, even on error) requestHeaders := c.Request().Header.Clone() requestBody := make([]byte, len(body)) copy(requestBody, body) responseHeaders := c.Response().Header().Clone() responseBody := make([]byte, resBody.Len()) copy(responseBody, resBody.Bytes()) exchange := APIExchange{ Timestamp: startTime, Duration: time.Since(startTime), Request: APIExchangeRequest{ Method: c.Request().Method, Path: c.Path(), Headers: &requestHeaders, Body: &requestBody, }, Response: 
APIExchangeResponse{ Status: status, Headers: &responseHeaders, Body: &responseBody, }, } if handlerErr != nil { exchange.Error = handlerErr.Error() } if user := auth.GetUser(c); user != nil { exchange.UserID = user.ID exchange.UserName = user.Name } select { case logChan <- exchange: default: xlog.Warn("Trace channel full, dropping trace") } return handlerErr } } } // GetTraces returns a copy of the logged API exchanges for display func GetTraces() []APIExchange { mu.Lock() if traceBuffer == nil { mu.Unlock() return []APIExchange{} } traces := traceBuffer.Values() mu.Unlock() sort.Slice(traces, func(i, j int) bool { return traces[i].Timestamp.After(traces[j].Timestamp) }) return traces } // ClearTraces clears the in-memory logs func ClearTraces() { mu.Lock() if traceBuffer != nil { traceBuffer.Clear() } mu.Unlock() } ================================================ FILE: core/http/middleware/usage.go ================================================ package middleware import ( "bytes" "encoding/json" "sync" "time" "github.com/labstack/echo/v4" "github.com/mudler/LocalAI/core/http/auth" "github.com/mudler/xlog" "gorm.io/gorm" ) const ( usageFlushInterval = 5 * time.Second usageMaxPending = 5000 ) // usageBatcher accumulates usage records and flushes them to the DB periodically. type usageBatcher struct { mu sync.Mutex pending []*auth.UsageRecord db *gorm.DB } func (b *usageBatcher) add(r *auth.UsageRecord) { b.mu.Lock() b.pending = append(b.pending, r) b.mu.Unlock() } func (b *usageBatcher) flush() { b.mu.Lock() batch := b.pending b.pending = nil b.mu.Unlock() if len(batch) == 0 { return } if err := b.db.Create(&batch).Error; err != nil { xlog.Error("Failed to flush usage batch", "count", len(batch), "error", err) // Re-queue failed records with a cap to avoid unbounded growth b.mu.Lock() if len(b.pending) < usageMaxPending { b.pending = append(batch, b.pending...) 
} b.mu.Unlock() } } var batcher *usageBatcher // InitUsageRecorder starts a background goroutine that periodically flushes // accumulated usage records to the database. func InitUsageRecorder(db *gorm.DB) { if db == nil { return } batcher = &usageBatcher{db: db} go func() { ticker := time.NewTicker(usageFlushInterval) defer ticker.Stop() for range ticker.C { batcher.flush() } }() } // usageResponseBody is the minimal structure we need from the response JSON. type usageResponseBody struct { Model string `json:"model"` Usage *struct { PromptTokens int64 `json:"prompt_tokens"` CompletionTokens int64 `json:"completion_tokens"` TotalTokens int64 `json:"total_tokens"` } `json:"usage"` } // UsageMiddleware extracts token usage from OpenAI-compatible response JSON // and records it per-user. func UsageMiddleware(db *gorm.DB) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if db == nil || batcher == nil { return next(c) } startTime := time.Now() // Wrap response writer to capture body resBody := new(bytes.Buffer) origWriter := c.Response().Writer mw := &bodyWriter{ ResponseWriter: origWriter, body: resBody, } c.Response().Writer = mw handlerErr := next(c) // Restore original writer c.Response().Writer = origWriter // Only record on successful responses if c.Response().Status < 200 || c.Response().Status >= 300 { return handlerErr } // Get authenticated user user := auth.GetUser(c) if user == nil { return handlerErr } // Try to parse usage from response responseBytes := resBody.Bytes() if len(responseBytes) == 0 { return handlerErr } // Check content type ct := c.Response().Header().Get("Content-Type") isJSON := ct == "" || ct == "application/json" || bytes.HasPrefix([]byte(ct), []byte("application/json")) isSSE := bytes.HasPrefix([]byte(ct), []byte("text/event-stream")) if !isJSON && !isSSE { return handlerErr } var resp usageResponseBody if isSSE { last, ok := lastSSEData(responseBytes) if !ok { return 
// lastSSEData returns the payload of the last non-empty "data" line in an SSE
// stream whose content is not the "[DONE]" terminator.
//
// Lines may end in "\n" or "\r\n". Following the server-sent events field
// syntax, a single space after "data:" is optional and is not part of the
// payload, so both "data: {...}" and "data:{...}" are recognised. Empty data
// lines are skipped so they cannot hide an earlier payload.
//
// The returned slice aliases b. The boolean is false when no suitable data
// line exists (including nil or empty input).
func lastSSEData(b []byte) ([]byte, bool) {
	var (
		field = []byte("data:")
		done  = []byte("[DONE]")
		last  []byte
	)

	for _, line := range bytes.Split(b, []byte("\n")) {
		line = bytes.TrimRight(line, "\r")
		if !bytes.HasPrefix(line, field) {
			continue
		}

		payload := line[len(field):]
		// Strip the single optional space that may follow the colon.
		if len(payload) > 0 && payload[0] == ' ' {
			payload = payload[1:]
		}
		if len(payload) == 0 || bytes.Equal(payload, done) {
			continue
		}
		last = payload
	}

	return last, last != nil
}
"github.com/onsi/gomega" "github.com/mudler/xlog" ) const testModel = "Qwen3-VL-2B-Instruct-GGUF" var _ = Describe("Open Responses API", func() { var app *echo.Echo var c context.Context var cancel context.CancelFunc commonOpts := []config.AppOption{ config.WithDebug(true), } Context("API with ephemeral models", func() { BeforeEach(func(sc SpecContext) { var err error backendPath := os.Getenv("BACKENDS_PATH") c, cancel = context.WithCancel(context.Background()) systemState, err := system.GetSystemState( system.WithBackendPath(backendPath), system.WithModelPath(modelDir), ) Expect(err).ToNot(HaveOccurred()) application, err := application.New( append(commonOpts, config.WithContext(c), config.WithSystemState(systemState), config.WithApiKeys([]string{apiKey}), config.WithModelsURL("https://huggingface.co/unsloth/Qwen3-VL-2B-Instruct-GGUF"), )...) Expect(err).ToNot(HaveOccurred()) app, err = API(application) Expect(err).ToNot(HaveOccurred()) go func() { if err := app.Start("127.0.0.1:9090"); err != nil && err != http.ErrServerClosed { xlog.Error("server error", "error", err) } }() // Wait for API to be ready Eventually(func() error { resp, err := http.Get("http://127.0.0.1:9090/healthz") if err != nil { return err } resp.Body.Close() return nil }, "2m").ShouldNot(HaveOccurred()) }) AfterEach(func(sc SpecContext) { cancel() if app != nil { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() err := app.Shutdown(ctx) Expect(err).ToNot(HaveOccurred()) } }) Context("HTTP Protocol Compliance", func() { It("MUST accept application/json Content-Type", func() { reqBody := map[string]interface{}{ "model": testModel, "input": "Hello", } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := 
&http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() // Should accept the request (may fail on model not found, but should accept Content-Type) Expect(resp.StatusCode).To(Or(Equal(200), Equal(400), Equal(500))) }) It("MUST return application/json for non-streaming responses", func() { reqBody := map[string]interface{}{ "model": testModel, "input": "Hello", "stream": false, } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() contentType := resp.Header.Get("Content-Type") if resp.StatusCode == 200 { Expect(contentType).To(ContainSubstring("application/json")) } }) It("MUST return text/event-stream for streaming responses", func() { reqBody := map[string]interface{}{ "model": testModel, "input": "Hello", "stream": true, } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() contentType := resp.Header.Get("Content-Type") if resp.StatusCode == 200 { Expect(contentType).To(Equal("text/event-stream")) } }) It("MUST end streaming with [DONE] terminal event", func() { reqBody := map[string]interface{}{ "model": testModel, "input": "Hello", "stream": true, } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", 
bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() if resp.StatusCode == 200 { body, err := io.ReadAll(resp.Body) Expect(err).ToNot(HaveOccurred()) bodyStr := string(body) // Should end with [DONE] Expect(bodyStr).To(ContainSubstring("data: [DONE]")) } }) It("MUST have event field matching type in body", func() { reqBody := map[string]interface{}{ "model": testModel, "input": "Hello", "stream": true, } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() if resp.StatusCode == 200 { body, err := io.ReadAll(resp.Body) Expect(err).ToNot(HaveOccurred()) bodyStr := string(body) // Parse SSE events lines := strings.Split(bodyStr, "\n") for i, line := range lines { if strings.HasPrefix(line, "event: ") { eventType := strings.TrimPrefix(line, "event: ") // Next line should be data: with matching type if i+1 < len(lines) && strings.HasPrefix(lines[i+1], "data: ") { dataLine := strings.TrimPrefix(lines[i+1], "data: ") var eventData map[string]interface{} if err := json.Unmarshal([]byte(dataLine), &eventData); err == nil { if typeVal, ok := eventData["type"].(string); ok { Expect(typeVal).To(Equal(eventType)) } } } } } } }) }) Context("Response Structure", func() { It("MUST return id field", func() { reqBody := map[string]interface{}{ "model": testModel, "input": "Hello", } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", 
"http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() if resp.StatusCode == 200 { var response map[string]interface{} body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) Expect(response).To(HaveKey("id")) Expect(response["id"]).ToNot(BeEmpty()) } }) It("MUST return object field as 'response'", func() { reqBody := map[string]interface{}{ "model": testModel, "input": "Hello", } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() if resp.StatusCode == 200 { var response map[string]interface{} body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) Expect(response).To(HaveKey("object")) Expect(response["object"]).To(Equal("response")) } }) It("MUST return created_at timestamp", func() { reqBody := map[string]interface{}{ "model": testModel, "input": "Hello", } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() if resp.StatusCode == 200 { var response map[string]interface{} body, _ := io.ReadAll(resp.Body) err = 
json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) Expect(response).To(HaveKey("created_at")) // Should be a number (unix timestamp) createdAt, ok := response["created_at"].(float64) Expect(ok).To(BeTrue()) Expect(createdAt).To(BeNumerically(">", 0)) } }) It("MUST return status field", func() { reqBody := map[string]interface{}{ "model": testModel, "input": "Hello", } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() if resp.StatusCode == 200 { var response map[string]interface{} body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) Expect(response).To(HaveKey("status")) status, ok := response["status"].(string) Expect(ok).To(BeTrue()) Expect(status).To(BeElementOf("in_progress", "completed", "failed", "incomplete")) } }) It("MUST return model field", func() { reqBody := map[string]interface{}{ "model": testModel, "input": "Hello", } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() if resp.StatusCode == 200 { var response map[string]interface{} body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) Expect(response).To(HaveKey("model")) Expect(response["model"]).ToNot(BeEmpty()) } }) It("MUST return output array of items", func() { reqBody := 
map[string]interface{}{ "model": testModel, "input": "Hello", } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() if resp.StatusCode == 200 { var response map[string]interface{} body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) Expect(response).To(HaveKey("output")) output, ok := response["output"].([]interface{}) Expect(ok).To(BeTrue()) Expect(output).ToNot(BeNil()) } }) }) Context("Items", func() { It("MUST include id field on all items", func() { reqBody := map[string]interface{}{ "model": testModel, "input": "Hello", } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() if resp.StatusCode == 200 { var response map[string]interface{} body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) output, ok := response["output"].([]interface{}) if ok { for _, item := range output { itemMap, ok := item.(map[string]interface{}) Expect(ok).To(BeTrue()) Expect(itemMap).To(HaveKey("id")) Expect(itemMap["id"]).ToNot(BeEmpty()) } } } }) It("MUST include type field on all items", func() { reqBody := map[string]interface{}{ "model": testModel, "input": "Hello", } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", 
"http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() if resp.StatusCode == 200 { var response map[string]interface{} body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) output, ok := response["output"].([]interface{}) if ok { for _, item := range output { itemMap, ok := item.(map[string]interface{}) Expect(ok).To(BeTrue()) Expect(itemMap).To(HaveKey("type")) Expect(itemMap["type"]).ToNot(BeEmpty()) } } } }) It("MUST include status field on all items", func() { reqBody := map[string]interface{}{ "model": testModel, "input": "Hello", } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() if resp.StatusCode == 200 { var response map[string]interface{} body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) output, ok := response["output"].([]interface{}) if ok { for _, item := range output { itemMap, ok := item.(map[string]interface{}) Expect(ok).To(BeTrue()) Expect(itemMap).To(HaveKey("status")) status, ok := itemMap["status"].(string) Expect(ok).To(BeTrue()) Expect(status).To(BeElementOf("in_progress", "completed", "incomplete")) } } } }) It("MUST support message items with role field", func() { reqBody := map[string]interface{}{ "model": testModel, "input": []map[string]interface{}{ { "type": "message", "role": "user", "content": []map[string]interface{}{ { "type": 
"input_text", "text": "Hello", }, }, }, }, } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() if resp.StatusCode == 200 { var response map[string]interface{} body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) output, ok := response["output"].([]interface{}) if ok && len(output) > 0 { itemMap, ok := output[0].(map[string]interface{}) Expect(ok).To(BeTrue()) if itemMap["type"] == "message" { Expect(itemMap).To(HaveKey("role")) role, ok := itemMap["role"].(string) Expect(ok).To(BeTrue()) Expect(role).To(BeElementOf("user", "assistant", "system", "developer")) } } } }) }) Context("Content Types", func() { It("MUST support input_text content", func() { reqBody := map[string]interface{}{ "model": testModel, "input": []map[string]interface{}{ { "type": "message", "role": "user", "content": []map[string]interface{}{ { "type": "input_text", "text": "Hello world", }, }, }, }, } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() // Should accept the request Expect(resp.StatusCode).To(Or(Equal(200), Equal(400), Equal(500))) }) It("MUST support input_image content with URL", func() { reqBody := map[string]interface{}{ "model": testModel, "input": []map[string]interface{}{ { "type": "message", "role": "user", 
"content": []map[string]interface{}{
								{
									"type":      "input_image",
									"image_url": "https://example.com/image.png",
									"detail":    "auto",
								},
							},
						},
					},
				}
				payload, err := json.Marshal(reqBody)
				Expect(err).ToNot(HaveOccurred())

				req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
				Expect(err).ToNot(HaveOccurred())
				req.Header.Set("Content-Type", "application/json")
				req.Header.Set("Authorization", bearerKey)

				client := &http.Client{}
				resp, err := client.Do(req)
				Expect(err).ToNot(HaveOccurred())
				defer resp.Body.Close()

				// Should accept the request
				Expect(resp.StatusCode).To(Or(Equal(200), Equal(400), Equal(500)))
			})

			// Same shape as the URL variant, but the image is supplied inline as a
			// base64 data: URL. Only the status code (200, 400 or 500) is checked.
			It("MUST support input_image content with base64", func() {
				imagePart := map[string]interface{}{
					"type":      "input_image",
					"image_url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==",
					"detail":    "auto",
				}
				userMessage := map[string]interface{}{
					"type":    "message",
					"role":    "user",
					"content": []map[string]interface{}{imagePart},
				}
				payload, err := json.Marshal(map[string]interface{}{
					"model": testModel,
					"input": []map[string]interface{}{userMessage},
				})
				Expect(err).ToNot(HaveOccurred())

				req, err := http.NewRequest(http.MethodPost, "http://127.0.0.1:9090/v1/responses", bytes.NewReader(payload))
				Expect(err).ToNot(HaveOccurred())
				req.Header.Set("Content-Type", "application/json")
				req.Header.Set("Authorization", bearerKey)

				resp, err := (&http.Client{}).Do(req)
				Expect(err).ToNot(HaveOccurred())
				defer resp.Body.Close()

				// Should accept the request
				Expect(resp.StatusCode).To(Or(Equal(200), Equal(400), Equal(500)))
			})

			It("MUST support output_text content", func() {
				reqBody := map[string]interface{}{
					"model": testModel,
					"input": "Hello",
				}
				payload, err := json.Marshal(reqBody)
				Expect(err).ToNot(HaveOccurred())

				req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
				Expect(err).ToNot(HaveOccurred())
				req.Header.Set("Content-Type", "application/json")
				req.Header.Set("Authorization", bearerKey)
client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() if resp.StatusCode == 200 { var response map[string]interface{} body, _ := io.ReadAll(resp.Body) err = json.Unmarshal(body, &response) Expect(err).ToNot(HaveOccurred()) output, ok := response["output"].([]interface{}) if ok && len(output) > 0 { itemMap, ok := output[0].(map[string]interface{}) Expect(ok).To(BeTrue()) if itemMap["type"] == "message" { content, ok := itemMap["content"].([]interface{}) if ok && len(content) > 0 { contentMap, ok := content[0].(map[string]interface{}) if ok { contentType, _ := contentMap["type"].(string) if contentType == "output_text" { Expect(contentMap).To(HaveKey("text")) } } } } } } }) }) Context("Streaming Events", func() { It("MUST emit response.created as first event", func() { reqBody := map[string]interface{}{ "model": testModel, "input": "Hello", "stream": true, } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() if resp.StatusCode == 200 { body, err := io.ReadAll(resp.Body) Expect(err).ToNot(HaveOccurred()) bodyStr := string(body) // Should contain response.created event Expect(bodyStr).To(ContainSubstring("response.created")) } }) It("MUST include sequence_number in all events", func() { reqBody := map[string]interface{}{ "model": testModel, "input": "Hello", "stream": true, } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", 
bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() if resp.StatusCode == 200 { body, err := io.ReadAll(resp.Body) Expect(err).ToNot(HaveOccurred()) bodyStr := string(body) // Parse SSE events and check for sequence_number lines := strings.Split(bodyStr, "\n") for _, line := range lines { if strings.HasPrefix(line, "data: ") { dataLine := strings.TrimPrefix(line, "data: ") if dataLine != "[DONE]" { var eventData map[string]interface{} if err := json.Unmarshal([]byte(dataLine), &eventData); err == nil { if _, hasType := eventData["type"]; hasType { Expect(eventData).To(HaveKey("sequence_number")) } } } } } } }) }) Context("Error Handling", func() { It("MUST return structured error with type and message fields", func() { reqBody := map[string]interface{}{ "model": "nonexistent-model", "input": "Hello", } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() if resp.StatusCode >= 400 { var errorResp map[string]interface{} body, _ := io.ReadAll(resp.Body) json.Unmarshal(body, &errorResp) if errorResp["error"] != nil { errorObj, ok := errorResp["error"].(map[string]interface{}) if ok { Expect(errorObj).To(HaveKey("type")) Expect(errorObj).To(HaveKey("message")) } } } }) }) Context("Previous Response ID", func() { It("should load previous response and concatenate context", func() { // First, create a response reqBody1 := map[string]interface{}{ "model": testModel, "input": "What is 2+2?", } payload1, err := json.Marshal(reqBody1) Expect(err).ToNot(HaveOccurred()) req1, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", 
bytes.NewBuffer(payload1))
				Expect(err).ToNot(HaveOccurred())
				req1.Header.Set("Content-Type", "application/json")
				req1.Header.Set("Authorization", bearerKey)

				client := &http.Client{}
				resp1, err := client.Do(req1)
				Expect(err).ToNot(HaveOccurred())
				defer resp1.Body.Close()

				// Check if first response succeeded
				if resp1.StatusCode != 200 {
					Skip("First response failed, skipping previous_response_id test (backend may not be available)")
				}

				var response1 map[string]interface{}
				body1, err := io.ReadAll(resp1.Body)
				Expect(err).ToNot(HaveOccurred())
				err = json.Unmarshal(body1, &response1)
				Expect(err).ToNot(HaveOccurred())

				responseID, ok := response1["id"].(string)
				Expect(ok).To(BeTrue())
				Expect(responseID).ToNot(BeEmpty())

				// Now create a new response with previous_response_id
				reqBody2 := map[string]interface{}{
					"model":                testModel,
					"input":                "What about 3+3?",
					"previous_response_id": responseID,
				}
				payload2, err := json.Marshal(reqBody2)
				Expect(err).ToNot(HaveOccurred())

				req2, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload2))
				Expect(err).ToNot(HaveOccurred())
				req2.Header.Set("Content-Type", "application/json")
				req2.Header.Set("Authorization", bearerKey)

				resp2, err := client.Do(req2)
				Expect(err).ToNot(HaveOccurred())
				defer resp2.Body.Close()

				var response2 map[string]interface{}
				body2, err := io.ReadAll(resp2.Body)
				Expect(err).ToNot(HaveOccurred())
				err = json.Unmarshal(body2, &response2)
				Expect(err).ToNot(HaveOccurred())

				Expect(response2["previous_response_id"]).To(Equal(responseID))
				Expect(response2["status"]).To(Equal("completed"))
			})

			// Sends a previous_response_id that is not expected to exist and
			// requires a 404. When the body is a JSON object with an "error"
			// object, its "type" must be "not_found" and its "param" must be
			// "previous_response_id"; other body shapes are tolerated.
			It("should return error for invalid previous_response_id", func() {
				reqBody := map[string]interface{}{
					"model":                testModel,
					"input":                "Test",
					"previous_response_id": "nonexistent_response_id",
				}
				payload, err := json.Marshal(reqBody)
				Expect(err).ToNot(HaveOccurred())

				req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload))
				Expect(err).ToNot(HaveOccurred())
				req.Header.Set("Content-Type", "application/json")
				req.Header.Set("Authorization", bearerKey)

				client := &http.Client{}
				resp, err := client.Do(req)
				Expect(err).ToNot(HaveOccurred())
				defer resp.Body.Close()

				Expect(resp.StatusCode).To(Equal(404))

				var errorResp map[string]interface{}
				// A failed read is a real failure and is no longer discarded.
				body, err := io.ReadAll(resp.Body)
				Expect(err).ToNot(HaveOccurred())
				// Best effort, now explicit: without a JSON object there is nothing
				// further to validate (same outcome as the ignored error before).
				if err := json.Unmarshal(body, &errorResp); err != nil {
					return
				}
				if errorResp["error"] != nil {
					errorObj, ok := errorResp["error"].(map[string]interface{})
					if ok {
						Expect(errorObj["type"]).To(Equal("not_found"))
						Expect(errorObj["param"]).To(Equal("previous_response_id"))
					}
				}
			})
		})

		Context("Item Reference", func() {
			It("should resolve item_reference in input", func() {
				// First, create a response with items
				reqBody1 := map[string]interface{}{
					"model": testModel,
					"input": "Hello",
				}
				payload1, err := json.Marshal(reqBody1)
				Expect(err).ToNot(HaveOccurred())

				req1, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload1))
				Expect(err).ToNot(HaveOccurred())
				req1.Header.Set("Content-Type", "application/json")
				req1.Header.Set("Authorization", bearerKey)

				client := &http.Client{}
				resp1, err := client.Do(req1)
				Expect(err).ToNot(HaveOccurred())
				defer resp1.Body.Close()

				// Check if first response succeeded
				if resp1.StatusCode != 200 {
					Skip("First response failed, skipping item_reference test (backend may not be available)")
				}

				var response1 map[string]interface{}
				body1, err := io.ReadAll(resp1.Body)
				Expect(err).ToNot(HaveOccurred())
				err = json.Unmarshal(body1, &response1)
				Expect(err).ToNot(HaveOccurred())

				// Get the first output item ID
				output, ok := response1["output"].([]interface{})
				Expect(ok).To(BeTrue())
				Expect(len(output)).To(BeNumerically(">", 0))

				firstItem, ok := output[0].(map[string]interface{})
				Expect(ok).To(BeTrue())
				itemID, ok := firstItem["id"].(string)
				Expect(ok).To(BeTrue())
				Expect(itemID).ToNot(BeEmpty())

				// Now create a new response with item_reference
				reqBody2 := map[string]interface{}{
					"model": testModel,
					"input": []interface{}{
map[string]interface{}{ "type": "item_reference", "item_id": itemID, }, map[string]interface{}{ "type": "message", "role": "user", "content": "Continue from the previous message", }, }, } payload2, err := json.Marshal(reqBody2) Expect(err).ToNot(HaveOccurred()) req2, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload2)) Expect(err).ToNot(HaveOccurred()) req2.Header.Set("Content-Type", "application/json") req2.Header.Set("Authorization", bearerKey) resp2, err := client.Do(req2) Expect(err).ToNot(HaveOccurred()) defer resp2.Body.Close() // Should succeed (item reference resolved) Expect(resp2.StatusCode).To(Equal(200)) }) It("should return error for invalid item_reference", func() { reqBody := map[string]interface{}{ "model": testModel, "input": []interface{}{ map[string]interface{}{ "type": "item_reference", "item_id": "nonexistent_item_id", }, }, } payload, err := json.Marshal(reqBody) Expect(err).ToNot(HaveOccurred()) req, err := http.NewRequest("POST", "http://127.0.0.1:9090/v1/responses", bytes.NewBuffer(payload)) Expect(err).ToNot(HaveOccurred()) req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", bearerKey) client := &http.Client{} resp, err := client.Do(req) Expect(err).ToNot(HaveOccurred()) defer resp.Body.Close() // Should return error Expect(resp.StatusCode).To(BeNumerically(">=", 400)) }) }) }) }) ================================================ FILE: core/http/react-ui/e2e/backend-logs.spec.js ================================================ import { test, expect } from '@playwright/test' test.describe('Backend Logs', () => { test('model detail page shows title', async ({ page }) => { await page.goto('/app/backend-logs/mock-model') await expect(page.locator('.page-title')).toContainText('mock-model') }) test('no back arrow link on detail page', async ({ page }) => { await page.goto('/app/backend-logs/mock-model') await 
expect(page.locator('a[href="/app/backend-logs"]')).not.toBeVisible() }) test('filter buttons are visible', async ({ page }) => { await page.goto('/app/backend-logs/mock-model') await expect(page.locator('button', { hasText: 'All' })).toBeVisible() await expect(page.locator('button', { hasText: 'stdout' })).toBeVisible() await expect(page.locator('button', { hasText: 'stderr' })).toBeVisible() }) test('filter buttons toggle active state', async ({ page }) => { await page.goto('/app/backend-logs/mock-model') const allBtn = page.locator('button', { hasText: 'All' }) const stdoutBtn = page.locator('button', { hasText: 'stdout' }) // All is active by default await expect(allBtn).toHaveClass(/btn-primary/) // Click stdout await stdoutBtn.click() await expect(stdoutBtn).toHaveClass(/btn-primary/) await expect(allBtn).not.toHaveClass(/btn-primary/) }) test('export button is present', async ({ page }) => { await page.goto('/app/backend-logs/mock-model') await expect(page.locator('button', { hasText: 'Export' })).toBeVisible() }) test('auto-scroll checkbox is present', async ({ page }) => { await page.goto('/app/backend-logs/mock-model') await expect(page.locator('text=Auto-scroll')).toBeVisible() }) test('clear button is present', async ({ page }) => { await page.goto('/app/backend-logs/mock-model') await expect(page.locator('button', { hasText: 'Clear' })).toBeVisible() }) test('details toggle button is present and toggles', async ({ page }) => { await page.goto('/app/backend-logs/mock-model') // "Text only" button visible by default (details are shown) const toggleBtn = page.locator('button', { hasText: 'Text only' }) await expect(toggleBtn).toBeVisible() // Click to hide details await toggleBtn.click() // Button label changes to "Show details" await expect(page.locator('button', { hasText: 'Show details' })).toBeVisible() }) }) ================================================ FILE: core/http/react-ui/e2e/manage-logs-link.spec.js 
================================================ import { test, expect } from '@playwright/test' test.describe('Manage Page - Backend Logs Link', () => { test('models table shows terminal icon for logs', async ({ page }) => { await page.goto('/app/manage') // Wait for models to load await expect(page.locator('.table')).toBeVisible({ timeout: 10_000 }) // Check for terminal icon (backend logs link) const terminalIcon = page.locator('a[title="Backend logs"] i.fa-terminal') await expect(terminalIcon.first()).toBeVisible() }) test('terminal icon links to backend-logs page', async ({ page }) => { await page.goto('/app/manage') await expect(page.locator('.table')).toBeVisible({ timeout: 10_000 }) const logsLink = page.locator('a[title="Backend logs"]').first() await expect(logsLink).toBeVisible() // Link uses href="#" with onClick for navigation const href = await logsLink.getAttribute('href') expect(href).toBe('#') // Click and verify navigation await logsLink.click() await expect(page).toHaveURL(/\/app\/backend-logs\//) }) }) ================================================ FILE: core/http/react-ui/e2e/models-gallery.spec.js ================================================ import { test, expect } from '@playwright/test' const MOCK_MODELS_RESPONSE = { models: [ { name: 'llama-model', description: 'A llama model', backend: 'llama-cpp', installed: false, tags: ['llm'] }, { name: 'whisper-model', description: 'A whisper model', backend: 'whisper', installed: true, tags: ['stt'] }, { name: 'stablediffusion-model', description: 'An image model', backend: 'stablediffusion', installed: false, tags: ['sd'] }, { name: 'unknown-model', description: 'No backend', backend: '', installed: false, tags: [] }, ], allBackends: ['llama-cpp', 'stablediffusion', 'whisper'], allTags: ['llm', 'sd', 'stt'], availableModels: 4, installedModels: 1, totalPages: 1, currentPage: 1, } test.describe('Models Gallery - Backend Features', () => { test.beforeEach(async ({ page }) => { await 
page.route('**/api/models*', (route) => { route.fulfill({ contentType: 'application/json', body: JSON.stringify(MOCK_MODELS_RESPONSE), }) }) await page.goto('/app/models') // Wait for the table to render await expect(page.locator('th', { hasText: 'Backend' })).toBeVisible({ timeout: 10_000 }) }) test('backend column header is visible', async ({ page }) => { await expect(page.locator('th', { hasText: 'Backend' })).toBeVisible() }) test('backend badges shown in table rows', async ({ page }) => { const table = page.locator('table') await expect(table.locator('.badge', { hasText: 'llama-cpp' })).toBeVisible() await expect(table.locator('.badge', { hasText: /^whisper$/ })).toBeVisible() }) test('backend dropdown is visible', async ({ page }) => { await expect(page.locator('button', { hasText: 'All Backends' })).toBeVisible() }) test('clicking backend dropdown opens searchable panel', async ({ page }) => { await page.locator('button', { hasText: 'All Backends' }).click() await expect(page.locator('input[placeholder="Search backends..."]')).toBeVisible() }) test('typing in search filters dropdown options', async ({ page }) => { await page.locator('button', { hasText: 'All Backends' }).click() const searchInput = page.locator('input[placeholder="Search backends..."]') await searchInput.fill('llama') // llama-cpp option should be visible, whisper should not const dropdown = page.locator('input[placeholder="Search backends..."]').locator('..') .locator('..') await expect(dropdown.locator('text=llama-cpp')).toBeVisible() await expect(dropdown.locator('text=whisper')).not.toBeVisible() }) test('selecting a backend updates the dropdown label', async ({ page }) => { await page.locator('button', { hasText: 'All Backends' }).click() // Click the llama-cpp option within the dropdown (not the table badge) const dropdown = page.locator('input[placeholder="Search backends..."]').locator('..').locator('..') await dropdown.locator('text=llama-cpp').click() // The dropdown button should 
now show the selected backend instead of "All Backends" await expect(page.locator('button span', { hasText: 'llama-cpp' })).toBeVisible() }) test('expanded row shows backend in detail', async ({ page }) => { // Click the first model row to expand it await page.locator('tr', { hasText: 'llama-model' }).click() // The detail view should show Backend label and value const detail = page.locator('td[colspan="8"]') await expect(detail.locator('text=Backend')).toBeVisible() await expect(detail.locator('text=llama-cpp')).toBeVisible() }) }) ================================================ FILE: core/http/react-ui/e2e/navigation.spec.js ================================================ import { test, expect } from '@playwright/test' test.describe('Navigation', () => { test('/ redirects to /app', async ({ page }) => { await page.goto('/') await expect(page).toHaveURL(/\/app/) }) test('/app shows home page with LocalAI title', async ({ page }) => { await page.goto('/app') await expect(page.locator('.sidebar')).toBeVisible() await expect(page.locator('.home-page')).toBeVisible() }) test('sidebar traces link navigates to /app/traces', async ({ page }) => { await page.goto('/app') const tracesLink = page.locator('a.nav-item[href="/app/traces"]') await expect(tracesLink).toBeVisible() await tracesLink.click() await expect(page).toHaveURL(/\/app\/traces/) await expect(page.getByRole('heading', { name: 'Traces', exact: true })).toBeVisible() }) }) ================================================ FILE: core/http/react-ui/e2e/settings-backend-logging.spec.js ================================================ import { test, expect } from '@playwright/test' test.describe('Settings - Backend Logging', () => { test.beforeEach(async ({ page }) => { await page.goto('/app/settings') // Wait for settings to load await expect(page.locator('h3', { hasText: 'Tracing' })).toBeVisible({ timeout: 10_000 }) }) test('backend logging toggle is visible in tracing section', async ({ page }) => { await 
expect(page.locator('text=Enable Backend Logging')).toBeVisible() }) test('backend logging toggle can be toggled', async ({ page }) => { // Find the checkbox associated with backend logging const section = page.locator('div', { has: page.locator('text=Enable Backend Logging') }) const checkbox = section.locator('input[type="checkbox"]').last() // Toggle on const wasChecked = await checkbox.isChecked() await checkbox.locator('..').click() if (wasChecked) { await expect(checkbox).not.toBeChecked() } else { await expect(checkbox).toBeChecked() } }) test('save shows toast', async ({ page }) => { // Click save button await page.locator('button', { hasText: 'Save' }).click() // Verify toast appears await expect(page.locator('text=Settings saved')).toBeVisible({ timeout: 5_000 }) }) }) ================================================ FILE: core/http/react-ui/e2e/traces-errors.spec.js ================================================ import { test, expect } from '@playwright/test' test.describe('Traces - Error Display', () => { test.beforeEach(async ({ page }) => { // Mock API traces with sample data so the table renders await page.route('**/api/traces', (route) => { route.fulfill({ contentType: 'application/json', body: JSON.stringify([ { request: { method: 'POST', path: '/v1/chat/completions' }, response: { status: 200 }, error: null, }, ]), }) }) // Mock backend traces with sample data await page.route('**/api/backend-traces', (route) => { route.fulfill({ contentType: 'application/json', body: JSON.stringify([ { type: 'model_load', timestamp: Date.now() * 1_000_000, model_name: 'mock-model', summary: 'Loaded model', duration: 500_000_000, error: null, }, ]), }) }) await page.goto('/app/traces') await expect(page.locator('text=Tracing is')).toBeVisible({ timeout: 10_000 }) }) test('API traces tab has Result column header', async ({ page }) => { // API tab is active by default await expect(page.locator('th', { hasText: 'Result' })).toBeVisible() }) test('backend traces tab 
shows model_load type if present', async ({ page }) => { // Switch to backend traces tab await page.locator('button', { hasText: 'Backend Traces' }).click() // The table should be visible with Type column await expect(page.locator('th', { hasText: 'Type' })).toBeVisible() }) }) ================================================ FILE: core/http/react-ui/e2e/traces.spec.js ================================================ import { test, expect } from '@playwright/test' test.describe('Traces Settings', () => { test.beforeEach(async ({ page }) => { await page.goto('/app/traces') // Wait for settings panel to load await expect(page.locator('text=Tracing is')).toBeVisible({ timeout: 10_000 }) }) test('settings panel is visible on page load', async ({ page }) => { await expect(page.locator('text=Tracing is')).toBeVisible() }) test('expand and collapse settings', async ({ page }) => { // The test server starts with tracing enabled, so the panel starts collapsed const settingsHeader = page.locator('button', { hasText: 'Tracing is' }) // Click to expand await settingsHeader.click() await expect(page.locator('text=Enable Tracing')).toBeVisible() // Click to collapse await settingsHeader.click() await expect(page.locator('text=Enable Tracing')).not.toBeVisible() }) test('toggle tracing on and off', async ({ page }) => { // Expand settings const settingsHeader = page.locator('button', { hasText: 'Tracing is' }) await settingsHeader.click() await expect(page.locator('text=Enable Tracing')).toBeVisible() // The Toggle component is a