Repository: router-for-me/CLIProxyAPI
Branch: main
Commit: db63f9b5d60e
Files: 451
Total size: 3.6 MB
Directory structure:
gitextract_9oqd9is2/
├── .dockerignore
├── .github/
│ ├── FUNDING.yml
│ ├── ISSUE_TEMPLATE/
│ │ └── bug_report.md
│ └── workflows/
│ ├── docker-image.yml
│ ├── pr-path-guard.yml
│ ├── pr-test-build.yml
│ └── release.yaml
├── .gitignore
├── .goreleaser.yml
├── Dockerfile
├── LICENSE
├── README.md
├── README_CN.md
├── auths/
│ └── .gitkeep
├── cmd/
│ ├── fetch_antigravity_models/
│ │ └── main.go
│ └── server/
│ └── main.go
├── config.example.yaml
├── docker-build.ps1
├── docker-build.sh
├── docker-compose.yml
├── examples/
│ ├── custom-provider/
│ │ └── main.go
│ ├── http-request/
│ │ └── main.go
│ └── translator/
│ └── main.go
├── go.mod
├── go.sum
├── internal/
│ ├── access/
│ │ ├── config_access/
│ │ │ └── provider.go
│ │ └── reconcile.go
│ ├── api/
│ │ ├── handlers/
│ │ │ └── management/
│ │ │ ├── api_tools.go
│ │ │ ├── api_tools_test.go
│ │ │ ├── auth_files.go
│ │ │ ├── auth_files_delete_test.go
│ │ │ ├── config_basic.go
│ │ │ ├── config_lists.go
│ │ │ ├── handler.go
│ │ │ ├── logs.go
│ │ │ ├── model_definitions.go
│ │ │ ├── oauth_callback.go
│ │ │ ├── oauth_sessions.go
│ │ │ ├── quota.go
│ │ │ ├── test_store_test.go
│ │ │ ├── usage.go
│ │ │ └── vertex_import.go
│ │ ├── middleware/
│ │ │ ├── request_logging.go
│ │ │ ├── request_logging_test.go
│ │ │ ├── response_writer.go
│ │ │ └── response_writer_test.go
│ │ ├── modules/
│ │ │ ├── amp/
│ │ │ │ ├── amp.go
│ │ │ │ ├── amp_test.go
│ │ │ │ ├── fallback_handlers.go
│ │ │ │ ├── fallback_handlers_test.go
│ │ │ │ ├── gemini_bridge.go
│ │ │ │ ├── gemini_bridge_test.go
│ │ │ │ ├── model_mapping.go
│ │ │ │ ├── model_mapping_test.go
│ │ │ │ ├── proxy.go
│ │ │ │ ├── proxy_test.go
│ │ │ │ ├── response_rewriter.go
│ │ │ │ ├── response_rewriter_test.go
│ │ │ │ ├── routes.go
│ │ │ │ ├── routes_test.go
│ │ │ │ ├── secret.go
│ │ │ │ └── secret_test.go
│ │ │ └── modules.go
│ │ ├── server.go
│ │ └── server_test.go
│ ├── auth/
│ │ ├── antigravity/
│ │ │ ├── auth.go
│ │ │ ├── constants.go
│ │ │ └── filename.go
│ │ ├── claude/
│ │ │ ├── anthropic.go
│ │ │ ├── anthropic_auth.go
│ │ │ ├── errors.go
│ │ │ ├── html_templates.go
│ │ │ ├── oauth_server.go
│ │ │ ├── pkce.go
│ │ │ ├── token.go
│ │ │ └── utls_transport.go
│ │ ├── codex/
│ │ │ ├── errors.go
│ │ │ ├── filename.go
│ │ │ ├── html_templates.go
│ │ │ ├── jwt_parser.go
│ │ │ ├── oauth_server.go
│ │ │ ├── openai.go
│ │ │ ├── openai_auth.go
│ │ │ ├── openai_auth_test.go
│ │ │ ├── pkce.go
│ │ │ └── token.go
│ │ ├── empty/
│ │ │ └── token.go
│ │ ├── gemini/
│ │ │ ├── gemini_auth.go
│ │ │ └── gemini_token.go
│ │ ├── iflow/
│ │ │ ├── cookie_helpers.go
│ │ │ ├── iflow_auth.go
│ │ │ ├── iflow_token.go
│ │ │ └── oauth_server.go
│ │ ├── kimi/
│ │ │ ├── kimi.go
│ │ │ └── token.go
│ │ ├── models.go
│ │ ├── qwen/
│ │ │ ├── qwen_auth.go
│ │ │ └── qwen_token.go
│ │ └── vertex/
│ │ ├── keyutil.go
│ │ └── vertex_credentials.go
│ ├── browser/
│ │ └── browser.go
│ ├── buildinfo/
│ │ └── buildinfo.go
│ ├── cache/
│ │ ├── signature_cache.go
│ │ └── signature_cache_test.go
│ ├── cmd/
│ │ ├── anthropic_login.go
│ │ ├── antigravity_login.go
│ │ ├── auth_manager.go
│ │ ├── iflow_cookie.go
│ │ ├── iflow_login.go
│ │ ├── kimi_login.go
│ │ ├── login.go
│ │ ├── openai_device_login.go
│ │ ├── openai_login.go
│ │ ├── qwen_login.go
│ │ ├── run.go
│ │ └── vertex_import.go
│ ├── config/
│ │ ├── codex_websocket_header_defaults_test.go
│ │ ├── config.go
│ │ ├── oauth_model_alias_test.go
│ │ ├── sdk_config.go
│ │ └── vertex_compat.go
│ ├── constant/
│ │ └── constant.go
│ ├── interfaces/
│ │ ├── api_handler.go
│ │ ├── client_models.go
│ │ ├── error_message.go
│ │ └── types.go
│ ├── logging/
│ │ ├── gin_logger.go
│ │ ├── gin_logger_test.go
│ │ ├── global_logger.go
│ │ ├── log_dir_cleaner.go
│ │ ├── log_dir_cleaner_test.go
│ │ ├── request_logger.go
│ │ └── requestid.go
│ ├── managementasset/
│ │ └── updater.go
│ ├── misc/
│ │ ├── claude_code_instructions.go
│ │ ├── claude_code_instructions.txt
│ │ ├── copy-example-config.go
│ │ ├── credentials.go
│ │ ├── header_utils.go
│ │ ├── mime-type.go
│ │ └── oauth.go
│ ├── registry/
│ │ ├── model_definitions.go
│ │ ├── model_registry.go
│ │ ├── model_registry_cache_test.go
│ │ ├── model_registry_hook_test.go
│ │ ├── model_registry_safety_test.go
│ │ ├── model_updater.go
│ │ └── models/
│ │ └── models.json
│ ├── runtime/
│ │ ├── executor/
│ │ │ ├── aistudio_executor.go
│ │ │ ├── antigravity_executor.go
│ │ │ ├── antigravity_executor_buildrequest_test.go
│ │ │ ├── cache_helpers.go
│ │ │ ├── caching_verify_test.go
│ │ │ ├── claude_executor.go
│ │ │ ├── claude_executor_test.go
│ │ │ ├── cloak_obfuscate.go
│ │ │ ├── cloak_utils.go
│ │ │ ├── codex_executor.go
│ │ │ ├── codex_executor_cache_test.go
│ │ │ ├── codex_executor_retry_test.go
│ │ │ ├── codex_websockets_executor.go
│ │ │ ├── codex_websockets_executor_test.go
│ │ │ ├── gemini_cli_executor.go
│ │ │ ├── gemini_executor.go
│ │ │ ├── gemini_vertex_executor.go
│ │ │ ├── iflow_executor.go
│ │ │ ├── iflow_executor_test.go
│ │ │ ├── kimi_executor.go
│ │ │ ├── kimi_executor_test.go
│ │ │ ├── logging_helpers.go
│ │ │ ├── openai_compat_executor.go
│ │ │ ├── openai_compat_executor_compact_test.go
│ │ │ ├── payload_helpers.go
│ │ │ ├── proxy_helpers.go
│ │ │ ├── proxy_helpers_test.go
│ │ │ ├── qwen_executor.go
│ │ │ ├── qwen_executor_test.go
│ │ │ ├── thinking_providers.go
│ │ │ ├── token_helpers.go
│ │ │ ├── usage_helpers.go
│ │ │ ├── usage_helpers_test.go
│ │ │ ├── user_id_cache.go
│ │ │ └── user_id_cache_test.go
│ │ └── geminicli/
│ │ └── state.go
│ ├── store/
│ │ ├── gitstore.go
│ │ ├── objectstore.go
│ │ └── postgresstore.go
│ ├── thinking/
│ │ ├── apply.go
│ │ ├── apply_user_defined_test.go
│ │ ├── convert.go
│ │ ├── errors.go
│ │ ├── provider/
│ │ │ ├── antigravity/
│ │ │ │ └── apply.go
│ │ │ ├── claude/
│ │ │ │ └── apply.go
│ │ │ ├── codex/
│ │ │ │ └── apply.go
│ │ │ ├── gemini/
│ │ │ │ └── apply.go
│ │ │ ├── geminicli/
│ │ │ │ └── apply.go
│ │ │ ├── iflow/
│ │ │ │ └── apply.go
│ │ │ ├── kimi/
│ │ │ │ ├── apply.go
│ │ │ │ └── apply_test.go
│ │ │ └── openai/
│ │ │ └── apply.go
│ │ ├── strip.go
│ │ ├── suffix.go
│ │ ├── text.go
│ │ ├── types.go
│ │ └── validate.go
│ ├── translator/
│ │ ├── antigravity/
│ │ │ ├── claude/
│ │ │ │ ├── antigravity_claude_request.go
│ │ │ │ ├── antigravity_claude_request_test.go
│ │ │ │ ├── antigravity_claude_response.go
│ │ │ │ ├── antigravity_claude_response_test.go
│ │ │ │ └── init.go
│ │ │ ├── gemini/
│ │ │ │ ├── antigravity_gemini_request.go
│ │ │ │ ├── antigravity_gemini_request_test.go
│ │ │ │ ├── antigravity_gemini_response.go
│ │ │ │ ├── antigravity_gemini_response_test.go
│ │ │ │ └── init.go
│ │ │ └── openai/
│ │ │ ├── chat-completions/
│ │ │ │ ├── antigravity_openai_request.go
│ │ │ │ ├── antigravity_openai_response.go
│ │ │ │ ├── antigravity_openai_response_test.go
│ │ │ │ └── init.go
│ │ │ └── responses/
│ │ │ ├── antigravity_openai-responses_request.go
│ │ │ ├── antigravity_openai-responses_response.go
│ │ │ └── init.go
│ │ ├── claude/
│ │ │ ├── gemini/
│ │ │ │ ├── claude_gemini_request.go
│ │ │ │ ├── claude_gemini_response.go
│ │ │ │ └── init.go
│ │ │ ├── gemini-cli/
│ │ │ │ ├── claude_gemini-cli_request.go
│ │ │ │ ├── claude_gemini-cli_response.go
│ │ │ │ └── init.go
│ │ │ └── openai/
│ │ │ ├── chat-completions/
│ │ │ │ ├── claude_openai_request.go
│ │ │ │ ├── claude_openai_request_test.go
│ │ │ │ ├── claude_openai_response.go
│ │ │ │ └── init.go
│ │ │ └── responses/
│ │ │ ├── claude_openai-responses_request.go
│ │ │ ├── claude_openai-responses_response.go
│ │ │ └── init.go
│ │ ├── codex/
│ │ │ ├── claude/
│ │ │ │ ├── codex_claude_request.go
│ │ │ │ ├── codex_claude_request_test.go
│ │ │ │ ├── codex_claude_response.go
│ │ │ │ └── init.go
│ │ │ ├── gemini/
│ │ │ │ ├── codex_gemini_request.go
│ │ │ │ ├── codex_gemini_response.go
│ │ │ │ └── init.go
│ │ │ ├── gemini-cli/
│ │ │ │ ├── codex_gemini-cli_request.go
│ │ │ │ ├── codex_gemini-cli_response.go
│ │ │ │ └── init.go
│ │ │ └── openai/
│ │ │ ├── chat-completions/
│ │ │ │ ├── codex_openai_request.go
│ │ │ │ ├── codex_openai_request_test.go
│ │ │ │ ├── codex_openai_response.go
│ │ │ │ ├── codex_openai_response_test.go
│ │ │ │ └── init.go
│ │ │ └── responses/
│ │ │ ├── codex_openai-responses_request.go
│ │ │ ├── codex_openai-responses_request_test.go
│ │ │ ├── codex_openai-responses_response.go
│ │ │ └── init.go
│ │ ├── gemini/
│ │ │ ├── claude/
│ │ │ │ ├── gemini_claude_request.go
│ │ │ │ ├── gemini_claude_request_test.go
│ │ │ │ ├── gemini_claude_response.go
│ │ │ │ └── init.go
│ │ │ ├── common/
│ │ │ │ └── safety.go
│ │ │ ├── gemini/
│ │ │ │ ├── gemini_gemini_request.go
│ │ │ │ ├── gemini_gemini_request_test.go
│ │ │ │ ├── gemini_gemini_response.go
│ │ │ │ └── init.go
│ │ │ ├── gemini-cli/
│ │ │ │ ├── gemini_gemini-cli_request.go
│ │ │ │ ├── gemini_gemini-cli_response.go
│ │ │ │ └── init.go
│ │ │ └── openai/
│ │ │ ├── chat-completions/
│ │ │ │ ├── gemini_openai_request.go
│ │ │ │ ├── gemini_openai_response.go
│ │ │ │ └── init.go
│ │ │ └── responses/
│ │ │ ├── gemini_openai-responses_request.go
│ │ │ ├── gemini_openai-responses_response.go
│ │ │ ├── gemini_openai-responses_response_test.go
│ │ │ └── init.go
│ │ ├── gemini-cli/
│ │ │ ├── claude/
│ │ │ │ ├── gemini-cli_claude_request.go
│ │ │ │ ├── gemini-cli_claude_request_test.go
│ │ │ │ ├── gemini-cli_claude_response.go
│ │ │ │ └── init.go
│ │ │ ├── gemini/
│ │ │ │ ├── gemini-cli_gemini_request.go
│ │ │ │ ├── gemini-cli_gemini_response.go
│ │ │ │ └── init.go
│ │ │ └── openai/
│ │ │ ├── chat-completions/
│ │ │ │ ├── gemini-cli_openai_request.go
│ │ │ │ ├── gemini-cli_openai_response.go
│ │ │ │ └── init.go
│ │ │ └── responses/
│ │ │ ├── gemini-cli_openai-responses_request.go
│ │ │ ├── gemini-cli_openai-responses_response.go
│ │ │ └── init.go
│ │ ├── init.go
│ │ ├── openai/
│ │ │ ├── claude/
│ │ │ │ ├── init.go
│ │ │ │ ├── openai_claude_request.go
│ │ │ │ ├── openai_claude_request_test.go
│ │ │ │ └── openai_claude_response.go
│ │ │ ├── gemini/
│ │ │ │ ├── init.go
│ │ │ │ ├── openai_gemini_request.go
│ │ │ │ └── openai_gemini_response.go
│ │ │ ├── gemini-cli/
│ │ │ │ ├── init.go
│ │ │ │ ├── openai_gemini_request.go
│ │ │ │ └── openai_gemini_response.go
│ │ │ └── openai/
│ │ │ ├── chat-completions/
│ │ │ │ ├── init.go
│ │ │ │ ├── openai_openai_request.go
│ │ │ │ └── openai_openai_response.go
│ │ │ └── responses/
│ │ │ ├── init.go
│ │ │ ├── openai_openai-responses_request.go
│ │ │ └── openai_openai-responses_response.go
│ │ └── translator/
│ │ └── translator.go
│ ├── tui/
│ │ ├── app.go
│ │ ├── auth_tab.go
│ │ ├── browser.go
│ │ ├── client.go
│ │ ├── config_tab.go
│ │ ├── dashboard.go
│ │ ├── i18n.go
│ │ ├── keys_tab.go
│ │ ├── loghook.go
│ │ ├── logs_tab.go
│ │ ├── oauth_tab.go
│ │ ├── styles.go
│ │ └── usage_tab.go
│ ├── usage/
│ │ └── logger_plugin.go
│ ├── util/
│ │ ├── claude_model.go
│ │ ├── claude_model_test.go
│ │ ├── claude_tool_id.go
│ │ ├── gemini_schema.go
│ │ ├── gemini_schema_test.go
│ │ ├── header_helpers.go
│ │ ├── image.go
│ │ ├── provider.go
│ │ ├── proxy.go
│ │ ├── sanitize_test.go
│ │ ├── ssh_helper.go
│ │ ├── translator.go
│ │ └── util.go
│ ├── watcher/
│ │ ├── clients.go
│ │ ├── config_reload.go
│ │ ├── diff/
│ │ │ ├── auth_diff.go
│ │ │ ├── config_diff.go
│ │ │ ├── config_diff_test.go
│ │ │ ├── model_hash.go
│ │ │ ├── model_hash_test.go
│ │ │ ├── models_summary.go
│ │ │ ├── oauth_excluded.go
│ │ │ ├── oauth_excluded_test.go
│ │ │ ├── oauth_model_alias.go
│ │ │ ├── openai_compat.go
│ │ │ └── openai_compat_test.go
│ │ ├── dispatcher.go
│ │ ├── events.go
│ │ ├── synthesizer/
│ │ │ ├── config.go
│ │ │ ├── config_test.go
│ │ │ ├── context.go
│ │ │ ├── file.go
│ │ │ ├── file_test.go
│ │ │ ├── helpers.go
│ │ │ ├── helpers_test.go
│ │ │ └── interface.go
│ │ ├── watcher.go
│ │ └── watcher_test.go
│ └── wsrelay/
│ ├── http.go
│ ├── manager.go
│ ├── message.go
│ └── session.go
├── sdk/
│ ├── access/
│ │ ├── errors.go
│ │ ├── manager.go
│ │ ├── registry.go
│ │ └── types.go
│ ├── api/
│ │ ├── handlers/
│ │ │ ├── claude/
│ │ │ │ └── code_handlers.go
│ │ │ ├── gemini/
│ │ │ │ ├── gemini-cli_handlers.go
│ │ │ │ └── gemini_handlers.go
│ │ │ ├── handlers.go
│ │ │ ├── handlers_error_response_test.go
│ │ │ ├── handlers_request_details_test.go
│ │ │ ├── handlers_stream_bootstrap_test.go
│ │ │ ├── header_filter.go
│ │ │ ├── header_filter_test.go
│ │ │ ├── openai/
│ │ │ │ ├── openai_handlers.go
│ │ │ │ ├── openai_responses_compact_test.go
│ │ │ │ ├── openai_responses_handlers.go
│ │ │ │ ├── openai_responses_handlers_stream_error_test.go
│ │ │ │ ├── openai_responses_websocket.go
│ │ │ │ └── openai_responses_websocket_test.go
│ │ │ ├── openai_responses_stream_error.go
│ │ │ ├── openai_responses_stream_error_test.go
│ │ │ └── stream_forwarder.go
│ │ ├── management.go
│ │ └── options.go
│ ├── auth/
│ │ ├── antigravity.go
│ │ ├── claude.go
│ │ ├── codex.go
│ │ ├── codex_device.go
│ │ ├── errors.go
│ │ ├── filestore.go
│ │ ├── filestore_test.go
│ │ ├── gemini.go
│ │ ├── iflow.go
│ │ ├── interfaces.go
│ │ ├── kimi.go
│ │ ├── manager.go
│ │ ├── qwen.go
│ │ ├── refresh_registry.go
│ │ └── store_registry.go
│ ├── cliproxy/
│ │ ├── auth/
│ │ │ ├── api_key_model_alias_test.go
│ │ │ ├── conductor.go
│ │ │ ├── conductor_availability_test.go
│ │ │ ├── conductor_executor_replace_test.go
│ │ │ ├── conductor_overrides_test.go
│ │ │ ├── conductor_scheduler_refresh_test.go
│ │ │ ├── conductor_update_test.go
│ │ │ ├── errors.go
│ │ │ ├── oauth_model_alias.go
│ │ │ ├── oauth_model_alias_test.go
│ │ │ ├── openai_compat_pool_test.go
│ │ │ ├── persist_policy.go
│ │ │ ├── persist_policy_test.go
│ │ │ ├── scheduler.go
│ │ │ ├── scheduler_benchmark_test.go
│ │ │ ├── scheduler_test.go
│ │ │ ├── selector.go
│ │ │ ├── selector_test.go
│ │ │ ├── status.go
│ │ │ ├── store.go
│ │ │ ├── types.go
│ │ │ └── types_test.go
│ │ ├── builder.go
│ │ ├── executor/
│ │ │ ├── context.go
│ │ │ └── types.go
│ │ ├── model_registry.go
│ │ ├── pipeline/
│ │ │ └── context.go
│ │ ├── pprof_server.go
│ │ ├── providers.go
│ │ ├── rtprovider.go
│ │ ├── rtprovider_test.go
│ │ ├── service.go
│ │ ├── service_codex_executor_binding_test.go
│ │ ├── service_excluded_models_test.go
│ │ ├── service_oauth_model_alias_test.go
│ │ ├── types.go
│ │ ├── usage/
│ │ │ └── manager.go
│ │ └── watcher.go
│ ├── config/
│ │ └── config.go
│ ├── logging/
│ │ └── request_logger.go
│ ├── proxyutil/
│ │ ├── proxy.go
│ │ └── proxy_test.go
│ └── translator/
│ ├── builtin/
│ │ └── builtin.go
│ ├── format.go
│ ├── formats.go
│ ├── helpers.go
│ ├── pipeline.go
│ ├── registry.go
│ └── types.go
└── test/
├── amp_management_test.go
├── builtin_tools_translation_test.go
└── thinking_conversion_test.go
================================================
FILE CONTENTS
================================================
================================================
FILE: .dockerignore
================================================
# Git and GitHub folders
.git/*
.github/*
# Docker and CI/CD related files
docker-compose.yml
.dockerignore
.gitignore
.goreleaser.yml
Dockerfile
# Documentation and license
docs/*
README.md
README_CN.md
LICENSE
# Runtime data folders (should be mounted as volumes)
auths/*
logs/*
conv/*
config.yaml
# Development/editor
bin/*
.vscode/*
.claude/*
.codex/*
.gemini/*
.serena/*
.agent/*
.agents/*
.opencode/*
.idea/*
.bmad/*
_bmad/*
_bmad-output/*
================================================
FILE: .github/FUNDING.yml
================================================
github: [router-for-me]
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: ''
assignees: ''
---
**Is it a request payload issue?**
[ ] Yes, this is a request payload issue. I am using a client/cURL to send a request payload, but I received an unexpected error.
[ ] No, it's another issue.
**If it's a request payload issue, you MUST know**
Our team doesn't have any GODs or ORACLEs or MIND READERs. Please make sure to attach the request log or curl payload.
**Describe the bug**
A clear and concise description of what the bug is.
**CLI Type**
What type of CLI account do you use? (gemini-cli, gemini, codex, claude code or openai-compatibility)
**Model Name**
What model are you using? (example: gemini-2.5-pro, claude-sonnet-4-20250514, gpt-5, etc.)
**LLM Client**
What LLM Client are you using? (example: roo-code, cline, claude code, etc.)
**Request Information**
The best way is to paste the cURL command of the HTTP request here.
Alternatively, you can set `request-log: true` in the `config.yaml` file and then upload the detailed log file.
**Expected behavior**
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
**OS Type**
- OS: [e.g. macOS]
- Version [e.g. 15.6.0]
**Additional context**
Add any other context about the problem here.
================================================
FILE: .github/workflows/docker-image.yml
================================================
name: docker-image
on:
push:
tags:
- v*
env:
APP_NAME: CLIProxyAPI
DOCKERHUB_REPO: eceasy/cli-proxy-api
jobs:
docker_amd64:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Refresh models catalog
run: |
git fetch --depth 1 https://github.com/router-for-me/models.git main
git show FETCH_HEAD:models.json > internal/registry/models/models.json
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Generate Build Metadata
run: |
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
- name: Build and push (amd64)
uses: docker/build-push-action@v6
with:
context: .
platforms: linux/amd64
push: true
build-args: |
VERSION=${{ env.VERSION }}
COMMIT=${{ env.COMMIT }}
BUILD_DATE=${{ env.BUILD_DATE }}
tags: |
${{ env.DOCKERHUB_REPO }}:latest-amd64
${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }}-amd64
docker_arm64:
runs-on: ubuntu-24.04-arm
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Refresh models catalog
run: |
git fetch --depth 1 https://github.com/router-for-me/models.git main
git show FETCH_HEAD:models.json > internal/registry/models/models.json
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Generate Build Metadata
run: |
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
- name: Build and push (arm64)
uses: docker/build-push-action@v6
with:
context: .
platforms: linux/arm64
push: true
build-args: |
VERSION=${{ env.VERSION }}
COMMIT=${{ env.COMMIT }}
BUILD_DATE=${{ env.BUILD_DATE }}
tags: |
${{ env.DOCKERHUB_REPO }}:latest-arm64
${{ env.DOCKERHUB_REPO }}:${{ env.VERSION }}-arm64
docker_manifest:
runs-on: ubuntu-latest
needs:
- docker_amd64
- docker_arm64
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Generate Build Metadata
run: |
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
- name: Create and push multi-arch manifests
run: |
docker buildx imagetools create \
--tag "${DOCKERHUB_REPO}:latest" \
"${DOCKERHUB_REPO}:latest-amd64" \
"${DOCKERHUB_REPO}:latest-arm64"
docker buildx imagetools create \
--tag "${DOCKERHUB_REPO}:${VERSION}" \
"${DOCKERHUB_REPO}:${VERSION}-amd64" \
"${DOCKERHUB_REPO}:${VERSION}-arm64"
- name: Cleanup temporary tags
continue-on-error: true
env:
DOCKERHUB_USERNAME: ${{ secrets.DOCKERHUB_USERNAME }}
DOCKERHUB_TOKEN: ${{ secrets.DOCKERHUB_TOKEN }}
run: |
set -euo pipefail
namespace="${DOCKERHUB_REPO%%/*}"
repo_name="${DOCKERHUB_REPO#*/}"
token="$(
curl -fsSL \
-H 'Content-Type: application/json' \
-d "{\"username\":\"${DOCKERHUB_USERNAME}\",\"password\":\"${DOCKERHUB_TOKEN}\"}" \
'https://hub.docker.com/v2/users/login/' \
| python3 -c 'import json,sys; print(json.load(sys.stdin)["token"])'
)"
delete_tag() {
local tag="$1"
local url="https://hub.docker.com/v2/repositories/${namespace}/${repo_name}/tags/${tag}/"
local http_code
http_code="$(curl -sS -o /dev/null -w "%{http_code}" -X DELETE -H "Authorization: JWT ${token}" "${url}" || true)"
if [ "${http_code}" = "204" ] || [ "${http_code}" = "404" ]; then
echo "Docker Hub tag removed (or missing): ${DOCKERHUB_REPO}:${tag} (HTTP ${http_code})"
return 0
fi
echo "Docker Hub tag delete failed: ${DOCKERHUB_REPO}:${tag} (HTTP ${http_code})"
return 0
}
delete_tag "latest-amd64"
delete_tag "latest-arm64"
delete_tag "${VERSION}-amd64"
delete_tag "${VERSION}-arm64"
================================================
FILE: .github/workflows/pr-path-guard.yml
================================================
name: translator-path-guard
on:
pull_request:
types:
- opened
- synchronize
- reopened
jobs:
ensure-no-translator-changes:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Detect internal/translator changes
id: changed-files
uses: tj-actions/changed-files@v45
with:
files: |
internal/translator/**
- name: Fail when restricted paths change
if: steps.changed-files.outputs.any_changed == 'true'
run: |
echo "Changes under internal/translator are not allowed in pull requests."
echo "You need to create an issue for our maintenance team to make the necessary changes."
exit 1
================================================
FILE: .github/workflows/pr-test-build.yml
================================================
name: pr-test-build
on:
pull_request:
permissions:
contents: read
jobs:
build:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Refresh models catalog
run: |
git fetch --depth 1 https://github.com/router-for-me/models.git main
git show FETCH_HEAD:models.json > internal/registry/models/models.json
- name: Set up Go
uses: actions/setup-go@v5
with:
go-version-file: go.mod
cache: true
- name: Build
run: |
go build -o test-output ./cmd/server
rm -f test-output
================================================
FILE: .github/workflows/release.yaml
================================================
name: goreleaser
on:
push:
# run only against tags
tags:
- '*'
permissions:
contents: write
jobs:
goreleaser:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Refresh models catalog
run: |
git fetch --depth 1 https://github.com/router-for-me/models.git main
git show FETCH_HEAD:models.json > internal/registry/models/models.json
- run: git fetch --force --tags
- uses: actions/setup-go@v4
with:
go-version: '>=1.26.0'
cache: true
- name: Generate Build Metadata
run: |
echo "VERSION=${GITHUB_REF_NAME}" >> $GITHUB_ENV
echo COMMIT=`git rev-parse --short HEAD` >> $GITHUB_ENV
echo BUILD_DATE=`date -u +%Y-%m-%dT%H:%M:%SZ` >> $GITHUB_ENV
- uses: goreleaser/goreleaser-action@v4
with:
distribution: goreleaser
version: latest
args: release --clean --skip=validate
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
VERSION: ${{ env.VERSION }}
COMMIT: ${{ env.COMMIT }}
BUILD_DATE: ${{ env.BUILD_DATE }}
================================================
FILE: .gitignore
================================================
# Binaries
cli-proxy-api
*.exe
# Configuration
config.yaml
.env
# Generated content
bin/*
logs/*
conv/*
temp/*
refs/*
# Storage backends
pgstore/*
gitstore/*
objectstore/*
# Static assets
static/*
# Authentication data
auths/*
!auths/.gitkeep
# Documentation
docs/*
AGENTS.md
CLAUDE.md
GEMINI.md
# Tooling metadata
.vscode/*
.codex/*
.claude/*
.gemini/*
.serena/*
.agent/*
.agents/*
.opencode/*
.idea/*
.bmad/*
_bmad/*
_bmad-output/*
# macOS
.DS_Store
._*
================================================
FILE: .goreleaser.yml
================================================
version: 2
builds:
- id: "cli-proxy-api"
env:
- CGO_ENABLED=0
goos:
- linux
- windows
- darwin
goarch:
- amd64
- arm64
main: ./cmd/server/
binary: cli-proxy-api
ldflags:
- -s -w -X 'main.Version={{.Version}}' -X 'main.Commit={{.ShortCommit}}' -X 'main.BuildDate={{.Date}}'
archives:
- id: "cli-proxy-api"
format: tar.gz
format_overrides:
- goos: windows
format: zip
files:
- LICENSE
- README.md
- README_CN.md
- config.example.yaml
checksum:
name_template: 'checksums.txt'
snapshot:
name_template: "{{ incpatch .Version }}-next"
changelog:
sort: asc
filters:
exclude:
- '^docs:'
- '^test:'
================================================
FILE: Dockerfile
================================================
FROM golang:1.26-alpine AS builder
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . .
ARG VERSION=dev
ARG COMMIT=none
ARG BUILD_DATE=unknown
RUN CGO_ENABLED=0 GOOS=linux go build -ldflags="-s -w -X 'main.Version=${VERSION}' -X 'main.Commit=${COMMIT}' -X 'main.BuildDate=${BUILD_DATE}'" -o ./CLIProxyAPI ./cmd/server/
FROM alpine:3.22.0
RUN apk add --no-cache tzdata
RUN mkdir /CLIProxyAPI
COPY --from=builder ./app/CLIProxyAPI /CLIProxyAPI/CLIProxyAPI
COPY config.example.yaml /CLIProxyAPI/config.example.yaml
WORKDIR /CLIProxyAPI
EXPOSE 8317
ENV TZ=Asia/Shanghai
RUN cp /usr/share/zoneinfo/${TZ} /etc/localtime && echo "${TZ}" > /etc/timezone
CMD ["./CLIProxyAPI"]
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2025-2025.9 Luis Pater
Copyright (c) 2025.9-present Router-For.ME
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# CLI Proxy API
English | [中文](README_CN.md)
A proxy server that provides OpenAI/Gemini/Claude/Codex compatible API interfaces for CLI.
It now also supports OpenAI Codex (GPT models) and Claude Code via OAuth.
This lets you use local or multi-account CLI access with OpenAI (including Responses)/Gemini/Claude-compatible clients and SDKs.
## Sponsor
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
This project is sponsored by Z.ai, supporting us with their GLM CODING PLAN.
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $10/month. It provides access to their flagship GLM-4.7 and GLM-5 models (GLM-5 is available to Pro users only) across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
Get 10% OFF GLM CODING PLAN: https://z.ai/subscribe?ic=8JVLJQFSKB
---
Thanks to PackyCode for sponsoring this project! PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. PackyCode provides special discounts for our software users: register using this link and enter the "cliproxyapi" promo code during recharge to get 10% off.
Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for CLIProxyAPI users: register via this link to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off!
## Overview
- OpenAI/Gemini/Claude compatible API endpoints for CLI models
- OpenAI Codex support (GPT models) via OAuth login
- Claude Code support via OAuth login
- Qwen Code support via OAuth login
- iFlow support via OAuth login
- Amp CLI and IDE extensions support with provider routing
- Streaming and non-streaming responses
- Function calling/tools support
- Multimodal input support (text and images)
- Multiple accounts with round-robin load balancing (Gemini, OpenAI, Claude, Qwen and iFlow)
- Simple CLI authentication flows (Gemini, OpenAI, Claude, Qwen and iFlow)
- Generative Language API Key support
- AI Studio Build multi-account load balancing
- Gemini CLI multi-account load balancing
- Claude Code multi-account load balancing
- Qwen Code multi-account load balancing
- iFlow multi-account load balancing
- OpenAI Codex multi-account load balancing
- OpenAI-compatible upstream providers via config (e.g., OpenRouter)
- Reusable Go SDK for embedding the proxy (see `docs/sdk-usage.md`)
## Getting Started
CLIProxyAPI Guides: [https://help.router-for.me/](https://help.router-for.me/)
## Management API
See [MANAGEMENT_API.md](https://help.router-for.me/management/api).
## Amp CLI Support
CLIProxyAPI includes integrated support for [Amp CLI](https://ampcode.com) and Amp IDE extensions, enabling you to use your Google/ChatGPT/Claude OAuth subscriptions with Amp's coding tools:
- Provider route aliases for Amp's API patterns (`/api/provider/{provider}/v1...`)
- Management proxy for OAuth authentication and account features
- Smart model fallback with automatic routing
- **Model mapping** to route unavailable models to alternatives (e.g., `claude-opus-4.5` → `claude-sonnet-4`)
- Security-first design with localhost-only management endpoints
**→ [Complete Amp CLI Integration Guide](https://help.router-for.me/agent-client/amp-cli.html)**
## SDK Docs
- Usage: [docs/sdk-usage.md](docs/sdk-usage.md)
- Advanced (executors & translators): [docs/sdk-advanced.md](docs/sdk-advanced.md)
- Access: [docs/sdk-access.md](docs/sdk-access.md)
- Watcher: [docs/sdk-watcher.md](docs/sdk-watcher.md)
- Custom Provider Example: `examples/custom-provider`
## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.
1. Fork the repository
2. Create your feature branch (`git checkout -b feature/amazing-feature`)
3. Commit your changes (`git commit -m 'Add some amazing feature'`)
4. Push to the branch (`git push origin feature/amazing-feature`)
5. Open a Pull Request
## Who is with us?
These projects are based on CLIProxyAPI:
### [vibeproxy](https://github.com/automazeio/vibeproxy)
Native macOS menu bar app to use your Claude Code & ChatGPT subscriptions with AI coding tools - no API keys needed
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
Browser-based tool to translate SRT subtitles using your Gemini subscription via CLIProxyAPI with automatic validation/error correction - no API keys needed
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
CLI wrapper for instant switching between multiple Claude accounts and alternative models (Gemini, Codex, Antigravity) via CLIProxyAPI OAuth - no API keys needed
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
Native macOS GUI for managing CLIProxyAPI: configure providers, model mappings, and endpoints via OAuth - no API keys needed.
### [Quotio](https://github.com/nguyenphutrong/quotio)
Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions with real-time quota tracking and smart auto-failover for AI coding tools like Claude Code, OpenCode, and Droid - no API keys needed.
### [CodMate](https://github.com/loocor/CodMate)
Native macOS SwiftUI app for managing CLI AI sessions (Codex, Claude Code, Gemini CLI) with unified provider management, Git review, project organization, global search, and terminal integration. Integrates CLIProxyAPI to provide OAuth authentication for Codex, Claude, Gemini, Antigravity, and Qwen Code, with built-in and third-party provider rerouting through a single proxy endpoint - no API keys needed for OAuth providers.
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
Windows-native CLIProxyAPI fork with TUI, system tray, and multi-provider OAuth for AI coding tools - no API keys needed.
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
VSCode extension for quick switching between Claude Code models, featuring integrated CLIProxyAPI as its backend with automatic background lifecycle management.
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
Windows desktop app built with Tauri + React for monitoring AI coding assistant quotas via CLIProxyAPI. Track usage across Gemini, Claude, OpenAI Codex, and Antigravity accounts with real-time dashboard, system tray integration, and one-click proxy control - no API keys needed.
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
A lightweight web admin panel for CLIProxyAPI with health checks, resource monitoring, real-time logs, auto-update, request statistics and pricing display. Supports one-click installation and systemd service.
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
A Windows tray application implemented using PowerShell scripts, without relying on any third-party libraries. The main features include: automatic creation of shortcuts, silent running, password management, channel switching (Main / Plus), and automatic downloading and updating.
### [霖君](https://github.com/wangdabaoqq/LinJun)
霖君 is a cross-platform desktop application for managing AI programming assistants on macOS, Windows, and Linux. It provides unified management of Claude Code, Gemini CLI, OpenAI Codex, Qwen Code, and other AI coding tools, with a local proxy for multi-account quota tracking and one-click configuration.
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
A modern web-based management dashboard for CLIProxyAPI built with Next.js, React, and PostgreSQL. Features real-time log streaming, structured configuration editing, API key management, OAuth provider integration for Claude/Gemini/Codex, usage analytics, container management, and config sync with OpenCode via companion plugin - no manual YAML editing needed.
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
Browser extension for one-stop management of New API-compatible relay site accounts, featuring balance and usage dashboards, auto check-in, one-click key export to common apps, in-page API availability testing, and channel/model sync and redirection. It integrates with CLIProxyAPI through the Management API for one-click provider import and config sync.
### [Shadow AI](https://github.com/HEUDavid/shadow-ai)
Shadow AI is an AI assistant tool designed specifically for restricted environments. It provides a stealthy operation mode that leaves no windows or traces, and enables cross-device AI Q&A interaction and control over the local area network (LAN). Essentially, it is an automated collaboration layer of "screen/audio capture + AI inference + low-friction delivery" that helps users immersively use AI assistants across applications on controlled devices or in restricted environments.
> [!NOTE]
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
## More choices
These projects are ports of CLIProxyAPI or are inspired by it:
### [9Router](https://github.com/decolua/9router)
A Next.js implementation inspired by CLIProxyAPI, easy to install and use, built from scratch with format translation (OpenAI/Claude/Gemini/Ollama), combo system with auto-fallback, multi-account management with exponential backoff, a Next.js web dashboard, and support for CLI tools (Cursor, Claude Code, Cline, RooCode) - no API keys needed.
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
Never stop coding. Smart routing to FREE & low-cost AI models with automatic fallback.
OmniRoute is an AI gateway for multi-provider LLMs: an OpenAI-compatible endpoint with smart routing, load balancing, retries, and fallbacks. Add policies, rate limits, caching, and observability for reliable, cost-aware inference.
> [!NOTE]
> If you have developed a port of CLIProxyAPI or a project inspired by it, please open a PR to add it to this list.
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.
================================================
FILE: README_CN.md
================================================
# CLI 代理 API
[English](README.md) | 中文
一个为 CLI 提供 OpenAI/Gemini/Claude/Codex 兼容 API 接口的代理服务器。
现已支持通过 OAuth 登录接入 OpenAI Codex(GPT 系列)和 Claude Code。
您可以使用本地或多账户的CLI方式,通过任何与 OpenAI(包括Responses)/Gemini/Claude 兼容的客户端和SDK进行访问。
## 赞助商
[](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
本项目由 Z智谱 提供赞助, 他们通过 GLM CODING PLAN 对本项目提供技术支持。
GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7(受限于算力,目前仅限Pro用户开放),为开发者提供顶尖的编码体验。
智谱AI为本产品提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
---
感谢 PackyCode 对本项目的赞助!PackyCode 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。PackyCode 为本软件用户提供了特别优惠:使用此链接 注册,并在充值时输入 "cliproxyapi" 优惠码即可享受九折优惠。
感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定中转服务,支持企业级高并发、极速开票、7×24 专属技术支持。 Claude Code / Codex / Gemini 官方渠道低至 3.8 / 0.2 / 0.9 折,充值更有折上折!AICodeMirror 为 CLIProxyAPI 的用户提供了特别福利,通过此链接 注册的用户,可享受首充8折,企业客户最高可享 7.5 折!
## 功能特性
- 为 CLI 模型提供 OpenAI/Gemini/Claude/Codex 兼容的 API 端点
- 新增 OpenAI Codex(GPT 系列)支持(OAuth 登录)
- 新增 Claude Code 支持(OAuth 登录)
- 新增 Qwen Code 支持(OAuth 登录)
- 新增 iFlow 支持(OAuth 登录)
- 支持流式与非流式响应
- 函数调用/工具支持
- 多模态输入(文本、图片)
- 多账户支持与轮询负载均衡(Gemini、OpenAI、Claude、Qwen 与 iFlow)
- 简单的 CLI 身份验证流程(Gemini、OpenAI、Claude、Qwen 与 iFlow)
- 支持 Gemini AIStudio API 密钥
- 支持 AI Studio Build 多账户轮询
- 支持 Gemini CLI 多账户轮询
- 支持 Claude Code 多账户轮询
- 支持 Qwen Code 多账户轮询
- 支持 iFlow 多账户轮询
- 支持 OpenAI Codex 多账户轮询
- 通过配置接入上游 OpenAI 兼容提供商(例如 OpenRouter)
- 可复用的 Go SDK(见 `docs/sdk-usage_CN.md`)
## 新手入门
CLIProxyAPI 用户手册: [https://help.router-for.me/](https://help.router-for.me/cn/)
## 管理 API 文档
请参见 [MANAGEMENT_API_CN.md](https://help.router-for.me/cn/management/api)
## Amp CLI 支持
CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支持,可让你使用自己的 Google/ChatGPT/Claude OAuth 订阅来配合 Amp 编码工具:
- 提供商路由别名,兼容 Amp 的 API 路径模式(`/api/provider/{provider}/v1...`)
- 管理代理,处理 OAuth 认证和账号功能
- 智能模型回退与自动路由
- 以安全为先的设计,管理端点仅限 localhost
**→ [Amp CLI 完整集成指南](https://help.router-for.me/cn/agent-client/amp-cli.html)**
## SDK 文档
- 使用文档:[docs/sdk-usage_CN.md](docs/sdk-usage_CN.md)
- 高级(执行器与翻译器):[docs/sdk-advanced_CN.md](docs/sdk-advanced_CN.md)
- 认证: [docs/sdk-access_CN.md](docs/sdk-access_CN.md)
- 凭据加载/更新: [docs/sdk-watcher_CN.md](docs/sdk-watcher_CN.md)
- 自定义 Provider 示例:`examples/custom-provider`
## 贡献
欢迎贡献!请随时提交 Pull Request。
1. Fork 仓库
2. 创建您的功能分支(`git checkout -b feature/amazing-feature`)
3. 提交您的更改(`git commit -m 'Add some amazing feature'`)
4. 推送到分支(`git push origin feature/amazing-feature`)
5. 打开 Pull Request
## 谁与我们在一起?
这些项目基于 CLIProxyAPI:
### [vibeproxy](https://github.com/automazeio/vibeproxy)
一个原生 macOS 菜单栏应用,让您可以使用 Claude Code & ChatGPT 订阅服务和 AI 编程工具,无需 API 密钥。
### [Subtitle Translator](https://github.com/VjayC/SRT-Subtitle-Translator-Validator)
一款基于浏览器的 SRT 字幕翻译工具,可通过 CLI 代理 API 使用您的 Gemini 订阅。内置自动验证与错误修正功能,无需 API 密钥。
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户和替代模型(Gemini, Codex, Antigravity),无需 API 密钥。
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
基于 macOS 平台的原生 CLIProxyAPI GUI:配置供应商、模型映射以及OAuth端点,无需 API 密钥。
### [Quotio](https://github.com/nguyenphutrong/quotio)
原生 macOS 菜单栏应用,统一管理 Claude、Gemini、OpenAI、Qwen 和 Antigravity 订阅,提供实时配额追踪和智能自动故障转移,支持 Claude Code、OpenCode 和 Droid 等 AI 编程工具,无需 API 密钥。
### [CodMate](https://github.com/loocor/CodMate)
原生 macOS SwiftUI 应用,用于管理 CLI AI 会话(Claude Code、Codex、Gemini CLI),提供统一的提供商管理、Git 审查、项目组织、全局搜索和终端集成。集成 CLIProxyAPI 为 Codex、Claude、Gemini、Antigravity 和 Qwen Code 提供统一的 OAuth 认证,支持内置和第三方提供商通过单一代理端点重路由 - OAuth 提供商无需 API 密钥。
### [ProxyPilot](https://github.com/Finesssee/ProxyPilot)
原生 Windows CLIProxyAPI 分支,集成 TUI、系统托盘及多服务商 OAuth 认证,专为 AI 编程工具打造,无需 API 密钥。
### [Claude Proxy VSCode](https://github.com/uzhao/claude-proxy-vscode)
一款 VSCode 扩展,提供了在 VSCode 中快速切换 Claude Code 模型的功能,内置 CLIProxyAPI 作为其后端,支持后台自动启动和关闭。
### [ZeroLimit](https://github.com/0xtbug/zero-limit)
Windows 桌面应用,基于 Tauri + React 构建,用于通过 CLIProxyAPI 监控 AI 编程助手配额。支持跨 Gemini、Claude、OpenAI Codex 和 Antigravity 账户的使用量追踪,提供实时仪表盘、系统托盘集成和一键代理控制,无需 API 密钥。
### [CPA-XXX Panel](https://github.com/ferretgeek/CPA-X)
面向 CLIProxyAPI 的 Web 管理面板,提供健康检查、资源监控、日志查看、自动更新、请求统计与定价展示,支持一键安装与 systemd 服务。
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方库。主要功能包括:自动创建快捷方式、静默运行、密码管理、通道切换(Main / Plus)以及自动下载与更新。
### [霖君](https://github.com/wangdabaoqq/LinJun)
霖君是一款用于管理AI编程助手的跨平台桌面应用,支持macOS、Windows、Linux系统。统一管理Claude Code、Gemini CLI、OpenAI Codex、Qwen Code等AI编程工具,本地代理实现多账户配额跟踪和一键配置。
### [CLIProxyAPI Dashboard](https://github.com/itsmylife44/cliproxyapi-dashboard)
一个面向 CLIProxyAPI 的现代化 Web 管理仪表盘,基于 Next.js、React 和 PostgreSQL 构建。支持实时日志流、结构化配置编辑、API Key 管理、Claude/Gemini/Codex 的 OAuth 提供方集成、使用量分析、容器管理,并可通过配套插件与 OpenCode 同步配置,无需手动编辑 YAML。
### [All API Hub](https://github.com/qixing-jk/all-api-hub)
用于一站式管理 New API 兼容中转站账号的浏览器扩展,提供余额与用量看板、自动签到、密钥一键导出到常用应用、网页内 API 可用性测试,以及渠道与模型同步和重定向。支持通过 CLIProxyAPI Management API 一键导入 Provider 与同步配置。
### [Shadow AI](https://github.com/HEUDavid/shadow-ai)
Shadow AI 是一款专为受限环境设计的 AI 辅助工具。提供无窗口、无痕迹的隐蔽运行方式,并通过局域网实现跨设备的 AI 问答交互与控制。本质上是一个「屏幕/音频采集 + AI 推理 + 低摩擦投送」的自动化协作层,帮助用户在受控设备/受限环境下沉浸式跨应用地使用 AI 助手。
> [!NOTE]
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
## 更多选择
以下项目是 CLIProxyAPI 的移植版或受其启发:
### [9Router](https://github.com/decolua/9router)
基于 Next.js 的实现,灵感来自 CLIProxyAPI,易于安装使用;自研格式转换(OpenAI/Claude/Gemini/Ollama)、组合系统与自动回退、多账户管理(指数退避)、Next.js Web 控制台,并支持 Cursor、Claude Code、Cline、RooCode 等 CLI 工具,无需 API 密钥。
### [OmniRoute](https://github.com/diegosouzapw/OmniRoute)
代码不止,创新不停。智能路由至免费及低成本 AI 模型,并支持自动故障转移。
OmniRoute 是一个面向多供应商大语言模型的 AI 网关:它提供兼容 OpenAI 的端点,具备智能路由、负载均衡、重试及回退机制。通过添加策略、速率限制、缓存和可观测性,确保推理过程既可靠又具备成本意识。
> [!NOTE]
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
## 许可证
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
## 写给所有中国网友的
QQ 群:188637136
或
Telegram 群:https://t.me/CLIProxyAPI
================================================
FILE: auths/.gitkeep
================================================
================================================
FILE: cmd/fetch_antigravity_models/main.go
================================================
// Command fetch_antigravity_models connects to the Antigravity API using the
// stored auth credentials and saves the dynamically fetched model list to a
// JSON file for inspection or offline use.
//
// Usage:
//
// go run ./cmd/fetch_antigravity_models [flags]
//
// Flags:
//
// --auths-dir Directory containing auth JSON files (default: "auths")
// --output Output JSON file path (default: "antigravity_models.json")
// --pretty Pretty-print the output JSON (default: true; pass --pretty=false for compact output)
package main
import (
	"context"
	"encoding/json"
	"flag"
	"fmt"
	"io"
	"net/http"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"time"

	"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
	sdkauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
	coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
	"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
	log "github.com/sirupsen/logrus"
	"github.com/tidwall/gjson"
)
// Antigravity upstream endpoints. fetchModels tries the base URLs in the
// order production → daily → daily sandbox, appending antigravityModelsPath
// to each one.
const (
	antigravityBaseURLProd         = "https://cloudcode-pa.googleapis.com"
	antigravityBaseURLDaily        = "https://daily-cloudcode-pa.googleapis.com"
	antigravitySandboxBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"

	// antigravityModelsPath is the RPC path that lists the available models.
	antigravityModelsPath = "/v1internal:fetchAvailableModels"
)
// init runs before main: it applies the project's base logger setup via
// logging.SetupBaseLogger and then sets the logrus level to Info.
func init() {
	logging.SetupBaseLogger()
	// Set the level after the base setup so Info is what takes effect, even if
	// SetupBaseLogger configures a different level (not visible from here).
	log.SetLevel(log.InfoLevel)
}
// modelOutput is the top-level JSON document written to the output file. It
// currently carries only the fetched model list, under the "models" key.
type modelOutput struct {
	Models []modelEntry `json:"models"`
}
// modelEntry contains only the fields we want to keep for static model
// definitions. The JSON tags define the on-disk format of the output file.
type modelEntry struct {
	ID          string `json:"id"`
	Object      string `json:"object"`
	OwnedBy     string `json:"owned_by"`
	Type        string `json:"type"`
	DisplayName string `json:"display_name"`
	Name        string `json:"name"`
	Description string `json:"description"`
	// ContextLength is filled from the upstream "maxTokens" value; omitted
	// from the JSON when zero (not reported).
	ContextLength int `json:"context_length,omitempty"`
	// MaxCompletionTokens is filled from the upstream "maxOutputTokens"
	// value; omitted from the JSON when zero (not reported).
	MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
}
func main() {
var authsDir string
var outputPath string
var pretty bool
flag.StringVar(&authsDir, "auths-dir", "auths", "Directory containing auth JSON files")
flag.StringVar(&outputPath, "output", "antigravity_models.json", "Output JSON file path")
flag.BoolVar(&pretty, "pretty", true, "Pretty-print the output JSON")
flag.Parse()
// Resolve relative paths against the working directory.
wd, err := os.Getwd()
if err != nil {
fmt.Fprintf(os.Stderr, "error: cannot get working directory: %v\n", err)
os.Exit(1)
}
if !filepath.IsAbs(authsDir) {
authsDir = filepath.Join(wd, authsDir)
}
if !filepath.IsAbs(outputPath) {
outputPath = filepath.Join(wd, outputPath)
}
fmt.Printf("Scanning auth files in: %s\n", authsDir)
// Load all auth records from the directory.
fileStore := sdkauth.NewFileTokenStore()
fileStore.SetBaseDir(authsDir)
ctx := context.Background()
auths, err := fileStore.List(ctx)
if err != nil {
fmt.Fprintf(os.Stderr, "error: failed to list auth files: %v\n", err)
os.Exit(1)
}
if len(auths) == 0 {
fmt.Fprintf(os.Stderr, "error: no auth files found in %s\n", authsDir)
os.Exit(1)
}
// Find the first enabled antigravity auth.
var chosen *coreauth.Auth
for _, a := range auths {
if a == nil || a.Disabled {
continue
}
if strings.EqualFold(strings.TrimSpace(a.Provider), "antigravity") {
chosen = a
break
}
}
if chosen == nil {
fmt.Fprintf(os.Stderr, "error: no enabled antigravity auth found in %s\n", authsDir)
os.Exit(1)
}
fmt.Printf("Using auth: id=%s label=%s\n", chosen.ID, chosen.Label)
// Fetch models from the upstream Antigravity API.
fmt.Println("Fetching Antigravity model list from upstream...")
fetchCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
defer cancel()
models := fetchModels(fetchCtx, chosen)
if len(models) == 0 {
fmt.Fprintln(os.Stderr, "warning: no models returned (API may be unavailable or token expired)")
} else {
fmt.Printf("Fetched %d models.\n", len(models))
}
// Build the output payload.
out := modelOutput{
Models: models,
}
// Marshal to JSON.
var raw []byte
if pretty {
raw, err = json.MarshalIndent(out, "", " ")
} else {
raw, err = json.Marshal(out)
}
if err != nil {
fmt.Fprintf(os.Stderr, "error: failed to marshal JSON: %v\n", err)
os.Exit(1)
}
if err = os.WriteFile(outputPath, raw, 0o644); err != nil {
fmt.Fprintf(os.Stderr, "error: failed to write output file %s: %v\n", outputPath, err)
os.Exit(1)
}
fmt.Printf("Model list saved to: %s\n", outputPath)
}
// fetchModels asks the Antigravity fetchAvailableModels endpoint for the
// models visible to auth and converts them to modelEntry values.
//
// Base URLs are tried in order (prod, daily, daily sandbox). A base URL whose
// request fails, returns a non-2xx status, or whose body lacks a "models"
// field is reported on stderr and the next one is tried. The first response
// that has a "models" field is used, even if every model in it is filtered
// out. The result is sorted by ID so repeated runs write identical JSON.
//
// It returns nil when auth is nil, when auth.Metadata has no "access_token"
// string, when the context ends, or when no base URL yields a usable response.
func fetchModels(ctx context.Context, auth *coreauth.Auth) []modelEntry {
	if auth == nil {
		fmt.Fprintln(os.Stderr, "error: no auth provided")
		return nil
	}
	accessToken := metaStringValue(auth.Metadata, "access_token")
	if accessToken == "" {
		fmt.Fprintln(os.Stderr, "error: no access token found in auth")
		return nil
	}

	// The payload and HTTP client do not depend on the base URL, so build them
	// once instead of on every attempt.
	payload, errPayload := buildModelsPayload(auth.Metadata)
	if errPayload != nil {
		fmt.Fprintf(os.Stderr, "error: failed to encode request payload: %v\n", errPayload)
		return nil
	}
	httpClient := &http.Client{Timeout: 30 * time.Second}
	// Best effort: if the proxy transport cannot be built, fall back to the
	// default transport.
	if transport, _, errProxy := proxyutil.BuildHTTPTransport(auth.ProxyURL); errProxy == nil && transport != nil {
		httpClient.Transport = transport
	}

	baseURLs := []string{antigravityBaseURLProd, antigravityBaseURLDaily, antigravitySandboxBaseURLDaily}
	for _, baseURL := range baseURLs {
		body, errFetch := requestModels(ctx, httpClient, baseURL+antigravityModelsPath, accessToken, payload)
		if errFetch != nil {
			fmt.Fprintf(os.Stderr, "warning: %s: %v\n", baseURL, errFetch)
			if ctx.Err() != nil {
				// Deadline passed or context cancelled: the remaining base
				// URLs would fail the same way.
				return nil
			}
			continue
		}
		models, ok := parseModelEntries(body)
		if !ok {
			fmt.Fprintf(os.Stderr, "warning: %s: response has no \"models\" field\n", baseURL)
			continue
		}
		return models
	}
	return nil
}

// buildModelsPayload returns the JSON request body for fetchAvailableModels:
// {"project":"<id>"} when metadata holds a non-blank string "project_id"
// (trimmed), otherwise {}. json.Marshal escapes quotes and backslashes, so
// arbitrary project IDs cannot break or inject into the JSON.
func buildModelsPayload(metadata map[string]interface{}) ([]byte, error) {
	if pid, ok := metadata["project_id"].(string); ok {
		if pid = strings.TrimSpace(pid); pid != "" {
			return json.Marshal(map[string]string{"project": pid})
		}
	}
	return []byte(`{}`), nil
}

// requestModels POSTs payload to modelsURL using the bearer token and returns
// the response body. It returns an error when the request cannot be built or
// sent, when the body cannot be read, or when the status is not 2xx.
func requestModels(ctx context.Context, client *http.Client, modelsURL, accessToken string, payload []byte) ([]byte, error) {
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, strings.NewReader(string(payload)))
	if err != nil {
		return nil, fmt.Errorf("build request: %w", err)
	}
	httpReq.Close = true
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+accessToken)
	httpReq.Header.Set("User-Agent", "antigravity/1.19.6 darwin/arm64")

	httpResp, err := client.Do(httpReq)
	if err != nil {
		return nil, fmt.Errorf("send request: %w", err)
	}
	defer func() { _ = httpResp.Body.Close() }()

	body, err := io.ReadAll(httpResp.Body)
	if err != nil {
		return nil, fmt.Errorf("read response: %w", err)
	}
	if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
		return nil, fmt.Errorf("unexpected status %d", httpResp.StatusCode)
	}
	return body, nil
}

// parseModelEntries converts the "models" object of a fetchAvailableModels
// response into modelEntry values, sorted by ID. Blank IDs and a fixed set of
// internal/experimental models are skipped. The bool is false when the body
// has no "models" field at all.
func parseModelEntries(body []byte) ([]modelEntry, bool) {
	result := gjson.GetBytes(body, "models")
	if !result.Exists() {
		return nil, false
	}

	var models []modelEntry
	for originalName, modelData := range result.Map() {
		modelID := strings.TrimSpace(originalName)
		if modelID == "" {
			continue
		}

		// Skip internal/experimental models.
		switch modelID {
		case "chat_20706", "chat_23310", "tab_flash_lite_preview", "tab_jump_flash_lite_preview", "gemini-2.5-flash-thinking", "gemini-2.5-pro":
			continue
		}

		displayName := modelData.Get("displayName").String()
		if displayName == "" {
			displayName = modelID
		}

		entry := modelEntry{
			ID:          modelID,
			Object:      "model",
			OwnedBy:     "antigravity",
			Type:        "antigravity",
			DisplayName: displayName,
			Name:        modelID,
			Description: displayName,
		}
		if maxTok := modelData.Get("maxTokens").Int(); maxTok > 0 {
			entry.ContextLength = int(maxTok)
		}
		if maxOut := modelData.Get("maxOutputTokens").Int(); maxOut > 0 {
			entry.MaxCompletionTokens = int(maxOut)
		}
		models = append(models, entry)
	}

	// result.Map() returns a Go map, whose iteration order is randomized;
	// sort so the written file is stable across runs.
	sort.Slice(models, func(i, j int) bool { return models[i].ID < models[j].ID })
	return models, true
}
// metaStringValue returns m[key] when that value is a string. A nil map, a
// missing key, or a value of any other type all yield "".
func metaStringValue(m map[string]interface{}, key string) string {
	// Indexing a nil map yields nil, and a failed comma-ok type assertion
	// leaves s at its zero value, so no explicit guards are needed.
	s, _ := m[key].(string)
	return s
}
================================================
FILE: cmd/server/main.go
================================================
// Package main provides the entry point for the CLI Proxy API server.
// This server acts as a proxy that provides OpenAI/Gemini/Claude compatible API interfaces
// for CLI models, allowing CLI models to be used with tools and libraries designed for standard AI APIs.
package main
import (
"context"
"errors"
"flag"
"fmt"
"io"
"io/fs"
"net/url"
"os"
"path/filepath"
"strings"
"time"
"github.com/joho/godotenv"
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cmd"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/store"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
"github.com/router-for-me/CLIProxyAPI/v6/internal/tui"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// Version is printed in the startup banner and copied into buildinfo by init.
// NOTE(review): presumably overridden at link time (-ldflags "-X main.Version=...")
// by the release tooling — confirm against the build configuration.
var Version = "dev"

// Commit is printed in the startup banner alongside Version.
var Commit = "none"

// BuildDate is printed in the startup banner alongside Version and Commit.
var BuildDate = "unknown"

// DefaultConfigPath is the default value of the -config flag. When it is
// empty and no store backend is configured, main loads config.yaml from the
// working directory.
var DefaultConfigPath = ""
// init performs setup before main runs. It calls logging.SetupBaseLogger and
// publishes the Version, Commit and BuildDate variables to the buildinfo
// package, which main reads for its startup banner.
func init() {
	logging.SetupBaseLogger()
	// The three assignments are independent, so publish them in one statement.
	buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate = Version, Commit, BuildDate
}
// main is the entry point of the application.
// It parses command-line flags, loads configuration, and starts the appropriate
// service based on the provided flags (login, codex-login, or server mode).
func main() {
fmt.Printf("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s\n", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate)
// Command-line flags to control the application's behavior.
var login bool
var codexLogin bool
var codexDeviceLogin bool
var claudeLogin bool
var qwenLogin bool
var iflowLogin bool
var iflowCookie bool
var noBrowser bool
var oauthCallbackPort int
var antigravityLogin bool
var kimiLogin bool
var projectID string
var vertexImport string
var configPath string
var password string
var tuiMode bool
var standalone bool
// Define command-line flags for different operation modes.
flag.BoolVar(&login, "login", false, "Login Google Account")
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
flag.BoolVar(&iflowCookie, "iflow-cookie", false, "Login to iFlow using Cookie")
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
flag.StringVar(&password, "password", "", "")
flag.BoolVar(&tuiMode, "tui", false, "Start with terminal management UI")
flag.BoolVar(&standalone, "standalone", false, "In TUI mode, start an embedded local server")
flag.CommandLine.Usage = func() {
out := flag.CommandLine.Output()
_, _ = fmt.Fprintf(out, "Usage of %s\n", os.Args[0])
flag.CommandLine.VisitAll(func(f *flag.Flag) {
if f.Name == "password" {
return
}
s := fmt.Sprintf(" -%s", f.Name)
name, unquoteUsage := flag.UnquoteUsage(f)
if name != "" {
s += " " + name
}
if len(s) <= 4 {
s += " "
} else {
s += "\n "
}
if unquoteUsage != "" {
s += unquoteUsage
}
if f.DefValue != "" && f.DefValue != "false" && f.DefValue != "0" {
s += fmt.Sprintf(" (default %s)", f.DefValue)
}
_, _ = fmt.Fprint(out, s+"\n")
})
}
// Parse the command-line flags.
flag.Parse()
// Core application variables.
var err error
var cfg *config.Config
var isCloudDeploy bool
var (
usePostgresStore bool
pgStoreDSN string
pgStoreSchema string
pgStoreLocalPath string
pgStoreInst *store.PostgresStore
useGitStore bool
gitStoreRemoteURL string
gitStoreUser string
gitStorePassword string
gitStoreLocalPath string
gitStoreInst *store.GitTokenStore
gitStoreRoot string
useObjectStore bool
objectStoreEndpoint string
objectStoreAccess string
objectStoreSecret string
objectStoreBucket string
objectStoreLocalPath string
objectStoreInst *store.ObjectTokenStore
)
wd, err := os.Getwd()
if err != nil {
log.Errorf("failed to get working directory: %v", err)
return
}
// Load environment variables from .env if present.
if errLoad := godotenv.Load(filepath.Join(wd, ".env")); errLoad != nil {
if !errors.Is(errLoad, os.ErrNotExist) {
log.WithError(errLoad).Warn("failed to load .env file")
}
}
lookupEnv := func(keys ...string) (string, bool) {
for _, key := range keys {
if value, ok := os.LookupEnv(key); ok {
if trimmed := strings.TrimSpace(value); trimmed != "" {
return trimmed, true
}
}
}
return "", false
}
writableBase := util.WritablePath()
if value, ok := lookupEnv("PGSTORE_DSN", "pgstore_dsn"); ok {
usePostgresStore = true
pgStoreDSN = value
}
if usePostgresStore {
if value, ok := lookupEnv("PGSTORE_SCHEMA", "pgstore_schema"); ok {
pgStoreSchema = value
}
if value, ok := lookupEnv("PGSTORE_LOCAL_PATH", "pgstore_local_path"); ok {
pgStoreLocalPath = value
}
if pgStoreLocalPath == "" {
if writableBase != "" {
pgStoreLocalPath = writableBase
} else {
pgStoreLocalPath = wd
}
}
useGitStore = false
}
if value, ok := lookupEnv("GITSTORE_GIT_URL", "gitstore_git_url"); ok {
useGitStore = true
gitStoreRemoteURL = value
}
if value, ok := lookupEnv("GITSTORE_GIT_USERNAME", "gitstore_git_username"); ok {
gitStoreUser = value
}
if value, ok := lookupEnv("GITSTORE_GIT_TOKEN", "gitstore_git_token"); ok {
gitStorePassword = value
}
if value, ok := lookupEnv("GITSTORE_LOCAL_PATH", "gitstore_local_path"); ok {
gitStoreLocalPath = value
}
if value, ok := lookupEnv("OBJECTSTORE_ENDPOINT", "objectstore_endpoint"); ok {
useObjectStore = true
objectStoreEndpoint = value
}
if value, ok := lookupEnv("OBJECTSTORE_ACCESS_KEY", "objectstore_access_key"); ok {
objectStoreAccess = value
}
if value, ok := lookupEnv("OBJECTSTORE_SECRET_KEY", "objectstore_secret_key"); ok {
objectStoreSecret = value
}
if value, ok := lookupEnv("OBJECTSTORE_BUCKET", "objectstore_bucket"); ok {
objectStoreBucket = value
}
if value, ok := lookupEnv("OBJECTSTORE_LOCAL_PATH", "objectstore_local_path"); ok {
objectStoreLocalPath = value
}
// Check for cloud deploy mode only on first execution
// Read env var name in uppercase: DEPLOY
deployEnv := os.Getenv("DEPLOY")
if deployEnv == "cloud" {
isCloudDeploy = true
}
// Determine and load the configuration file.
// Prefer the Postgres store when configured, otherwise fallback to git or local files.
var configFilePath string
if usePostgresStore {
if pgStoreLocalPath == "" {
pgStoreLocalPath = wd
}
pgStoreLocalPath = filepath.Join(pgStoreLocalPath, "pgstore")
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
pgStoreInst, err = store.NewPostgresStore(ctx, store.PostgresStoreConfig{
DSN: pgStoreDSN,
Schema: pgStoreSchema,
SpoolDir: pgStoreLocalPath,
})
cancel()
if err != nil {
log.Errorf("failed to initialize postgres token store: %v", err)
return
}
examplePath := filepath.Join(wd, "config.example.yaml")
ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second)
if errBootstrap := pgStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
cancel()
log.Errorf("failed to bootstrap postgres-backed config: %v", errBootstrap)
return
}
cancel()
configFilePath = pgStoreInst.ConfigPath()
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
if err == nil {
cfg.AuthDir = pgStoreInst.AuthDir()
log.Infof("postgres-backed token store enabled, workspace path: %s", pgStoreInst.WorkDir())
}
} else if useObjectStore {
if objectStoreLocalPath == "" {
if writableBase != "" {
objectStoreLocalPath = writableBase
} else {
objectStoreLocalPath = wd
}
}
objectStoreRoot := filepath.Join(objectStoreLocalPath, "objectstore")
resolvedEndpoint := strings.TrimSpace(objectStoreEndpoint)
useSSL := true
if strings.Contains(resolvedEndpoint, "://") {
parsed, errParse := url.Parse(resolvedEndpoint)
if errParse != nil {
log.Errorf("failed to parse object store endpoint %q: %v", objectStoreEndpoint, errParse)
return
}
switch strings.ToLower(parsed.Scheme) {
case "http":
useSSL = false
case "https":
useSSL = true
default:
log.Errorf("unsupported object store scheme %q (only http and https are allowed)", parsed.Scheme)
return
}
if parsed.Host == "" {
log.Errorf("object store endpoint %q is missing host information", objectStoreEndpoint)
return
}
resolvedEndpoint = parsed.Host
if parsed.Path != "" && parsed.Path != "/" {
resolvedEndpoint = strings.TrimSuffix(parsed.Host+parsed.Path, "/")
}
}
resolvedEndpoint = strings.TrimRight(resolvedEndpoint, "/")
objCfg := store.ObjectStoreConfig{
Endpoint: resolvedEndpoint,
Bucket: objectStoreBucket,
AccessKey: objectStoreAccess,
SecretKey: objectStoreSecret,
LocalRoot: objectStoreRoot,
UseSSL: useSSL,
PathStyle: true,
}
objectStoreInst, err = store.NewObjectTokenStore(objCfg)
if err != nil {
log.Errorf("failed to initialize object token store: %v", err)
return
}
examplePath := filepath.Join(wd, "config.example.yaml")
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
if errBootstrap := objectStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
cancel()
log.Errorf("failed to bootstrap object-backed config: %v", errBootstrap)
return
}
cancel()
configFilePath = objectStoreInst.ConfigPath()
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
if err == nil {
if cfg == nil {
cfg = &config.Config{}
}
cfg.AuthDir = objectStoreInst.AuthDir()
log.Infof("object-backed token store enabled, bucket: %s", objectStoreBucket)
}
} else if useGitStore {
if gitStoreLocalPath == "" {
if writableBase != "" {
gitStoreLocalPath = writableBase
} else {
gitStoreLocalPath = wd
}
}
gitStoreRoot = filepath.Join(gitStoreLocalPath, "gitstore")
authDir := filepath.Join(gitStoreRoot, "auths")
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword)
gitStoreInst.SetBaseDir(authDir)
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
log.Errorf("failed to prepare git token store: %v", errRepo)
return
}
configFilePath = gitStoreInst.ConfigPath()
if configFilePath == "" {
configFilePath = filepath.Join(gitStoreRoot, "config", "config.yaml")
}
if _, statErr := os.Stat(configFilePath); errors.Is(statErr, fs.ErrNotExist) {
examplePath := filepath.Join(wd, "config.example.yaml")
if _, errExample := os.Stat(examplePath); errExample != nil {
log.Errorf("failed to find template config file: %v", errExample)
return
}
if errCopy := misc.CopyConfigTemplate(examplePath, configFilePath); errCopy != nil {
log.Errorf("failed to bootstrap git-backed config: %v", errCopy)
return
}
if errCommit := gitStoreInst.PersistConfig(context.Background()); errCommit != nil {
log.Errorf("failed to commit initial git-backed config: %v", errCommit)
return
}
log.Infof("git-backed config initialized from template: %s", configFilePath)
} else if statErr != nil {
log.Errorf("failed to inspect git-backed config: %v", statErr)
return
}
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
if err == nil {
cfg.AuthDir = gitStoreInst.AuthDir()
log.Infof("git-backed token store enabled, repository path: %s", gitStoreRoot)
}
} else if configPath != "" {
configFilePath = configPath
cfg, err = config.LoadConfigOptional(configPath, isCloudDeploy)
} else {
wd, err = os.Getwd()
if err != nil {
log.Errorf("failed to get working directory: %v", err)
return
}
configFilePath = filepath.Join(wd, "config.yaml")
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
}
if err != nil {
log.Errorf("failed to load config: %v", err)
return
}
if cfg == nil {
cfg = &config.Config{}
}
// In cloud deploy mode, check if we have a valid configuration
var configFileExists bool
if isCloudDeploy {
if info, errStat := os.Stat(configFilePath); errStat != nil {
// Don't mislead: API server will not start until configuration is provided.
log.Info("Cloud deploy mode: No configuration file detected; standing by for configuration")
configFileExists = false
} else if info.IsDir() {
log.Info("Cloud deploy mode: Config path is a directory; standing by for configuration")
configFileExists = false
} else if cfg.Port == 0 {
// LoadConfigOptional returns empty config when file is empty or invalid.
// Config file exists but is empty or invalid; treat as missing config
log.Info("Cloud deploy mode: Configuration file is empty or invalid; standing by for valid configuration")
configFileExists = false
} else {
log.Info("Cloud deploy mode: Configuration file detected; starting service")
configFileExists = true
}
}
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
if err = logging.ConfigureLogOutput(cfg); err != nil {
log.Errorf("failed to configure log output: %v", err)
return
}
log.Infof("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate)
// Set the log level based on the configuration.
util.SetLogLevel(cfg)
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir)
return
} else {
cfg.AuthDir = resolvedAuthDir
}
managementasset.SetCurrentConfig(cfg)
// Create login options to be used in authentication flows.
options := &cmd.LoginOptions{
NoBrowser: noBrowser,
CallbackPort: oauthCallbackPort,
}
// Register the shared token store once so all components use the same persistence backend.
if usePostgresStore {
sdkAuth.RegisterTokenStore(pgStoreInst)
} else if useObjectStore {
sdkAuth.RegisterTokenStore(objectStoreInst)
} else if useGitStore {
sdkAuth.RegisterTokenStore(gitStoreInst)
} else {
sdkAuth.RegisterTokenStore(sdkAuth.NewFileTokenStore())
}
// Register built-in access providers before constructing services.
configaccess.Register(&cfg.SDKConfig)
// Handle different command modes based on the provided flags.
if vertexImport != "" {
// Handle Vertex service account import
cmd.DoVertexImport(cfg, vertexImport)
} else if login {
// Handle Google/Gemini login
cmd.DoLogin(cfg, projectID, options)
} else if antigravityLogin {
// Handle Antigravity login
cmd.DoAntigravityLogin(cfg, options)
} else if codexLogin {
// Handle Codex login
cmd.DoCodexLogin(cfg, options)
} else if codexDeviceLogin {
// Handle Codex device-code login
cmd.DoCodexDeviceLogin(cfg, options)
} else if claudeLogin {
// Handle Claude login
cmd.DoClaudeLogin(cfg, options)
} else if qwenLogin {
cmd.DoQwenLogin(cfg, options)
} else if iflowLogin {
cmd.DoIFlowLogin(cfg, options)
} else if iflowCookie {
cmd.DoIFlowCookieAuth(cfg, options)
} else if kimiLogin {
cmd.DoKimiLogin(cfg, options)
} else {
// In cloud deploy mode without config file, just wait for shutdown signals
if isCloudDeploy && !configFileExists {
// No config file available, just wait for shutdown
cmd.WaitForCloudDeploy()
return
}
if tuiMode {
if standalone {
// Standalone mode: start an embedded local server and connect TUI client to it.
managementasset.StartAutoUpdater(context.Background(), configFilePath)
registry.StartModelsUpdater(context.Background())
hook := tui.NewLogHook(2000)
hook.SetFormatter(&logging.LogFormatter{})
log.AddHook(hook)
origStdout := os.Stdout
origStderr := os.Stderr
origLogOutput := log.StandardLogger().Out
log.SetOutput(io.Discard)
devNull, errOpenDevNull := os.Open(os.DevNull)
if errOpenDevNull == nil {
os.Stdout = devNull
os.Stderr = devNull
}
restoreIO := func() {
os.Stdout = origStdout
os.Stderr = origStderr
log.SetOutput(origLogOutput)
if devNull != nil {
_ = devNull.Close()
}
}
localMgmtPassword := fmt.Sprintf("tui-%d-%d", os.Getpid(), time.Now().UnixNano())
if password == "" {
password = localMgmtPassword
}
cancel, done := cmd.StartServiceBackground(cfg, configFilePath, password)
client := tui.NewClient(cfg.Port, password)
ready := false
backoff := 100 * time.Millisecond
for i := 0; i < 30; i++ {
if _, errGetConfig := client.GetConfig(); errGetConfig == nil {
ready = true
break
}
time.Sleep(backoff)
if backoff < time.Second {
backoff = time.Duration(float64(backoff) * 1.5)
}
}
if !ready {
restoreIO()
cancel()
<-done
fmt.Fprintf(os.Stderr, "TUI error: embedded server is not ready\n")
return
}
if errRun := tui.Run(cfg.Port, password, hook, origStdout); errRun != nil {
restoreIO()
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
} else {
restoreIO()
}
cancel()
<-done
} else {
// Default TUI mode: pure management client.
// The proxy server must already be running.
if errRun := tui.Run(cfg.Port, password, nil, os.Stdout); errRun != nil {
fmt.Fprintf(os.Stderr, "TUI error: %v\n", errRun)
}
}
} else {
// Start the main proxy service
managementasset.StartAutoUpdater(context.Background(), configFilePath)
registry.StartModelsUpdater(context.Background())
cmd.StartService(cfg, configFilePath, password)
}
}
}
================================================
FILE: config.example.yaml
================================================
# Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6).
# Use "127.0.0.1" or "localhost" to restrict access to local machine only.
host: ""
# Server port
port: 8317
# TLS settings for HTTPS. When enabled, the server listens with the provided certificate and key.
tls:
enable: false
cert: ""
key: ""
# Management API settings
remote-management:
# Whether to allow remote (non-localhost) management access.
# When false, only localhost can access management endpoints (a key is still required).
allow-remote: false
# Management key. If a plaintext value is provided here, it will be hashed on startup.
# All management requests (even from localhost) require this key.
# Leave empty to disable the Management API entirely (404 for all /v0/management routes).
secret-key: ""
# Disable the bundled management control panel asset download and HTTP route when true.
disable-control-panel: false
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
# Authentication directory (supports ~ for home directory)
auth-dir: "~/.cli-proxy-api"
# API keys for authentication
api-keys:
- "your-api-key-1"
- "your-api-key-2"
- "your-api-key-3"
# Enable debug logging
debug: false
# Enable pprof HTTP debug server (host:port). Keep it bound to localhost for safety.
pprof:
enable: false
addr: "127.0.0.1:8316"
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
commercial-mode: false
# When true, write application logs to rotating files instead of stdout
logging-to-file: false
# Maximum total size (MB) of log files under the logs directory. When exceeded, the oldest log
# files are deleted until the total is back within the limit. Set to 0 to disable the limit.
logs-max-total-size-mb: 0
# Maximum number of error log files retained when request logging is disabled.
# When exceeded, the oldest error log files are deleted. Default is 10. Set to 0 to disable cleanup.
error-logs-max-files: 10
# Enable in-memory usage statistics aggregation. Set to false to disable it.
usage-statistics-enabled: false
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
# Per-entry proxy-url also supports "direct" or "none" to bypass both the global proxy-url and environment proxies explicitly.
proxy-url: ""
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
force-model-prefix: false
# When true, forward filtered upstream response headers to downstream clients.
# Default is false (disabled).
passthrough-headers: false
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
request-retry: 3
# Maximum number of different credentials to try for one failed request.
# Set to 0 to keep legacy behavior (try all available credentials).
max-retry-credentials: 0
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
max-retry-interval: 30
# Quota exceeded behavior
quota-exceeded:
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
# Routing strategy for selecting credentials when multiple match.
routing:
strategy: "round-robin" # round-robin (default), fill-first
# When true, enable authentication for the WebSocket API (/v1/ws).
ws-auth: false
# When > 0, emit blank lines every N seconds for non-streaming responses to prevent idle timeouts.
nonstream-keepalive-interval: 0
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
# streaming:
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
# Gemini API keys
# gemini-api-key:
# - api-key: "AIzaSy...01"
# prefix: "test" # optional: require calls like "test/gemini-3-pro-preview" to target this credential
# base-url: "https://generativelanguage.googleapis.com"
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080"
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models:
# - name: "gemini-2.5-flash" # upstream model name
# alias: "gemini-flash" # client alias mapped to the upstream model
# excluded-models:
# - "gemini-2.5-pro" # exclude specific models from this provider (exact match)
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
# - "*-preview" # wildcard matching suffix (e.g. gemini-3-pro-preview)
# - "*flash*" # wildcard matching substring (e.g. gemini-2.5-flash-lite)
# - api-key: "AIzaSy...02"
# Codex API keys
# codex-api-key:
# - api-key: "sk-atSM..."
# prefix: "test" # optional: require calls like "test/gpt-5-codex" to target this credential
# base-url: "https://www.example.com" # use the custom codex API endpoint
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models:
# - name: "gpt-5-codex" # upstream model name
# alias: "codex-latest" # client alias mapped to the upstream model
# excluded-models:
# - "gpt-5.1" # exclude specific models (exact match)
# - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex)
# - "*-mini" # wildcard matching suffix (e.g. gpt-5-codex-mini)
# - "*codex*" # wildcard matching substring (e.g. gpt-5-codex-low)
# Claude API keys
# claude-api-key:
# - api-key: "sk-atSM..." # use the official claude API key, no need to set the base url
# - api-key: "sk-atSM..."
# prefix: "test" # optional: require calls like "test/claude-sonnet-latest" to target this credential
# base-url: "https://www.example.com" # use the custom claude API endpoint
# headers:
# X-Custom-Header: "custom-value"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# models:
# - name: "claude-3-5-sonnet-20241022" # upstream model name
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
# excluded-models:
# - "claude-opus-4-5-20251101" # exclude specific models (exact match)
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
# - "*-thinking" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
# cloak: # optional: request cloaking for non-Claude-Code clients
# mode: "auto" # "auto" (default): cloak only when client is not Claude Code
# # "always": always apply cloaking
# # "never": never apply cloaking
# strict-mode: false # false (default): prepend Claude Code prompt to user system messages
# # true: strip all user system messages, keep only Claude Code prompt
# sensitive-words: # optional: words to obfuscate with zero-width characters
# - "API"
# - "proxy"
# cache-user-id: true # optional: default is false; set true to reuse cached user_id per API key instead of generating a random one each request
# Default headers for Claude API requests. Update when Claude Code releases new versions.
# These are used as fallbacks when the client does not send its own headers.
# claude-header-defaults:
# user-agent: "claude-cli/2.1.44 (external, sdk-cli)"
# package-version: "0.74.0"
# runtime-version: "v24.3.0"
# timeout: "600"
# Default headers for Codex OAuth model requests.
# These are used only for file-backed/OAuth Codex requests when the client
# does not send the header. `user-agent` applies to HTTP and websocket requests;
# `beta-features` only applies to websocket requests. They do not apply to codex-api-key entries.
# codex-header-defaults:
# user-agent: "codex_cli_rs/0.114.0 (Mac OS 14.2.0; x86_64) vscode/1.111.0"
# beta-features: "multi_agent"
# OpenAI compatibility providers
# openai-compatibility:
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
# prefix: "test" # optional: require calls like "test/kimi-k2" to target this provider's credentials
# base-url: "https://openrouter.ai/api/v1" # The base URL of the provider.
# headers:
# X-Custom-Header: "custom-value"
# api-key-entries:
# - api-key: "sk-or-v1-...b780"
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# - api-key: "sk-or-v1-...b781" # without proxy-url
# models: # The models supported by the provider.
# - name: "moonshotai/kimi-k2:free" # The actual model name.
# alias: "kimi-k2" # The alias used in the API.
# # You may repeat the same alias to build an internal model pool.
# # The client still sees only one alias in the model list.
# # Requests to that alias will round-robin across the upstream names below,
# # and if the chosen upstream fails before producing output, the request will
# # continue with the next upstream model in the same alias pool.
# - name: "qwen3.5-plus"
# alias: "claude-opus-4.66"
# - name: "glm-5"
# alias: "claude-opus-4.66"
# - name: "kimi-k2.5"
# alias: "claude-opus-4.66"
# Vertex API keys (Vertex-compatible endpoints, base-url is optional)
# vertex-api-key:
# - api-key: "vk-123..." # x-goog-api-key header
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
# base-url: "https://example.com/api" # optional, e.g. https://zenmux.ai/api; falls back to Google Vertex when omitted
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
# # proxy-url: "direct" # optional: explicit direct connect for this credential
# headers:
# X-Custom-Header: "custom-value"
# models: # optional: map aliases to upstream model names
# - name: "gemini-2.5-flash" # upstream model name
# alias: "vertex-flash" # client-visible alias
# - name: "gemini-2.5-pro"
# alias: "vertex-pro"
# excluded-models: # optional: models to exclude from listing
# - "imagen-3.0-generate-002"
# - "imagen-*"
# Amp Integration
# ampcode:
# # Configure upstream URL for Amp CLI OAuth and management features
# upstream-url: "https://ampcode.com"
# # Optional: Override API key for Amp upstream (otherwise uses env or file)
# upstream-api-key: ""
# # Per-client upstream API key mapping
# # Maps client API keys (from top-level api-keys) to different Amp upstream API keys.
# # Useful when different clients need to use different Amp accounts/quotas.
# # If a client key isn't mapped, falls back to upstream-api-key (default behavior).
# upstream-api-keys:
# - upstream-api-key: "amp_key_for_team_a" # Upstream key to use for these clients
# api-keys: # Client keys that use this upstream key
# - "your-api-key-1"
# - "your-api-key-2"
# - upstream-api-key: "amp_key_for_team_b"
# api-keys:
# - "your-api-key-3"
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
# restrict-management-to-localhost: false
# # Force model mappings to run before checking local API keys (default: false)
# force-model-mappings: false
# # Amp Model Mappings
# # Route unavailable Amp models to alternative models available in your local proxy.
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
# # but you have a similar model available (e.g., Claude Sonnet 4).
# model-mappings:
# - from: "claude-opus-4-5-20251101" # Model requested by Amp CLI
# to: "gemini-claude-opus-4-5-thinking" # Route to this available model instead
# - from: "claude-sonnet-4-5-20250929"
# to: "gemini-claude-sonnet-4-5-thinking"
# - from: "claude-haiku-4-5-20251001"
# to: "gemini-2.5-flash"
# Global OAuth model name aliases (per channel)
# These aliases rename model IDs for both model listing and request routing.
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kimi.
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
# You can repeat the same name with different aliases to expose multiple client model names.
# oauth-model-alias:
# gemini-cli:
# - name: "gemini-2.5-pro" # original model name under this channel
# alias: "g2.5p" # client-visible alias
# fork: true # when true, keep original and also add the alias as an extra model (default: false)
# vertex:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"
# aistudio:
# - name: "gemini-2.5-pro"
# alias: "g2.5p"
# antigravity:
# - name: "gemini-3-pro-high"
# alias: "gemini-3-pro-preview"
# claude:
# - name: "claude-sonnet-4-5-20250929"
# alias: "cs4.5"
# codex:
# - name: "gpt-5"
# alias: "g5"
# qwen:
# - name: "qwen3-coder-plus"
# alias: "qwen-plus"
# iflow:
# - name: "glm-4.7"
# alias: "glm-god"
# kimi:
# - name: "kimi-k2.5"
# alias: "k2.5"
# OAuth provider excluded models
# oauth-excluded-models:
# gemini-cli:
# - "gemini-2.5-pro" # exclude specific models (exact match)
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
# - "*-preview" # wildcard matching suffix (e.g. gemini-3-pro-preview)
# - "*flash*" # wildcard matching substring (e.g. gemini-2.5-flash-lite)
# vertex:
# - "gemini-3-pro-preview"
# aistudio:
# - "gemini-3-pro-preview"
# antigravity:
# - "gemini-3-pro-preview"
# claude:
# - "claude-3-5-haiku-20241022"
# codex:
# - "gpt-5-codex-mini"
# qwen:
# - "vision-model"
# iflow:
# - "tstars2.0"
# kimi:
# - "kimi-k2-thinking"
# Optional payload configuration
# payload:
# default: # Default rules only set parameters when they are missing in the payload.
# - models:
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
# params: # JSON path (gjson/sjson syntax) -> value
# "generationConfig.thinkingConfig.thinkingBudget": 32768
# default-raw: # Default raw rules set parameters using raw JSON when missing (must be valid JSON).
# - models:
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON)
# "generationConfig.responseJsonSchema": "{\"type\":\"object\",\"properties\":{\"answer\":{\"type\":\"string\"}}}"
# override: # Override rules always set parameters, overwriting any existing values.
# - models:
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
# params: # JSON path (gjson/sjson syntax) -> value
# "reasoning.effort": "high"
# override-raw: # Override raw rules always set parameters using raw JSON (must be valid JSON).
# - models:
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON)
# "response_format": "{\"type\":\"json_schema\",\"json_schema\":{\"name\":\"answer\",\"schema\":{\"type\":\"object\"}}}"
# filter: # Filter rules remove specified parameters from the payload.
# - models:
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
# params: # JSON paths (gjson/sjson syntax) to remove from the payload
# - "generationConfig.thinkingConfig.thinkingBudget"
# - "generationConfig.responseJsonSchema"
================================================
FILE: docker-build.ps1
================================================
# docker-build.ps1 - Windows PowerShell Build Script
#
# This script automates the process of building and running the Docker container
# with version information dynamically injected at build time.
# Stop script execution on any PowerShell error. By default this does NOT apply
# to native programs such as git and docker: they report failure only through
# $LASTEXITCODE, so every native call below is followed by Assert-LastExitCode.
$ErrorActionPreference = "Stop"

# Assert-LastExitCode stops the script, printing which step failed, when the
# most recent native command returned a non-zero exit code. The script exits
# with that same code.
function Assert-LastExitCode {
    param([string]$Step)
    if ($LASTEXITCODE -ne 0) {
        Write-Host "Error: $Step failed (exit code $LASTEXITCODE)."
        exit $LASTEXITCODE
    }
}

# --- Step 1: Choose Environment ---
Write-Host "Please select an option:"
Write-Host "1) Run using Pre-built Image (Recommended)"
Write-Host "2) Build from Source and Run (For Developers)"
$choice = Read-Host -Prompt "Enter choice [1-2]"
# --- Step 2: Execute based on choice ---
switch ($choice) {
    "1" {
        Write-Host "--- Running with Pre-built Image ---"
        docker compose up -d --remove-orphans --no-build
        Assert-LastExitCode "docker compose up"
        Write-Host "Services are starting from remote image."
        Write-Host "Run 'docker compose logs -f' to see the logs."
    }
    "2" {
        Write-Host "--- Building from Source and Running ---"
        # Get Version Information
        $VERSION = (git describe --tags --always --dirty)
        Assert-LastExitCode "git describe"
        $COMMIT = (git rev-parse --short HEAD)
        Assert-LastExitCode "git rev-parse"
        $BUILD_DATE = (Get-Date).ToUniversalTime().ToString("yyyy-MM-ddTHH:mm:ssZ")
        Write-Host "Building with the following info:"
        Write-Host " Version: $VERSION"
        Write-Host " Commit: $COMMIT"
        Write-Host " Build Date: $BUILD_DATE"
        Write-Host "----------------------------------------"
        # Build and start the services with a local-only image tag
        $env:CLI_PROXY_IMAGE = "cli-proxy-api:local"
        Write-Host "Building the Docker image..."
        docker compose build --build-arg VERSION=$VERSION --build-arg COMMIT=$COMMIT --build-arg BUILD_DATE=$BUILD_DATE
        # Without this check a failed build would still start the previous image
        # and report success.
        Assert-LastExitCode "docker compose build"
        Write-Host "Starting the services..."
        docker compose up -d --remove-orphans --pull never
        Assert-LastExitCode "docker compose up"
        Write-Host "Build complete. Services are starting."
        Write-Host "Run 'docker compose logs -f' to see the logs."
    }
    default {
        Write-Host "Invalid choice. Please enter 1 or 2."
        exit 1
    }
}
================================================
FILE: docker-build.sh
================================================
#!/usr/bin/env bash
#
# docker-build.sh - Linux/macOS Build Script
#
# This script automates the process of building and running the Docker container
# with version information dynamically injected at build time.
# Hidden feature: Preserve usage statistics across rebuilds
# Usage: ./docker-build.sh --with-usage
# First run prompts for management API key, saved to temp/stats/.api_secret
# Abort on any failing command, on use of an unset variable, and on a failure
# anywhere inside a pipeline.
set -euo pipefail

# Files used by --with-usage to carry data between runs. They are never
# reassigned, so they are declared read-only.
readonly STATS_DIR="temp/stats"
readonly STATS_FILE="${STATS_DIR}/.usage_backup.json"
readonly SECRET_FILE="${STATS_DIR}/.api_secret"

# Flipped to true when the script is invoked with --with-usage.
WITH_USAGE=false
# get_port prints the server port taken from the first top-level "port:" line
# of ./config.yaml (the value may be wrapped in single or double quotes).
# It prints the default 8317 when config.yaml is missing or has no numeric
# top-level port entry. It never fails, so callers are safe under
# `set -euo pipefail`. The previous grep|sed pipeline failed (aborting the
# script) when no port line existed, and echoed non-numeric lines verbatim.
get_port() {
    local line
    local re='^port:[[:space:]]*["'"'"']?([0-9]+)'
    if [[ -f "config.yaml" ]]; then
        # `|| [[ -n ... ]]` also processes a final line without a trailing newline.
        while IFS= read -r line || [[ -n "${line}" ]]; do
            if [[ "${line}" =~ ${re} ]]; then
                echo "${BASH_REMATCH[1]}"
                return 0
            fi
        done < "config.yaml"
    fi
    echo "8317"
}
# export_stats_api_secret sets API_SECRET to the management API key used by the
# --with-usage export/import requests.
#
# If ${SECRET_FILE} exists, the key is read from it. Otherwise the user is
# prompted (input hidden), and the key is saved to ${SECRET_FILE} for later
# runs, readable by the owner only. An empty key is rejected and the script
# exits with status 1.
export_stats_api_secret() {
    if [[ -f "${SECRET_FILE}" ]]; then
        API_SECRET=$(cat "${SECRET_FILE}")
        return 0
    fi
    mkdir -p "${STATS_DIR}"
    echo "First time using --with-usage. Management API key required."
    read -r -p "Enter management key: " -s API_SECRET
    echo
    # An empty key would be persisted and silently reused on every later run.
    if [[ -z "${API_SECRET}" ]]; then
        echo "Error: management key must not be empty."
        exit 1
    fi
    # Create the file under a restrictive umask so the key is never readable by
    # other users, not even in the window between creation and chmod.
    (umask 077 && printf '%s\n' "${API_SECRET}" > "${SECRET_FILE}")
    chmod 600 "${SECRET_FILE}"
}
# check_container_running returns normally when the proxy answers HTTP 200 at
# http://localhost:<port>/ (port from get_port). Otherwise it prints an error
# and exits the whole script with status 1.
check_container_running() {
    local port
    port=$(get_port)
    if curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
        return 0
    fi
    echo "Error: cli-proxy-api service is not responding at localhost:${port}"
    echo "Please start the container first or use without --with-usage flag."
    exit 1
}
# export_stats downloads the running service's usage statistics from the
# management API and writes them to ${STATS_FILE}.
#
# The script exits if the service is not reachable (via
# check_container_running) or if the export endpoint does not answer HTTP 200.
export_stats() {
    local port
    port=$(get_port)
    [[ -d "${STATS_DIR}" ]] || mkdir -p "${STATS_DIR}"
    check_container_running
    echo "Exporting usage statistics..."
    local export_url="http://localhost:${port}/v0/management/usage/export"
    # -w appends the status code on its own final line after the body.
    EXPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -H "X-Management-Key: ${API_SECRET}" "${export_url}")
    HTTP_CODE=$(echo "${EXPORT_RESPONSE}" | tail -n1)
    RESPONSE_BODY=$(echo "${EXPORT_RESPONSE}" | sed '$d')
    if [[ "${HTTP_CODE}" == "200" ]]; then
        echo "${RESPONSE_BODY}" > "${STATS_FILE}"
        echo "Statistics exported to ${STATS_FILE}"
        return
    fi
    echo "Export failed (HTTP ${HTTP_CODE}): ${RESPONSE_BODY}"
    exit 1
}
# import_stats uploads the backup in ${STATS_FILE} to the management API of the
# (re)started service.
#
# The backup is deleted only after a successful (HTTP 200) import. On any
# failure, including a connection error, it is kept and its location is
# printed, so the statistics are not lost. The next --with-usage export writes
# to the same path, so the backup must be moved aside before re-running.
import_stats() {
    local port
    port=$(get_port)
    if [[ ! -f "${STATS_FILE}" ]]; then
        echo "No usage statistics backup found at ${STATS_FILE}; skipping import."
        return 0
    fi
    echo "Importing usage statistics..."
    # `|| true`: on a connection error curl still prints "000" as the status,
    # and we want to reach the failure branch below instead of aborting via set -e.
    # --data-binary sends the file byte-for-byte (-d @file strips newlines).
    IMPORT_RESPONSE=$(curl -s -w "\n%{http_code}" -X POST \
        -H "X-Management-Key: ${API_SECRET}" \
        -H "Content-Type: application/json" \
        --data-binary @"${STATS_FILE}" \
        "http://localhost:${port}/v0/management/usage/import") || true
    IMPORT_CODE=$(echo "${IMPORT_RESPONSE}" | tail -n1)
    IMPORT_BODY=$(echo "${IMPORT_RESPONSE}" | sed '$d')
    if [[ "${IMPORT_CODE}" == "200" ]]; then
        echo "Statistics imported successfully"
        rm -f "${STATS_FILE}"
    else
        echo "Import failed (HTTP ${IMPORT_CODE}): ${IMPORT_BODY}"
        echo "The usage statistics backup was kept at ${STATS_FILE}."
        echo "Move it aside before running --with-usage again; the next export overwrites it."
    fi
}
# wait_for_service polls http://localhost:<port>/ once per second, for up to
# 30 attempts, until it answers HTTP 200. It then waits 2 more seconds before
# returning. If the service never answers, a warning is printed and the
# script continues.
wait_for_service() {
    local port
    local ready=false
    port=$(get_port)
    echo "Waiting for service to be ready..."
    for _ in {1..30}; do
        if curl -s -o /dev/null -w "%{http_code}" "http://localhost:${port}/" | grep -q "200"; then
            ready=true
            break
        fi
        sleep 1
    done
    if [[ "${ready}" != "true" ]]; then
        echo "Warning: service did not answer HTTP 200 at localhost:${port} after 30 attempts; continuing anyway."
    fi
    # Extra grace period before the caller starts talking to the service.
    sleep 2
}
# --- Argument handling ---
# Only the first argument is inspected. "--with-usage" enables carrying usage
# statistics across the redeploy. The management key is loaded from
# ${SECRET_FILE}, or prompted for and saved there, before the menu is shown.
if [[ "${1:-}" == "--with-usage" ]]; then
    WITH_USAGE=true
    export_stats_api_secret
fi
# --- Step 1: Choose Environment ---
echo "Please select an option:"
echo "1) Run using Pre-built Image (Recommended)"
echo "2) Build from Source and Run (For Developers)"
read -r -p "Enter choice [1-2]: " choice
# --- Step 2: Execute based on choice ---
case "$choice" in
    1)
        echo "--- Running with Pre-built Image ---"
        # Snapshot statistics from the currently running container before it is
        # replaced. export_stats aborts the script if the service is unreachable.
        if [[ "${WITH_USAGE}" == "true" ]]; then
            export_stats
        fi
        # --no-build: run the image named by CLI_PROXY_IMAGE. docker-compose.yml
        # defaults it to eceasy/cli-proxy-api:latest with pull_policy: always.
        docker compose up -d --remove-orphans --no-build
        # Restore the snapshot into the new container after waiting for it to come up.
        if [[ "${WITH_USAGE}" == "true" ]]; then
            wait_for_service
            import_stats
        fi
        echo "Services are starting from remote image."
        echo "Run 'docker compose logs -f' to see the logs."
        ;;
    2)
        echo "--- Building from Source and Running ---"
        # Get Version Information
        # (set -euo pipefail aborts the script if any git/docker command below fails.)
        VERSION="$(git describe --tags --always --dirty)"
        COMMIT="$(git rev-parse --short HEAD)"
        BUILD_DATE="$(date -u +%Y-%m-%dT%H:%M:%SZ)"
        echo "Building with the following info:"
        echo " Version: ${VERSION}"
        echo " Commit: ${COMMIT}"
        echo " Build Date: ${BUILD_DATE}"
        echo "----------------------------------------"
        # Build and start the services with a local-only image tag
        export CLI_PROXY_IMAGE="cli-proxy-api:local"
        echo "Building the Docker image..."
        docker compose build \
            --build-arg VERSION="${VERSION}" \
            --build-arg COMMIT="${COMMIT}" \
            --build-arg BUILD_DATE="${BUILD_DATE}"
        # Export runs after the build finishes, immediately before the running
        # container is replaced.
        if [[ "${WITH_USAGE}" == "true" ]]; then
            export_stats
        fi
        echo "Starting the services..."
        # --pull never: use the freshly built local image instead of honoring
        # the pull_policy: always set in docker-compose.yml.
        docker compose up -d --remove-orphans --pull never
        if [[ "${WITH_USAGE}" == "true" ]]; then
            wait_for_service
            import_stats
        fi
        echo "Build complete. Services are starting."
        echo "Run 'docker compose logs -f' to see the logs."
        ;;
    *)
        echo "Invalid choice. Please enter 1 or 2."
        exit 1
        ;;
esac
================================================
FILE: docker-compose.yml
================================================
services:
cli-proxy-api:
image: ${CLI_PROXY_IMAGE:-eceasy/cli-proxy-api:latest}
pull_policy: always
build:
context: .
dockerfile: Dockerfile
args:
VERSION: ${VERSION:-dev}
COMMIT: ${COMMIT:-none}
BUILD_DATE: ${BUILD_DATE:-unknown}
container_name: cli-proxy-api
# env_file:
# - .env
environment:
DEPLOY: ${DEPLOY:-}
ports:
- "8317:8317"
- "8085:8085"
- "1455:1455"
- "54545:54545"
- "51121:51121"
- "11451:11451"
volumes:
- ${CLI_PROXY_CONFIG_PATH:-./config.yaml}:/CLIProxyAPI/config.yaml
- ${CLI_PROXY_AUTH_PATH:-./auths}:/root/.cli-proxy-api
- ${CLI_PROXY_LOG_PATH:-./logs}:/CLIProxyAPI/logs
restart: unless-stopped
================================================
FILE: examples/custom-provider/main.go
================================================
// Package main demonstrates how to create a custom AI provider executor
// and integrate it with the CLI Proxy API server. This example shows how to:
// - Create a custom executor that implements the Executor interface
// - Register custom translators for request/response transformation
// - Integrate the custom provider with the SDK server
// - Register custom models in the model registry
//
// This example uses a simple echo service (httpbin.org) as the upstream API
// for demonstration purposes. In a real implementation, you would replace
// this with your actual AI service provider.
package main
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/router-for-me/CLIProxyAPI/v6/sdk/api"
	sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
	"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy"
	coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
	clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
	"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
	"github.com/router-for-me/CLIProxyAPI/v6/sdk/logging"
	sdktr "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
)
// providerKey identifies the custom provider. It is the executor's Identifier
// and the Type of the demo models this example registers.
const providerKey = "myprov"

// Translator formats bridged by the pass-through translators registered in init.
const (
	// fOpenAI represents the OpenAI chat format.
	fOpenAI sdktr.Format = "openai.chat"
	// fMyProv represents the custom provider's chat format.
	fMyProv sdktr.Format = "myprov.chat"
)
// init registers a pass-through translator pair from the OpenAI chat format to
// the demo provider's format. Request bodies and both streaming and
// non-streaming responses are forwarded untouched. A real provider would
// convert between the two schemas here.
func init() {
	passRequest := func(model string, raw []byte, stream bool) []byte { return raw }
	passStream := func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) []string {
		return []string{string(raw)}
	}
	passNonStream := func(ctx context.Context, model string, originalReq, translatedReq, raw []byte, param *any) string {
		return string(raw)
	}

	responses := sdktr.ResponseTransform{
		Stream:    passStream,
		NonStream: passNonStream,
	}
	sdktr.Register(fOpenAI, fMyProv, passRequest, responses)
}
// MyExecutor is a stateless demo executor for the "myprov" provider. main
// registers it with the core auth manager via RegisterExecutor. Non-streaming
// requests are forwarded to the endpoint chosen by upstreamEndpoint.
type MyExecutor struct{}

// Identifier reports the provider key this executor serves.
func (MyExecutor) Identifier() string {
	return providerKey
}
// PrepareRequest adds the credential's API key to req as a Bearer
// Authorization header.
//
// Nothing is changed when req or a is nil, when a has no attributes, or when
// the "api_key" attribute is blank after trimming whitespace. The method
// always returns nil.
func (MyExecutor) PrepareRequest(req *http.Request, a *coreauth.Auth) error {
	if req == nil || a == nil || a.Attributes == nil {
		return nil
	}
	key := strings.TrimSpace(a.Attributes["api_key"])
	if key == "" {
		return nil
	}
	req.Header.Set("Authorization", "Bearer "+key)
	return nil
}
// proxyClients caches one *http.Client per distinct proxy setting (key: the
// normalized proxy-url, or "direct"). Requests through the same proxy then
// share one connection pool. Previously a fresh zero-value http.Transport,
// whose idle connections never time out, was built and leaked on every call.
var proxyClients sync.Map

// buildHTTPClient returns the HTTP client to use for requests made with a.
//
// The credential's ProxyURL selects the route:
//   - empty, or a nil auth: http.DefaultClient, which honors environment proxies;
//   - "direct" or "none" (case-insensitive): a client that bypasses every proxy,
//     including environment proxies, as described for proxy-url in
//     config.example.yaml;
//   - an http, https, socks5 or socks5h URL: a client routed through that proxy.
//
// Unparseable URLs and unsupported schemes fall back to http.DefaultClient,
// preserving the original best-effort behavior.
func buildHTTPClient(a *coreauth.Auth) *http.Client {
	if a == nil {
		return http.DefaultClient
	}
	raw := strings.TrimSpace(a.ProxyURL)
	if raw == "" {
		return http.DefaultClient
	}
	if strings.EqualFold(raw, "direct") || strings.EqualFold(raw, "none") {
		return cachedProxyClient("direct", nil)
	}
	u, err := url.Parse(raw)
	if err != nil {
		return http.DefaultClient
	}
	switch u.Scheme { // url.Parse lower-cases the scheme.
	case "http", "https", "socks5", "socks5h":
		return cachedProxyClient(u.String(), http.ProxyURL(u))
	default:
		return http.DefaultClient
	}
}

// cachedProxyClient returns the shared client stored under key, creating it on
// first use with the given proxy selector (nil means connect directly).
func cachedProxyClient(key string, proxy func(*http.Request) (*url.URL, error)) *http.Client {
	if c, ok := proxyClients.Load(key); ok {
		return c.(*http.Client)
	}
	var tr *http.Transport
	if base, ok := http.DefaultTransport.(*http.Transport); ok {
		// Clone keeps the default dialer, TLS and idle-connection timeouts.
		tr = base.Clone()
	} else {
		tr = &http.Transport{IdleConnTimeout: 90 * time.Second}
	}
	tr.Proxy = proxy
	// LoadOrStore resolves a race between concurrent first callers; a losing
	// transport has never dialed, so discarding it is free.
	c, _ := proxyClients.LoadOrStore(key, &http.Client{Transport: tr})
	return c.(*http.Client)
}
// upstreamEndpoint returns the URL that Execute POSTs to. This is the
// credential's "endpoint" attribute (whitespace-trimmed) when non-empty, and
// otherwise the httpbin.org echo service used by this demo.
func upstreamEndpoint(a *coreauth.Auth) string {
	var endpoint string
	if a != nil {
		// Indexing a nil Attributes map safely yields "".
		endpoint = strings.TrimSpace(a.Attributes["endpoint"])
	}
	if endpoint == "" {
		// Demo echo endpoint; replace with your upstream.
		endpoint = "https://httpbin.org/post"
	}
	return endpoint
}
// Execute performs a non-streaming request against the upstream endpoint.
//
// The already-translated payload is POSTed as JSON to upstreamEndpoint(a).
// Credentials are injected by PrepareRequest and the proxy is chosen by
// buildHTTPClient. The request is bound to ctx, so cancelling ctx aborts it.
//
// A failure to read the response body, or a non-2xx upstream status, is
// returned as an error instead of being passed on as a successful payload.
func (MyExecutor) Execute(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (clipexec.Response, error) {
	client := buildHTTPClient(a)
	endpoint := upstreamEndpoint(a)
	httpReq, errNew := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(req.Payload))
	if errNew != nil {
		return clipexec.Response{}, errNew
	}
	httpReq.Header.Set("Content-Type", "application/json")
	// Inject credentials via PrepareRequest hook.
	if errPrep := (MyExecutor{}).PrepareRequest(httpReq, a); errPrep != nil {
		return clipexec.Response{}, errPrep
	}
	resp, errDo := client.Do(httpReq)
	if errDo != nil {
		return clipexec.Response{}, errDo
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			fmt.Fprintf(os.Stderr, "close response body error: %v\n", errClose)
		}
	}()
	body, errRead := io.ReadAll(resp.Body)
	if errRead != nil {
		return clipexec.Response{}, fmt.Errorf("myprov executor: read response body: %w", errRead)
	}
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return clipexec.Response{}, fmt.Errorf("myprov executor: upstream returned %s: %s", resp.Status, body)
	}
	return clipexec.Response{Payload: body}, nil
}
// HttpRequest sends an arbitrary caller-built request with credentials
// attached.
//
// A nil req is rejected. A nil ctx falls back to the request's own context.
// The request is re-bound to that context, passed through PrepareRequest and
// sent with the client chosen by buildHTTPClient(a). As with http.Client.Do,
// the caller must close the returned response body.
func (e MyExecutor) HttpRequest(ctx context.Context, a *coreauth.Auth, req *http.Request) (*http.Response, error) {
	if req == nil {
		return nil, fmt.Errorf("myprov executor: request is nil")
	}
	runCtx := req.Context()
	if ctx != nil {
		runCtx = ctx
	}
	// WithContext makes a shallow copy: the Header map is shared with req, so
	// the Authorization header set below is visible on req as well.
	outgoing := req.WithContext(runCtx)
	if err := e.PrepareRequest(outgoing, a); err != nil {
		return nil, err
	}
	return buildHTTPClient(a).Do(outgoing)
}
// CountTokens is not supported by this demo provider; it always returns an
// empty response together with an error.
func (MyExecutor) CountTokens(_ context.Context, _ *coreauth.Auth, _ clipexec.Request, _ clipexec.Options) (clipexec.Response, error) {
	errUnsupported := errors.New("count tokens not implemented")
	return clipexec.Response{}, errUnsupported
}
// ExecuteStream emits one canned SSE event and then ends the stream.
//
// A goroutine queues the single chunk and closes the channel afterwards, so a
// consumer ranging over Chunks receives exactly one chunk. The inputs (ctx, a,
// req, opts) are not consulted and no error is ever returned.
func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (*clipexec.StreamResult, error) {
	chunks := make(chan clipexec.StreamChunk, 1)
	go func() {
		defer close(chunks)
		event := []byte("data: {\"ok\":true}\n\n")
		chunks <- clipexec.StreamChunk{Payload: event}
	}()
	result := &clipexec.StreamResult{Chunks: chunks}
	return result, nil
}
// Refresh performs no refresh: it hands back the same auth record it was
// given and never fails.
func (MyExecutor) Refresh(_ context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
	refreshed := a
	return refreshed, nil
}
// main wires MyExecutor into a CLIProxyAPI service and runs it.
//
// Setup proceeds as follows:
//   - Load config.yaml.
//   - Point the token store at cfg.AuthDir when the store exposes SetBaseDir.
//   - Register MyExecutor with a new auth manager.
//   - Once the service has started, register a demo model for every auth
//     entry whose provider matches providerKey (case-insensitively), so it is
//     listed by /v1/models.
//   - Install a header-setting middleware and a file-backed request logger to
//     show how server options are passed through.
//
// Any setup failure, or any run failure other than context cancellation,
// panics.
//
// NOTE(review): nothing cancels ctx before main returns, so the service runs
// until Run returns on its own or the process is terminated.
func main() {
	const configPath = "config.yaml"

	cfg, err := config.LoadConfig(configPath)
	if err != nil {
		panic(err)
	}

	tokenStore := sdkAuth.GetTokenStore()
	// The store's concrete type is not fixed here; configure its base
	// directory only when it exposes SetBaseDir.
	if dirSetter, ok := tokenStore.(interface{ SetBaseDir(string) }); ok {
		dirSetter.SetBaseDir(cfg.AuthDir)
	}

	core := coreauth.NewManager(tokenStore, nil, nil)
	core.RegisterExecutor(MyExecutor{})

	hooks := cliproxy.Hooks{
		OnAfterStart: func(s *cliproxy.Service) {
			// Register demo models for the custom provider so they appear in /v1/models.
			models := []*cliproxy.ModelInfo{{ID: "myprov-pro-1", Object: "model", Type: providerKey, DisplayName: "MyProv Pro 1"}}
			for _, a := range core.List() {
				if strings.EqualFold(a.Provider, providerKey) {
					cliproxy.GlobalModelRegistry().RegisterClient(a.ID, providerKey, models)
				}
			}
		},
	}

	svc, err := cliproxy.NewBuilder().
		WithConfig(cfg).
		WithConfigPath(configPath).
		WithCoreAuthManager(core).
		WithServerOptions(
			// Optional: add a simple middleware + custom request logger
			api.WithMiddleware(func(c *gin.Context) { c.Header("X-Example", "custom-provider"); c.Next() }),
			// loggerCfg is the config handed to the factory; named distinctly so
			// it does not shadow the cfg loaded above.
			api.WithRequestLoggerFactory(func(loggerCfg *config.Config, cfgPath string) logging.RequestLogger {
				return logging.NewFileRequestLoggerWithOptions(true, "logs", filepath.Dir(cfgPath), loggerCfg.ErrorLogsMaxFiles)
			}),
		).
		WithHooks(hooks).
		Build()
	if err != nil {
		panic(err)
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	if errRun := svc.Run(ctx); errRun != nil && !errors.Is(errRun, context.Canceled) {
		panic(errRun)
	}

	// NOTE(review): retained only to keep the time import in use in this demo
	// file; remove it if time is referenced elsewhere.
	_ = time.Second
}
================================================
FILE: examples/http-request/main.go
================================================
// Package main demonstrates how to use coreauth.Manager.HttpRequest/NewHttpRequest
// to execute arbitrary HTTP requests with provider credentials injected.
//
// This example registers a minimal custom executor that injects an Authorization
// header from auth.Attributes["api_key"], then performs two requests against
// httpbin.org to show the injected headers.
package main
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"
"time"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
)
// providerKey is the provider name this example registers; the demo auth
// record in main carries the same Provider value.
const (
	providerKey = "echo"
)

// EchoExecutor is a stateless demo executor whose only real behavior is
// injecting a Bearer token into outgoing HTTP requests.
type EchoExecutor struct{}

// Identifier returns providerKey, the key EchoExecutor is registered under.
func (EchoExecutor) Identifier() string {
	return providerKey
}
// PrepareRequest copies auth's "api_key" attribute, trimmed of surrounding
// whitespace, into req's Authorization header as a Bearer token.
//
// A nil req, a nil auth, missing attributes or a blank key leave req
// unchanged. It always returns nil.
func (EchoExecutor) PrepareRequest(req *http.Request, auth *coreauth.Auth) error {
	var apiKey string
	if req != nil && auth != nil && auth.Attributes != nil {
		apiKey = strings.TrimSpace(auth.Attributes["api_key"])
	}
	// apiKey can only be non-empty when req was non-nil above.
	if apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}
	return nil
}
// HttpRequest injects credentials into req and sends it with
// http.DefaultClient.
//
// A nil req is rejected. When ctx is nil the request keeps its own context.
// The response and error are exactly what http.DefaultClient.Do returns, and
// the caller must close the response body.
func (e EchoExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
	if req == nil {
		return nil, fmt.Errorf("echo executor: request is nil")
	}
	runCtx := req.Context()
	if ctx != nil {
		runCtx = ctx
	}
	// Shallow copy bound to runCtx; its Header map is shared with req.
	prepared := req.WithContext(runCtx)
	if errPrep := e.PrepareRequest(prepared, auth); errPrep != nil {
		return nil, errPrep
	}
	return http.DefaultClient.Do(prepared)
}
// Execute is unsupported: this example only exercises the raw HTTP path, so
// it always fails with an empty response.
func (EchoExecutor) Execute(_ context.Context, _ *coreauth.Auth, _ clipexec.Request, _ clipexec.Options) (clipexec.Response, error) {
	var empty clipexec.Response
	return empty, errors.New("echo executor: Execute not implemented")
}
// ExecuteStream is unsupported and always returns a nil stream with an error.
func (EchoExecutor) ExecuteStream(_ context.Context, _ *coreauth.Auth, _ clipexec.Request, _ clipexec.Options) (*clipexec.StreamResult, error) {
	errUnsupported := errors.New("echo executor: ExecuteStream not implemented")
	return nil, errUnsupported
}
// Refresh is unsupported and always returns a nil auth with an error.
func (EchoExecutor) Refresh(_ context.Context, _ *coreauth.Auth) (*coreauth.Auth, error) {
	errUnsupported := errors.New("echo executor: Refresh not implemented")
	return nil, errUnsupported
}
// CountTokens is unsupported and always returns an empty response with an
// error.
func (EchoExecutor) CountTokens(_ context.Context, _ *coreauth.Auth, _ clipexec.Request, _ clipexec.Options) (clipexec.Response, error) {
	var empty clipexec.Response
	return empty, errors.New("echo executor: CountTokens not implemented")
}
// main demonstrates two ways of sending an HTTP request with credentials
// injected by a registered executor:
//
//  1. core.NewHttpRequest builds a prepared request that the caller sends
//     with its own http.Client.
//  2. core.HttpRequest takes a caller-built request, injects credentials and
//     sends it.
//
// Both requests target httpbin.org/anything so the printed bodies show the
// headers that were sent. All work shares a 30-second deadline, and any
// failure panics.
func main() {
	log.SetLevel(log.InfoLevel)
	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer cancel()

	core := coreauth.NewManager(nil, nil, nil)
	core.RegisterExecutor(EchoExecutor{})
	auth := &coreauth.Auth{
		ID:       "demo-echo",
		Provider: providerKey,
		Attributes: map[string]string{
			"api_key": "demo-api-key",
		},
	}

	// Example 1: Build a prepared request and execute it using your own http.Client.
	reqPrepared, errReqPrepared := core.NewHttpRequest(
		ctx,
		auth,
		http.MethodGet,
		"https://httpbin.org/anything",
		nil,
		http.Header{"X-Example": []string{"prepared"}},
	)
	if errReqPrepared != nil {
		panic(errReqPrepared)
	}
	respPrepared, errDoPrepared := http.DefaultClient.Do(reqPrepared)
	if errDoPrepared != nil {
		panic(errDoPrepared)
	}
	bodyPrepared, errReadPrepared := readAndCloseBody(respPrepared)
	if errReadPrepared != nil {
		panic(errReadPrepared)
	}
	fmt.Printf("Prepared request status: %d\n%s\n\n", respPrepared.StatusCode, bodyPrepared)

	// Example 2: Execute a raw request via core.HttpRequest (auto inject + do).
	rawBody := []byte(`{"hello":"world"}`)
	rawReq, errRawReq := http.NewRequestWithContext(ctx, http.MethodPost, "https://httpbin.org/anything", bytes.NewReader(rawBody))
	if errRawReq != nil {
		panic(errRawReq)
	}
	rawReq.Header.Set("Content-Type", "application/json")
	rawReq.Header.Set("X-Example", "executed")
	respExec, errDoExec := core.HttpRequest(ctx, auth, rawReq)
	if errDoExec != nil {
		panic(errDoExec)
	}
	bodyExec, errReadExec := readAndCloseBody(respExec)
	if errReadExec != nil {
		panic(errReadExec)
	}
	fmt.Printf("Manager HttpRequest status: %d\n%s\n", respExec.StatusCode, bodyExec)
}

// readAndCloseBody reads resp.Body to EOF and closes it before returning, so
// each response is fully finished before the next request is made.
//
// A close failure is logged rather than returned; the read error, if any, is
// returned unchanged.
func readAndCloseBody(resp *http.Response) ([]byte, error) {
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("close response body error: %v", errClose)
		}
	}()
	return io.ReadAll(resp.Body)
}
================================================
FILE: examples/translator/main.go
================================================
package main
import (
"context"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
_ "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator/builtin"
)
// main walks through the translator SDK. The translators it uses are
// presumably registered by the blank import of sdk/translator/builtin.
//
// It performs three steps:
//  1. Report whether a Gemini -> OpenAI response transformer is registered.
//  2. Translate an OpenAI-format chat request into Gemini format.
//  3. Convert a canned non-streaming Gemini generateContent response (one
//     thought part plus one function call) for OpenAI clients.
//
// Each result is printed to stdout.
func main() {
	rawRequest := []byte(`{"messages":[{"content":[{"text":"Hello! Gemini","type":"text"}],"role":"user"}],"model":"gemini-2.5-pro","stream":false}`)
	fmt.Println("Has gemini->openai response translator:", translator.HasResponseTransformerByFormatName(
		translator.FormatGemini,
		translator.FormatOpenAI,
	))

	translatedRequest := translator.TranslateRequestByFormatName(
		translator.FormatOpenAI,
		translator.FormatGemini,
		"gemini-2.5-pro",
		rawRequest,
		false,
	)
	fmt.Printf("Translated request to Gemini format:\n%s\n\n", translatedRequest)

	// A Gemini-format upstream response (candidates/usageMetadata), as
	// returned for the translated request above.
	geminiResponse := []byte(`{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"Okay, here's what's going through my mind. I need to schedule a meeting"},{"thoughtSignature":"","functionCall":{"name":"schedule_meeting","args":{"topic":"Q3 planning","attendees":["Bob","Alice"],"time":"10:00","date":"2025-03-27"}}}]},"finishReason":"STOP","avgLogprobs":-0.50018133435930523}],"usageMetadata":{"promptTokenCount":117,"candidatesTokenCount":28,"totalTokenCount":474,"trafficType":"PROVISIONED_THROUGHPUT","promptTokensDetails":[{"modality":"TEXT","tokenCount":117}],"candidatesTokensDetails":[{"modality":"TEXT","tokenCount":28}],"thoughtsTokenCount":329},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T04:12:55.249090Z","responseId":"x7OeaIKaD6CU48APvNXDyA4"}`)
	convertedResponse := translator.TranslateNonStreamByFormatName(
		context.Background(),
		translator.FormatGemini,
		translator.FormatOpenAI,
		"gemini-2.5-pro",
		rawRequest,
		translatedRequest,
		geminiResponse,
		nil,
	)
	fmt.Printf("Converted response for OpenAI clients:\n%s\n", convertedResponse)
}
================================================
FILE: go.mod
================================================
module github.com/router-for-me/CLIProxyAPI/v6
go 1.26.0
require (
github.com/andybalholm/brotli v1.0.6
github.com/atotto/clipboard v0.1.4
github.com/charmbracelet/bubbles v1.0.0
github.com/charmbracelet/bubbletea v1.3.10
github.com/charmbracelet/lipgloss v1.1.0
github.com/fsnotify/fsnotify v1.9.0
github.com/gin-gonic/gin v1.10.1
github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145
github.com/google/uuid v1.6.0
github.com/gorilla/websocket v1.5.3
github.com/jackc/pgx/v5 v5.7.6
github.com/joho/godotenv v1.5.1
github.com/klauspost/compress v1.17.4
github.com/minio/minio-go/v7 v7.0.66
github.com/refraction-networking/utls v1.8.2
github.com/sirupsen/logrus v1.9.3
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
github.com/tiktoken-go/tokenizer v0.7.0
golang.org/x/crypto v0.45.0
golang.org/x/net v0.47.0
golang.org/x/oauth2 v0.30.0
golang.org/x/sync v0.18.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v3 v3.0.1
)
require (
cloud.google.com/go/compute/metadata v0.3.0 // indirect
github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/ProtonMail/go-crypto v1.3.0 // indirect
github.com/aymanbagabas/go-osc52/v2 v2.0.1 // indirect
github.com/bytedance/sonic v1.11.6 // indirect
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/charmbracelet/colorprofile v0.4.1 // indirect
github.com/charmbracelet/x/ansi v0.11.6 // indirect
github.com/charmbracelet/x/cellbuf v0.0.15 // indirect
github.com/charmbracelet/x/term v0.2.2 // indirect
github.com/clipperhouse/displaywidth v0.9.0 // indirect
github.com/clipperhouse/stringish v0.1.1 // indirect
github.com/clipperhouse/uax29/v2 v2.5.0 // indirect
github.com/cloudflare/circl v1.6.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/cyphar/filepath-securejoin v0.4.1 // indirect
github.com/dlclark/regexp2 v1.11.5 // indirect
github.com/dustin/go-humanize v1.0.1 // indirect
github.com/emirpasic/gods v1.18.1 // indirect
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-git/gcfg/v2 v2.0.2 // indirect
github.com/go-git/go-billy/v6 v6.0.0-20250627091229-31e2a16eef30 // indirect
github.com/go-playground/locales v0.14.1 // indirect
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.20.0 // indirect
github.com/goccy/go-json v0.10.2 // indirect
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/json-iterator/go v1.1.12 // indirect
github.com/kevinburke/ssh_config v1.4.0 // indirect
github.com/klauspost/cpuid/v2 v2.3.0 // indirect
github.com/leodido/go-urn v1.4.0 // indirect
github.com/lucasb-eyer/go-colorful v1.3.0 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-localereader v0.0.1 // indirect
github.com/mattn/go-runewidth v0.0.19 // indirect
github.com/minio/md5-simd v1.1.2 // indirect
github.com/minio/sha256-simd v1.0.1 // indirect
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect
github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 // indirect
github.com/muesli/cancelreader v0.2.2 // indirect
github.com/muesli/termenv v0.16.0 // indirect
github.com/pelletier/go-toml/v2 v2.2.2 // indirect
github.com/pjbgf/sha1cd v0.5.0 // indirect
github.com/rivo/uniseg v0.4.7 // indirect
github.com/rs/xid v1.5.0 // indirect
github.com/sergi/go-diff v1.4.0 // indirect
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.0 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/sys v0.38.0 // indirect
golang.org/x/text v0.31.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect
gopkg.in/ini.v1 v1.67.0 // indirect
)
================================================
FILE: go.sum
================================================
cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc=
cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k=
github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY=
github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU=
github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw=
github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE=
github.com/andybalholm/brotli v1.0.6 h1:Yf9fFpf49Zrxb9NlQaluyE92/+X7UVHlhMNJN2sxfOI=
github.com/andybalholm/brotli v1.0.6/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5 h1:0CwZNZbxp69SHPdPJAN/hZIm0C4OItdklCFmMRWYpio=
github.com/armon/go-socks5 v0.0.0-20160902184237-e75332964ef5/go.mod h1:wHh0iHkYZB8zMSxRWpUBQtwG5a7fFgvEO+odwuTv2gs=
github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4=
github.com/atotto/clipboard v0.1.4/go.mod h1:ZY9tmq7sm5xIbd9bOK4onWV4S6X0u6GY7Vn0Yu86PYI=
github.com/aymanbagabas/go-osc52/v2 v2.0.1 h1:HwpRHbFMcZLEVr42D4p7XBqjyuxQH5SMiErDT4WkJ2k=
github.com/aymanbagabas/go-osc52/v2 v2.0.1/go.mod h1:uYgXzlJ7ZpABp8OJ+exZzJJhRNQ2ASbcXHWsFqH8hp8=
github.com/bytedance/sonic v1.11.6 h1:oUp34TzMlL+OY1OUWxHqsdkgC/Zfc85zGqw9siXjrc0=
github.com/bytedance/sonic v1.11.6/go.mod h1:LysEHSvpvDySVdC2f87zGWf6CIKJcAvqab1ZaiQtds4=
github.com/bytedance/sonic/loader v0.1.1 h1:c+e5Pt1k/cy5wMveRDyk2X4B9hF4g7an8N3zCYjJFNM=
github.com/bytedance/sonic/loader v0.1.1/go.mod h1:ncP89zfokxS5LZrJxl5z0UJcsk4M4yY2JpfqGeCtNLU=
github.com/charmbracelet/bubbles v1.0.0 h1:12J8/ak/uCZEMQ6KU7pcfwceyjLlWsDLAxB5fXonfvc=
github.com/charmbracelet/bubbles v1.0.0/go.mod h1:9d/Zd5GdnauMI5ivUIVisuEm3ave1XwXtD1ckyV6r3E=
github.com/charmbracelet/bubbletea v1.3.10 h1:otUDHWMMzQSB0Pkc87rm691KZ3SWa4KUlvF9nRvCICw=
github.com/charmbracelet/bubbletea v1.3.10/go.mod h1:ORQfo0fk8U+po9VaNvnV95UPWA1BitP1E0N6xJPlHr4=
github.com/charmbracelet/colorprofile v0.4.1 h1:a1lO03qTrSIRaK8c3JRxJDZOvhvIeSco3ej+ngLk1kk=
github.com/charmbracelet/colorprofile v0.4.1/go.mod h1:U1d9Dljmdf9DLegaJ0nGZNJvoXAhayhmidOdcBwAvKk=
github.com/charmbracelet/lipgloss v1.1.0 h1:vYXsiLHVkK7fp74RkV7b2kq9+zDLoEU4MZoFqR/noCY=
github.com/charmbracelet/lipgloss v1.1.0/go.mod h1:/6Q8FR2o+kj8rz4Dq0zQc3vYf7X+B0binUUBwA0aL30=
github.com/charmbracelet/x/ansi v0.11.6 h1:GhV21SiDz/45W9AnV2R61xZMRri5NlLnl6CVF7ihZW8=
github.com/charmbracelet/x/ansi v0.11.6/go.mod h1:2JNYLgQUsyqaiLovhU2Rv/pb8r6ydXKS3NIttu3VGZQ=
github.com/charmbracelet/x/cellbuf v0.0.15 h1:ur3pZy0o6z/R7EylET877CBxaiE1Sp1GMxoFPAIztPI=
github.com/charmbracelet/x/cellbuf v0.0.15/go.mod h1:J1YVbR7MUuEGIFPCaaZ96KDl5NoS0DAWkskup+mOY+Q=
github.com/charmbracelet/x/term v0.2.2 h1:xVRT/S2ZcKdhhOuSP4t5cLi5o+JxklsoEObBSgfgZRk=
github.com/charmbracelet/x/term v0.2.2/go.mod h1:kF8CY5RddLWrsgVwpw4kAa6TESp6EB5y3uxGLeCqzAI=
github.com/clipperhouse/displaywidth v0.9.0 h1:Qb4KOhYwRiN3viMv1v/3cTBlz3AcAZX3+y9OLhMtAtA=
github.com/clipperhouse/displaywidth v0.9.0/go.mod h1:aCAAqTlh4GIVkhQnJpbL0T/WfcrJXHcj8C0yjYcjOZA=
github.com/clipperhouse/stringish v0.1.1 h1:+NSqMOr3GR6k1FdRhhnXrLfztGzuG+VuFDfatpWHKCs=
github.com/clipperhouse/stringish v0.1.1/go.mod h1:v/WhFtE1q0ovMta2+m+UbpZ+2/HEXNWYXQgCt4hdOzA=
github.com/clipperhouse/uax29/v2 v2.5.0 h1:x7T0T4eTHDONxFJsL94uKNKPHrclyFI0lm7+w94cO8U=
github.com/clipperhouse/uax29/v2 v2.5.0/go.mod h1:Wn1g7MK6OoeDT0vL+Q0SQLDz/KpfsVRgg6W7ihQeh4g=
github.com/cloudflare/circl v1.6.1 h1:zqIqSPIndyBh1bjLVVDHMPpVKqp8Su/V+6MeDzzQBQ0=
github.com/cloudflare/circl v1.6.1/go.mod h1:uddAzsPgqdMAYatqJ0lsjX1oECcQLIlRpzZh3pJrofs=
github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/0Y=
github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w=
github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg=
github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY=
github.com/cyphar/filepath-securejoin v0.4.1 h1:JyxxyPEaktOD+GAnqIqTf9A8tHyAG22rowi7HkoSU1s=
github.com/cyphar/filepath-securejoin v0.4.1/go.mod h1:Sdj7gXlvMcPZsbhwhQ33GguGLDGQL7h7bg04C/+u9jI=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dlclark/regexp2 v1.11.5 h1:Q/sSnsKerHeCkc/jSTNq1oCm7KiVgUMZRDUoRu0JQZQ=
github.com/dlclark/regexp2 v1.11.5/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY=
github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto=
github.com/elazarl/goproxy v1.7.2 h1:Y2o6urb7Eule09PjlhQRGNsqRfPmYI3KKQLFpCAV3+o=
github.com/elazarl/goproxy v1.7.2/go.mod h1:82vkLNir0ALaW14Rc399OTTjyNREgmdL2cVoIbS6XaE=
github.com/emirpasic/gods v1.18.1 h1:FXtiHYKDGKCW2KzwZKx0iC0PQmdlorYgdFG9jPXJ1Bc=
github.com/emirpasic/gods v1.18.1/go.mod h1:8tpGGwCnJ5H4r6BWwaV6OrWmMoPhUl5jm/FMNAnJvWQ=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f h1:Y/CXytFA4m6baUTXGLOoWe4PQhGxaX0KpnayAqC48p4=
github.com/erikgeiser/coninput v0.0.0-20211004153227-1c3628e74d0f/go.mod h1:vw97MGsxSvLiUE2X8qFplwetxpGLQrlU1Q9AUEIzCaM=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
github.com/gin-gonic/gin v1.10.1 h1:T0ujvqyCSqRopADpgPgiTT63DUQVSfojyME59Ei63pQ=
github.com/gin-gonic/gin v1.10.1/go.mod h1:4PMNQiOhvDRa013RKVbsiNwoyezlm2rm0uX/T7kzp5Y=
github.com/gliderlabs/ssh v0.3.8 h1:a4YXD1V7xMF9g5nTkdfnja3Sxy1PVDCj1Zg4Wb8vY6c=
github.com/gliderlabs/ssh v0.3.8/go.mod h1:xYoytBv1sV0aL3CavoDuJIQNURXkkfPA/wxQ1pL1fAU=
github.com/go-git/gcfg/v2 v2.0.2 h1:MY5SIIfTGGEMhdA7d7JePuVVxtKL7Hp+ApGDJAJ7dpo=
github.com/go-git/gcfg/v2 v2.0.2/go.mod h1:/lv2NsxvhepuMrldsFilrgct6pxzpGdSRC13ydTLSLs=
github.com/go-git/go-billy/v6 v6.0.0-20250627091229-31e2a16eef30 h1:4KqVJTL5eanN8Sgg3BV6f2/QzfZEFbCd+rTak1fGRRA=
github.com/go-git/go-billy/v6 v6.0.0-20250627091229-31e2a16eef30/go.mod h1:snwvGrbywVFy2d6KJdQ132zapq4aLyzLMgpo79XdEfM=
github.com/go-git/go-git-fixtures/v5 v5.1.1 h1:OH8i1ojV9bWfr0ZfasfpgtUXQHQyVS8HXik/V1C099w=
github.com/go-git/go-git-fixtures/v5 v5.1.1/go.mod h1:Altk43lx3b1ks+dVoAG2300o5WWUnktvfY3VI6bcaXU=
github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145 h1:C/oVxHd6KkkuvthQ/StZfHzZK07gl6xjfCfT3derko0=
github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145/go.mod h1:gR+xpbL+o1wuJJDwRN4pOkpNwDS0D24Eo4AD5Aau2DY=
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA=
github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY=
github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY=
github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY=
github.com/go-playground/validator/v10 v10.20.0 h1:K9ISHbSaI0lyB2eWMPJo+kOS/FBExVwjEviJTixqxL8=
github.com/go-playground/validator/v10 v10.20.0/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I=
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ=
github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk=
github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M=
github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo=
github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0=
github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ=
github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0=
github.com/knz/go-libedit v1.10.1/go.mod h1:MZTVkCWyz0oBc7JOWP3wNAzd002ZbM/5hgShxwh4x8M=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ=
github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI=
github.com/lucasb-eyer/go-colorful v1.3.0 h1:2/yBRLdWBZKrf7gB40FoiKfAWYQ0lqNcbuQwVHXptag=
github.com/lucasb-eyer/go-colorful v1.3.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0=
github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY=
github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y=
github.com/mattn/go-localereader v0.0.1 h1:ygSAOl7ZXTx4RdPYinUpg6W99U8jWvWi9Ye2JC/oIi4=
github.com/mattn/go-localereader v0.0.1/go.mod h1:8fBrzywKY7BI3czFoHkuzRoWE9C+EiG4R1k4Cjx5p88=
github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw=
github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs=
github.com/minio/md5-simd v1.1.2 h1:Gdi1DZK69+ZVMoNHRXJyNcxrMA4dSxoYHZSQbirFg34=
github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEpN+20RM=
github.com/minio/minio-go/v7 v7.0.66 h1:bnTOXOHjOqv/gcMuiVbN9o2ngRItvqE774dG9nq0Dzw=
github.com/minio/minio-go/v7 v7.0.66/go.mod h1:DHAgmyQEGdW3Cif0UooKOyrT3Vxs82zNdV6tkKhRtbs=
github.com/minio/sha256-simd v1.0.1 h1:6kaan5IFmwTNynnKKpDHe6FWHohJOHhCPchzK49dzMM=
github.com/minio/sha256-simd v1.0.1/go.mod h1:Pz6AKMiUdngCLpeTL/RJY1M9rUuPMYujV5xJjtbRSN8=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6 h1:ZK8zHtRHOkbHy6Mmr5D264iyp3TiX5OmNcI5cIARiQI=
github.com/muesli/ansi v0.0.0-20230316100256-276c6243b2f6/go.mod h1:CJlz5H+gyd6CUWT45Oy4q24RdLyn7Md9Vj2/ldJBSIo=
github.com/muesli/cancelreader v0.2.2 h1:3I4Kt4BQjOR54NavqnDogx/MIoWBFa0StPA8ELUXHmA=
github.com/muesli/cancelreader v0.2.2/go.mod h1:3XuTXfFS2VjM+HTLZY9Ak0l6eUKfijIfMUZ4EgX0QYo=
github.com/muesli/termenv v0.16.0 h1:S5AlUN9dENB57rsbnkPyfdGuWIlkmzJjbFf0Tf5FWUc=
github.com/muesli/termenv v0.16.0/go.mod h1:ZRfOIKPFDYQoDFF4Olj7/QJbW60Ol/kL1pU3VfY/Cnk=
github.com/pelletier/go-toml/v2 v2.2.2 h1:aYUidT7k73Pcl9nb2gScu7NSrKCSHIDE89b3+6Wq+LM=
github.com/pelletier/go-toml/v2 v2.2.2/go.mod h1:1t835xjRzz80PqgE6HHgN2JOsmgYu/h4qDAS4n929Rs=
github.com/pjbgf/sha1cd v0.5.0 h1:a+UkboSi1znleCDUNT3M5YxjOnN1fz2FhN48FlwCxs0=
github.com/pjbgf/sha1cd v0.5.0/go.mod h1:lhpGlyHLpQZoxMv8HcgXvZEhcGs0PG/vsZnEJ7H0iCM=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/refraction-networking/utls v1.8.2 h1:j4Q1gJj0xngdeH+Ox/qND11aEfhpgoEvV+S9iJ2IdQo=
github.com/refraction-networking/utls v1.8.2/go.mod h1:jkSOEkLqn+S/jtpEHPOsVv/4V4EVnelwbMQl4vCWXAM=
github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ=
github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88=
github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ=
github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc=
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/sergi/go-diff v1.4.0 h1:n/SP9D5ad1fORl+llWyN+D6qoUETXNZARKjyY2/KVCw=
github.com/sergi/go-diff v1.4.0/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966 h1:JIAuq3EEf9cgbU6AtGPK4CTG3Zf6CKMNqf0MHTggAUA=
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966/go.mod h1:sUM3LWHvSMaG192sy56D9F7CNvL7jUJVXoqM1QKLnog=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs=
github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/tiktoken-go/tokenizer v0.7.0 h1:VMu6MPT0bXFDHr7UPh9uii7CNItVt3X9K90omxL54vw=
github.com/tiktoken-go/tokenizer v0.7.0/go.mod h1:6UCYI/DtOallbmL7sSy30p6YQv60qNyU/4aVigPOx6w=
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
github.com/ugorji/go/codec v1.2.12 h1:9LC83zGrHhuUA9l16C9AHXAqEV/2wBQ4nkvumAE65EE=
github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no=
github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM=
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d h1:jtJma62tbqLibJ5sFQz8bKtEM8rJBtfilJ2qTU199MI=
golang.org/x/exp v0.0.0-20231006140011-7918f672742d/go.mod h1:ldy0pHrwJyGW56pPQzzkH36rKxoZW1tw7ZJpeKx+hdo=
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA=
gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k=
gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc=
gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
nullprogram.com/x/optparse v1.0.0/go.mod h1:KdyPE+Igbe0jQUrVfMqDMeJQIJZEuyV7pjYmp6pbG50=
rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4=
================================================
FILE: internal/access/config_access/provider.go
================================================
package configaccess
import (
"context"
"net/http"
"strings"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// Register installs the config-backed API key provider when the SDK config
// supplies at least one non-blank key. When cfg is nil or every key is blank,
// the provider is unregistered instead.
func Register(cfg *sdkconfig.SDKConfig) {
	var keys []string
	if cfg != nil {
		keys = normalizeKeys(cfg.APIKeys)
	}
	if len(keys) == 0 {
		sdkaccess.UnregisterProvider(sdkaccess.AccessProviderTypeConfigAPIKey)
		return
	}
	p := newProvider(sdkaccess.DefaultAccessProviderName, keys)
	sdkaccess.RegisterProvider(sdkaccess.AccessProviderTypeConfigAPIKey, p)
}
// provider authenticates requests against a fixed set of API keys taken from
// the SDK configuration.
type provider struct {
	// name is what Identifier reports; newProvider substitutes
	// sdkaccess.DefaultAccessProviderName when it is blank.
	name string
	// keys is the set of accepted API keys. When built through Register they
	// have already been trimmed and de-duplicated by normalizeKeys.
	keys map[string]struct{}
}
// newProvider builds a provider that accepts exactly the given keys. The name
// is trimmed, and a blank name falls back to
// sdkaccess.DefaultAccessProviderName.
func newProvider(name string, keys []string) *provider {
	p := &provider{
		name: strings.TrimSpace(name),
		keys: make(map[string]struct{}, len(keys)),
	}
	if p.name == "" {
		p.name = sdkaccess.DefaultAccessProviderName
	}
	for i := range keys {
		p.keys[keys[i]] = struct{}{}
	}
	return p
}
// Identifier reports the provider's name. A nil or unnamed provider reports
// sdkaccess.DefaultAccessProviderName.
func (p *provider) Identifier() string {
	if p != nil && p.name != "" {
		return p.name
	}
	return sdkaccess.DefaultAccessProviderName
}
// Authenticate checks the request for one of the configured API keys.
//
// Candidates are examined in this order, and the first one that matches a
// configured key wins:
//   - the Authorization header (a bearer token, or the raw header value);
//   - X-Goog-Api-Key;
//   - X-Api-Key;
//   - the "key" query parameter;
//   - the "auth_token" query parameter.
//
// The matching source is reported in the result metadata under "source".
//
// Errors:
//   - not-handled when the provider is nil or holds no keys;
//   - no-credentials when none of the inputs above is present;
//   - invalid-credential when inputs are present but none matches.
func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.Result, *sdkaccess.AuthError) {
	if p == nil || len(p.keys) == 0 {
		return nil, sdkaccess.NewNotHandledError()
	}

	var fromQueryKey, fromQueryToken string
	if r.URL != nil {
		query := r.URL.Query()
		fromQueryKey = query.Get("key")
		fromQueryToken = query.Get("auth_token")
	}
	rawAuthorization := r.Header.Get("Authorization")

	type credential struct {
		value  string
		source string
	}
	credentials := [...]credential{
		{value: extractBearerToken(rawAuthorization), source: "authorization"},
		{value: r.Header.Get("X-Goog-Api-Key"), source: "x-goog-api-key"},
		{value: r.Header.Get("X-Api-Key"), source: "x-api-key"},
		{value: fromQueryKey, source: "query-key"},
		{value: fromQueryToken, source: "query-auth-token"},
	}

	// Presence is judged on the raw Authorization header, so a bare "Bearer "
	// counts as a (bad) credential rather than a missing one.
	anyPresent := rawAuthorization != ""
	for _, cred := range credentials[1:] {
		if cred.value != "" {
			anyPresent = true
		}
	}
	if !anyPresent {
		return nil, sdkaccess.NewNoCredentialsError()
	}

	for _, cred := range credentials {
		if cred.value == "" {
			continue
		}
		if _, known := p.keys[cred.value]; !known {
			continue
		}
		return &sdkaccess.Result{
			Provider:  p.Identifier(),
			Principal: cred.value,
			Metadata:  map[string]string{"source": cred.source},
		}, nil
	}
	return nil, sdkaccess.NewInvalidCredentialError()
}
// extractBearerToken returns the credential carried by an Authorization header
// value.
//
// "Bearer <token>" (scheme matched case-insensitively) yields the trimmed
// token. Any other non-empty value is returned as-is (after trimming) so that
// a raw API key placed directly in the header is still accepted. Surrounding
// whitespace is trimmed because configured keys are trimmed too; an untrimmed
// value could never match.
func extractBearerToken(header string) string {
	header = strings.TrimSpace(header)
	if header == "" {
		return ""
	}
	scheme, token, found := strings.Cut(header, " ")
	if !found || !strings.EqualFold(scheme, "bearer") {
		return header
	}
	return strings.TrimSpace(token)
}
// normalizeKeys trims every key and drops blanks and duplicates. The first
// occurrence wins and the input order is preserved. It returns nil when no
// usable key remains.
func normalizeKeys(keys []string) []string {
	var out []string
	seen := make(map[string]struct{}, len(keys))
	for _, raw := range keys {
		k := strings.TrimSpace(raw)
		if _, dup := seen[k]; k == "" || dup {
			continue
		}
		seen[k] = struct{}{}
		out = append(out, k)
	}
	return out
}
================================================
FILE: internal/access/reconcile.go
================================================
package access
import (
"fmt"
"reflect"
"sort"
"strings"
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
log "github.com/sirupsen/logrus"
)
// ReconcileProviders diffs the currently registered access providers against
// the providers the manager held before (existing).
//
// It returns the registered providers as the new list, along with the sorted
// identifiers that were added, updated (same identifier, different instance)
// or removed. Providers with a blank identifier are ignored for the diff, and
// changes to the inline default provider are never reported.
//
// A nil newCfg yields all-nil results. oldCfg is currently unused.
func ReconcileProviders(oldCfg, newCfg *config.Config, existing []sdkaccess.Provider) (result []sdkaccess.Provider, added, updated, removed []string, err error) {
	_ = oldCfg
	if newCfg == nil {
		return nil, nil, nil, nil, nil
	}
	result = sdkaccess.RegisteredProviders()

	previous := make(map[string]sdkaccess.Provider, len(existing))
	for _, prov := range existing {
		if id := identifierFromProvider(prov); id != "" {
			previous[id] = prov
		}
	}

	// The built-in inline provider is never reported as a change.
	record := func(dst *[]string, id string) {
		if !strings.EqualFold(id, sdkaccess.DefaultAccessProviderName) {
			*dst = append(*dst, id)
		}
	}

	kept := make(map[string]struct{}, len(result))
	for _, prov := range result {
		id := identifierFromProvider(prov)
		if id == "" {
			continue
		}
		kept[id] = struct{}{}
		old, seenBefore := previous[id]
		switch {
		case !seenBefore:
			record(&added, id)
		case !providerInstanceEqual(old, prov):
			record(&updated, id)
		}
	}
	for id := range previous {
		if _, stillThere := kept[id]; !stillThere {
			record(&removed, id)
		}
	}

	sort.Strings(added)
	sort.Strings(updated)
	sort.Strings(removed)
	return result, added, updated, removed, nil
}
// ApplyAccessProviders re-registers the config-backed provider from newCfg,
// reconciles the registry against the manager's current providers and
// installs the result on the manager.
//
// It reports whether any provider was added, updated or removed; the change
// summary is logged at debug level. A nil manager or nil newCfg is a no-op.
func ApplyAccessProviders(manager *sdkaccess.Manager, oldCfg, newCfg *config.Config) (bool, error) {
	if manager == nil || newCfg == nil {
		return false, nil
	}

	previous := manager.Providers()
	configaccess.Register(&newCfg.SDKConfig)

	providers, added, updated, removed, errReconcile := ReconcileProviders(oldCfg, newCfg, previous)
	if errReconcile != nil {
		log.Errorf("failed to reconcile request auth providers: %v", errReconcile)
		return false, fmt.Errorf("reconciling access providers: %w", errReconcile)
	}
	manager.SetProviders(providers)

	if len(added)+len(updated)+len(removed) == 0 {
		log.Debug("auth providers unchanged after config update")
		return false, nil
	}
	log.Debugf("auth providers reconciled (added=%d updated=%d removed=%d)", len(added), len(updated), len(removed))
	log.Debugf("auth providers changes details - added=%v updated=%v removed=%v", added, updated, removed)
	return true, nil
}
// identifierFromProvider returns the provider's trimmed identifier, or "" for
// a nil provider.
func identifierFromProvider(provider sdkaccess.Provider) string {
	if provider != nil {
		return strings.TrimSpace(provider.Identifier())
	}
	return ""
}
// providerInstanceEqual reports whether a and b denote the same provider
// instance:
//   - two nils are equal;
//   - different dynamic types never are;
//   - pointer-typed providers compare by address;
//   - any other kind falls back to reflect.DeepEqual.
func providerInstanceEqual(a, b sdkaccess.Provider) bool {
	switch {
	case a == nil && b == nil:
		return true
	case a == nil || b == nil:
		return false
	}
	typeA := reflect.TypeOf(a)
	if typeA != reflect.TypeOf(b) {
		return false
	}
	if typeA.Kind() == reflect.Pointer {
		return reflect.ValueOf(a).Pointer() == reflect.ValueOf(b).Pointer()
	}
	return reflect.DeepEqual(a, b)
}
================================================
FILE: internal/api/handlers/management/api_tools.go
================================================
package management
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
// defaultAPICallTimeout bounds every HTTP client built in this file: the
// api-call upstream request and the gemini-cli / antigravity token refreshes.
const defaultAPICallTimeout = 60 * time.Second

// OAuth client used by refreshGeminiOAuthAccessToken to refresh gemini-cli
// access tokens.
// NOTE(review): credentials are embedded in source; presumably the public
// installed-app client of the upstream CLI — confirm they are meant to be public.
const (
	geminiOAuthClientID     = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
	geminiOAuthClientSecret = "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"
)

// geminiOAuthScopes is the scope list of the oauth2.Config used for gemini-cli
// token refreshes.
var geminiOAuthScopes = []string{
	"https://www.googleapis.com/auth/cloud-platform",
	"https://www.googleapis.com/auth/userinfo.email",
	"https://www.googleapis.com/auth/userinfo.profile",
}

// OAuth client sent by refreshAntigravityOAuthAccessToken in its
// refresh_token grant.
// NOTE(review): embedded credentials, same caveat as the gemini client above.
const (
	antigravityOAuthClientID     = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
	antigravityOAuthClientSecret = "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"
)

// antigravityOAuthTokenURL is the token endpoint for antigravity refreshes.
// A blank value falls back to Google's endpoint. It is a var rather than a
// const, presumably so tests can point it at a fake server.
var antigravityOAuthTokenURL = "https://oauth2.googleapis.com/token"
// apiCallRequest is the JSON body accepted by APICall.
type apiCallRequest struct {
	// The credential index may arrive under any of three spellings. APICall
	// uses the first non-blank one, trying snake, camel and Pascal case in turn.
	AuthIndexSnake  *string `json:"auth_index"`
	AuthIndexCamel  *string `json:"authIndex"`
	AuthIndexPascal *string `json:"AuthIndex"`
	// Method is the upstream HTTP method; APICall trims and upper-cases it.
	Method string `json:"method"`
	// URL must be absolute: APICall rejects it without a scheme and host.
	URL string `json:"url"`
	// Header values may contain the "$TOKEN$" placeholder. A "Host" entry
	// (any case) overrides the request Host instead of being sent as a header.
	Header map[string]string `json:"header"`
	// Data is sent verbatim as the request body when non-empty.
	Data string `json:"data"`
}
// apiCallResponse is the JSON payload APICall returns, with HTTP 200, once the
// upstream request has completed. It mirrors the upstream response.
type apiCallResponse struct {
	StatusCode int                 `json:"status_code"` // upstream HTTP status code
	Header     map[string][]string `json:"header"`      // upstream response headers
	Body       string              `json:"body"`        // upstream body, read in full
}
// APICall makes a generic HTTP request on behalf of the management API caller.
// It is protected by the management middleware.
//
// Endpoint:
//
//	POST /v0/management/api-call
//
// Authentication:
//
// Same as other management APIs (requires a management key and remote-management rules).
// You can provide the key via:
//   - Authorization: Bearer <MANAGEMENT_KEY>
//   - X-Management-Key: <MANAGEMENT_KEY>
//
// Request JSON:
//   - auth_index / authIndex / AuthIndex (optional):
//     The credential "auth_index" from GET /v0/management/auth-files (or other endpoints returning it).
//     If omitted or not found, credential-specific proxy/token substitution is skipped
//     and any "$TOKEN$" placeholder is forwarded unchanged.
//   - method (required): HTTP method, e.g. GET, POST, PUT, PATCH, DELETE.
//   - url (required): Absolute URL including scheme and host, e.g. "https://api.example.com/v1/ping".
//   - header (optional): Request headers map.
//     Supports magic variable "$TOKEN$", which is replaced using the selected credential.
//     gemini-cli and antigravity credentials use their OAuth access token, refreshed first
//     when needed. Other credentials use, in order:
//     1) metadata accessToken / access_token / token (a string, or its access_token field)
//     2) metadata id_token / cookie
//     3) attributes.api_key
//     4) the same metadata lookup on a shared gemini-cli credential, if any
//     Example: {"Authorization":"Bearer $TOKEN$"}.
//     Note: if you need to override the HTTP Host header, set header["Host"].
//   - data (optional): Raw request body as string (useful for POST/PUT/PATCH).
//
// Proxy selection (highest priority first; an unusable value falls through to the next):
//  1. Selected credential proxy_url
//  2. Global config proxy-url
//  3. Direct connect (environment proxies are not used)
//
// Response JSON (returned with HTTP 200 when the APICall itself succeeds):
//   - status_code: Upstream HTTP status code.
//   - header: Upstream response headers.
//   - body: Upstream response body as string.
//
// Example:
//
//	curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
//	  -H "Authorization: Bearer <MANAGEMENT_KEY>" \
//	  -H "Content-Type: application/json" \
//	  -d '{"auth_index":"<AUTH_INDEX>","method":"GET","url":"https://api.example.com/v1/ping","header":{"Authorization":"Bearer $TOKEN$"}}'
//
//	curl -sS -X POST "http://127.0.0.1:8317/v0/management/api-call" \
//	  -H "Authorization: Bearer <MANAGEMENT_KEY>" \
//	  -H "Content-Type: application/json" \
//	  -d '{"auth_index":"<AUTH_INDEX>","method":"POST","url":"https://api.example.com/v1/fetchAvailableModels","header":{"Authorization":"Bearer $TOKEN$","Content-Type":"application/json","User-Agent":"cliproxyapi"},"data":"{}"}'
func (h *Handler) APICall(c *gin.Context) {
	var body apiCallRequest
	if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
		return
	}

	method := strings.ToUpper(strings.TrimSpace(body.Method))
	if method == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "missing method"})
		return
	}

	urlStr := strings.TrimSpace(body.URL)
	if urlStr == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "missing url"})
		return
	}
	parsedURL, errParseURL := url.Parse(urlStr)
	if errParseURL != nil || parsedURL.Scheme == "" || parsedURL.Host == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid url"})
		return
	}

	authIndex := firstNonEmptyString(body.AuthIndexSnake, body.AuthIndexCamel, body.AuthIndexPascal)
	auth := h.authByIndex(authIndex)

	reqHeaders := body.Header
	if reqHeaders == nil {
		reqHeaders = map[string]string{}
	}

	// Resolve the credential token lazily: only if some header actually uses
	// the placeholder, and at most once, because it may trigger an OAuth refresh.
	var token string
	var tokenResolved bool
	var tokenErr error
	for key, value := range reqHeaders {
		if !strings.Contains(value, "$TOKEN$") {
			continue
		}
		if !tokenResolved {
			token, tokenErr = h.resolveTokenForAuth(c.Request.Context(), auth)
			tokenResolved = true
		}
		if auth != nil && token == "" {
			if tokenErr != nil {
				log.WithError(tokenErr).Debug("management APICall token resolution failed")
				c.JSON(http.StatusBadRequest, gin.H{"error": "auth token refresh failed"})
				return
			}
			c.JSON(http.StatusBadRequest, gin.H{"error": "auth token not found"})
			return
		}
		if token == "" {
			// No credential selected: the placeholder is forwarded as-is.
			continue
		}
		// Overwriting an existing key while ranging over a map is allowed.
		reqHeaders[key] = strings.ReplaceAll(value, "$TOKEN$", token)
	}

	var requestBody io.Reader
	if body.Data != "" {
		requestBody = strings.NewReader(body.Data)
	}
	req, errNewRequest := http.NewRequestWithContext(c.Request.Context(), method, urlStr, requestBody)
	if errNewRequest != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "failed to build request"})
		return
	}

	var hostOverride string
	for key, value := range reqHeaders {
		// Go sends req.Host, not a "Host" header entry, so route it there.
		if strings.EqualFold(key, "host") {
			hostOverride = strings.TrimSpace(value)
			continue
		}
		req.Header.Set(key, value)
	}
	if hostOverride != "" {
		req.Host = hostOverride
	}

	httpClient := &http.Client{
		Timeout:   defaultAPICallTimeout,
		Transport: h.apiCallTransport(auth),
	}
	// apiCallTransport hands back a transport that later calls do not reuse
	// (the direct path always clones a fresh one), so close its pooled
	// connections when this call finishes instead of leaving them open.
	// This defer is registered before the body-close defer, so it runs after it.
	defer httpClient.CloseIdleConnections()

	resp, errDo := httpClient.Do(req)
	if errDo != nil {
		log.WithError(errDo).Debug("management APICall request failed")
		c.JSON(http.StatusBadGateway, gin.H{"error": "request failed"})
		return
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("response body close error: %v", errClose)
		}
	}()

	respBody, errReadAll := io.ReadAll(resp.Body)
	if errReadAll != nil {
		c.JSON(http.StatusBadGateway, gin.H{"error": "failed to read response"})
		return
	}

	c.JSON(http.StatusOK, apiCallResponse{
		StatusCode: resp.StatusCode,
		Header:     resp.Header,
		Body:       string(respBody),
	})
}
// firstNonEmptyString returns the first non-nil value that is non-blank after
// trimming, in trimmed form. If there is none, it returns "".
func firstNonEmptyString(values ...*string) string {
	for _, candidate := range values {
		var trimmed string
		if candidate != nil {
			trimmed = strings.TrimSpace(*candidate)
		}
		if len(trimmed) > 0 {
			return trimmed
		}
	}
	return ""
}
// tokenValueForAuth picks the static token to substitute for a credential.
// It tries, in order:
//  1. the token-like fields of auth.Metadata (see tokenValueFromMetadata);
//  2. the trimmed "api_key" attribute;
//  3. the metadata of a shared gemini-cli credential attached to auth.Runtime.
//
// It returns "" when none of these yields a value or auth is nil.
func tokenValueForAuth(auth *coreauth.Auth) string {
	if auth == nil {
		return ""
	}
	token := tokenValueFromMetadata(auth.Metadata)
	if token == "" && auth.Attributes != nil {
		token = strings.TrimSpace(auth.Attributes["api_key"])
	}
	if token != "" {
		return token
	}
	shared := geminicli.ResolveSharedCredential(auth.Runtime)
	if shared == nil {
		return ""
	}
	return tokenValueFromMetadata(shared.MetadataSnapshot())
}
// resolveTokenForAuth returns the token used for "$TOKEN$" substitution.
// gemini-cli and antigravity credentials (provider matched case-insensitively)
// go through their OAuth refresh helpers. Every other provider uses the static
// value from tokenValueForAuth. A nil auth yields ("", nil).
func (h *Handler) resolveTokenForAuth(ctx context.Context, auth *coreauth.Auth) (string, error) {
	if auth == nil {
		return "", nil
	}
	switch strings.ToLower(strings.TrimSpace(auth.Provider)) {
	case "gemini-cli":
		return h.refreshGeminiOAuthAccessToken(ctx, auth)
	case "antigravity":
		return h.refreshAntigravityOAuthAccessToken(ctx, auth)
	default:
		return tokenValueForAuth(auth), nil
	}
}
// refreshGeminiOAuthAccessToken returns an access token for a gemini-cli
// credential. It uses oauth2's token source, which refreshes against Google's
// endpoint only when the stored token is no longer valid.
//
// Token fields are read from metadata["token"] (a JSON-shaped map) first. The
// flat access_token / refresh_token / token_type / expiry (RFC 3339) keys serve
// as fallbacks. The resulting fields are written back through the updater
// returned by geminiOAuthMetadata (the shared credential or auth.Metadata).
//
// A nil auth yields ("", nil). Missing metadata or a failed refresh returns an
// error.
func (h *Handler) refreshGeminiOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	if auth == nil {
		return "", nil
	}
	metadata, updater := geminiOAuthMetadata(auth)
	if len(metadata) == 0 {
		return "", fmt.Errorf("gemini oauth metadata missing")
	}

	base := make(map[string]any)
	if tokenRaw, ok := metadata["token"].(map[string]any); ok && tokenRaw != nil {
		base = cloneMap(tokenRaw)
	}
	var token oauth2.Token
	if len(base) > 0 {
		// Best effort: a malformed nested token leaves the flat-key fallbacks below.
		if raw, errMarshal := json.Marshal(base); errMarshal == nil {
			_ = json.Unmarshal(raw, &token)
		}
	}
	if token.AccessToken == "" {
		token.AccessToken = stringValue(metadata, "access_token")
	}
	if token.RefreshToken == "" {
		token.RefreshToken = stringValue(metadata, "refresh_token")
	}
	if token.TokenType == "" {
		token.TokenType = stringValue(metadata, "token_type")
	}
	if token.Expiry.IsZero() {
		if expiry := stringValue(metadata, "expiry"); expiry != "" {
			if ts, errParseTime := time.Parse(time.RFC3339, expiry); errParseTime == nil {
				token.Expiry = ts
			}
		}
	}

	conf := &oauth2.Config{
		ClientID:     geminiOAuthClientID,
		ClientSecret: geminiOAuthClientSecret,
		Scopes:       geminiOAuthScopes,
		Endpoint:     google.Endpoint,
	}
	httpClient := &http.Client{
		Timeout:   defaultAPICallTimeout,
		Transport: h.apiCallTransport(auth),
	}
	// The transport is built for this call only; release its idle connections
	// once the token exchange (if any) has finished.
	defer httpClient.CloseIdleConnections()

	// oauth2 takes the HTTP client for the refresh request from the context.
	ctxToken := context.WithValue(ctx, oauth2.HTTPClient, httpClient)
	currentToken, errToken := conf.TokenSource(ctxToken, &token).Token()
	if errToken != nil {
		return "", fmt.Errorf("refreshing gemini oauth token: %w", errToken)
	}

	merged := buildOAuthTokenMap(base, currentToken)
	fields := buildOAuthTokenFields(currentToken, merged)
	if updater != nil {
		updater(fields)
	}
	return strings.TrimSpace(currentToken.AccessToken), nil
}
// refreshAntigravityOAuthAccessToken returns an access token for an
// antigravity credential.
//
// The stored token is reused when it is not close to expiry (see
// antigravityTokenNeedsRefresh). Otherwise a refresh_token grant is posted to
// antigravityOAuthTokenURL. On success the token fields are written into
// auth.Metadata, and the auth is pushed to the auth manager when one is
// configured. Persisting is best-effort: failures are logged and the fresh
// token is still returned.
//
// A nil auth yields ("", nil). Missing metadata, a missing refresh token, a
// transport error, a non-2xx response or an empty access_token return an
// error.
func (h *Handler) refreshAntigravityOAuthAccessToken(ctx context.Context, auth *coreauth.Auth) (string, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	if auth == nil {
		return "", nil
	}
	metadata := auth.Metadata
	if len(metadata) == 0 {
		return "", fmt.Errorf("antigravity oauth metadata missing")
	}

	current := strings.TrimSpace(tokenValueFromMetadata(metadata))
	if current != "" && !antigravityTokenNeedsRefresh(metadata) {
		return current, nil
	}

	refreshToken := stringValue(metadata, "refresh_token")
	if refreshToken == "" {
		return "", fmt.Errorf("antigravity refresh token missing")
	}

	tokenURL := strings.TrimSpace(antigravityOAuthTokenURL)
	if tokenURL == "" {
		tokenURL = "https://oauth2.googleapis.com/token"
	}

	form := url.Values{}
	form.Set("client_id", antigravityOAuthClientID)
	form.Set("client_secret", antigravityOAuthClientSecret)
	form.Set("grant_type", "refresh_token")
	form.Set("refresh_token", refreshToken)

	req, errReq := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(form.Encode()))
	if errReq != nil {
		return "", errReq
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	httpClient := &http.Client{
		Timeout:   defaultAPICallTimeout,
		Transport: h.apiCallTransport(auth),
	}
	// The transport is built for this call only. Registered before the
	// body-close defer, so idle connections are dropped after the body is closed.
	defer httpClient.CloseIdleConnections()

	resp, errDo := httpClient.Do(req)
	if errDo != nil {
		return "", errDo
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("response body close error: %v", errClose)
		}
	}()

	bodyBytes, errRead := io.ReadAll(resp.Body)
	if errRead != nil {
		return "", errRead
	}
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return "", fmt.Errorf("antigravity oauth token refresh failed: status %d: %s", resp.StatusCode, strings.TrimSpace(string(bodyBytes)))
	}

	var tokenResp struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		ExpiresIn    int64  `json:"expires_in"`
		TokenType    string `json:"token_type"`
	}
	if errUnmarshal := json.Unmarshal(bodyBytes, &tokenResp); errUnmarshal != nil {
		return "", errUnmarshal
	}
	if strings.TrimSpace(tokenResp.AccessToken) == "" {
		return "", fmt.Errorf("antigravity oauth token refresh returned empty access_token")
	}

	if auth.Metadata == nil {
		auth.Metadata = make(map[string]any)
	}
	now := time.Now()
	auth.Metadata["access_token"] = strings.TrimSpace(tokenResp.AccessToken)
	// Google only returns a refresh_token when it rotates it; keep the old one otherwise.
	if strings.TrimSpace(tokenResp.RefreshToken) != "" {
		auth.Metadata["refresh_token"] = strings.TrimSpace(tokenResp.RefreshToken)
	}
	if tokenResp.ExpiresIn > 0 {
		auth.Metadata["expires_in"] = tokenResp.ExpiresIn
		auth.Metadata["timestamp"] = now.UnixMilli()
		auth.Metadata["expired"] = now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339)
	}
	auth.Metadata["type"] = "antigravity"

	if h != nil && h.authManager != nil {
		auth.LastRefreshedAt = now
		auth.UpdatedAt = now
		// Best-effort: the fresh token is returned regardless, but report the
		// failure so the refreshed metadata not being persisted is visible.
		if _, errUpdate := h.authManager.Update(ctx, auth); errUpdate != nil {
			log.WithError(errUpdate).Warn("management: failed to persist refreshed antigravity token")
		}
	}
	return strings.TrimSpace(tokenResp.AccessToken), nil
}
// antigravityTokenNeedsRefresh reports whether the stored antigravity token
// expires within the next 30 seconds, or its expiry cannot be determined.
//
// The RFC 3339 "expired" field is preferred. Otherwise "timestamp"
// (milliseconds) plus "expires_in" (seconds) is used. Nil metadata, or
// metadata with neither usable source, always needs a refresh.
func antigravityTokenNeedsRefresh(metadata map[string]any) bool {
	if metadata == nil {
		return true
	}
	// Tokens this close to expiry count as expired, so requests don't race it.
	const skew = 30 * time.Second

	if rawExpiry, isString := metadata["expired"].(string); isString {
		expiry, errParse := time.Parse(time.RFC3339, strings.TrimSpace(rawExpiry))
		if errParse == nil {
			return !expiry.After(time.Now().Add(skew))
		}
	}

	lifetimeSec := int64Value(metadata["expires_in"])
	issuedAtMs := int64Value(metadata["timestamp"])
	if lifetimeSec <= 0 || issuedAtMs <= 0 {
		return true
	}
	expiry := time.UnixMilli(issuedAtMs).Add(time.Duration(lifetimeSec) * time.Second)
	return !expiry.After(time.Now().Add(skew))
}
// int64Value coerces a loosely typed metadata value into an int64. Such values
// come from JSON decoded into map[string]any, or are stored directly by Go code.
//
// Accepted inputs:
//   - signed and unsigned integers of any width;
//   - float32/float64, truncated toward zero;
//   - json.Number and numeric strings. These are parsed as an integer first,
//     then as a float, so "3600.0" and "1e3" also work.
//
// Anything else yields 0. So does any value that cannot be represented as an
// int64, including unsigned values above MaxInt64 and NaN, infinite or
// out-of-range floats.
func int64Value(raw any) int64 {
	const maxInt64 = int64(^uint64(0) >> 1)

	// fromFloat guards the float->int64 conversion, whose result is
	// implementation-defined in Go for NaN, ±Inf and out-of-range values.
	// 2^63 is exactly representable as a float64, and NaN fails both comparisons.
	fromFloat := func(f float64) int64 {
		if f >= -9223372036854775808.0 && f < 9223372036854775808.0 {
			return int64(f)
		}
		return 0
	}
	fromText := func(s string) int64 {
		n := json.Number(strings.TrimSpace(s))
		if n == "" {
			return 0
		}
		if i, errParse := n.Int64(); errParse == nil {
			return i
		}
		if f, errParse := n.Float64(); errParse == nil {
			return fromFloat(f)
		}
		return 0
	}

	switch typed := raw.(type) {
	case int:
		return int64(typed)
	case int8:
		return int64(typed)
	case int16:
		return int64(typed)
	case int32:
		return int64(typed)
	case int64:
		return typed
	case uint:
		if uint64(typed) > uint64(maxInt64) {
			return 0
		}
		return int64(typed)
	case uint8:
		return int64(typed)
	case uint16:
		return int64(typed)
	case uint32:
		return int64(typed)
	case uint64:
		if typed > uint64(maxInt64) {
			return 0
		}
		return int64(typed)
	case float32:
		return fromFloat(float64(typed))
	case float64:
		return fromFloat(typed)
	case json.Number:
		return fromText(string(typed))
	case string:
		return fromText(typed)
	}
	return 0
}
// geminiOAuthMetadata picks where gemini-cli token metadata is read from and
// written back to, and returns the metadata along with an updater.
//
// When auth.Runtime resolves to a shared credential, it returns that
// credential's metadata snapshot and an updater that merges into it.
// Otherwise it returns auth.Metadata itself and an updater that copies fields
// into it, creating the map if needed. A nil auth yields (nil, nil).
func geminiOAuthMetadata(auth *coreauth.Auth) (map[string]any, func(map[string]any)) {
	if auth == nil {
		return nil, nil
	}
	shared := geminicli.ResolveSharedCredential(auth.Runtime)
	if shared != nil {
		mergeShared := func(fields map[string]any) { shared.MergeMetadata(fields) }
		return shared.MetadataSnapshot(), mergeShared
	}
	mergeLocal := func(fields map[string]any) {
		if auth.Metadata == nil {
			auth.Metadata = make(map[string]any, len(fields))
		}
		for key, value := range fields {
			auth.Metadata[key] = value
		}
	}
	return auth.Metadata, mergeLocal
}
// stringValue returns metadata[key], trimmed, when it holds a string. A
// missing key, a non-string value, empty metadata or an empty key all yield "".
func stringValue(metadata map[string]any, key string) string {
	if key == "" || len(metadata) == 0 {
		return ""
	}
	s, _ := metadata[key].(string)
	return strings.TrimSpace(s)
}
// cloneMap returns a shallow copy of in. It returns nil, rather than an empty
// map, when in is nil or empty.
func cloneMap(in map[string]any) map[string]any {
	var out map[string]any
	if len(in) > 0 {
		out = make(map[string]any, len(in))
		for key, value := range in {
			out[key] = value
		}
	}
	return out
}
// buildOAuthTokenMap returns a new, never-nil map holding a shallow copy of
// base, overlaid with tok's JSON-encoded fields.
//
// The overlay is best-effort: if tok is nil or cannot be round-tripped through
// JSON, the copy of base is returned unchanged.
func buildOAuthTokenMap(base map[string]any, tok *oauth2.Token) map[string]any {
	merged := make(map[string]any, len(base))
	for key, value := range base {
		merged[key] = value
	}
	if tok == nil {
		return merged
	}

	encoded, errMarshal := json.Marshal(tok)
	if errMarshal != nil {
		return merged
	}
	var overlay map[string]any
	if errUnmarshal := json.Unmarshal(encoded, &overlay); errUnmarshal != nil {
		return merged
	}
	for key, value := range overlay {
		merged[key] = value
	}
	return merged
}
// buildOAuthTokenFields builds the flat metadata fields written back after a
// token refresh:
//   - access_token, token_type and refresh_token, each only when non-empty;
//   - expiry (RFC 3339), only when set;
//   - "token", a shallow copy of merged, only when merged is non-empty.
//
// A nil tok contributes no flat fields.
func buildOAuthTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any {
	fields := make(map[string]any, 5)
	if tok != nil {
		if tok.AccessToken != "" {
			fields["access_token"] = tok.AccessToken
		}
		if tok.TokenType != "" {
			fields["token_type"] = tok.TokenType
		}
		if tok.RefreshToken != "" {
			fields["refresh_token"] = tok.RefreshToken
		}
		if !tok.Expiry.IsZero() {
			fields["expiry"] = tok.Expiry.Format(time.RFC3339)
		}
	}
	if len(merged) > 0 {
		fields["token"] = cloneMap(merged)
	}
	return fields
}
// tokenValueFromMetadata extracts a bearer-style token from credential
// metadata. It returns the first non-blank, trimmed value found in this order:
//  1. top-level "accessToken", then "access_token";
//  2. "token" when it is a string, or its "access_token" / "accessToken" entry
//     when it is a map[string]any or map[string]string;
//  3. top-level "id_token", then "cookie".
//
// It returns "" when none is present.
func tokenValueFromMetadata(metadata map[string]any) string {
	if len(metadata) == 0 {
		return ""
	}
	// firstString returns the first listed key in m that holds a non-blank string.
	firstString := func(m map[string]any, keys ...string) string {
		for _, key := range keys {
			if v, ok := m[key].(string); ok {
				if trimmed := strings.TrimSpace(v); trimmed != "" {
					return trimmed
				}
			}
		}
		return ""
	}

	if v := firstString(metadata, "accessToken", "access_token"); v != "" {
		return v
	}
	// The string form of "token" is handled here; a separate top-level string
	// check afterwards would be unreachable.
	switch nested := metadata["token"].(type) {
	case string:
		if v := strings.TrimSpace(nested); v != "" {
			return v
		}
	case map[string]any:
		if v := firstString(nested, "access_token", "accessToken"); v != "" {
			return v
		}
	case map[string]string:
		for _, key := range [...]string{"access_token", "accessToken"} {
			if v := strings.TrimSpace(nested[key]); v != "" {
				return v
			}
		}
	}
	return firstString(metadata, "id_token", "cookie")
}
// authByIndex returns the auth whose Index equals the trimmed authIndex, or
// nil when the index is blank, the handler has no auth manager, or no auth
// matches.
//
// Each auth is scanned in List order and has EnsureIndex called on it before
// comparing, which may assign its index as a side effect.
func (h *Handler) authByIndex(authIndex string) *coreauth.Auth {
	wanted := strings.TrimSpace(authIndex)
	if wanted == "" || h == nil || h.authManager == nil {
		return nil
	}
	for _, candidate := range h.authManager.List() {
		if candidate == nil {
			continue
		}
		candidate.EnsureIndex()
		if candidate.Index == wanted {
			return candidate
		}
	}
	return nil
}
// apiCallTransport picks the round tripper for outbound management calls.
//
// It tries the credential's proxy URL, then the global config proxy URL. The
// first one that buildProxyTransport accepts wins; blank or rejected values
// are skipped. Without a usable proxy, it returns a clone of
// http.DefaultTransport with proxying disabled, so environment proxy
// variables are ignored.
func (h *Handler) apiCallTransport(auth *coreauth.Auth) http.RoundTripper {
	candidates := make([]string, 0, 2)
	if auth != nil {
		candidates = append(candidates, strings.TrimSpace(auth.ProxyURL))
	}
	if h != nil && h.cfg != nil {
		candidates = append(candidates, strings.TrimSpace(h.cfg.ProxyURL))
	}
	for _, proxyStr := range candidates {
		if proxyStr == "" {
			continue
		}
		if transport := buildProxyTransport(proxyStr); transport != nil {
			return transport
		}
	}

	base, ok := http.DefaultTransport.(*http.Transport)
	if !ok || base == nil {
		return &http.Transport{Proxy: nil}
	}
	direct := base.Clone()
	direct.Proxy = nil
	return direct
}
// buildProxyTransport asks proxyutil for a transport for proxyStr. It returns
// nil, after logging at debug level, when the value cannot be used.
func buildProxyTransport(proxyStr string) *http.Transport {
	transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
	if errBuild == nil {
		return transport
	}
	log.WithError(errBuild).Debug("build proxy transport failed")
	return nil
}
================================================
FILE: internal/api/handlers/management/api_tools_test.go
================================================
package management
import (
"context"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// TestAPICallTransportDirectBypassesGlobalProxy checks that a credential
// proxy of "direct" yields a transport without a proxy function, even when a
// global proxy is configured.
func TestAPICallTransportDirectBypassesGlobalProxy(t *testing.T) {
	t.Parallel()

	handler := &Handler{cfg: &config.Config{
		SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"},
	}}

	rt := handler.apiCallTransport(&coreauth.Auth{ProxyURL: "direct"})
	tr, isHTTP := rt.(*http.Transport)
	if !isHTTP {
		t.Fatalf("transport type = %T, want *http.Transport", rt)
	}
	if tr.Proxy != nil {
		t.Fatal("expected direct transport to disable proxy function")
	}
}
// TestAPICallTransportInvalidAuthFallsBackToGlobalProxy checks that an
// unusable credential proxy URL is skipped in favour of the global proxy.
func TestAPICallTransportInvalidAuthFallsBackToGlobalProxy(t *testing.T) {
	t.Parallel()

	const globalProxy = "http://global-proxy.example.com:8080"
	h := &Handler{
		cfg: &config.Config{
			SDKConfig: sdkconfig.SDKConfig{ProxyURL: globalProxy},
		},
	}

	transport := h.apiCallTransport(&coreauth.Auth{ProxyURL: "bad-value"})
	httpTransport, ok := transport.(*http.Transport)
	if !ok {
		t.Fatalf("transport type = %T, want *http.Transport", transport)
	}
	// A regression to the direct transport leaves Proxy nil; calling it would
	// panic, so report it as a normal failure instead.
	if httpTransport.Proxy == nil {
		t.Fatal("expected transport to route through the global proxy, got no proxy function")
	}

	req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
	if errRequest != nil {
		t.Fatalf("http.NewRequest returned error: %v", errRequest)
	}
	proxyURL, errProxy := httpTransport.Proxy(req)
	if errProxy != nil {
		t.Fatalf("httpTransport.Proxy returned error: %v", errProxy)
	}
	if proxyURL == nil || proxyURL.String() != globalProxy {
		t.Fatalf("proxy URL = %v, want %s", proxyURL, globalProxy)
	}
}
// TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders checks that two
// credentials sharing an API key under different providers get distinct
// auth indexes, and that authByIndex resolves each index to its own credential.
func TestAuthByIndexDistinguishesSharedAPIKeysAcrossProviders(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	manager := coreauth.NewManager(nil, nil, nil)

	geminiAuth := &coreauth.Auth{
		ID:         "gemini:apikey:123",
		Provider:   "gemini",
		Attributes: map[string]string{"api_key": "shared-key"},
	}
	compatAuth := &coreauth.Auth{
		ID:       "openai-compatibility:bohe:456",
		Provider: "bohe",
		Label:    "bohe",
		Attributes: map[string]string{
			"api_key":      "shared-key",
			"compat_name":  "bohe",
			"provider_key": "bohe",
		},
	}

	registrations := []struct {
		label string
		auth  *coreauth.Auth
	}{
		{"gemini", geminiAuth},
		{"compat", compatAuth},
	}
	for _, reg := range registrations {
		if _, errRegister := manager.Register(ctx, reg.auth); errRegister != nil {
			t.Fatalf("register %s auth: %v", reg.label, errRegister)
		}
	}

	geminiIndex := geminiAuth.EnsureIndex()
	compatIndex := compatAuth.EnsureIndex()
	if geminiIndex == compatIndex {
		t.Fatalf("shared api key produced duplicate auth_index %q", geminiIndex)
	}

	h := &Handler{authManager: manager}
	lookups := []struct {
		label, index, wantID string
	}{
		{"gemini", geminiIndex, geminiAuth.ID},
		{"compat", compatIndex, compatAuth.ID},
	}
	for _, lk := range lookups {
		got := h.authByIndex(lk.index)
		if got == nil {
			t.Fatalf("expected %s auth by index", lk.label)
		}
		if got.ID != lk.wantID {
			t.Fatalf("authByIndex(%s) returned %q, want %q", lk.label, got.ID, lk.wantID)
		}
	}
}
================================================
FILE: internal/api/handlers/management/auth_files.go
================================================
package management
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net"
"net/http"
"os"
"path/filepath"
"runtime"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
// lastRefreshKeys lists, in priority order, the metadata keys that
// extractLastRefreshTimestamp inspects for a last-refresh timestamp.
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
const (
// Loopback ports used for per-provider OAuth callback forwarders
// (see startCallbackForwarder). The anthropic and gemini ports are used below;
// codexCallbackPort is presumably used by a handler outside this view.
anthropicCallbackPort = 54545
geminiCallbackPort = 8085
codexCallbackPort = 1455
// Gemini CLI API base endpoint and API version. Not referenced in this part
// of the file — presumably used by the Gemini CLI flow further down.
geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
geminiCLIVersion = "v1internal"
)
// callbackForwarder describes a running loopback redirect server created by
// startCallbackForwarder. done is closed once the server's Serve call returns.
type callbackForwarder struct {
provider string
server *http.Server
done chan struct{}
}
// callbackForwarders maps a port to its active forwarder. The visible
// accessors (startCallbackForwarder, stopCallbackForwarderInstance) hold
// callbackForwardersMu while reading or modifying it.
var (
callbackForwardersMu sync.Mutex
callbackForwarders = make(map[int]*callbackForwarder)
)
// extractLastRefreshTimestamp looks up the keys in lastRefreshKeys, in order,
// and returns the first value that parseLastRefreshValue accepts. The boolean
// is false when meta is empty or no key holds a usable timestamp.
func extractLastRefreshTimestamp(meta map[string]any) (time.Time, bool) {
	// Lookups on a nil or empty map simply miss, so no separate guard is needed.
	for _, key := range lastRefreshKeys {
		raw, present := meta[key]
		if !present {
			continue
		}
		if ts, ok := parseLastRefreshValue(raw); ok {
			return ts, true
		}
	}
	return time.Time{}, false
}
// parseLastRefreshValue interprets a loosely-typed metadata value as a UTC
// timestamp.
//
// Strings may be RFC 3339 timestamps, "YYYY-MM-DD HH:MM:SS" (read as UTC), or
// positive decimal Unix seconds; surrounding whitespace is ignored. float64,
// int64, int and json.Number values are treated as Unix seconds and must be
// positive (a float64 is truncated after the sign check). Anything else yields
// (time.Time{}, false).
func parseLastRefreshValue(v any) (time.Time, bool) {
	fromUnix := func(sec int64) (time.Time, bool) {
		return time.Unix(sec, 0).UTC(), true
	}

	switch val := v.(type) {
	case string:
		text := strings.TrimSpace(val)
		if text == "" {
			break
		}
		for _, layout := range [...]string{time.RFC3339, time.RFC3339Nano, "2006-01-02 15:04:05", "2006-01-02T15:04:05Z07:00"} {
			if ts, err := time.Parse(layout, text); err == nil {
				return ts.UTC(), true
			}
		}
		if sec, err := strconv.ParseInt(text, 10, 64); err == nil && sec > 0 {
			return fromUnix(sec)
		}
	case float64:
		// Written as "<= 0 → reject" (not "> 0 → accept") so NaN keeps the
		// original handling.
		if val <= 0 {
			break
		}
		return fromUnix(int64(val))
	case int64:
		if val <= 0 {
			break
		}
		return fromUnix(val)
	case int:
		if val <= 0 {
			break
		}
		return fromUnix(int64(val))
	case json.Number:
		if sec, err := val.Int64(); err == nil && sec > 0 {
			return fromUnix(sec)
		}
	}
	return time.Time{}, false
}
// isWebUIRequest reports whether the request's is_webui query parameter is
// set to a truthy value: "1", "true", "yes" or "on", case-insensitively and
// ignoring surrounding whitespace. A missing parameter means false.
func isWebUIRequest(c *gin.Context) bool {
	flag := strings.ToLower(strings.TrimSpace(c.Query("is_webui")))
	return flag == "1" || flag == "true" || flag == "yes" || flag == "on"
}
// startCallbackForwarder launches a loopback HTTP server on 127.0.0.1:port.
// The server answers every request with a 302 redirect to targetBase and
// appends the incoming raw query string (using "&" when targetBase already
// has a query).
//
// A forwarder previously registered on the same port is unregistered and shut
// down before the new listener is opened. The new forwarder is recorded in
// callbackForwarders under port and also returned to the caller.
func startCallbackForwarder(port int, provider, targetBase string) (*callbackForwarder, error) {
	// Detach any previous forwarder while holding the lock. Shut it down only
	// after releasing the lock, because shutdown can block for a few seconds.
	callbackForwardersMu.Lock()
	stale := callbackForwarders[port]
	if stale != nil {
		delete(callbackForwarders, port)
	}
	callbackForwardersMu.Unlock()
	if stale != nil {
		stopForwarderInstance(port, stale)
	}

	addr := fmt.Sprintf("127.0.0.1:%d", port)
	listener, errListen := net.Listen("tcp", addr)
	if errListen != nil {
		return nil, fmt.Errorf("failed to listen on %s: %w", addr, errListen)
	}

	redirect := func(w http.ResponseWriter, r *http.Request) {
		location := targetBase
		if query := r.URL.RawQuery; query != "" {
			separator := "?"
			if strings.Contains(location, "?") {
				separator = "&"
			}
			location += separator + query
		}
		w.Header().Set("Cache-Control", "no-store")
		http.Redirect(w, r, location, http.StatusFound)
	}

	forwarder := &callbackForwarder{
		provider: provider,
		server: &http.Server{
			Handler:           http.HandlerFunc(redirect),
			ReadHeaderTimeout: 5 * time.Second,
			WriteTimeout:      5 * time.Second,
		},
		done: make(chan struct{}),
	}

	go func() {
		// done is closed only after any warning has been logged.
		defer close(forwarder.done)
		errServe := forwarder.server.Serve(listener)
		if errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
			log.WithError(errServe).Warnf("callback forwarder for %s stopped unexpectedly", provider)
		}
	}()

	callbackForwardersMu.Lock()
	callbackForwarders[port] = forwarder
	callbackForwardersMu.Unlock()

	log.Infof("callback forwarder for %s listening on %s", provider, addr)
	return forwarder, nil
}
// stopCallbackForwarderInstance removes forwarder from the registry for port
// and then shuts it down.
//
// The registry entry is removed only if it still points at this exact
// forwarder, so a newer forwarder on the same port is left registered. A nil
// forwarder is ignored.
func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) {
	if forwarder == nil {
		return
	}
	func() {
		callbackForwardersMu.Lock()
		defer callbackForwardersMu.Unlock()
		if callbackForwarders[port] == forwarder {
			delete(callbackForwarders, port)
		}
	}()
	stopForwarderInstance(port, forwarder)
}
func stopForwarderInstance(port int, forwarder *callbackForwarder) {
if forwarder == nil || forwarder.server == nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if err := forwarder.server.Shutdown(ctx); err != nil && !errors.Is(err, http.ErrServerClosed) {
log.WithError(err).Warnf("failed to shut down callback forwarder on port %d", port)
}
select {
case <-forwarder.done:
case <-time.After(2 * time.Second):
}
log.Infof("callback forwarder on port %d stopped", port)
}
// managementCallbackURL builds this server's loopback URL (127.0.0.1 on the
// configured port) for the given callback path.
//
// A leading "/" is added to path when missing, and https is used when TLS is
// enabled in the config. It returns an error when the handler, config or a
// positive port is missing.
func (h *Handler) managementCallbackURL(path string) (string, error) {
	if h == nil || h.cfg == nil || h.cfg.Port <= 0 {
		return "", fmt.Errorf("server port is not configured")
	}
	scheme := "http"
	if h.cfg.TLS.Enable {
		scheme = "https"
	}
	if path == "" || path[0] != '/' {
		path = "/" + path
	}
	return fmt.Sprintf("%s://127.0.0.1:%d%s", scheme, h.cfg.Port, path), nil
}
// ListAuthFiles responds with {"files": [...]}, holding every auth known to
// the core auth manager that buildAuthFileEntry does not hide. Entries are
// sorted case-insensitively by name. Without an auth manager it falls back to
// scanning AuthDir on disk.
func (h *Handler) ListAuthFiles(c *gin.Context) {
	switch {
	case h == nil:
		c.JSON(500, gin.H{"error": "handler not initialized"})
		return
	case h.authManager == nil:
		h.listAuthFilesFromDisk(c)
		return
	}

	managed := h.authManager.List()
	files := make([]gin.H, 0, len(managed))
	for _, a := range managed {
		entry := h.buildAuthFileEntry(a)
		if entry == nil {
			continue
		}
		files = append(files, entry)
	}

	sort.Slice(files, func(i, j int) bool {
		left, _ := files[i]["name"].(string)
		right, _ := files[j]["name"].(string)
		return strings.ToLower(left) < strings.ToLower(right)
	})
	c.JSON(200, gin.H{"files": files})
}
// GetAuthFileModels responds with {"models": [...]}, the models the global
// registry records for one auth.
//
// The required "name" query parameter is matched against each managed auth's
// FileName or ID. When nothing matches (or no auth manager is configured),
// name itself is used as the registry client ID. display_name, type and
// owned_by are included only when non-empty.
func (h *Handler) GetAuthFileModels(c *gin.Context) {
	name := c.Query("name")
	if name == "" {
		c.JSON(400, gin.H{"error": "name is required"})
		return
	}

	clientID := ""
	if h.authManager != nil {
		for _, auth := range h.authManager.List() {
			if auth.FileName != name && auth.ID != name {
				continue
			}
			clientID = auth.ID
			break
		}
	}
	if clientID == "" {
		clientID = name // fall back to treating the file name as the ID
	}

	models := registry.GetGlobalRegistry().GetModelsForClient(clientID)
	result := make([]gin.H, 0, len(models))
	for _, model := range models {
		item := gin.H{"id": model.ID}
		if model.DisplayName != "" {
			item["display_name"] = model.DisplayName
		}
		if model.Type != "" {
			item["type"] = model.Type
		}
		if model.OwnedBy != "" {
			item["owned_by"] = model.OwnedBy
		}
		result = append(result, item)
	}
	c.JSON(200, gin.H{"models": result})
}
// listAuthFilesFromDisk is the ListAuthFiles fallback used when no auth
// manager is configured. It lists the regular *.json files directly inside
// AuthDir, in os.ReadDir order.
//
// Every entry has name, size and modtime. When the file can be read, the entry
// also gets its "type" and "email" strings, "priority" (a JSON number, or a
// string holding an integer) and a non-blank trimmed "note". Files whose info
// cannot be read are skipped.
func (h *Handler) listAuthFilesFromDisk(c *gin.Context) {
	dirEntries, errDir := os.ReadDir(h.cfg.AuthDir)
	if errDir != nil {
		c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read auth dir: %v", errDir)})
		return
	}

	files := make([]gin.H, 0)
	for _, de := range dirEntries {
		if de.IsDir() {
			continue
		}
		fileName := de.Name()
		if !strings.HasSuffix(strings.ToLower(fileName), ".json") {
			continue
		}
		info, errInfo := de.Info()
		if errInfo != nil {
			continue
		}

		item := gin.H{"name": fileName, "size": info.Size(), "modtime": info.ModTime()}
		raw, errRead := os.ReadFile(filepath.Join(h.cfg.AuthDir, fileName))
		if errRead == nil {
			item["type"] = gjson.GetBytes(raw, "type").String()
			item["email"] = gjson.GetBytes(raw, "email").String()

			priority := gjson.GetBytes(raw, "priority")
			switch {
			case !priority.Exists():
			case priority.Type == gjson.Number:
				item["priority"] = int(priority.Int())
			case priority.Type == gjson.String:
				if n, errAtoi := strconv.Atoi(strings.TrimSpace(priority.String())); errAtoi == nil {
					item["priority"] = n
				}
			}

			if note := gjson.GetBytes(raw, "note"); note.Exists() && note.Type == gjson.String {
				if text := strings.TrimSpace(note.String()); text != "" {
					item["note"] = text
				}
			}
		}
		files = append(files, item)
	}
	c.JSON(200, gin.H{"files": files})
}
// buildAuthFileEntry renders one in-memory auth as the JSON object used by
// ListAuthFiles. It returns nil when the auth should be hidden:
//   - auth is nil;
//   - it is runtime-only and disabled;
//   - it is not runtime-only and has no "path" attribute;
//   - its backing file no longer exists and it is disabled or marked
//     "removed via management api" (case-insensitive).
// It calls auth.EnsureIndex() before reading auth.Index, so the index may be
// populated on auth as a side effect.
func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
if auth == nil {
return nil
}
auth.EnsureIndex()
runtimeOnly := isRuntimeOnlyAuth(auth)
if runtimeOnly && (auth.Disabled || auth.Status == coreauth.StatusDisabled) {
return nil
}
// File-backed (non runtime-only) auths must record where they live.
path := strings.TrimSpace(authAttribute(auth, "path"))
if path == "" && !runtimeOnly {
return nil
}
name := strings.TrimSpace(auth.FileName)
if name == "" {
name = auth.ID
}
// "source" and "size" start as in-memory defaults and are overwritten below
// when a backing file is found.
entry := gin.H{
"id": auth.ID,
"auth_index": auth.Index,
"name": name,
"type": strings.TrimSpace(auth.Provider),
"provider": strings.TrimSpace(auth.Provider),
"label": auth.Label,
"status": auth.Status,
"status_message": auth.StatusMessage,
"disabled": auth.Disabled,
"unavailable": auth.Unavailable,
"runtime_only": runtimeOnly,
"source": "memory",
"size": int64(0),
}
if email := authEmail(auth); email != "" {
entry["email"] = email
}
if accountType, account := auth.AccountInfo(); accountType != "" || account != "" {
if accountType != "" {
entry["account_type"] = accountType
}
if account != "" {
entry["account"] = account
}
}
if !auth.CreatedAt.IsZero() {
entry["created_at"] = auth.CreatedAt
}
// modtime defaults to UpdatedAt; a successful os.Stat below replaces it with
// the file's own modification time.
if !auth.UpdatedAt.IsZero() {
entry["modtime"] = auth.UpdatedAt
entry["updated_at"] = auth.UpdatedAt
}
if !auth.LastRefreshedAt.IsZero() {
entry["last_refresh"] = auth.LastRefreshedAt
}
if !auth.NextRetryAfter.IsZero() {
entry["next_retry_after"] = auth.NextRetryAfter
}
if path != "" {
entry["path"] = path
entry["source"] = "file"
if info, err := os.Stat(path); err == nil {
entry["size"] = info.Size()
entry["modtime"] = info.ModTime()
} else if os.IsNotExist(err) {
// Hide credentials removed from disk but still lingering in memory.
if !runtimeOnly && (auth.Disabled || auth.Status == coreauth.StatusDisabled || strings.EqualFold(strings.TrimSpace(auth.StatusMessage), "removed via management api")) {
return nil
}
// File is gone but the auth is still live: report it as memory-backed.
entry["source"] = "memory"
} else {
// Any other stat error is logged; the entry is still returned as "file".
log.WithError(err).Warnf("failed to stat auth file %s", path)
}
}
// Codex auths additionally expose selected claims decoded from their id_token.
if claims := extractCodexIDTokenClaims(auth); claims != nil {
entry["id_token"] = claims
}
// Expose priority from Attributes (set by synthesizer from JSON "priority" field).
// Fall back to Metadata for auths registered via UploadAuthFile (no synthesizer).
// Non-integer attribute strings are ignored rather than reported.
if p := strings.TrimSpace(authAttribute(auth, "priority")); p != "" {
if parsed, err := strconv.Atoi(p); err == nil {
entry["priority"] = parsed
}
} else if auth.Metadata != nil {
if rawPriority, ok := auth.Metadata["priority"]; ok {
// float64 is what encoding/json yields for numbers; int is what
// PatchAuthFileFields stores. float64 values are truncated.
switch v := rawPriority.(type) {
case float64:
entry["priority"] = int(v)
case int:
entry["priority"] = v
case string:
if parsed, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
entry["priority"] = parsed
}
}
}
}
// Expose note from Attributes (set by synthesizer from JSON "note" field).
// Fall back to Metadata for auths registered via UploadAuthFile (no synthesizer).
if note := strings.TrimSpace(authAttribute(auth, "note")); note != "" {
entry["note"] = note
} else if auth.Metadata != nil {
if rawNote, ok := auth.Metadata["note"].(string); ok {
if trimmed := strings.TrimSpace(rawNote); trimmed != "" {
entry["note"] = trimmed
}
}
}
return entry
}
// extractCodexIDTokenClaims decodes the id_token stored in a Codex auth's
// metadata and returns the ChatGPT account ID, plan type and subscription
// window claims that are present.
//
// It returns nil for nil or non-Codex auths, for a missing, blank or
// unparseable token, and when none of those claims is set.
func extractCodexIDTokenClaims(auth *coreauth.Auth) gin.H {
	if auth == nil || auth.Metadata == nil || !strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
		return nil
	}
	rawToken, _ := auth.Metadata["id_token"].(string)
	token := strings.TrimSpace(rawToken)
	if token == "" {
		return nil
	}
	claims, errParse := codex.ParseJWTToken(token)
	if errParse != nil || claims == nil {
		return nil
	}

	info := claims.CodexAuthInfo
	out := gin.H{}
	if accountID := strings.TrimSpace(info.ChatgptAccountID); accountID != "" {
		out["chatgpt_account_id"] = accountID
	}
	if plan := strings.TrimSpace(info.ChatgptPlanType); plan != "" {
		out["plan_type"] = plan
	}
	if start := info.ChatgptSubscriptionActiveStart; start != nil {
		out["chatgpt_subscription_active_start"] = start
	}
	if until := info.ChatgptSubscriptionActiveUntil; until != nil {
		out["chatgpt_subscription_active_until"] = until
	}
	if len(out) == 0 {
		return nil
	}
	return out
}
// authEmail returns the trimmed e-mail address associated with auth, or "" if
// none is known.
//
// Lookup order: Metadata["email"] (when it is a string), then the "email"
// attribute, then the "account_email" attribute. Blank values at any step
// fall through to the next source, so an empty metadata field cannot hide an
// address recorded in the attributes.
func authEmail(auth *coreauth.Auth) string {
	if auth == nil {
		return ""
	}
	// Indexing a nil map is safe and simply yields the zero value.
	if v, ok := auth.Metadata["email"].(string); ok {
		if trimmed := strings.TrimSpace(v); trimmed != "" {
			return trimmed
		}
	}
	for _, key := range [...]string{"email", "account_email"} {
		if v := strings.TrimSpace(auth.Attributes[key]); v != "" {
			return v
		}
	}
	return ""
}
// authAttribute returns auth.Attributes[key], or "" when auth is nil or has no
// such attribute. The value is returned untrimmed.
func authAttribute(auth *coreauth.Auth, key string) string {
	if auth == nil {
		return ""
	}
	// A nil or empty Attributes map yields "" on lookup.
	return auth.Attributes[key]
}
// isRuntimeOnlyAuth reports whether auth is flagged runtime-only. That means
// its "runtime_only" attribute equals "true" (case-insensitive, surrounding
// whitespace ignored).
func isRuntimeOnlyAuth(auth *coreauth.Auth) bool {
	if auth == nil {
		return false
	}
	flag := strings.TrimSpace(auth.Attributes["runtime_only"])
	return strings.EqualFold(flag, "true")
}
// DownloadAuthFile returns a single auth file from AuthDir as a JSON
// attachment.
//
// The "name" query parameter must be a bare file name ending in ".json".
// Names that contain '/' or '\\', or that carry a directory or volume
// component, are rejected with 400 so the lookup cannot leave AuthDir.
// Responds 404 when the file does not exist and 500 on other read errors.
func (h *Handler) DownloadAuthFile(c *gin.Context) {
	name := c.Query("name")
	// Checking only os.PathSeparator is not enough. On Windows '/' is also a
	// separator, so "../../x.json" would slip through to filepath.Join.
	// filepath.Base additionally strips volume names such as "C:".
	if name == "" || strings.ContainsAny(name, `/\`) || filepath.Base(name) != name {
		c.JSON(400, gin.H{"error": "invalid name"})
		return
	}
	if !strings.HasSuffix(strings.ToLower(name), ".json") {
		c.JSON(400, gin.H{"error": "name must end with .json"})
		return
	}
	full := filepath.Join(h.cfg.AuthDir, name)
	data, err := os.ReadFile(full)
	if err != nil {
		if os.IsNotExist(err) {
			c.JSON(404, gin.H{"error": "file not found"})
		} else {
			c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
		}
		return
	}
	c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", name))
	c.Data(200, "application/json", data)
}
// Upload auth file: multipart or raw JSON with ?name=
func (h *Handler) UploadAuthFile(c *gin.Context) {
if h.authManager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
return
}
ctx := c.Request.Context()
if file, err := c.FormFile("file"); err == nil && file != nil {
name := filepath.Base(file.Filename)
if !strings.HasSuffix(strings.ToLower(name), ".json") {
c.JSON(400, gin.H{"error": "file must be .json"})
return
}
dst := filepath.Join(h.cfg.AuthDir, name)
if !filepath.IsAbs(dst) {
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
dst = abs
}
}
if errSave := c.SaveUploadedFile(file, dst); errSave != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to save file: %v", errSave)})
return
}
data, errRead := os.ReadFile(dst)
if errRead != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read saved file: %v", errRead)})
return
}
if errReg := h.registerAuthFromFile(ctx, dst, data); errReg != nil {
c.JSON(500, gin.H{"error": errReg.Error()})
return
}
c.JSON(200, gin.H{"status": "ok"})
return
}
name := c.Query("name")
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
c.JSON(400, gin.H{"error": "invalid name"})
return
}
if !strings.HasSuffix(strings.ToLower(name), ".json") {
c.JSON(400, gin.H{"error": "name must end with .json"})
return
}
data, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(400, gin.H{"error": "failed to read body"})
return
}
dst := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
if !filepath.IsAbs(dst) {
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
dst = abs
}
}
if errWrite := os.WriteFile(dst, data, 0o600); errWrite != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to write file: %v", errWrite)})
return
}
if err = h.registerAuthFromFile(ctx, dst, data); err != nil {
c.JSON(500, gin.H{"error": err.Error()})
return
}
c.JSON(200, gin.H{"status": "ok"})
}
// DeleteAuthFile removes auth files from disk, deletes their token-store
// records and disables the matching in-memory auths.
//
//   - ?all=true|1|* removes every top-level *.json file in AuthDir and responds
//     with {"status":"ok","deleted":N}. Files that os.Remove fails on are
//     skipped silently. A token-store deletion error aborts with 500 partway
//     through; files already removed stay removed.
//   - ?name=<file name or auth ID> removes a single file. name is first
//     resolved to a registered auth (see findAuthForDelete) so that its
//     recorded "path" is used; otherwise AuthDir/<base name> is targeted.
//     Responds 404 when the file does not exist.
func (h *Handler) DeleteAuthFile(c *gin.Context) {
if h.authManager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
return
}
ctx := c.Request.Context()
if all := c.Query("all"); all == "true" || all == "1" || all == "*" {
entries, err := os.ReadDir(h.cfg.AuthDir)
if err != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read auth dir: %v", err)})
return
}
deleted := 0
for _, e := range entries {
if e.IsDir() {
continue
}
name := e.Name()
if !strings.HasSuffix(strings.ToLower(name), ".json") {
continue
}
full := filepath.Join(h.cfg.AuthDir, name)
if !filepath.IsAbs(full) {
if abs, errAbs := filepath.Abs(full); errAbs == nil {
full = abs
}
}
// Only successfully removed files are counted and disabled; removal
// failures are ignored.
if err = os.Remove(full); err == nil {
if errDel := h.deleteTokenRecord(ctx, full); errDel != nil {
c.JSON(500, gin.H{"error": errDel.Error()})
return
}
deleted++
// disableAuth accepts a path and maps it to an auth ID itself.
h.disableAuth(ctx, full)
}
}
c.JSON(200, gin.H{"status": "ok", "deleted": deleted})
return
}
name := c.Query("name")
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
c.JSON(400, gin.H{"error": "invalid name"})
return
}
// Default target is AuthDir/<base name>. A matching registered auth
// overrides it with its recorded path and supplies the ID to disable.
targetPath := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
targetID := ""
if targetAuth := h.findAuthForDelete(name); targetAuth != nil {
targetID = strings.TrimSpace(targetAuth.ID)
if path := strings.TrimSpace(authAttribute(targetAuth, "path")); path != "" {
targetPath = path
}
}
if !filepath.IsAbs(targetPath) {
if abs, errAbs := filepath.Abs(targetPath); errAbs == nil {
targetPath = abs
}
}
if errRemove := os.Remove(targetPath); errRemove != nil {
if os.IsNotExist(errRemove) {
c.JSON(404, gin.H{"error": "file not found"})
} else {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to remove file: %v", errRemove)})
}
return
}
if errDeleteRecord := h.deleteTokenRecord(ctx, targetPath); errDeleteRecord != nil {
c.JSON(500, gin.H{"error": errDeleteRecord.Error()})
return
}
// Prefer the resolved auth ID. Otherwise let disableAuth derive it from the path.
if targetID != "" {
h.disableAuth(ctx, targetID)
} else {
h.disableAuth(ctx, targetPath)
}
c.JSON(200, gin.H{"status": "ok"})
}
// findAuthForDelete resolves name to a managed auth. It tries, in order, an
// exact ID match, the auth's file name, and the base name of its "path"
// attribute.
//
// Surrounding whitespace in name is ignored. It returns nil when name is
// blank, no auth manager is available, or nothing matches.
func (h *Handler) findAuthForDelete(name string) *coreauth.Auth {
	if h == nil || h.authManager == nil {
		return nil
	}
	if name = strings.TrimSpace(name); name == "" {
		return nil
	}
	if byID, ok := h.authManager.GetByID(name); ok {
		return byID
	}
	for _, candidate := range h.authManager.List() {
		if candidate == nil {
			continue
		}
		if strings.TrimSpace(candidate.FileName) == name ||
			filepath.Base(strings.TrimSpace(authAttribute(candidate, "path"))) == name {
			return candidate
		}
	}
	return nil
}
// authIDForPath derives the auth ID for a file-backed auth. It is the path
// relative to the configured AuthDir when filepath.Rel succeeds, otherwise the
// trimmed path itself.
//
// On Windows the ID is lower-cased because paths there are case-insensitive;
// this avoids duplicate auth entries for the same file. A blank path yields "".
func (h *Handler) authIDForPath(path string) string {
	id := strings.TrimSpace(path)
	if id == "" {
		return ""
	}
	if h != nil && h.cfg != nil {
		if baseDir := strings.TrimSpace(h.cfg.AuthDir); baseDir != "" {
			if rel, errRel := filepath.Rel(baseDir, id); errRel == nil && rel != "" {
				id = rel
			}
		}
	}
	if runtime.GOOS != "windows" {
		return id
	}
	return strings.ToLower(id)
}
func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []byte) error {
if h.authManager == nil {
return nil
}
if path == "" {
return fmt.Errorf("auth path is empty")
}
if data == nil {
var err error
data, err = os.ReadFile(path)
if err != nil {
return fmt.Errorf("failed to read auth file: %w", err)
}
}
metadata := make(map[string]any)
if err := json.Unmarshal(data, &metadata); err != nil {
return fmt.Errorf("invalid auth file: %w", err)
}
provider, _ := metadata["type"].(string)
if provider == "" {
provider = "unknown"
}
label := provider
if email, ok := metadata["email"].(string); ok && email != "" {
label = email
}
lastRefresh, hasLastRefresh := extractLastRefreshTimestamp(metadata)
authID := h.authIDForPath(path)
if authID == "" {
authID = path
}
attr := map[string]string{
"path": path,
"source": path,
}
auth := &coreauth.Auth{
ID: authID,
Provider: provider,
FileName: filepath.Base(path),
Label: label,
Status: coreauth.StatusActive,
Attributes: attr,
Metadata: metadata,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if hasLastRefresh {
auth.LastRefreshedAt = lastRefresh
}
if existing, ok := h.authManager.GetByID(authID); ok {
auth.CreatedAt = existing.CreatedAt
if !hasLastRefresh {
auth.LastRefreshedAt = existing.LastRefreshedAt
}
auth.NextRefreshAfter = existing.NextRefreshAfter
auth.Runtime = existing.Runtime
_, err := h.authManager.Update(ctx, auth)
return err
}
_, err := h.authManager.Register(ctx, auth)
return err
}
// PatchAuthFileStatus enables or disables an auth.
//
// The JSON body must contain "name" (an auth ID or file name) and "disabled".
// Disabling sets the status to disabled with a management-API status message.
// Enabling resets the status to active and clears the message. Responds 404
// when no auth matches.
func (h *Handler) PatchAuthFileStatus(c *gin.Context) {
	if h.authManager == nil {
		c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
		return
	}
	var body struct {
		Name     string `json:"name"`
		Disabled *bool  `json:"disabled"`
	}
	if errBind := c.ShouldBindJSON(&body); errBind != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
		return
	}
	name := strings.TrimSpace(body.Name)
	switch {
	case name == "":
		c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
		return
	case body.Disabled == nil:
		c.JSON(http.StatusBadRequest, gin.H{"error": "disabled is required"})
		return
	}
	ctx := c.Request.Context()

	// Resolve by ID first, then by exact file name.
	lookup := func() *coreauth.Auth {
		if found, ok := h.authManager.GetByID(name); ok {
			return found
		}
		for _, candidate := range h.authManager.List() {
			if candidate.FileName == name {
				return candidate
			}
		}
		return nil
	}
	target := lookup()
	if target == nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
		return
	}

	disabled := *body.Disabled
	target.Disabled = disabled
	if disabled {
		target.Status, target.StatusMessage = coreauth.StatusDisabled, "disabled via management API"
	} else {
		target.Status, target.StatusMessage = coreauth.StatusActive, ""
	}
	target.UpdatedAt = time.Now()

	if _, errUpdate := h.authManager.Update(ctx, target); errUpdate != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", errUpdate)})
		return
	}
	c.JSON(http.StatusOK, gin.H{"status": "ok", "disabled": disabled})
}
// PatchAuthFileFields updates the editable fields of an auth: prefix,
// proxy_url, priority and note. Only fields present in the JSON body change.
//
// A priority of 0, or a note that is blank after trimming, removes that field
// from both the auth's Metadata and Attributes. Other values are written to
// both. Responds 404 when no auth matches "name", and 400 when the body names
// no field to update.
func (h *Handler) PatchAuthFileFields(c *gin.Context) {
	if h.authManager == nil {
		c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
		return
	}
	var body struct {
		Name     string  `json:"name"`
		Prefix   *string `json:"prefix"`
		ProxyURL *string `json:"proxy_url"`
		Priority *int    `json:"priority"`
		Note     *string `json:"note"`
	}
	if errBind := c.ShouldBindJSON(&body); errBind != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
		return
	}
	name := strings.TrimSpace(body.Name)
	if name == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "name is required"})
		return
	}
	ctx := c.Request.Context()

	// Resolve by ID first, then by exact file name.
	lookup := func() *coreauth.Auth {
		if found, ok := h.authManager.GetByID(name); ok {
			return found
		}
		for _, candidate := range h.authManager.List() {
			if candidate.FileName == name {
				return candidate
			}
		}
		return nil
	}
	target := lookup()
	if target == nil {
		c.JSON(http.StatusNotFound, gin.H{"error": "auth file not found"})
		return
	}

	touchesMaps := body.Priority != nil || body.Note != nil
	if body.Prefix == nil && body.ProxyURL == nil && !touchesMaps {
		c.JSON(http.StatusBadRequest, gin.H{"error": "no fields to update"})
		return
	}

	if body.Prefix != nil {
		target.Prefix = *body.Prefix
	}
	if body.ProxyURL != nil {
		target.ProxyURL = *body.ProxyURL
	}
	if touchesMaps {
		if target.Metadata == nil {
			target.Metadata = make(map[string]any)
		}
		if target.Attributes == nil {
			target.Attributes = make(map[string]string)
		}
	}
	if body.Priority != nil {
		switch p := *body.Priority; p {
		case 0:
			delete(target.Metadata, "priority")
			delete(target.Attributes, "priority")
		default:
			target.Metadata["priority"] = p
			target.Attributes["priority"] = strconv.Itoa(p)
		}
	}
	if body.Note != nil {
		if note := strings.TrimSpace(*body.Note); note == "" {
			delete(target.Metadata, "note")
			delete(target.Attributes, "note")
		} else {
			target.Metadata["note"] = note
			target.Attributes["note"] = note
		}
	}

	target.UpdatedAt = time.Now()
	if _, errUpdate := h.authManager.Update(ctx, target); errUpdate != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to update auth: %v", errUpdate)})
		return
	}
	c.JSON(http.StatusOK, gin.H{"status": "ok"})
}
// disableAuth marks a managed auth as disabled, with the status message
// "removed via management API", and saves it via the auth manager.
//
// id may be an auth ID or a file path; a path is mapped to an ID with
// authIDForPath when the direct lookup misses. This is best effort: unknown
// IDs are ignored and the Update error is discarded.
func (h *Handler) disableAuth(ctx context.Context, id string) {
	if h == nil || h.authManager == nil {
		return
	}
	id = strings.TrimSpace(id)
	if id == "" {
		return
	}

	auth, ok := h.authManager.GetByID(id)
	if !ok {
		derivedID := h.authIDForPath(id)
		if derivedID == "" {
			return
		}
		if auth, ok = h.authManager.GetByID(derivedID); !ok {
			return
		}
	}

	auth.Disabled = true
	auth.Status = coreauth.StatusDisabled
	auth.StatusMessage = "removed via management API"
	auth.UpdatedAt = time.Now()
	_, _ = h.authManager.Update(ctx, auth)
}
// deleteTokenRecord asks the token store, rooted at AuthDir, to delete the
// record stored at path. It fails when path is blank or no token store is
// available.
func (h *Handler) deleteTokenRecord(ctx context.Context, path string) error {
	if strings.TrimSpace(path) == "" {
		return fmt.Errorf("auth path is empty")
	}
	if store := h.tokenStoreWithBaseDir(); store != nil {
		return store.Delete(ctx, path)
	}
	return fmt.Errorf("token store unavailable")
}
// tokenStoreWithBaseDir returns the handler's token store. On first use it
// falls back to sdkAuth.GetTokenStore() and caches that on h. When a config is
// present and the store supports SetBaseDir, the store is pointed at AuthDir
// on every call. Returns nil for a nil handler.
func (h *Handler) tokenStoreWithBaseDir() coreauth.Store {
	if h == nil {
		return nil
	}
	if h.tokenStore == nil {
		h.tokenStore = sdkAuth.GetTokenStore()
	}
	store := h.tokenStore
	if h.cfg == nil {
		return store
	}
	if setter, ok := store.(interface{ SetBaseDir(string) }); ok {
		setter.SetBaseDir(h.cfg.AuthDir)
	}
	return store
}
// saveTokenRecord persists record through the token store and returns what
// the store's Save returns (the saved location, as printed by the OAuth
// flows). The optional post-auth hook runs first; a hook error aborts the save
// and is wrapped.
func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (string, error) {
	if record == nil {
		return "", fmt.Errorf("token record is nil")
	}
	store := h.tokenStoreWithBaseDir()
	if store == nil {
		return "", fmt.Errorf("token store unavailable")
	}
	if hook := h.postAuthHook; hook != nil {
		if errHook := hook(ctx, record); errHook != nil {
			return "", fmt.Errorf("post-auth hook failed: %w", errHook)
		}
	}
	return store.Save(ctx, record)
}
// RequestAnthropicToken starts a Claude OAuth (PKCE) login and immediately
// responds with {"status":"ok","url":<authorization URL>,"state":<state>}.
//
// The rest of the flow runs in a background goroutine that outlives this
// request:
//   - it polls AuthDir every 500ms, for up to 5 minutes, for
//     ".oauth-anthropic-<state>.oauth" (presumably written by the
//     /anthropic/callback route, which is not visible here);
//   - it checks the returned state, exchanges the code for tokens, and saves a
//     "claude-<email>.json" record via saveTokenRecord;
//   - it reports failures with SetOAuthSessionError and success with
//     CompleteOAuthSession / CompleteOAuthSessionsByProvider.
// The goroutine exits early once the session is no longer pending.
//
// With a truthy ?is_webui, a loopback forwarder on anthropicCallbackPort
// redirects the provider callback to this server's /anthropic/callback. It is
// stopped when the goroutine exits.
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
// Background rather than the request context: the token exchange runs after
// this handler has responded. PopulateAuthContext is defined elsewhere and
// presumably copies request-derived values onto the context.
ctx := context.Background()
ctx = PopulateAuthContext(ctx, c)
fmt.Println("Initializing Claude authentication...")
// Generate PKCE codes
pkceCodes, err := claude.GeneratePKCECodes()
if err != nil {
log.Errorf("Failed to generate PKCE codes: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
return
}
// Generate random state parameter
state, err := misc.GenerateRandomState()
if err != nil {
log.Errorf("Failed to generate state parameter: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
return
}
// Initialize Claude auth service
anthropicAuth := claude.NewClaudeAuth(h.cfg)
// Generate authorization URL (then override redirect_uri to reuse server port).
// Note: state is reassigned from GenerateAuthURL's result, and that value is
// the one registered and tracked below.
authURL, state, err := anthropicAuth.GenerateAuthURL(state, pkceCodes)
if err != nil {
log.Errorf("Failed to generate authorization URL: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
return
}
RegisterOAuthSession(state, "anthropic")
isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/anthropic/callback")
if errTarget != nil {
log.WithError(errTarget).Error("failed to compute anthropic callback target")
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return
}
var errStart error
if forwarder, errStart = startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start anthropic callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return
}
}
go func() {
if isWebUI {
defer stopCallbackForwarderInstance(anthropicCallbackPort, forwarder)
}
// Helper: wait for callback file.
// It returns errOAuthSessionNotPending if the session stops being pending,
// or a timeout error (after marking the session errored) once the deadline
// passes. The file is removed once read; JSON decode errors are ignored,
// which leaves a nil map.
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-anthropic-%s.oauth", state))
waitForFile := func(path string, timeout time.Duration) (map[string]string, error) {
deadline := time.Now().Add(timeout)
for {
if !IsOAuthSessionPending(state, "anthropic") {
return nil, errOAuthSessionNotPending
}
if time.Now().After(deadline) {
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
return nil, fmt.Errorf("timeout waiting for OAuth callback")
}
data, errRead := os.ReadFile(path)
if errRead == nil {
var m map[string]string
_ = json.Unmarshal(data, &m)
_ = os.Remove(path)
return m, nil
}
time.Sleep(500 * time.Millisecond)
}
}
fmt.Println("Waiting for authentication callback...")
// Wait up to 5 minutes
resultMap, errWait := waitForFile(waitFile, 5*time.Minute)
if errWait != nil {
// A non-pending session was already resolved elsewhere; exit quietly.
if errors.Is(errWait, errOAuthSessionNotPending) {
return
}
authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait)
log.Error(claude.GetUserFriendlyMessage(authErr))
return
}
if errStr := resultMap["error"]; errStr != "" {
oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
log.Error(claude.GetUserFriendlyMessage(oauthErr))
SetOAuthSessionError(state, "Bad request")
return
}
// Reject callbacks whose state does not match this flow.
if resultMap["state"] != state {
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
log.Error(claude.GetUserFriendlyMessage(authErr))
SetOAuthSessionError(state, "State code error")
return
}
// Parse code (Claude may append state after '#')
rawCode := resultMap["code"]
code := strings.Split(rawCode, "#")[0]
// Exchange code for tokens using internal auth service
bundle, errExchange := anthropicAuth.ExchangeCodeForTokens(ctx, code, state, pkceCodes)
if errExchange != nil {
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errExchange)
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
return
}
// Create token storage; the record ID and file name are both derived from
// the account e-mail.
tokenStorage := anthropicAuth.CreateTokenStorage(bundle)
record := &coreauth.Auth{
ID: fmt.Sprintf("claude-%s.json", tokenStorage.Email),
Provider: "claude",
FileName: fmt.Sprintf("claude-%s.json", tokenStorage.Email),
Storage: tokenStorage,
Metadata: map[string]any{"email": tokenStorage.Email},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save authentication tokens: %v", errSave)
SetOAuthSessionError(state, "Failed to save authentication tokens")
return
}
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
if bundle.APIKey != "" {
fmt.Println("API key obtained and saved")
}
fmt.Println("You can now use Claude services through this CLI")
// Complete this session, then every remaining anthropic session.
CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("anthropic")
}()
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
// RequestGeminiCLIToken starts a Google OAuth flow for Gemini CLI credentials.
//
// It responds immediately with the authorization URL and the session state.
// A background goroutine then completes the flow:
//   - waits for the callback file written by the callback route;
//   - exchanges the code for a token;
//   - fetches the user's email;
//   - onboards the requested project(s);
//   - saves the credential record.
//
// The optional "project_id" query parameter selects how projects are handled:
//   - empty: ensureGeminiProjectAndOnboard picks a project automatically
//   - "ALL": every project visible to the account is onboarded
//   - "GOOGLE_ONE": performGeminiCLISetup auto-discovers the project
//   - anything else: that specific project is onboarded
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
	ctx := context.Background()
	ctx = PopulateAuthContext(ctx, c)
	proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
	// oauth2 reads its HTTP client from the context, so the token exchange and
	// conf.Client below go through the configured proxy.
	ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)

	// Optional project ID from query
	projectID := c.Query("project_id")

	fmt.Println("Initializing Google authentication...")

	// OAuth2 configuration using exported constants from internal/auth/gemini
	conf := &oauth2.Config{
		ClientID:     geminiAuth.ClientID,
		ClientSecret: geminiAuth.ClientSecret,
		RedirectURL:  fmt.Sprintf("http://localhost:%d/oauth2callback", geminiAuth.DefaultCallbackPort),
		Scopes:       geminiAuth.Scopes,
		Endpoint:     google.Endpoint,
	}

	// Build authorization URL and return it immediately
	state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
	authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))

	RegisterOAuthSession(state, "gemini")

	isWebUI := isWebUIRequest(c)
	var forwarder *callbackForwarder
	if isWebUI {
		targetURL, errTarget := h.managementCallbackURL("/google/callback")
		if errTarget != nil {
			log.WithError(errTarget).Error("failed to compute gemini callback target")
			c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
			return
		}
		var errStart error
		if forwarder, errStart = startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil {
			log.WithError(errStart).Error("failed to start gemini callback forwarder")
			c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
			return
		}
	}

	go func() {
		if isWebUI {
			defer stopCallbackForwarderInstance(geminiCallbackPort, forwarder)
		}

		// Wait for callback file written by server route
		waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gemini-%s.oauth", state))
		fmt.Println("Waiting for authentication callback...")
		deadline := time.Now().Add(5 * time.Minute)
		var authCode string
		for {
			// Stop when the session is no longer pending (e.g. finished or errored elsewhere).
			if !IsOAuthSessionPending(state, "gemini") {
				return
			}
			if time.Now().After(deadline) {
				log.Error("oauth flow timed out")
				SetOAuthSessionError(state, "OAuth flow timed out")
				return
			}
			if data, errR := os.ReadFile(waitFile); errR == nil {
				var m map[string]string
				// An unparseable file leaves m nil, which surfaces as a missing code below.
				_ = json.Unmarshal(data, &m)
				_ = os.Remove(waitFile)
				if errStr := m["error"]; errStr != "" {
					log.Errorf("Authentication failed: %s", errStr)
					SetOAuthSessionError(state, "Authentication failed")
					return
				}
				authCode = m["code"]
				if authCode == "" {
					log.Errorf("Authentication failed: code not found")
					SetOAuthSessionError(state, "Authentication failed: code not found")
					return
				}
				break
			}
			time.Sleep(500 * time.Millisecond)
		}

		// Exchange authorization code for token
		token, err := conf.Exchange(ctx, authCode)
		if err != nil {
			log.Errorf("Failed to exchange token: %v", err)
			SetOAuthSessionError(state, "Failed to exchange token")
			return
		}

		requestedProjectID := strings.TrimSpace(projectID)

		// Create token storage (mirrors internal/auth/gemini createTokenStorage)
		authHTTPClient := conf.Client(ctx, token)
		req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
		if errNewRequest != nil {
			log.Errorf("Could not get user info: %v", errNewRequest)
			SetOAuthSessionError(state, "Could not get user info")
			return
		}
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))

		resp, errDo := authHTTPClient.Do(req)
		if errDo != nil {
			log.Errorf("Failed to execute request: %v", errDo)
			SetOAuthSessionError(state, "Failed to execute request")
			return
		}
		defer func() {
			if errClose := resp.Body.Close(); errClose != nil {
				log.Warnf("failed to close response body: %v", errClose)
			}
		}()

		// A truncated body would otherwise be parsed as a response with no email.
		bodyBytes, errRead := io.ReadAll(resp.Body)
		if errRead != nil {
			log.Errorf("Failed to read user info response: %v", errRead)
			SetOAuthSessionError(state, "Could not get user info")
			return
		}
		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
			log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
			SetOAuthSessionError(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode))
			return
		}

		email := gjson.GetBytes(bodyBytes, "email").String()
		if email != "" {
			fmt.Printf("Authenticated user email: %s\n", email)
		} else {
			fmt.Println("Failed to get user email from token")
		}

		// Marshal/unmarshal oauth2.Token to generic map and enrich fields
		var ifToken map[string]any
		jsonData, errMarshal := json.Marshal(token)
		if errMarshal != nil {
			log.Errorf("Failed to marshal token: %v", errMarshal)
			SetOAuthSessionError(state, "Failed to marshal token")
			return
		}
		if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
			log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
			SetOAuthSessionError(state, "Failed to unmarshal token")
			return
		}

		ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
		ifToken["client_id"] = geminiAuth.ClientID
		ifToken["client_secret"] = geminiAuth.ClientSecret
		ifToken["scopes"] = geminiAuth.Scopes
		ifToken["universe_domain"] = "googleapis.com"

		ts := geminiAuth.GeminiTokenStorage{
			Token:     ifToken,
			ProjectID: requestedProjectID,
			Email:     email,
			Auto:      requestedProjectID == "",
		}

		// Initialize authenticated HTTP client via GeminiAuth to honor proxy settings
		gemAuth := geminiAuth.NewGeminiAuth()
		gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, &geminiAuth.WebLoginOptions{
			NoBrowser: true,
		})
		if errGetClient != nil {
			log.Errorf("failed to get authenticated client: %v", errGetClient)
			SetOAuthSessionError(state, "Failed to get authenticated client")
			return
		}
		fmt.Println("Authentication successful.")

		if strings.EqualFold(requestedProjectID, "ALL") {
			ts.Auto = false
			projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
			if errAll != nil {
				log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
				SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errAll))
				return
			}
			if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
				log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
				SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errVerify))
				return
			}
			ts.ProjectID = strings.Join(projects, ",")
			ts.Checked = true
		} else if strings.EqualFold(requestedProjectID, "GOOGLE_ONE") {
			ts.Auto = false
			if errSetup := performGeminiCLISetup(ctx, gemClient, &ts, ""); errSetup != nil {
				log.Errorf("Google One auto-discovery failed: %v", errSetup)
				SetOAuthSessionError(state, fmt.Sprintf("Google One auto-discovery failed: %v", errSetup))
				return
			}
			if strings.TrimSpace(ts.ProjectID) == "" {
				log.Error("Google One auto-discovery returned empty project ID")
				SetOAuthSessionError(state, "Google One auto-discovery returned empty project ID")
				return
			}
			isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
			if errCheck != nil {
				log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
				SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
				return
			}
			ts.Checked = isChecked
			if !isChecked {
				log.Error("Cloud AI API is not enabled for the auto-discovered project")
				SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
				return
			}
		} else {
			if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
				log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
				SetOAuthSessionError(state, fmt.Sprintf("Failed to complete Gemini CLI onboarding: %v", errEnsure))
				return
			}

			if strings.TrimSpace(ts.ProjectID) == "" {
				log.Error("Onboarding did not return a project ID")
				SetOAuthSessionError(state, "Failed to resolve project ID")
				return
			}

			isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
			if errCheck != nil {
				log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
				SetOAuthSessionError(state, fmt.Sprintf("Failed to verify Cloud AI API status: %v", errCheck))
				return
			}
			ts.Checked = isChecked
			if !isChecked {
				log.Error("Cloud AI API is not enabled for the selected project")
				SetOAuthSessionError(state, fmt.Sprintf("Cloud AI API not enabled for project %s", ts.ProjectID))
				return
			}
		}

		recordMetadata := map[string]any{
			"email":      ts.Email,
			"project_id": ts.ProjectID,
			"auto":       ts.Auto,
			"checked":    ts.Checked,
		}

		fileName := geminiAuth.CredentialFileName(ts.Email, ts.ProjectID, true)
		record := &coreauth.Auth{
			ID:       fileName,
			Provider: "gemini",
			FileName: fileName,
			Storage:  &ts,
			Metadata: recordMetadata,
		}
		savedPath, errSave := h.saveTokenRecord(ctx, record)
		if errSave != nil {
			log.Errorf("Failed to save token to file: %v", errSave)
			SetOAuthSessionError(state, "Failed to save token to file")
			return
		}

		CompleteOAuthSession(state)
		CompleteOAuthSessionsByProvider("gemini")
		fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
	}()

	c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
// RequestCodexToken starts a Codex (OpenAI) OAuth flow that uses PKCE.
//
// It responds immediately with the authorization URL and a random state.
// A background goroutine then completes the flow:
//   - waits for the callback file;
//   - validates the error, state and code fields;
//   - exchanges the code for tokens;
//   - saves the credential record.
//
// The record's file name is derived from the email, the plan type and a
// short SHA-256 prefix of the account ID.
func (h *Handler) RequestCodexToken(c *gin.Context) {
	ctx := context.Background()
	ctx = PopulateAuthContext(ctx, c)

	fmt.Println("Initializing Codex authentication...")

	// Generate PKCE codes
	pkceCodes, err := codex.GeneratePKCECodes()
	if err != nil {
		log.Errorf("Failed to generate PKCE codes: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
		return
	}

	// Generate random state parameter
	state, err := misc.GenerateRandomState()
	if err != nil {
		log.Errorf("Failed to generate state parameter: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
		return
	}

	// Initialize Codex auth service
	openaiAuth := codex.NewCodexAuth(h.cfg)

	// Generate authorization URL
	authURL, err := openaiAuth.GenerateAuthURL(state, pkceCodes)
	if err != nil {
		log.Errorf("Failed to generate authorization URL: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
		return
	}

	RegisterOAuthSession(state, "codex")

	isWebUI := isWebUIRequest(c)
	var forwarder *callbackForwarder
	if isWebUI {
		targetURL, errTarget := h.managementCallbackURL("/codex/callback")
		if errTarget != nil {
			log.WithError(errTarget).Error("failed to compute codex callback target")
			c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
			return
		}
		var errStart error
		if forwarder, errStart = startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil {
			log.WithError(errStart).Error("failed to start codex callback forwarder")
			c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
			return
		}
	}

	go func() {
		if isWebUI {
			defer stopCallbackForwarderInstance(codexCallbackPort, forwarder)
		}

		// Wait for callback file
		waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-codex-%s.oauth", state))
		deadline := time.Now().Add(5 * time.Minute)
		var code string
		for {
			if !IsOAuthSessionPending(state, "codex") {
				return
			}
			if time.Now().After(deadline) {
				authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
				log.Error(codex.GetUserFriendlyMessage(authErr))
				SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
				return
			}
			if data, errR := os.ReadFile(waitFile); errR == nil {
				var m map[string]string
				_ = json.Unmarshal(data, &m)
				_ = os.Remove(waitFile)
				if errStr := m["error"]; errStr != "" {
					oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
					log.Error(codex.GetUserFriendlyMessage(oauthErr))
					SetOAuthSessionError(state, "Bad Request")
					return
				}
				if m["state"] != state {
					authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
					SetOAuthSessionError(state, "State code error")
					log.Error(codex.GetUserFriendlyMessage(authErr))
					return
				}
				code = m["code"]
				// Reject a missing code up front instead of sending an empty code
				// to the token endpoint. The other flows in this file do the same.
				if strings.TrimSpace(code) == "" {
					log.Error("Authentication failed: code not found")
					SetOAuthSessionError(state, "Authentication failed: code not found")
					return
				}
				break
			}
			time.Sleep(500 * time.Millisecond)
		}

		log.Debug("Authorization code received, exchanging for tokens...")
		// Exchange code for tokens using internal auth service
		bundle, errExchange := openaiAuth.ExchangeCodeForTokens(ctx, code, pkceCodes)
		if errExchange != nil {
			authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errExchange)
			SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
			log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
			return
		}

		// Extract additional info for filename generation. Parsing is best-effort:
		// without claims the plan type and account hash stay empty.
		claims, _ := codex.ParseJWTToken(bundle.TokenData.IDToken)
		planType := ""
		hashAccountID := ""
		if claims != nil {
			planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
			if accountID := claims.GetAccountID(); accountID != "" {
				digest := sha256.Sum256([]byte(accountID))
				hashAccountID = hex.EncodeToString(digest[:])[:8]
			}
		}

		// Create token storage and persist
		tokenStorage := openaiAuth.CreateTokenStorage(bundle)
		fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
		record := &coreauth.Auth{
			ID:       fileName,
			Provider: "codex",
			FileName: fileName,
			Storage:  tokenStorage,
			Metadata: map[string]any{
				"email":      tokenStorage.Email,
				"account_id": tokenStorage.AccountID,
			},
		}
		savedPath, errSave := h.saveTokenRecord(ctx, record)
		if errSave != nil {
			SetOAuthSessionError(state, "Failed to save authentication tokens")
			log.Errorf("Failed to save authentication tokens: %v", errSave)
			return
		}

		fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
		if bundle.APIKey != "" {
			fmt.Println("API key obtained and saved")
		}
		fmt.Println("You can now use Codex services through this CLI")
		CompleteOAuthSession(state)
		CompleteOAuthSessionsByProvider("codex")
	}()

	c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
// RequestAntigravityToken starts an Antigravity (Google) OAuth flow.
//
// It responds immediately with the authorization URL and a random state.
// A background goroutine then completes the flow:
//   - waits for the callback file;
//   - exchanges the code for tokens;
//   - fetches the user's email;
//   - tries to fetch a project ID (failure here is only a warning);
//   - saves the tokens as record metadata.
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
	ctx := context.Background()
	ctx = PopulateAuthContext(ctx, c)

	fmt.Println("Initializing Antigravity authentication...")

	authSvc := antigravity.NewAntigravityAuth(h.cfg, nil)

	state, errState := misc.GenerateRandomState()
	if errState != nil {
		log.Errorf("Failed to generate state parameter: %v", errState)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
		return
	}

	redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", antigravity.CallbackPort)
	authURL := authSvc.BuildAuthURL(state, redirectURI)

	RegisterOAuthSession(state, "antigravity")

	isWebUI := isWebUIRequest(c)
	var forwarder *callbackForwarder
	if isWebUI {
		targetURL, errTarget := h.managementCallbackURL("/antigravity/callback")
		if errTarget != nil {
			log.WithError(errTarget).Error("failed to compute antigravity callback target")
			c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
			return
		}
		var errStart error
		if forwarder, errStart = startCallbackForwarder(antigravity.CallbackPort, "antigravity", targetURL); errStart != nil {
			log.WithError(errStart).Error("failed to start antigravity callback forwarder")
			c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
			return
		}
	}

	go func() {
		if isWebUI {
			defer stopCallbackForwarderInstance(antigravity.CallbackPort, forwarder)
		}

		waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
		deadline := time.Now().Add(5 * time.Minute)
		var authCode string
		for {
			if !IsOAuthSessionPending(state, "antigravity") {
				return
			}
			if time.Now().After(deadline) {
				log.Error("oauth flow timed out")
				SetOAuthSessionError(state, "OAuth flow timed out")
				return
			}
			if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
				var payload map[string]string
				_ = json.Unmarshal(data, &payload)
				_ = os.Remove(waitFile)
				if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
					log.Errorf("Authentication failed: %s", errStr)
					SetOAuthSessionError(state, "Authentication failed")
					return
				}
				// An absent state is tolerated; only a conflicting one is rejected.
				if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
					log.Errorf("Authentication failed: state mismatch")
					SetOAuthSessionError(state, "Authentication failed: state mismatch")
					return
				}
				authCode = strings.TrimSpace(payload["code"])
				if authCode == "" {
					log.Error("Authentication failed: code not found")
					SetOAuthSessionError(state, "Authentication failed: code not found")
					return
				}
				break
			}
			time.Sleep(500 * time.Millisecond)
		}

		tokenResp, errToken := authSvc.ExchangeCodeForTokens(ctx, authCode, redirectURI)
		if errToken != nil {
			log.Errorf("Failed to exchange token: %v", errToken)
			SetOAuthSessionError(state, "Failed to exchange token")
			return
		}

		accessToken := strings.TrimSpace(tokenResp.AccessToken)
		if accessToken == "" {
			log.Error("antigravity: token exchange returned empty access token")
			SetOAuthSessionError(state, "Failed to exchange token")
			return
		}

		email, errInfo := authSvc.FetchUserInfo(ctx, accessToken)
		if errInfo != nil {
			log.Errorf("Failed to fetch user info: %v", errInfo)
			SetOAuthSessionError(state, "Failed to fetch user info")
			return
		}
		email = strings.TrimSpace(email)
		if email == "" {
			log.Error("antigravity: user info returned empty email")
			SetOAuthSessionError(state, "Failed to fetch user info")
			return
		}

		// From here on, accessToken and email are both non-empty (guarded above).
		// A missing project ID is not fatal; the credential is saved without it.
		projectID := ""
		fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken)
		if errProject != nil {
			log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
		} else {
			projectID = fetchedProjectID
			log.Infof("antigravity: obtained project ID %s", projectID)
		}

		now := time.Now()
		metadata := map[string]any{
			"type":          "antigravity",
			"access_token":  tokenResp.AccessToken,
			"refresh_token": tokenResp.RefreshToken,
			"expires_in":    tokenResp.ExpiresIn,
			"timestamp":     now.UnixMilli(),
			"expired":       now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
			"email":         email,
		}
		if projectID != "" {
			metadata["project_id"] = projectID
		}

		fileName := antigravity.CredentialFileName(email)
		record := &coreauth.Auth{
			ID:       fileName,
			Provider: "antigravity",
			FileName: fileName,
			Label:    email,
			Metadata: metadata,
		}
		savedPath, errSave := h.saveTokenRecord(ctx, record)
		if errSave != nil {
			log.Errorf("Failed to save token to file: %v", errSave)
			SetOAuthSessionError(state, "Failed to save token to file")
			return
		}

		CompleteOAuthSession(state)
		CompleteOAuthSessionsByProvider("antigravity")
		fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
		if projectID != "" {
			fmt.Printf("Using GCP project: %s\n", projectID)
		}
		fmt.Println("You can now use Antigravity services through this CLI")
	}()

	c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
// RequestQwenToken starts a Qwen device-code flow.
//
// It responds immediately with the complete verification URL and the
// session state. A background goroutine polls for the token and saves the
// credential record, using a millisecond timestamp as the record's email and
// file-name key.
func (h *Handler) RequestQwenToken(c *gin.Context) {
	ctx := context.Background()
	ctx = PopulateAuthContext(ctx, c)

	fmt.Println("Initializing Qwen authentication...")

	// Use a Qwen-specific prefix, matching the other flows in this file
	// ("kmi-", "ifl-"). The old "gem-" prefix was a copy of the Gemini flow's.
	state := fmt.Sprintf("qwn-%d", time.Now().UnixNano())

	// Initialize Qwen auth service
	qwenAuth := qwen.NewQwenAuth(h.cfg)

	// Generate authorization URL
	deviceFlow, err := qwenAuth.InitiateDeviceFlow(ctx)
	if err != nil {
		log.Errorf("Failed to generate authorization URL: %v", err)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
		return
	}
	authURL := deviceFlow.VerificationURIComplete

	RegisterOAuthSession(state, "qwen")

	go func() {
		fmt.Println("Waiting for authentication...")
		tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
		if errPollForToken != nil {
			SetOAuthSessionError(state, "Authentication failed")
			fmt.Printf("Authentication failed: %v\n", errPollForToken)
			return
		}

		// Create token storage. The email field is overwritten with a timestamp,
		// which also serves as the unique part of the file name.
		tokenStorage := qwenAuth.CreateTokenStorage(tokenData)
		tokenStorage.Email = fmt.Sprintf("%d", time.Now().UnixMilli())

		fileName := fmt.Sprintf("qwen-%s.json", tokenStorage.Email)
		record := &coreauth.Auth{
			ID:       fileName,
			Provider: "qwen",
			FileName: fileName,
			Storage:  tokenStorage,
			Metadata: map[string]any{"email": tokenStorage.Email},
		}
		savedPath, errSave := h.saveTokenRecord(ctx, record)
		if errSave != nil {
			log.Errorf("Failed to save authentication tokens: %v", errSave)
			SetOAuthSessionError(state, "Failed to save authentication tokens")
			return
		}

		fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
		fmt.Println("You can now use Qwen services through this CLI")
		CompleteOAuthSession(state)
		// Every other provider in this file also completes its remaining
		// sessions by provider; Qwen was the only one that did not.
		CompleteOAuthSessionsByProvider("qwen")
	}()

	c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
// RequestKimiToken starts a Kimi device-code flow.
//
// It responds immediately with a verification URL and the session state.
// VerificationURIComplete is preferred when set; otherwise VerificationURI is
// used. A background goroutine waits for the user to authorize and then saves
// a credential record whose metadata mirrors the token fields.
func (h *Handler) RequestKimiToken(c *gin.Context) {
	ctx := PopulateAuthContext(context.Background(), c)

	fmt.Println("Initializing Kimi authentication...")

	state := fmt.Sprintf("kmi-%d", time.Now().UnixNano())
	kimiAuth := kimi.NewKimiAuth(h.cfg)

	deviceFlow, errStartDeviceFlow := kimiAuth.StartDeviceFlow(ctx)
	if errStartDeviceFlow != nil {
		log.Errorf("Failed to generate authorization URL: %v", errStartDeviceFlow)
		c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
		return
	}

	authURL := deviceFlow.VerificationURI
	if deviceFlow.VerificationURIComplete != "" {
		authURL = deviceFlow.VerificationURIComplete
	}

	RegisterOAuthSession(state, "kimi")

	go func() {
		fmt.Println("Waiting for authentication...")
		bundle, errWait := kimiAuth.WaitForAuthorization(ctx, deviceFlow)
		if errWait != nil {
			SetOAuthSessionError(state, "Authentication failed")
			fmt.Printf("Authentication failed: %v\n", errWait)
			return
		}

		storage := kimiAuth.CreateTokenStorage(bundle)

		tokenData := bundle.TokenData
		meta := map[string]any{
			"type":          "kimi",
			"access_token":  tokenData.AccessToken,
			"refresh_token": tokenData.RefreshToken,
			"token_type":    tokenData.TokenType,
			"scope":         tokenData.Scope,
			"timestamp":     time.Now().UnixMilli(),
		}
		// Only record an expiry when the server supplied one.
		if expiresAt := tokenData.ExpiresAt; expiresAt > 0 {
			meta["expired"] = time.Unix(expiresAt, 0).UTC().Format(time.RFC3339)
		}
		if deviceID := strings.TrimSpace(bundle.DeviceID); deviceID != "" {
			meta["device_id"] = deviceID
		}

		name := fmt.Sprintf("kimi-%d.json", time.Now().UnixMilli())
		savedPath, errSave := h.saveTokenRecord(ctx, &coreauth.Auth{
			ID:       name,
			Provider: "kimi",
			FileName: name,
			Label:    "Kimi User",
			Storage:  storage,
			Metadata: meta,
		})
		if errSave != nil {
			log.Errorf("Failed to save authentication tokens: %v", errSave)
			SetOAuthSessionError(state, "Failed to save authentication tokens")
			return
		}

		fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
		fmt.Println("You can now use Kimi services through this CLI")
		CompleteOAuthSession(state)
		CompleteOAuthSessionsByProvider("kimi")
	}()

	c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
// RequestIFlowToken starts an iFlow OAuth authorization-code flow.
//
// It responds with the authorization URL and the session state. A background
// goroutine then completes the flow:
//   - waits (up to five minutes) for the callback file;
//   - rejects error responses, a state mismatch and a missing code;
//   - exchanges the code for tokens;
//   - saves a credential record keyed by the email, or by a timestamp when
//     no email is present.
func (h *Handler) RequestIFlowToken(c *gin.Context) {
	ctx := PopulateAuthContext(context.Background(), c)

	fmt.Println("Initializing iFlow authentication...")

	state := fmt.Sprintf("ifl-%d", time.Now().UnixNano())
	authSvc := iflowauth.NewIFlowAuth(h.cfg)
	authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort)

	RegisterOAuthSession(state, "iflow")

	isWebUI := isWebUIRequest(c)
	var forwarder *callbackForwarder
	if isWebUI {
		targetURL, errTarget := h.managementCallbackURL("/iflow/callback")
		if errTarget != nil {
			log.WithError(errTarget).Error("failed to compute iflow callback target")
			c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"})
			return
		}
		fwd, errStart := startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL)
		if errStart != nil {
			log.WithError(errStart).Error("failed to start iflow callback forwarder")
			c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"})
			return
		}
		forwarder = fwd
	}

	go func() {
		if isWebUI {
			defer stopCallbackForwarderInstance(iflowauth.CallbackPort, forwarder)
		}

		fmt.Println("Waiting for authentication...")

		callbackPath := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-iflow-%s.oauth", state))
		deadline := time.Now().Add(5 * time.Minute)

		var callback map[string]string
		for {
			if !IsOAuthSessionPending(state, "iflow") {
				return
			}
			if time.Now().After(deadline) {
				SetOAuthSessionError(state, "Authentication failed")
				fmt.Println("Authentication failed: timeout waiting for callback")
				return
			}
			data, errRead := os.ReadFile(callbackPath)
			if errRead == nil {
				_ = os.Remove(callbackPath)
				_ = json.Unmarshal(data, &callback)
				break
			}
			time.Sleep(500 * time.Millisecond)
		}

		// Validate the callback payload; checks run in priority order.
		errText := strings.TrimSpace(callback["error"])
		code := strings.TrimSpace(callback["code"])
		switch {
		case errText != "":
			SetOAuthSessionError(state, "Authentication failed")
			fmt.Printf("Authentication failed: %s\n", errText)
			return
		case strings.TrimSpace(callback["state"]) != state:
			SetOAuthSessionError(state, "Authentication failed")
			fmt.Println("Authentication failed: state mismatch")
			return
		case code == "":
			SetOAuthSessionError(state, "Authentication failed")
			fmt.Println("Authentication failed: code missing")
			return
		}

		tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
		if errExchange != nil {
			SetOAuthSessionError(state, "Authentication failed")
			fmt.Printf("Authentication failed: %v\n", errExchange)
			return
		}

		storage := authSvc.CreateTokenStorage(tokenData)
		key := strings.TrimSpace(storage.Email)
		if key == "" {
			// No email in the token: fall back to a timestamp so the file name stays unique.
			key = fmt.Sprintf("%d", time.Now().UnixMilli())
			storage.Email = key
		}

		name := fmt.Sprintf("iflow-%s.json", key)
		savedPath, errSave := h.saveTokenRecord(ctx, &coreauth.Auth{
			ID:         name,
			Provider:   "iflow",
			FileName:   name,
			Storage:    storage,
			Metadata:   map[string]any{"email": key, "api_key": storage.APIKey},
			Attributes: map[string]string{"api_key": storage.APIKey},
		})
		if errSave != nil {
			SetOAuthSessionError(state, "Failed to save authentication tokens")
			log.Errorf("Failed to save authentication tokens: %v", errSave)
			return
		}

		fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
		if storage.APIKey != "" {
			fmt.Println("API key obtained and saved")
		}
		fmt.Println("You can now use iFlow services through this CLI")
		CompleteOAuthSession(state)
		CompleteOAuthSessionsByProvider("iflow")
	}()

	c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
}
// RequestIFlowCookieToken authenticates iFlow with a browser cookie and saves
// the resulting credential before replying.
//
// The request body is JSON of the form {"cookie": "..."}. Responses:
//   - 400 for a missing, malformed or rejected cookie;
//   - 409 when the cookie's BXAuth value already exists in a credential file;
//   - 500 when the duplicate check or the save fails;
//   - 200 with the saved path, email, expiry and type on success.
func (h *Handler) RequestIFlowCookieToken(c *gin.Context) {
	// Populate the auth context the same way as the other token handlers do
	// before calling saveTokenRecord.
	ctx := PopulateAuthContext(context.Background(), c)

	var payload struct {
		Cookie string `json:"cookie"`
	}
	if err := c.ShouldBindJSON(&payload); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "cookie is required"})
		return
	}

	cookieValue := strings.TrimSpace(payload.Cookie)
	if cookieValue == "" {
		c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "cookie is required"})
		return
	}

	cookieValue, errNormalize := iflowauth.NormalizeCookie(cookieValue)
	if errNormalize != nil {
		c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": errNormalize.Error()})
		return
	}

	// Check for duplicate BXAuth before authentication
	bxAuth := iflowauth.ExtractBXAuth(cookieValue)
	existingFile, errDup := iflowauth.CheckDuplicateBXAuth(h.cfg.AuthDir, bxAuth)
	if errDup != nil {
		log.Errorf("Failed to check duplicate iFlow BXAuth: %v", errDup)
		c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to check duplicate"})
		return
	}
	if existingFile != "" {
		existingFileName := filepath.Base(existingFile)
		c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "duplicate BXAuth found", "existing_file": existingFileName})
		return
	}

	authSvc := iflowauth.NewIFlowAuth(h.cfg)
	tokenData, errAuth := authSvc.AuthenticateWithCookie(ctx, cookieValue)
	if errAuth != nil {
		c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": errAuth.Error()})
		return
	}

	tokenData.Cookie = cookieValue

	tokenStorage := authSvc.CreateCookieTokenStorage(tokenData)
	email := strings.TrimSpace(tokenStorage.Email)
	if email == "" {
		c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "failed to extract email from token"})
		return
	}

	fileName := iflowauth.SanitizeIFlowFileName(email)
	if fileName == "" {
		fileName = fmt.Sprintf("iflow-%d", time.Now().UnixMilli())
	} else {
		fileName = fmt.Sprintf("iflow-%s", fileName)
	}

	tokenStorage.Email = email
	timestamp := time.Now().Unix()
	recordName := fmt.Sprintf("%s-%d.json", fileName, timestamp)

	record := &coreauth.Auth{
		ID:       recordName,
		Provider: "iflow",
		FileName: recordName,
		Storage:  tokenStorage,
		Metadata: map[string]any{
			"email":        email,
			"api_key":      tokenStorage.APIKey,
			"expired":      tokenStorage.Expire,
			"cookie":       tokenStorage.Cookie,
			"type":         tokenStorage.Type,
			"last_refresh": tokenStorage.LastRefresh,
		},
		Attributes: map[string]string{
			"api_key": tokenStorage.APIKey,
		},
	}

	savedPath, errSave := h.saveTokenRecord(ctx, record)
	if errSave != nil {
		log.Errorf("Failed to save authentication tokens: %v", errSave)
		c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to save authentication tokens"})
		return
	}

	fmt.Printf("iFlow cookie authentication successful. Token saved to %s\n", savedPath)
	c.JSON(http.StatusOK, gin.H{
		"status":     "ok",
		"saved_path": savedPath,
		"email":      email,
		"expired":    tokenStorage.Expire,
		"type":       tokenStorage.Type,
	})
}
type projectSelectionRequiredError struct{}
func (e *projectSelectionRequiredError) Error() string {
return "gemini cli: project selection required"
}
// ensureGeminiProjectAndOnboard runs Gemini CLI setup for a single project.
//
// A blank requestedProject means the first project returned by
// fetchGCPProjects is used and storage.Auto is set to true. Otherwise the
// trimmed requested project is used and storage.Auto is set to false. If
// setup leaves storage.ProjectID blank, it is filled with the chosen project.
func ensureGeminiProjectAndOnboard(ctx context.Context, httpClient *http.Client, storage *geminiAuth.GeminiTokenStorage, requestedProject string) error {
	if storage == nil {
		return fmt.Errorf("gemini storage is nil")
	}

	chosen := strings.TrimSpace(requestedProject)
	autoSelected := chosen == ""

	if autoSelected {
		projects, errProjects := fetchGCPProjects(ctx, httpClient)
		switch {
		case errProjects != nil:
			return fmt.Errorf("fetch project list: %w", errProjects)
		case len(projects) == 0:
			return fmt.Errorf("no Google Cloud projects available for this account")
		}
		if chosen = strings.TrimSpace(projects[0].ProjectID); chosen == "" {
			return fmt.Errorf("resolved project id is empty")
		}
	}
	storage.Auto = autoSelected

	if errSetup := performGeminiCLISetup(ctx, httpClient, storage, chosen); errSetup != nil {
		return errSetup
	}

	if strings.TrimSpace(storage.ProjectID) == "" {
		storage.ProjectID = chosen
	}
	return nil
}
// onboardAllGeminiProjects runs Gemini CLI setup for every project returned
// by fetchGCPProjects and returns the resolved project IDs in discovery order.
//
// Blank and repeated candidate IDs are skipped, and each resolved ID appears
// at most once in the result. storage is passed to performGeminiCLISetup for
// every project and is mutated by it. The first setup failure aborts the
// whole operation.
func onboardAllGeminiProjects(ctx context.Context, httpClient *http.Client, storage *geminiAuth.GeminiTokenStorage) ([]string, error) {
	// performGeminiCLISetup writes storage.ProjectID, so a nil storage would
	// panic there. Fail early, as ensureGeminiProjectAndOnboard does.
	if storage == nil {
		return nil, fmt.Errorf("gemini storage is nil")
	}

	projects, errProjects := fetchGCPProjects(ctx, httpClient)
	if errProjects != nil {
		return nil, fmt.Errorf("fetch project list: %w", errProjects)
	}
	if len(projects) == 0 {
		return nil, fmt.Errorf("no Google Cloud projects available for this account")
	}

	activated := make([]string, 0, len(projects))
	seenCandidates := make(map[string]struct{}, len(projects))
	seenActivated := make(map[string]struct{}, len(projects))
	for _, project := range projects {
		candidate := strings.TrimSpace(project.ProjectID)
		if candidate == "" {
			continue
		}
		if _, dup := seenCandidates[candidate]; dup {
			continue
		}
		seenCandidates[candidate] = struct{}{}

		if err := performGeminiCLISetup(ctx, httpClient, storage, candidate); err != nil {
			return nil, fmt.Errorf("onboard project %s: %w", candidate, err)
		}

		finalID := strings.TrimSpace(storage.ProjectID)
		if finalID == "" {
			finalID = candidate
		}
		// Different candidates may resolve to the same project. Keep it once so
		// the joined project list and the API checks have no duplicates.
		if _, dup := seenActivated[finalID]; dup {
			continue
		}
		seenActivated[finalID] = struct{}{}
		activated = append(activated, finalID)
	}

	if len(activated) == 0 {
		return nil, fmt.Errorf("no Google Cloud projects available for this account")
	}
	return activated, nil
}
// ensureGeminiProjectsEnabled checks, in order, that the Cloud AI API is
// enabled for every non-blank project ID. It returns an error naming the
// first project whose check fails or reports the API as disabled.
func ensureGeminiProjectsEnabled(ctx context.Context, httpClient *http.Client, projectIDs []string) error {
	for i := range projectIDs {
		id := strings.TrimSpace(projectIDs[i])
		if id == "" {
			continue
		}
		enabled, errCheck := checkCloudAPIIsEnabled(ctx, httpClient, id)
		switch {
		case errCheck != nil:
			return fmt.Errorf("project %s: %w", id, errCheck)
		case !enabled:
			return fmt.Errorf("project %s: Cloud AI API not enabled", id)
		}
	}
	return nil
}
// performGeminiCLISetup resolves the Gemini Code Assist project for storage and
// records it in storage.ProjectID.
//
// It first calls loadCodeAssist, passing requestedProject when it is non-blank,
// to learn the default tier and any project already tied to the account. If no
// project is requested or returned, it polls onboardUser without a project for
// up to 30 seconds so the backend can provision one. If that still yields no
// project, it returns a *projectSelectionRequiredError. Finally it polls
// onboardUser with the chosen project until the response reports done, and it
// stores the project ID that should be used. Cancelling ctx stops both polling
// loops.
func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage *geminiAuth.GeminiTokenStorage, requestedProject string) error {
	metadata := map[string]string{
		"ideType":    "IDE_UNSPECIFIED",
		"platform":   "PLATFORM_UNSPECIFIED",
		"pluginType": "GEMINI",
	}
	trimmedRequest := strings.TrimSpace(requestedProject)
	explicitProject := trimmedRequest != ""
	loadReqBody := map[string]any{
		"metadata": metadata,
	}
	if explicitProject {
		loadReqBody["cloudaicompanionProject"] = trimmedRequest
	}
	var loadResp map[string]any
	if errLoad := callGeminiCLI(ctx, httpClient, "loadCodeAssist", loadReqBody, &loadResp); errLoad != nil {
		return fmt.Errorf("load code assist: %w", errLoad)
	}
	tierID := selectGeminiDefaultTierID(loadResp)
	projectID := trimmedRequest
	if projectID == "" {
		projectID = extractGeminiCompanionProjectID(loadResp["cloudaicompanionProject"])
	}
	if projectID == "" {
		// Auto-discovery: try onboardUser without specifying a project
		// to let Google auto-provision one (matches Gemini CLI headless behavior
		// and Antigravity's FetchProjectID pattern).
		autoOnboardReq := map[string]any{
			"tierId":   tierID,
			"metadata": metadata,
		}
		autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second)
		defer autoCancel()
		for attempt := 1; ; attempt++ {
			var onboardResp map[string]any
			if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil {
				// If only the 30s auto-discovery budget expired (the caller's ctx is
				// still live), report it the same way as a timeout while waiting below.
				if autoCtx.Err() != nil && ctx.Err() == nil {
					return &projectSelectionRequiredError{}
				}
				return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard)
			}
			if done, okDone := onboardResp["done"].(bool); okDone && done {
				if resp, okResp := onboardResp["response"].(map[string]any); okResp {
					projectID = extractGeminiCompanionProjectID(resp["cloudaicompanionProject"])
				}
				break
			}
			log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt)
			select {
			case <-autoCtx.Done():
				return &projectSelectionRequiredError{}
			case <-time.After(2 * time.Second):
			}
		}
		if projectID == "" {
			return &projectSelectionRequiredError{}
		}
		log.Infof("Auto-discovered project ID via onboarding: %s", projectID)
	}
	onboardReqBody := map[string]any{
		"tierId":                  tierID,
		"metadata":                metadata,
		"cloudaicompanionProject": projectID,
	}
	storage.ProjectID = projectID
	for {
		var onboardResp map[string]any
		if errOnboard := callGeminiCLI(ctx, httpClient, "onboardUser", onboardReqBody, &onboardResp); errOnboard != nil {
			return fmt.Errorf("onboard user: %w", errOnboard)
		}
		if done, okDone := onboardResp["done"].(bool); okDone && done {
			responseProjectID := ""
			if resp, okResp := onboardResp["response"].(map[string]any); okResp {
				responseProjectID = extractGeminiCompanionProjectID(resp["cloudaicompanionProject"])
			}
			finalProjectID := projectID
			if responseProjectID != "" {
				if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
					// Check if this is a free user (gen-lang-client projects or free/legacy tier).
					// NOTE(review): tierID comes from allowedTiers[].id and defaults to
					// "legacy-tier", so the "FREE"/"LEGACY" comparisons only match if the
					// backend returns those literal IDs. Confirm the intended tier values.
					isFreeUser := strings.HasPrefix(projectID, "gen-lang-client-") ||
						strings.EqualFold(tierID, "FREE") ||
						strings.EqualFold(tierID, "LEGACY")
					if isFreeUser {
						// For free users, use backend project ID for preview model access
						log.Infof("Gemini onboarding: frontend project %s maps to backend project %s", projectID, responseProjectID)
						log.Infof("Using backend project ID: %s (recommended for preview model access)", responseProjectID)
						finalProjectID = responseProjectID
					} else {
						// Pro users: keep requested project ID (original behavior)
						log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
					}
				} else {
					finalProjectID = responseProjectID
				}
			}
			storage.ProjectID = strings.TrimSpace(finalProjectID)
			if storage.ProjectID == "" {
				storage.ProjectID = strings.TrimSpace(projectID)
			}
			if storage.ProjectID == "" {
				return fmt.Errorf("onboard user completed without project id")
			}
			log.Infof("Onboarding complete. Using Project ID: %s", storage.ProjectID)
			return nil
		}
		log.Println("Onboarding in progress, waiting 5 seconds...")
		// Wait between polls, but give up promptly once the caller cancels ctx
		// instead of sleeping unconditionally.
		select {
		case <-ctx.Done():
			return fmt.Errorf("onboard user: %w", ctx.Err())
		case <-time.After(5 * time.Second):
		}
	}
}

// selectGeminiDefaultTierID returns the trimmed, non-blank "id" of the first
// allowedTiers entry flagged isDefault in a loadCodeAssist response, or
// "legacy-tier" when no such entry exists.
func selectGeminiDefaultTierID(loadResp map[string]any) string {
	const fallbackTierID = "legacy-tier"
	tiers, okTiers := loadResp["allowedTiers"].([]any)
	if !okTiers {
		return fallbackTierID
	}
	for _, rawTier := range tiers {
		tier, okTier := rawTier.(map[string]any)
		if !okTier {
			continue
		}
		if isDefault, _ := tier["isDefault"].(bool); !isDefault {
			continue
		}
		if id, _ := tier["id"].(string); strings.TrimSpace(id) != "" {
			return strings.TrimSpace(id)
		}
	}
	return fallbackTierID
}

// extractGeminiCompanionProjectID reads a "cloudaicompanionProject" value,
// which may be a plain string or an object with an "id" string, and returns
// it trimmed; any other shape yields "".
func extractGeminiCompanionProjectID(value any) string {
	switch v := value.(type) {
	case string:
		return strings.TrimSpace(v)
	case map[string]any:
		if id, okID := v["id"].(string); okID {
			return strings.TrimSpace(id)
		}
	}
	return ""
}
// callGeminiCLI POSTs body, JSON-encoded when non-nil, to a Gemini CLI API
// method.
//
// The URL is "<geminiCLIEndpoint>/<geminiCLIVersion>:<endpoint>". Endpoints
// starting with "operations/" are instead appended directly as
// "<geminiCLIEndpoint>/<endpoint>". A non-2xx status becomes an error carrying
// the trimmed response body. On success the body is JSON-decoded into result,
// or drained and discarded when result is nil.
func callGeminiCLI(ctx context.Context, httpClient *http.Client, endpoint string, body any, result any) error {
	var endPointURL string
	if strings.HasPrefix(endpoint, "operations/") {
		endPointURL = fmt.Sprintf("%s/%s", geminiCLIEndpoint, endpoint)
	} else {
		endPointURL = fmt.Sprintf("%s/%s:%s", geminiCLIEndpoint, geminiCLIVersion, endpoint)
	}

	// Left as a nil io.Reader when there is no body, so the request has none.
	var payload io.Reader
	if body != nil {
		encoded, errMarshal := json.Marshal(body)
		if errMarshal != nil {
			return fmt.Errorf("marshal request body: %w", errMarshal)
		}
		payload = bytes.NewReader(encoded)
	}

	req, errRequest := http.NewRequestWithContext(ctx, http.MethodPost, endPointURL, payload)
	if errRequest != nil {
		return fmt.Errorf("create request: %w", errRequest)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))

	resp, errDo := httpClient.Do(req)
	if errDo != nil {
		return fmt.Errorf("execute request: %w", errDo)
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("response body close error: %v", errClose)
		}
	}()

	if succeeded := resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices; !succeeded {
		errBody, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("api request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(errBody)))
	}

	if result != nil {
		if errDecode := json.NewDecoder(resp.Body).Decode(result); errDecode != nil {
			return fmt.Errorf("decode response body: %w", errDecode)
		}
		return nil
	}
	_, _ = io.Copy(io.Discard, resp.Body)
	return nil
}
// fetchGCPProjects GETs the Cloud Resource Manager v1 project list with
// httpClient and returns the decoded Projects slice. A non-2xx status becomes
// an error that includes the trimmed response body.
// NOTE(review): no page token is sent or followed, so presumably only the
// first result page is returned. Confirm whether pagination is needed.
func fetchGCPProjects(ctx context.Context, httpClient *http.Client) ([]interfaces.GCPProjectProjects, error) {
	const projectsListURL = "https://cloudresourcemanager.googleapis.com/v1/projects"

	listReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodGet, projectsListURL, nil)
	if errNewReq != nil {
		return nil, fmt.Errorf("could not create project list request: %w", errNewReq)
	}
	listResp, errSend := httpClient.Do(listReq)
	if errSend != nil {
		return nil, fmt.Errorf("failed to execute project list request: %w", errSend)
	}
	defer func() {
		if errClose := listResp.Body.Close(); errClose != nil {
			log.Errorf("response body close error: %v", errClose)
		}
	}()

	// StatusCode/100 == 2 is exactly the 200..299 range.
	if listResp.StatusCode/100 != 2 {
		failureBody, _ := io.ReadAll(listResp.Body)
		return nil, fmt.Errorf("project list request failed with status %d: %s", listResp.StatusCode, strings.TrimSpace(string(failureBody)))
	}

	var listing interfaces.GCPProject
	if errDecode := json.NewDecoder(listResp.Body).Decode(&listing); errDecode != nil {
		return nil, fmt.Errorf("failed to unmarshal project list: %w", errDecode)
	}
	return listing.Projects, nil
}
// checkCloudAPIIsEnabled makes sure each required Service Usage service
// (currently only cloudaicompanion.googleapis.com) is on for projectID.
//
// For each service it GETs the service state. A 200 reply whose "state" is
// "ENABLED" is accepted as-is. Otherwise it POSTs to the service's ":enable"
// URL, and a 200/201 reply is accepted, as is a 400 whose error message
// contains "already enabled". Anything else returns a
// "project activation required" error carrying the API's error.message when
// present, or the raw body otherwise. The bool is true exactly when the error
// is nil.
func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projectID string) (bool, error) {
	const serviceUsageURL = "https://serviceusage.googleapis.com"
	for _, service := range []string{"cloudaicompanion.googleapis.com"} {
		serviceURL := fmt.Sprintf("%s/v1/projects/%s/services/%s", serviceUsageURL, projectID, service)

		// Step 1: look at the current service state.
		checkReq, errRequest := http.NewRequestWithContext(ctx, http.MethodGet, serviceURL, nil)
		if errRequest != nil {
			return false, fmt.Errorf("failed to create request: %w", errRequest)
		}
		checkReq.Header.Set("Content-Type", "application/json")
		checkReq.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
		checkResp, errDo := httpClient.Do(checkReq)
		if errDo != nil {
			return false, fmt.Errorf("failed to execute request: %w", errDo)
		}
		alreadyEnabled := false
		if checkResp.StatusCode == http.StatusOK {
			checkBody, _ := io.ReadAll(checkResp.Body)
			alreadyEnabled = gjson.GetBytes(checkBody, "state").String() == "ENABLED"
		}
		_ = checkResp.Body.Close()
		if alreadyEnabled {
			continue
		}

		// Step 2: ask Service Usage to enable it.
		enableReq, errRequest := http.NewRequestWithContext(ctx, http.MethodPost, serviceURL+":enable", strings.NewReader("{}"))
		if errRequest != nil {
			return false, fmt.Errorf("failed to create request: %w", errRequest)
		}
		enableReq.Header.Set("Content-Type", "application/json")
		enableReq.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
		enableResp, errDo := httpClient.Do(enableReq)
		if errDo != nil {
			return false, fmt.Errorf("failed to execute request: %w", errDo)
		}
		enableBody, _ := io.ReadAll(enableResp.Body)
		_ = enableResp.Body.Close()

		errMessage := string(enableBody)
		if apiMessage := gjson.GetBytes(enableBody, "error.message"); apiMessage.Exists() {
			errMessage = apiMessage.String()
		}
		switch enableResp.StatusCode {
		case http.StatusOK, http.StatusCreated:
			continue
		case http.StatusBadRequest:
			if strings.Contains(strings.ToLower(errMessage), "already enabled") {
				continue
			}
		}
		return false, fmt.Errorf("project activation required: %s", errMessage)
	}
	return true, nil
}
// GetAuthStatus reports the progress of an OAuth flow identified by the
// "state" query parameter. The response is always 200, except for an invalid
// state:
//   - no state, or no session for it: {"status":"ok"}
//   - state rejected by ValidateOAuthState: 400 {"status":"error","error":"invalid state"}
//   - session with a non-empty status string: {"status":"error","error":<status>}
//   - otherwise: {"status":"wait"}
func (h *Handler) GetAuthStatus(c *gin.Context) {
	state := strings.TrimSpace(c.Query("state"))
	if state == "" {
		c.JSON(http.StatusOK, gin.H{"status": "ok"})
		return
	}
	if errValidate := ValidateOAuthState(state); errValidate != nil {
		c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
		return
	}

	_, sessionStatus, found := GetOAuthSession(state)
	var payload gin.H
	switch {
	case !found:
		payload = gin.H{"status": "ok"}
	case sessionStatus != "":
		payload = gin.H{"status": "error", "error": sessionStatus}
	default:
		payload = gin.H{"status": "wait"}
	}
	c.JSON(http.StatusOK, payload)
}
// PopulateAuthContext attaches the incoming request's query parameters and
// headers to ctx as a coreauth.RequestInfo and returns the derived context.
// The header map is shared with c.Request rather than copied.
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
	req := c.Request
	return coreauth.WithRequestInfo(ctx, &coreauth.RequestInfo{
		Query:   req.URL.Query(),
		Headers: req.Header,
	})
}
================================================
FILE: internal/api/handlers/management/auth_files_delete_test.go
================================================
package management
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"os"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// TestDeleteAuthFile_UsesAuthPathFromManager checks the path DeleteAuthFile
// uses when a manager record exists. The record's "path" attribute points
// outside AuthDir, and AuthDir holds a same-named decoy file. Deleting by name
// must remove the external file, keep the decoy, and drop the record from
// ListAuthFiles.
func TestDeleteAuthFile_UsesAuthPathFromManager(t *testing.T) {
	t.Setenv("MANAGEMENT_PASSWORD", "")
	gin.SetMode(gin.TestMode)

	root := t.TempDir()
	authDir := filepath.Join(root, "auth")
	externalDir := filepath.Join(root, "external")
	if err := os.MkdirAll(authDir, 0o700); err != nil {
		t.Fatalf("failed to create auth dir: %v", err)
	}
	if err := os.MkdirAll(externalDir, 0o700); err != nil {
		t.Fatalf("failed to create external dir: %v", err)
	}

	const fileName = "codex-user@example.com-plus.json"
	shadowPath := filepath.Join(authDir, fileName)
	realPath := filepath.Join(externalDir, fileName)
	if err := os.WriteFile(shadowPath, []byte(`{"type":"codex","email":"shadow@example.com"}`), 0o600); err != nil {
		t.Fatalf("failed to write shadow file: %v", err)
	}
	if err := os.WriteFile(realPath, []byte(`{"type":"codex","email":"real@example.com"}`), 0o600); err != nil {
		t.Fatalf("failed to write real file: %v", err)
	}

	manager := coreauth.NewManager(nil, nil, nil)
	if _, err := manager.Register(context.Background(), &coreauth.Auth{
		ID:          "legacy/" + fileName,
		FileName:    fileName,
		Provider:    "codex",
		Status:      coreauth.StatusError,
		Unavailable: true,
		Attributes: map[string]string{
			"path": realPath,
		},
		Metadata: map[string]any{
			"type":  "codex",
			"email": "real@example.com",
		},
	}); err != nil {
		t.Fatalf("failed to register auth record: %v", err)
	}

	h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, manager)
	h.tokenStore = &memoryAuthStore{}

	// serve runs one handler against a fresh recorder and returns the recorder.
	serve := func(method, target string, handle func(*gin.Context)) *httptest.ResponseRecorder {
		rec := httptest.NewRecorder()
		ginCtx, _ := gin.CreateTestContext(rec)
		ginCtx.Request = httptest.NewRequest(method, target, nil)
		handle(ginCtx)
		return rec
	}

	deleteRec := serve(http.MethodDelete, "/v0/management/auth-files?name="+url.QueryEscape(fileName), h.DeleteAuthFile)
	if deleteRec.Code != http.StatusOK {
		t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, deleteRec.Code, deleteRec.Body.String())
	}
	if _, err := os.Stat(realPath); !os.IsNotExist(err) {
		t.Fatalf("expected managed auth file to be removed, stat err: %v", err)
	}
	if _, err := os.Stat(shadowPath); err != nil {
		t.Fatalf("expected shadow auth file to remain, stat err: %v", err)
	}

	listRec := serve(http.MethodGet, "/v0/management/auth-files", h.ListAuthFiles)
	if listRec.Code != http.StatusOK {
		t.Fatalf("expected list status %d, got %d with body %s", http.StatusOK, listRec.Code, listRec.Body.String())
	}
	var payload map[string]any
	if err := json.Unmarshal(listRec.Body.Bytes(), &payload); err != nil {
		t.Fatalf("failed to decode list payload: %v", err)
	}
	files, isArray := payload["files"].([]any)
	if !isArray {
		t.Fatalf("expected files array, payload: %#v", payload)
	}
	if len(files) != 0 {
		t.Fatalf("expected removed auth to be hidden from list, got %d entries", len(files))
	}
}
// TestDeleteAuthFile_FallbackToAuthDirPath checks that, with no manager
// record for the name, DeleteAuthFile removes the file directly under AuthDir.
func TestDeleteAuthFile_FallbackToAuthDirPath(t *testing.T) {
	t.Setenv("MANAGEMENT_PASSWORD", "")
	gin.SetMode(gin.TestMode)

	authDir := t.TempDir()
	const fileName = "fallback-user.json"
	target := filepath.Join(authDir, fileName)
	if err := os.WriteFile(target, []byte(`{"type":"codex"}`), 0o600); err != nil {
		t.Fatalf("failed to write auth file: %v", err)
	}

	h := NewHandlerWithoutConfigFilePath(&config.Config{AuthDir: authDir}, coreauth.NewManager(nil, nil, nil))
	h.tokenStore = &memoryAuthStore{}

	rec := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(rec)
	ginCtx.Request = httptest.NewRequest(http.MethodDelete, "/v0/management/auth-files?name="+url.QueryEscape(fileName), nil)
	h.DeleteAuthFile(ginCtx)

	if rec.Code != http.StatusOK {
		t.Fatalf("expected delete status %d, got %d with body %s", http.StatusOK, rec.Code, rec.Body.String())
	}
	if _, err := os.Stat(target); !os.IsNotExist(err) {
		t.Fatalf("expected auth file to be removed from auth dir, stat err: %v", err)
	}
}
================================================
FILE: internal/api/handlers/management/config_basic.go
================================================
package management
import (
"encoding/json"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
log "github.com/sirupsen/logrus"
"gopkg.in/yaml.v3"
)
// latestReleaseURL is the GitHub REST latest-release endpoint that
// GetLatestVersion queries for this repository.
const latestReleaseURL = "https://api.github.com/repos/router-for-me/CLIProxyAPI/releases/latest"

// latestReleaseUserAgent is the User-Agent header value sent with that request.
const latestReleaseUserAgent = "CLIProxyAPI"
// GetConfig responds 200 with the current configuration as JSON, or with an
// empty object when the handler or its config is nil. It encodes a by-value
// copy of *h.cfg. The copy is shallow, so any slices and maps inside are
// still shared with the live config.
func (h *Handler) GetConfig(c *gin.Context) {
	if h != nil && h.cfg != nil {
		snapshot := *h.cfg
		c.JSON(200, &snapshot)
		return
	}
	c.JSON(200, gin.H{})
}
// releaseInfo holds the two fields of a GitHub release payload that
// GetLatestVersion reads: the tag name, with the release title as a fallback.
type releaseInfo struct {
	Name    string `json:"name"`
	TagName string `json:"tag_name"`
}
// GetLatestVersion returns the latest release version from GitHub without downloading assets.
//
// It GETs latestReleaseURL using a client with a 10s timeout, routed through
// cfg.ProxyURL when that is set. On success it responds with
// {"latest-version": <tag_name, or name when the tag is blank>}. A failure to
// build the request is reported as 500. Upstream errors, a non-200 status, an
// undecodable body, or a missing version are reported as 502.
func (h *Handler) GetLatestVersion(c *gin.Context) {
	httpClient := &http.Client{Timeout: 10 * time.Second}
	if h != nil && h.cfg != nil {
		if proxyURL := strings.TrimSpace(h.cfg.ProxyURL); proxyURL != "" {
			util.SetProxy(&sdkconfig.SDKConfig{ProxyURL: proxyURL}, httpClient)
		}
	}

	req, errReq := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, latestReleaseURL, nil)
	if errReq != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "request_create_failed", "message": errReq.Error()})
		return
	}
	req.Header.Set("Accept", "application/vnd.github+json")
	req.Header.Set("User-Agent", latestReleaseUserAgent)

	resp, errDo := httpClient.Do(req)
	if errDo != nil {
		c.JSON(http.StatusBadGateway, gin.H{"error": "request_failed", "message": errDo.Error()})
		return
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.WithError(errClose).Debug("failed to close latest version response body")
		}
	}()

	if resp.StatusCode != http.StatusOK {
		// Only the first KiB of the upstream body is echoed back.
		snippet, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
		c.JSON(http.StatusBadGateway, gin.H{"error": "unexpected_status", "message": fmt.Sprintf("status %d: %s", resp.StatusCode, strings.TrimSpace(string(snippet)))})
		return
	}

	var release releaseInfo
	if errDecode := json.NewDecoder(resp.Body).Decode(&release); errDecode != nil {
		c.JSON(http.StatusBadGateway, gin.H{"error": "decode_failed", "message": errDecode.Error()})
		return
	}

	latest := strings.TrimSpace(release.TagName)
	if latest == "" {
		latest = strings.TrimSpace(release.Name)
	}
	if latest == "" {
		c.JSON(http.StatusBadGateway, gin.H{"error": "invalid_response", "message": "missing release version"})
		return
	}
	c.JSON(http.StatusOK, gin.H{"latest-version": latest})
}
// WriteConfig overwrites (or creates, mode 0644) the file at path with data
// after passing it through config.NormalizeCommentIndentation. The data is
// fsynced before the file is closed. The first write or sync error is
// returned; otherwise the result of Close is returned.
func WriteConfig(path string, data []byte) error {
	normalized := config.NormalizeCommentIndentation(data)
	file, errOpen := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
	if errOpen != nil {
		return errOpen
	}
	errFlush := func() error {
		if _, errWrite := file.Write(normalized); errWrite != nil {
			return errWrite
		}
		return file.Sync()
	}()
	if errFlush != nil {
		_ = file.Close()
		return errFlush
	}
	return file.Close()
}
// PutConfigYAML replaces the config file with the raw YAML request body.
//
// The body must unmarshal into config.Config. It must also load through
// config.LoadConfigOptional(path, false) when written to a temporary file in
// the config's directory. Only then, while holding h.mu, is it written with
// WriteConfig and reloaded into h.cfg.
//
// Responses:
//   - 400 invalid_yaml: unreadable body or YAML that does not unmarshal
//   - 422 invalid_config: YAML that fails LoadConfigOptional
//   - 500 write_failed or reload_failed: I/O or reload failures
//   - 200 {"ok":true,"changed":["config"]}: success
func (h *Handler) PutConfigYAML(c *gin.Context) {
	body, err := io.ReadAll(c.Request.Body)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_yaml", "message": "cannot read request body"})
		return
	}
	var cfg config.Config
	if err = yaml.Unmarshal(body, &cfg); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_yaml", "message": err.Error()})
		return
	}
	// Validate config using LoadConfigOptional with optional=false to enforce parsing.
	tmpFile, err := os.CreateTemp(filepath.Dir(h.configFilePath), "config-validate-*.yaml")
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": err.Error()})
		return
	}
	tempFile := tmpFile.Name()
	// Registered immediately so every exit path removes the temp file. Deferred
	// calls run after the explicit Close calls below.
	defer func() {
		_ = os.Remove(tempFile)
	}()
	if _, errWrite := tmpFile.Write(body); errWrite != nil {
		_ = tmpFile.Close()
		c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": errWrite.Error()})
		return
	}
	if errClose := tmpFile.Close(); errClose != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": errClose.Error()})
		return
	}
	if _, err = config.LoadConfigOptional(tempFile, false); err != nil {
		c.JSON(http.StatusUnprocessableEntity, gin.H{"error": "invalid_config", "message": err.Error()})
		return
	}

	h.mu.Lock()
	defer h.mu.Unlock()
	if errWrite := WriteConfig(h.configFilePath, body); errWrite != nil {
		// Report the underlying cause (permissions, disk full, ...) like the other
		// write_failed responses, instead of discarding it.
		c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": fmt.Sprintf("failed to write config: %v", errWrite)})
		return
	}
	// Reload into handler to keep memory in sync
	newCfg, err := config.LoadConfig(h.configFilePath)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "reload_failed", "message": err.Error()})
		return
	}
	h.cfg = newCfg
	c.JSON(http.StatusOK, gin.H{"ok": true, "changed": []string{"config"}})
}
// GetConfigYAML returns the raw config.yaml file bytes without re-encoding.
// It preserves comments and original formatting/styles.
//
// A missing file yields 404 not_found, and any other read error yields
// 500 read_failed. On success the bytes are written as application/yaml with
// no-store caching and nosniff.
func (h *Handler) GetConfigYAML(c *gin.Context) {
	raw, errRead := os.ReadFile(h.configFilePath)
	switch {
	case os.IsNotExist(errRead):
		c.JSON(http.StatusNotFound, gin.H{"error": "not_found", "message": "config file not found"})
	case errRead != nil:
		c.JSON(http.StatusInternalServerError, gin.H{"error": "read_failed", "message": errRead.Error()})
	default:
		c.Header("Content-Type", "application/yaml; charset=utf-8")
		c.Header("Cache-Control", "no-store")
		c.Header("X-Content-Type-Options", "nosniff")
		_, _ = c.Writer.Write(raw)
	}
}
// GetDebug reports the "debug" flag.
func (h *Handler) GetDebug(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{"debug": h.cfg.Debug})
}

// PutDebug sets the "debug" flag through updateBoolField.
func (h *Handler) PutDebug(c *gin.Context) {
	h.updateBoolField(c, func(enabled bool) { h.cfg.Debug = enabled })
}

// GetUsageStatisticsEnabled reports the "usage-statistics-enabled" flag.
func (h *Handler) GetUsageStatisticsEnabled(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{"usage-statistics-enabled": h.cfg.UsageStatisticsEnabled})
}

// PutUsageStatisticsEnabled sets the "usage-statistics-enabled" flag through updateBoolField.
func (h *Handler) PutUsageStatisticsEnabled(c *gin.Context) {
	h.updateBoolField(c, func(enabled bool) { h.cfg.UsageStatisticsEnabled = enabled })
}

// GetLoggingToFile reports the "logging-to-file" flag.
func (h *Handler) GetLoggingToFile(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{"logging-to-file": h.cfg.LoggingToFile})
}

// PutLoggingToFile sets the "logging-to-file" flag through updateBoolField.
func (h *Handler) PutLoggingToFile(c *gin.Context) {
	h.updateBoolField(c, func(enabled bool) { h.cfg.LoggingToFile = enabled })
}
// GetLogsMaxTotalSizeMB reports the "logs-max-total-size-mb" setting.
func (h *Handler) GetLogsMaxTotalSizeMB(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{"logs-max-total-size-mb": h.cfg.LogsMaxTotalSizeMB})
}

// PutLogsMaxTotalSizeMB sets the limit from {"value": <int>}, clamping
// negative numbers to 0, and then calls h.persist. It responds 400 when the
// body does not bind or has no value.
func (h *Handler) PutLogsMaxTotalSizeMB(c *gin.Context) {
	var req struct {
		Value *int `json:"value"`
	}
	if errBind := c.ShouldBindJSON(&req); errBind != nil || req.Value == nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
		return
	}
	h.cfg.LogsMaxTotalSizeMB = max(*req.Value, 0)
	h.persist(c)
}
// GetErrorLogsMaxFiles reports the "error-logs-max-files" setting.
func (h *Handler) GetErrorLogsMaxFiles(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{"error-logs-max-files": h.cfg.ErrorLogsMaxFiles})
}

// PutErrorLogsMaxFiles sets the limit from {"value": <int>} and then calls
// h.persist. It responds 400 when the body does not bind or has no value.
func (h *Handler) PutErrorLogsMaxFiles(c *gin.Context) {
	var req struct {
		Value *int `json:"value"`
	}
	if errBind := c.ShouldBindJSON(&req); errBind != nil || req.Value == nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
		return
	}
	files := *req.Value
	if files < 0 {
		// Negative input is replaced with 10 rather than clamped to 0.
		// NOTE(review): presumably 10 is the default for this setting. Confirm.
		files = 10
	}
	h.cfg.ErrorLogsMaxFiles = files
	h.persist(c)
}
// GetRequestLog reports the "request-log" flag.
func (h *Handler) GetRequestLog(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{"request-log": h.cfg.RequestLog})
}

// PutRequestLog sets the "request-log" flag through updateBoolField.
func (h *Handler) PutRequestLog(c *gin.Context) {
	h.updateBoolField(c, func(enabled bool) { h.cfg.RequestLog = enabled })
}

// GetWebsocketAuth reports the "ws-auth" flag.
func (h *Handler) GetWebsocketAuth(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{"ws-auth": h.cfg.WebsocketAuth})
}

// PutWebsocketAuth sets the "ws-auth" flag through updateBoolField.
func (h *Handler) PutWebsocketAuth(c *gin.Context) {
	h.updateBoolField(c, func(enabled bool) { h.cfg.WebsocketAuth = enabled })
}

// GetRequestRetry reports the "request-retry" setting.
func (h *Handler) GetRequestRetry(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{"request-retry": h.cfg.RequestRetry})
}

// PutRequestRetry sets the "request-retry" setting through updateIntField.
func (h *Handler) PutRequestRetry(c *gin.Context) {
	h.updateIntField(c, func(retries int) { h.cfg.RequestRetry = retries })
}

// GetMaxRetryInterval reports the "max-retry-interval" setting.
func (h *Handler) GetMaxRetryInterval(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{"max-retry-interval": h.cfg.MaxRetryInterval})
}

// PutMaxRetryInterval sets the "max-retry-interval" setting through updateIntField.
func (h *Handler) PutMaxRetryInterval(c *gin.Context) {
	h.updateIntField(c, func(interval int) { h.cfg.MaxRetryInterval = interval })
}

// GetForceModelPrefix reports the "force-model-prefix" flag.
func (h *Handler) GetForceModelPrefix(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{"force-model-prefix": h.cfg.ForceModelPrefix})
}

// PutForceModelPrefix sets the "force-model-prefix" flag through updateBoolField.
func (h *Handler) PutForceModelPrefix(c *gin.Context) {
	h.updateBoolField(c, func(enabled bool) { h.cfg.ForceModelPrefix = enabled })
}
// normalizeRoutingStrategy maps a user-supplied routing strategy name to its
// canonical form, ignoring case and surrounding whitespace.
//
// "round-robin", "roundrobin", "rr" and the empty string map to
// "round-robin". "fill-first", "fillfirst" and "ff" map to "fill-first".
// Anything else returns ("", false).
func normalizeRoutingStrategy(strategy string) (string, bool) {
	switch key := strings.ToLower(strings.TrimSpace(strategy)); key {
	case "fill-first", "fillfirst", "ff":
		return "fill-first", true
	case "", "round-robin", "roundrobin", "rr":
		return "round-robin", true
	}
	return "", false
}
// GetRoutingStrategy reports the routing strategy in canonical form. If the
// stored value is not recognized, it is echoed back trimmed instead.
func (h *Handler) GetRoutingStrategy(c *gin.Context) {
	stored := h.cfg.Routing.Strategy
	strategy, recognized := normalizeRoutingStrategy(stored)
	if !recognized {
		strategy = strings.TrimSpace(stored)
	}
	c.JSON(http.StatusOK, gin.H{"strategy": strategy})
}

// PutRoutingStrategy stores the canonical form of {"value": <name>} and calls
// h.persist. It responds 400 for an unbindable body, a missing value, or an
// unknown strategy name.
func (h *Handler) PutRoutingStrategy(c *gin.Context) {
	var req struct {
		Value *string `json:"value"`
	}
	if errBind := c.ShouldBindJSON(&req); errBind != nil || req.Value == nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
		return
	}
	strategy, recognized := normalizeRoutingStrategy(*req.Value)
	if !recognized {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid strategy"})
		return
	}
	h.cfg.Routing.Strategy = strategy
	h.persist(c)
}
// GetProxyURL reports the configured outbound "proxy-url", which may be empty.
func (h *Handler) GetProxyURL(c *gin.Context) {
	c.JSON(http.StatusOK, gin.H{"proxy-url": h.cfg.ProxyURL})
}

// PutProxyURL sets the proxy URL through updateStringField.
func (h *Handler) PutProxyURL(c *gin.Context) {
	h.updateStringField(c, func(proxyURL string) { h.cfg.ProxyURL = proxyURL })
}

// DeleteProxyURL clears the proxy URL and then calls h.persist.
func (h *Handler) DeleteProxyURL(c *gin.Context) {
	h.cfg.ProxyURL = ""
	h.persist(c)
}
================================================
FILE: internal/api/handlers/management/config_lists.go
================================================
package management
import (
"encoding/json"
"fmt"
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// Generic helpers for list[string]

// putStringList replaces a whole string list from the request body, which may
// be a bare JSON array (["a","b"]) or an object {"items": [...]}. The object
// form must carry at least one item. On success it calls set with the decoded
// list, then after (when non-nil), then h.persist. Otherwise it responds 400.
func (h *Handler) putStringList(c *gin.Context, set func([]string), after func()) {
	raw, errRead := c.GetRawData()
	if errRead != nil {
		c.JSON(400, gin.H{"error": "failed to read body"})
		return
	}
	var items []string
	if errArray := json.Unmarshal(raw, &items); errArray != nil {
		var wrapped struct {
			Items []string `json:"items"`
		}
		if errObject := json.Unmarshal(raw, &wrapped); errObject != nil || len(wrapped.Items) == 0 {
			c.JSON(400, gin.H{"error": "invalid body"})
			return
		}
		items = wrapped.Items
	}
	set(items)
	if after != nil {
		after()
	}
	h.persist(c)
}
// patchStringList edits *target in place using one of two request shapes:
//   - {"index": i, "value": v}: when i is in range, element i becomes v.
//   - {"old": o, "new": n}: the first element equal to o becomes n; if none
//     matches, n is appended.
//
// A successful edit is followed by after (when non-nil) and h.persist. An
// unbindable body, or a body matching neither shape, gets a 400.
func (h *Handler) patchStringList(c *gin.Context, target *[]string, after func()) {
	var body struct {
		Old   *string `json:"old"`
		New   *string `json:"new"`
		Index *int    `json:"index"`
		Value *string `json:"value"`
	}
	if errBind := c.ShouldBindJSON(&body); errBind != nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}
	commit := func() {
		if after != nil {
			after()
		}
		h.persist(c)
	}
	switch {
	case body.Index != nil && body.Value != nil && *body.Index >= 0 && *body.Index < len(*target):
		(*target)[*body.Index] = *body.Value
		commit()
	case body.Old != nil && body.New != nil:
		replaced := false
		for i, existing := range *target {
			if existing == *body.Old {
				(*target)[i] = *body.New
				replaced = true
				break
			}
		}
		if !replaced {
			*target = append(*target, *body.New)
		}
		commit()
	default:
		c.JSON(400, gin.H{"error": "missing fields"})
	}
}
// deleteFromStringList removes entries from *target. It then calls after
// (when non-nil) and h.persist.
//
// A parseable, in-range "index" query parameter removes that one element.
// Otherwise a non-blank "value" parameter removes every element whose trimmed
// text equals the trimmed value; this succeeds even if nothing matched. With
// neither usable, it responds 400.
func (h *Handler) deleteFromStringList(c *gin.Context, target *[]string, after func()) {
	commit := func() {
		if after != nil {
			after()
		}
		h.persist(c)
	}
	if rawIndex := c.Query("index"); rawIndex != "" {
		var idx int
		if _, errScan := fmt.Sscanf(rawIndex, "%d", &idx); errScan == nil && idx >= 0 && idx < len(*target) {
			*target = append((*target)[:idx], (*target)[idx+1:]...)
			commit()
			return
		}
	}
	needle := strings.TrimSpace(c.Query("value"))
	if needle == "" {
		c.JSON(400, gin.H{"error": "missing index or value"})
		return
	}
	kept := make([]string, 0, len(*target))
	for _, item := range *target {
		if strings.TrimSpace(item) == needle {
			continue
		}
		kept = append(kept, item)
	}
	*target = kept
	commit()
}
// api-keys

// GetAPIKeys returns the configured "api-keys" list.
func (h *Handler) GetAPIKeys(c *gin.Context) {
	c.JSON(200, gin.H{"api-keys": h.cfg.APIKeys})
}

// PutAPIKeys replaces "api-keys" with a copy of the list in the request body.
func (h *Handler) PutAPIKeys(c *gin.Context) {
	h.putStringList(c, func(keys []string) {
		h.cfg.APIKeys = append([]string(nil), keys...)
	}, nil)
}

// PatchAPIKeys edits one "api-keys" entry (see patchStringList). No post-edit hook is needed.
func (h *Handler) PatchAPIKeys(c *gin.Context) {
	h.patchStringList(c, &h.cfg.APIKeys, nil)
}

// DeleteAPIKeys removes "api-keys" entries (see deleteFromStringList). No post-edit hook is needed.
func (h *Handler) DeleteAPIKeys(c *gin.Context) {
	h.deleteFromStringList(c, &h.cfg.APIKeys, nil)
}
// gemini-api-key: []GeminiKey

// GetGeminiKeys returns the configured "gemini-api-key" entries.
func (h *Handler) GetGeminiKeys(c *gin.Context) {
	c.JSON(200, gin.H{"gemini-api-key": h.cfg.GeminiKey})
}

// PutGeminiKeys replaces all Gemini key entries. The body may be a bare JSON
// array or an object {"items": [...]} with at least one item. The new list is
// copied, run through SanitizeGeminiKeys, and then h.persist is called.
func (h *Handler) PutGeminiKeys(c *gin.Context) {
	raw, errRead := c.GetRawData()
	if errRead != nil {
		c.JSON(400, gin.H{"error": "failed to read body"})
		return
	}
	var keys []config.GeminiKey
	if errArray := json.Unmarshal(raw, &keys); errArray != nil {
		var wrapped struct {
			Items []config.GeminiKey `json:"items"`
		}
		if errObject := json.Unmarshal(raw, &wrapped); errObject != nil || len(wrapped.Items) == 0 {
			c.JSON(400, gin.H{"error": "invalid body"})
			return
		}
		keys = wrapped.Items
	}
	h.cfg.GeminiKey = append([]config.GeminiKey(nil), keys...)
	h.cfg.SanitizeGeminiKeys()
	h.persist(c)
}
// PatchGeminiKey updates one gemini-api-key entry.
//
// The target is chosen by "index" when it is in range. Otherwise it is chosen
// by a non-blank "match" compared with each entry's api-key. Only fields
// present in "value" change: string fields are trimmed, and headers and
// excluded-models are normalized. Setting api-key to a blank string removes
// the entry instead. The list is then sanitized and h.persist is called. A bad
// body gets 400, and an unresolved target gets 404.
func (h *Handler) PatchGeminiKey(c *gin.Context) {
	type geminiKeyPatch struct {
		APIKey         *string            `json:"api-key"`
		Prefix         *string            `json:"prefix"`
		BaseURL        *string            `json:"base-url"`
		ProxyURL       *string            `json:"proxy-url"`
		Headers        *map[string]string `json:"headers"`
		ExcludedModels *[]string          `json:"excluded-models"`
	}
	var body struct {
		Index *int            `json:"index"`
		Match *string         `json:"match"`
		Value *geminiKeyPatch `json:"value"`
	}
	if errBind := c.ShouldBindJSON(&body); errBind != nil || body.Value == nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}

	keys := h.cfg.GeminiKey
	targetIndex := -1
	switch {
	case body.Index != nil && *body.Index >= 0 && *body.Index < len(keys):
		targetIndex = *body.Index
	case body.Match != nil:
		if match := strings.TrimSpace(*body.Match); match != "" {
			for i := range keys {
				if keys[i].APIKey == match {
					targetIndex = i
					break
				}
			}
		}
	}
	if targetIndex < 0 {
		c.JSON(404, gin.H{"error": "item not found"})
		return
	}

	patch := body.Value
	entry := keys[targetIndex]
	if patch.APIKey != nil {
		apiKey := strings.TrimSpace(*patch.APIKey)
		if apiKey == "" {
			// A blank api-key means "remove this entry".
			h.cfg.GeminiKey = append(keys[:targetIndex], keys[targetIndex+1:]...)
			h.cfg.SanitizeGeminiKeys()
			h.persist(c)
			return
		}
		entry.APIKey = apiKey
	}
	setTrimmed := func(dst *string, src *string) {
		if src != nil {
			*dst = strings.TrimSpace(*src)
		}
	}
	setTrimmed(&entry.Prefix, patch.Prefix)
	setTrimmed(&entry.BaseURL, patch.BaseURL)
	setTrimmed(&entry.ProxyURL, patch.ProxyURL)
	if patch.Headers != nil {
		entry.Headers = config.NormalizeHeaders(*patch.Headers)
	}
	if patch.ExcludedModels != nil {
		entry.ExcludedModels = config.NormalizeExcludedModels(*patch.ExcludedModels)
	}
	h.cfg.GeminiKey[targetIndex] = entry
	h.cfg.SanitizeGeminiKeys()
	h.persist(c)
}
// DeleteGeminiKey removes Gemini key entries.
//
// A non-blank (trimmed) "api-key" query parameter removes every entry with
// that key, or responds 404 if none match. Otherwise a parseable, in-range
// "index" removes that entry. On success the list is sanitized and h.persist
// is called. With neither parameter usable, it responds 400.
func (h *Handler) DeleteGeminiKey(c *gin.Context) {
	if apiKey := strings.TrimSpace(c.Query("api-key")); apiKey != "" {
		kept := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
		for _, key := range h.cfg.GeminiKey {
			if key.APIKey != apiKey {
				kept = append(kept, key)
			}
		}
		if len(kept) == len(h.cfg.GeminiKey) {
			c.JSON(404, gin.H{"error": "item not found"})
			return
		}
		h.cfg.GeminiKey = kept
		h.cfg.SanitizeGeminiKeys()
		h.persist(c)
		return
	}

	rawIndex := c.Query("index")
	if rawIndex != "" {
		var idx int
		_, errScan := fmt.Sscanf(rawIndex, "%d", &idx)
		if errScan == nil && idx >= 0 && idx < len(h.cfg.GeminiKey) {
			h.cfg.GeminiKey = append(h.cfg.GeminiKey[:idx], h.cfg.GeminiKey[idx+1:]...)
			h.cfg.SanitizeGeminiKeys()
			h.persist(c)
			return
		}
	}
	c.JSON(400, gin.H{"error": "missing api-key or index"})
}
// claude-api-key: []ClaudeKey

// GetClaudeKeys returns the configured "claude-api-key" entries.
func (h *Handler) GetClaudeKeys(c *gin.Context) {
	c.JSON(200, gin.H{"claude-api-key": h.cfg.ClaudeKey})
}

// PutClaudeKeys replaces all Claude key entries. The body may be a bare JSON
// array or an object {"items": [...]} with at least one item. Each entry goes
// through normalizeClaudeKey, the list is sanitized, and then h.persist is
// called.
func (h *Handler) PutClaudeKeys(c *gin.Context) {
	raw, errRead := c.GetRawData()
	if errRead != nil {
		c.JSON(400, gin.H{"error": "failed to read body"})
		return
	}
	var keys []config.ClaudeKey
	if errArray := json.Unmarshal(raw, &keys); errArray != nil {
		var wrapped struct {
			Items []config.ClaudeKey `json:"items"`
		}
		if errObject := json.Unmarshal(raw, &wrapped); errObject != nil || len(wrapped.Items) == 0 {
			c.JSON(400, gin.H{"error": "invalid body"})
			return
		}
		keys = wrapped.Items
	}
	for i := range keys {
		normalizeClaudeKey(&keys[i])
	}
	h.cfg.ClaudeKey = keys
	h.cfg.SanitizeClaudeKeys()
	h.persist(c)
}
// PatchClaudeKey updates one claude-api-key entry.
//
// The target is chosen by "index" when it is in range. Otherwise it is chosen
// by a non-blank "match" compared with each entry's api-key. Only fields
// present in "value" change: strings are trimmed, models are copied, and
// headers and excluded-models are normalized. The entry is then passed through
// normalizeClaudeKey, the list is sanitized, and h.persist is called. A bad
// body gets 400, and an unresolved target gets 404.
func (h *Handler) PatchClaudeKey(c *gin.Context) {
	type claudeKeyPatch struct {
		APIKey         *string               `json:"api-key"`
		Prefix         *string               `json:"prefix"`
		BaseURL        *string               `json:"base-url"`
		ProxyURL       *string               `json:"proxy-url"`
		Models         *[]config.ClaudeModel `json:"models"`
		Headers        *map[string]string    `json:"headers"`
		ExcludedModels *[]string             `json:"excluded-models"`
	}
	var body struct {
		Index *int            `json:"index"`
		Match *string         `json:"match"`
		Value *claudeKeyPatch `json:"value"`
	}
	if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}
	targetIndex := -1
	if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) {
		targetIndex = *body.Index
	}
	if targetIndex == -1 && body.Match != nil {
		// Ignore a blank match. Otherwise it would select whichever entry
		// happens to have an empty api-key.
		if match := strings.TrimSpace(*body.Match); match != "" {
			for i := range h.cfg.ClaudeKey {
				if h.cfg.ClaudeKey[i].APIKey == match {
					targetIndex = i
					break
				}
			}
		}
	}
	if targetIndex == -1 {
		c.JSON(404, gin.H{"error": "item not found"})
		return
	}
	entry := h.cfg.ClaudeKey[targetIndex]
	if body.Value.APIKey != nil {
		entry.APIKey = strings.TrimSpace(*body.Value.APIKey)
	}
	if body.Value.Prefix != nil {
		entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
	}
	if body.Value.BaseURL != nil {
		entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL)
	}
	if body.Value.ProxyURL != nil {
		entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
	}
	if body.Value.Models != nil {
		entry.Models = append([]config.ClaudeModel(nil), (*body.Value.Models)...)
	}
	if body.Value.Headers != nil {
		entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
	}
	if body.Value.ExcludedModels != nil {
		entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
	}
	normalizeClaudeKey(&entry)
	h.cfg.ClaudeKey[targetIndex] = entry
	h.cfg.SanitizeClaudeKeys()
	h.persist(c)
}
// DeleteClaudeKey removes Claude key entries.
//
// A non-empty "api-key" query parameter (compared untrimmed) removes every
// matching entry. This succeeds even when nothing matched. Otherwise a
// parseable, in-range "index" removes that entry. On success the list is
// sanitized and h.persist is called. With neither parameter usable, it
// responds 400.
func (h *Handler) DeleteClaudeKey(c *gin.Context) {
	if apiKey := c.Query("api-key"); apiKey != "" {
		remaining := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey))
		for _, key := range h.cfg.ClaudeKey {
			if key.APIKey == apiKey {
				continue
			}
			remaining = append(remaining, key)
		}
		h.cfg.ClaudeKey = remaining
		h.cfg.SanitizeClaudeKeys()
		h.persist(c)
		return
	}
	if rawIndex := c.Query("index"); rawIndex != "" {
		var idx int
		if _, errScan := fmt.Sscanf(rawIndex, "%d", &idx); errScan == nil && idx >= 0 && idx < len(h.cfg.ClaudeKey) {
			h.cfg.ClaudeKey = append(h.cfg.ClaudeKey[:idx], h.cfg.ClaudeKey[idx+1:]...)
			h.cfg.SanitizeClaudeKeys()
			h.persist(c)
			return
		}
	}
	c.JSON(400, gin.H{"error": "missing api-key or index"})
}
// openai-compatibility: []OpenAICompatibility

// GetOpenAICompat responds with the configured openai-compatibility providers.
// The entries are passed through normalizedOpenAICompatibilityEntries, which
// normalizes per-entry copies rather than the stored entries.
func (h *Handler) GetOpenAICompat(c *gin.Context) {
	entries := normalizedOpenAICompatibilityEntries(h.cfg.OpenAICompatibility)
	c.JSON(200, gin.H{"openai-compatibility": entries})
}
// PutOpenAICompat replaces the whole openai-compatibility provider list.
//
// The body may be a bare JSON array or an object {"items": [...]}; the
// wrapped form must contain at least one item. Each entry is normalized, and
// entries left without a base-url are dropped before the list is stored,
// sanitized and passed to h.persist.
func (h *Handler) PutOpenAICompat(c *gin.Context) {
	raw, errRead := c.GetRawData()
	if errRead != nil {
		c.JSON(400, gin.H{"error": "failed to read body"})
		return
	}
	var entries []config.OpenAICompatibility
	if errArr := json.Unmarshal(raw, &entries); errArr != nil {
		var wrapped struct {
			Items []config.OpenAICompatibility `json:"items"`
		}
		errObj := json.Unmarshal(raw, &wrapped)
		if errObj != nil || len(wrapped.Items) == 0 {
			c.JSON(400, gin.H{"error": "invalid body"})
			return
		}
		entries = wrapped.Items
	}
	kept := make([]config.OpenAICompatibility, 0, len(entries))
	for i := range entries {
		current := &entries[i]
		normalizeOpenAICompatibilityEntry(current)
		// An empty base-url means the provider is being removed.
		if strings.TrimSpace(current.BaseURL) == "" {
			continue
		}
		kept = append(kept, *current)
	}
	h.cfg.OpenAICompatibility = kept
	h.cfg.SanitizeOpenAICompatibility()
	h.persist(c)
}
// PatchOpenAICompat updates a single openai-compatibility provider.
//
// The target is chosen by "index" when it is in range. Otherwise it is the
// first provider whose stored name equals the trimmed "name". Only fields
// present in "value" are applied. Setting base-url to an empty or
// whitespace-only string removes the provider instead of updating it.
// Responds 400 for an unparsable body and 404 when no provider is selected.
func (h *Handler) PatchOpenAICompat(c *gin.Context) {
	type openAICompatPatch struct {
		Name          *string                             `json:"name"`
		Prefix        *string                             `json:"prefix"`
		BaseURL       *string                             `json:"base-url"`
		APIKeyEntries *[]config.OpenAICompatibilityAPIKey `json:"api-key-entries"`
		Models        *[]config.OpenAICompatibilityModel  `json:"models"`
		Headers       *map[string]string                  `json:"headers"`
	}
	var body struct {
		Name  *string            `json:"name"`
		Index *int               `json:"index"`
		Value *openAICompatPatch `json:"value"`
	}
	if errBind := c.ShouldBindJSON(&body); errBind != nil || body.Value == nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}
	providers := h.cfg.OpenAICompatibility
	pos := -1
	switch {
	case body.Index != nil && *body.Index >= 0 && *body.Index < len(providers):
		pos = *body.Index
	case body.Name != nil:
		wanted := strings.TrimSpace(*body.Name)
		for i := range providers {
			if providers[i].Name == wanted {
				pos = i
				break
			}
		}
	}
	if pos < 0 {
		c.JSON(404, gin.H{"error": "item not found"})
		return
	}
	patch := body.Value
	updated := providers[pos]
	if patch.Name != nil {
		updated.Name = strings.TrimSpace(*patch.Name)
	}
	if patch.Prefix != nil {
		updated.Prefix = strings.TrimSpace(*patch.Prefix)
	}
	if patch.BaseURL != nil {
		baseURL := strings.TrimSpace(*patch.BaseURL)
		if baseURL == "" {
			// Clearing base-url deletes the provider.
			h.cfg.OpenAICompatibility = append(providers[:pos], providers[pos+1:]...)
			h.cfg.SanitizeOpenAICompatibility()
			h.persist(c)
			return
		}
		updated.BaseURL = baseURL
	}
	if patch.APIKeyEntries != nil {
		updated.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), (*patch.APIKeyEntries)...)
	}
	if patch.Models != nil {
		updated.Models = append([]config.OpenAICompatibilityModel(nil), (*patch.Models)...)
	}
	if patch.Headers != nil {
		updated.Headers = config.NormalizeHeaders(*patch.Headers)
	}
	normalizeOpenAICompatibilityEntry(&updated)
	h.cfg.OpenAICompatibility[pos] = updated
	h.cfg.SanitizeOpenAICompatibility()
	h.persist(c)
}
// DeleteOpenAICompat removes openai-compatibility providers.
//
// When the "name" query parameter is set, every provider with exactly that
// name is removed. Otherwise the zero-based "index" query parameter selects
// one provider. Responds 400 when neither identifies a provider.
func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
	if name := c.Query("name"); name != "" {
		kept := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility))
		for i := range h.cfg.OpenAICompatibility {
			if h.cfg.OpenAICompatibility[i].Name == name {
				continue
			}
			kept = append(kept, h.cfg.OpenAICompatibility[i])
		}
		h.cfg.OpenAICompatibility = kept
		h.cfg.SanitizeOpenAICompatibility()
		h.persist(c)
		return
	}
	rawIndex := c.Query("index")
	if rawIndex == "" {
		c.JSON(400, gin.H{"error": "missing name or index"})
		return
	}
	var pos int
	if _, errScan := fmt.Sscanf(rawIndex, "%d", &pos); errScan != nil || pos < 0 || pos >= len(h.cfg.OpenAICompatibility) {
		c.JSON(400, gin.H{"error": "missing name or index"})
		return
	}
	h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:pos], h.cfg.OpenAICompatibility[pos+1:]...)
	h.cfg.SanitizeOpenAICompatibility()
	h.persist(c)
}
// vertex-api-key: []VertexCompatKey

// GetVertexCompatKeys responds with the vertex-api-key entries exactly as stored.
func (h *Handler) GetVertexCompatKeys(c *gin.Context) {
	keys := h.cfg.VertexCompatAPIKey
	c.JSON(200, gin.H{"vertex-api-key": keys})
}
// PutVertexCompatKeys replaces the whole vertex-api-key list.
//
// The body may be a bare JSON array or an object {"items": [...]}; the
// wrapped form must contain at least one item. Every entry is normalized and
// must have a non-empty api-key. The first entry without one is reported by
// index with a 400 and nothing is stored.
func (h *Handler) PutVertexCompatKeys(c *gin.Context) {
	raw, errRead := c.GetRawData()
	if errRead != nil {
		c.JSON(400, gin.H{"error": "failed to read body"})
		return
	}
	var keys []config.VertexCompatKey
	if errArr := json.Unmarshal(raw, &keys); errArr != nil {
		var wrapped struct {
			Items []config.VertexCompatKey `json:"items"`
		}
		if errObj := json.Unmarshal(raw, &wrapped); errObj != nil || len(wrapped.Items) == 0 {
			c.JSON(400, gin.H{"error": "invalid body"})
			return
		}
		keys = wrapped.Items
	}
	for i := range keys {
		normalizeVertexCompatKey(&keys[i])
		if keys[i].APIKey != "" {
			continue
		}
		c.JSON(400, gin.H{"error": fmt.Sprintf("vertex-api-key[%d].api-key is required", i)})
		return
	}
	h.cfg.VertexCompatAPIKey = append([]config.VertexCompatKey(nil), keys...)
	h.cfg.SanitizeVertexCompatKeys()
	h.persist(c)
}
// PatchVertexCompatKey updates a single vertex-api-key entry.
//
// The target is selected by "index" when it is in range. Otherwise it is the
// first entry whose api-key equals the trimmed, non-empty "match". Only fields
// present in "value" are applied. Setting api-key or base-url to an empty or
// whitespace-only string removes the entry instead of updating it.
// Responds 400 for an unparsable body and 404 when no entry is selected.
func (h *Handler) PatchVertexCompatKey(c *gin.Context) {
	type vertexCompatPatch struct {
		APIKey         *string                    `json:"api-key"`
		Prefix         *string                    `json:"prefix"`
		BaseURL        *string                    `json:"base-url"`
		ProxyURL       *string                    `json:"proxy-url"`
		Headers        *map[string]string         `json:"headers"`
		Models         *[]config.VertexCompatModel `json:"models"`
		ExcludedModels *[]string                  `json:"excluded-models"`
	}
	var body struct {
		Index *int               `json:"index"`
		Match *string            `json:"match"`
		Value *vertexCompatPatch `json:"value"`
	}
	if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}
	keys := h.cfg.VertexCompatAPIKey
	pos := -1
	switch {
	case body.Index != nil && *body.Index >= 0 && *body.Index < len(keys):
		pos = *body.Index
	case body.Match != nil:
		if wanted := strings.TrimSpace(*body.Match); wanted != "" {
			for i := range keys {
				if keys[i].APIKey == wanted {
					pos = i
					break
				}
			}
		}
	}
	if pos < 0 {
		c.JSON(404, gin.H{"error": "item not found"})
		return
	}
	// removeTarget deletes the selected entry; used when a required field is
	// patched to an empty value.
	removeTarget := func() {
		h.cfg.VertexCompatAPIKey = append(keys[:pos], keys[pos+1:]...)
		h.cfg.SanitizeVertexCompatKeys()
		h.persist(c)
	}
	patch := body.Value
	updated := keys[pos]
	if patch.APIKey != nil {
		apiKey := strings.TrimSpace(*patch.APIKey)
		if apiKey == "" {
			removeTarget()
			return
		}
		updated.APIKey = apiKey
	}
	if patch.Prefix != nil {
		updated.Prefix = strings.TrimSpace(*patch.Prefix)
	}
	if patch.BaseURL != nil {
		baseURL := strings.TrimSpace(*patch.BaseURL)
		if baseURL == "" {
			removeTarget()
			return
		}
		updated.BaseURL = baseURL
	}
	if patch.ProxyURL != nil {
		updated.ProxyURL = strings.TrimSpace(*patch.ProxyURL)
	}
	if patch.Headers != nil {
		updated.Headers = config.NormalizeHeaders(*patch.Headers)
	}
	if patch.Models != nil {
		updated.Models = append([]config.VertexCompatModel(nil), (*patch.Models)...)
	}
	if patch.ExcludedModels != nil {
		updated.ExcludedModels = config.NormalizeExcludedModels(*patch.ExcludedModels)
	}
	normalizeVertexCompatKey(&updated)
	h.cfg.VertexCompatAPIKey[pos] = updated
	h.cfg.SanitizeVertexCompatKeys()
	h.persist(c)
}
// DeleteVertexCompatKey removes vertex-api-key entries.
//
// When the trimmed "api-key" query parameter is non-empty, every entry with
// that key is removed. Otherwise the zero-based "index" query parameter selects
// one entry. Responds 400 when neither identifies an entry.
func (h *Handler) DeleteVertexCompatKey(c *gin.Context) {
	if apiKey := strings.TrimSpace(c.Query("api-key")); apiKey != "" {
		kept := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey))
		for i := range h.cfg.VertexCompatAPIKey {
			if h.cfg.VertexCompatAPIKey[i].APIKey == apiKey {
				continue
			}
			kept = append(kept, h.cfg.VertexCompatAPIKey[i])
		}
		h.cfg.VertexCompatAPIKey = kept
		h.cfg.SanitizeVertexCompatKeys()
		h.persist(c)
		return
	}
	rawIndex := c.Query("index")
	if rawIndex == "" {
		c.JSON(400, gin.H{"error": "missing api-key or index"})
		return
	}
	var pos int
	if _, errScan := fmt.Sscanf(rawIndex, "%d", &pos); errScan != nil || pos < 0 || pos >= len(h.cfg.VertexCompatAPIKey) {
		c.JSON(400, gin.H{"error": "missing api-key or index"})
		return
	}
	h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:pos], h.cfg.VertexCompatAPIKey[pos+1:]...)
	h.cfg.SanitizeVertexCompatKeys()
	h.persist(c)
}
// oauth-excluded-models: map[string][]string

// GetOAuthExcludedModels responds with the per-provider excluded-model lists,
// normalized through config.NormalizeOAuthExcludedModels.
func (h *Handler) GetOAuthExcludedModels(c *gin.Context) {
	normalized := config.NormalizeOAuthExcludedModels(h.cfg.OAuthExcludedModels)
	c.JSON(200, gin.H{"oauth-excluded-models": normalized})
}
// PutOAuthExcludedModels replaces the whole oauth-excluded-models map.
//
// The body is a JSON object mapping provider to model list. If it does not
// decode as such, an object {"items": {...}} is accepted instead. The result
// is normalized with config.NormalizeOAuthExcludedModels before storing.
func (h *Handler) PutOAuthExcludedModels(c *gin.Context) {
	raw, errRead := c.GetRawData()
	if errRead != nil {
		c.JSON(400, gin.H{"error": "failed to read body"})
		return
	}
	var byProvider map[string][]string
	if errDirect := json.Unmarshal(raw, &byProvider); errDirect != nil {
		var wrapped struct {
			Items map[string][]string `json:"items"`
		}
		if errWrapped := json.Unmarshal(raw, &wrapped); errWrapped != nil {
			c.JSON(400, gin.H{"error": "invalid body"})
			return
		}
		byProvider = wrapped.Items
	}
	h.cfg.OAuthExcludedModels = config.NormalizeOAuthExcludedModels(byProvider)
	h.persist(c)
}
// PatchOAuthExcludedModels sets or clears the excluded models for one provider.
//
// The provider name is trimmed and lower-cased. A non-empty normalized model
// list replaces that provider's entry. An empty list removes the provider and
// responds 404 if it was not configured. The map is reset to nil once it
// becomes empty.
func (h *Handler) PatchOAuthExcludedModels(c *gin.Context) {
	var body struct {
		Provider *string  `json:"provider"`
		Models   []string `json:"models"`
	}
	if errBind := c.ShouldBindJSON(&body); errBind != nil || body.Provider == nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}
	provider := strings.ToLower(strings.TrimSpace(*body.Provider))
	if provider == "" {
		c.JSON(400, gin.H{"error": "invalid provider"})
		return
	}
	models := config.NormalizeExcludedModels(body.Models)
	if len(models) > 0 {
		if h.cfg.OAuthExcludedModels == nil {
			h.cfg.OAuthExcludedModels = make(map[string][]string)
		}
		h.cfg.OAuthExcludedModels[provider] = models
		h.persist(c)
		return
	}
	// Empty list: remove the provider. Looking up a key in a nil map is safe
	// and simply reports "not found".
	if _, found := h.cfg.OAuthExcludedModels[provider]; !found {
		c.JSON(404, gin.H{"error": "provider not found"})
		return
	}
	delete(h.cfg.OAuthExcludedModels, provider)
	if len(h.cfg.OAuthExcludedModels) == 0 {
		h.cfg.OAuthExcludedModels = nil
	}
	h.persist(c)
}
// DeleteOAuthExcludedModels removes the excluded-model list of the provider
// named by the "provider" query parameter (trimmed, lower-cased).
// Responds 400 when the parameter is blank and 404 when the provider is not
// configured. The map is reset to nil once it becomes empty.
func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) {
	provider := strings.ToLower(strings.TrimSpace(c.Query("provider")))
	if provider == "" {
		c.JSON(400, gin.H{"error": "missing provider"})
		return
	}
	// A lookup on a nil map reports !found, covering the "no map yet" case.
	if _, found := h.cfg.OAuthExcludedModels[provider]; !found {
		c.JSON(404, gin.H{"error": "provider not found"})
		return
	}
	delete(h.cfg.OAuthExcludedModels, provider)
	if len(h.cfg.OAuthExcludedModels) == 0 {
		h.cfg.OAuthExcludedModels = nil
	}
	h.persist(c)
}
// oauth-model-alias: map[string][]OAuthModelAlias

// GetOAuthModelAlias responds with a sanitized copy of the per-channel
// model alias map.
func (h *Handler) GetOAuthModelAlias(c *gin.Context) {
	aliases := sanitizedOAuthModelAlias(h.cfg.OAuthModelAlias)
	c.JSON(200, gin.H{"oauth-model-alias": aliases})
}
// PutOAuthModelAlias replaces the whole oauth-model-alias map.
//
// The body is a JSON object mapping channel to alias list. If it does not
// decode as such, an object {"items": {...}} is accepted instead. The result
// is sanitized before storing.
func (h *Handler) PutOAuthModelAlias(c *gin.Context) {
	raw, errRead := c.GetRawData()
	if errRead != nil {
		c.JSON(400, gin.H{"error": "failed to read body"})
		return
	}
	var byChannel map[string][]config.OAuthModelAlias
	if errDirect := json.Unmarshal(raw, &byChannel); errDirect != nil {
		var wrapped struct {
			Items map[string][]config.OAuthModelAlias `json:"items"`
		}
		if errWrapped := json.Unmarshal(raw, &wrapped); errWrapped != nil {
			c.JSON(400, gin.H{"error": "invalid body"})
			return
		}
		byChannel = wrapped.Items
	}
	h.cfg.OAuthModelAlias = sanitizedOAuthModelAlias(byChannel)
	h.persist(c)
}
// PatchOAuthModelAlias sets or clears the model aliases of one channel.
//
// The channel comes from "channel" when present, otherwise from "provider".
// It is trimmed and lower-cased. A non-empty sanitized alias list replaces
// that channel's entry. An empty list removes the channel and responds 404 if
// it was not configured. The map is reset to nil once it becomes empty.
func (h *Handler) PatchOAuthModelAlias(c *gin.Context) {
	var body struct {
		Provider *string                  `json:"provider"`
		Channel  *string                  `json:"channel"`
		Aliases  []config.OAuthModelAlias `json:"aliases"`
	}
	if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}
	var rawChannel string
	switch {
	case body.Channel != nil:
		rawChannel = *body.Channel
	case body.Provider != nil:
		rawChannel = *body.Provider
	}
	channel := strings.ToLower(strings.TrimSpace(rawChannel))
	if channel == "" {
		c.JSON(400, gin.H{"error": "invalid channel"})
		return
	}
	aliases := sanitizedOAuthModelAlias(map[string][]config.OAuthModelAlias{channel: body.Aliases})[channel]
	if len(aliases) > 0 {
		if h.cfg.OAuthModelAlias == nil {
			h.cfg.OAuthModelAlias = make(map[string][]config.OAuthModelAlias)
		}
		h.cfg.OAuthModelAlias[channel] = aliases
		h.persist(c)
		return
	}
	// Empty list: remove the channel (nil-map lookups report !found).
	if _, found := h.cfg.OAuthModelAlias[channel]; !found {
		c.JSON(404, gin.H{"error": "channel not found"})
		return
	}
	delete(h.cfg.OAuthModelAlias, channel)
	if len(h.cfg.OAuthModelAlias) == 0 {
		h.cfg.OAuthModelAlias = nil
	}
	h.persist(c)
}
// DeleteOAuthModelAlias removes the aliases of the channel named by the
// "channel" query parameter, falling back to "provider" when "channel" is
// blank. The name is trimmed and lower-cased. Responds 400 when no name is
// given and 404 when the channel is not configured.
func (h *Handler) DeleteOAuthModelAlias(c *gin.Context) {
	channel := strings.ToLower(strings.TrimSpace(c.Query("channel")))
	if channel == "" {
		channel = strings.ToLower(strings.TrimSpace(c.Query("provider")))
	}
	if channel == "" {
		c.JSON(400, gin.H{"error": "missing channel"})
		return
	}
	if _, found := h.cfg.OAuthModelAlias[channel]; !found {
		c.JSON(404, gin.H{"error": "channel not found"})
		return
	}
	delete(h.cfg.OAuthModelAlias, channel)
	if len(h.cfg.OAuthModelAlias) == 0 {
		h.cfg.OAuthModelAlias = nil
	}
	h.persist(c)
}
// codex-api-key: []CodexKey

// GetCodexKeys responds with the codex-api-key entries exactly as stored.
func (h *Handler) GetCodexKeys(c *gin.Context) {
	keys := h.cfg.CodexKey
	c.JSON(200, gin.H{"codex-api-key": keys})
}
// PutCodexKeys replaces the whole codex-api-key list.
//
// The body may be a bare JSON array or an object {"items": [...]}; the
// wrapped form must contain at least one item. Each entry is normalized, and
// entries whose base-url ends up empty are treated as removed.
func (h *Handler) PutCodexKeys(c *gin.Context) {
	raw, errRead := c.GetRawData()
	if errRead != nil {
		c.JSON(400, gin.H{"error": "failed to read body"})
		return
	}
	var keys []config.CodexKey
	if errArr := json.Unmarshal(raw, &keys); errArr != nil {
		var wrapped struct {
			Items []config.CodexKey `json:"items"`
		}
		if errObj := json.Unmarshal(raw, &wrapped); errObj != nil || len(wrapped.Items) == 0 {
			c.JSON(400, gin.H{"error": "invalid body"})
			return
		}
		keys = wrapped.Items
	}
	kept := make([]config.CodexKey, 0, len(keys))
	for _, key := range keys {
		normalizeCodexKey(&key)
		if key.BaseURL != "" {
			kept = append(kept, key)
		}
	}
	h.cfg.CodexKey = kept
	h.cfg.SanitizeCodexKeys()
	h.persist(c)
}
// PatchCodexKey updates a single codex-api-key entry.
//
// The target is selected by "index" when it is in range. Otherwise it is the
// first entry whose api-key equals the trimmed "match". Only fields present in
// "value" are applied. Setting base-url to an empty or whitespace-only string
// removes the entry instead of updating it. Responds 400 for an unparsable
// body and 404 when no entry is selected.
func (h *Handler) PatchCodexKey(c *gin.Context) {
	type codexKeyPatch struct {
		APIKey         *string             `json:"api-key"`
		Prefix         *string             `json:"prefix"`
		BaseURL        *string             `json:"base-url"`
		ProxyURL       *string             `json:"proxy-url"`
		Models         *[]config.CodexModel `json:"models"`
		Headers        *map[string]string  `json:"headers"`
		ExcludedModels *[]string           `json:"excluded-models"`
	}
	var body struct {
		Index *int           `json:"index"`
		Match *string        `json:"match"`
		Value *codexKeyPatch `json:"value"`
	}
	if errBind := c.ShouldBindJSON(&body); errBind != nil || body.Value == nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}
	keys := h.cfg.CodexKey
	pos := -1
	switch {
	case body.Index != nil && *body.Index >= 0 && *body.Index < len(keys):
		pos = *body.Index
	case body.Match != nil:
		wanted := strings.TrimSpace(*body.Match)
		for i := range keys {
			if keys[i].APIKey == wanted {
				pos = i
				break
			}
		}
	}
	if pos < 0 {
		c.JSON(404, gin.H{"error": "item not found"})
		return
	}
	patch := body.Value
	updated := keys[pos]
	if patch.APIKey != nil {
		updated.APIKey = strings.TrimSpace(*patch.APIKey)
	}
	if patch.Prefix != nil {
		updated.Prefix = strings.TrimSpace(*patch.Prefix)
	}
	if patch.BaseURL != nil {
		baseURL := strings.TrimSpace(*patch.BaseURL)
		if baseURL == "" {
			// Clearing base-url deletes the entry.
			h.cfg.CodexKey = append(keys[:pos], keys[pos+1:]...)
			h.cfg.SanitizeCodexKeys()
			h.persist(c)
			return
		}
		updated.BaseURL = baseURL
	}
	if patch.ProxyURL != nil {
		updated.ProxyURL = strings.TrimSpace(*patch.ProxyURL)
	}
	if patch.Models != nil {
		updated.Models = append([]config.CodexModel(nil), (*patch.Models)...)
	}
	if patch.Headers != nil {
		updated.Headers = config.NormalizeHeaders(*patch.Headers)
	}
	if patch.ExcludedModels != nil {
		updated.ExcludedModels = config.NormalizeExcludedModels(*patch.ExcludedModels)
	}
	normalizeCodexKey(&updated)
	h.cfg.CodexKey[pos] = updated
	h.cfg.SanitizeCodexKeys()
	h.persist(c)
}
// DeleteCodexKey removes codex-api-key entries.
//
// Entries are selected by the "api-key" query parameter: every entry whose
// trimmed key equals the trimmed parameter is removed. When that parameter is
// absent or blank, the zero-based "index" query parameter is used instead.
// Responds 400 when neither selects anything usable; otherwise calls h.persist.
func (h *Handler) DeleteCodexKey(c *gin.Context) {
	// Trim the query value so it matches keys stored by normalizeCodexKey,
	// which trims api-key values; this mirrors DeleteVertexCompatKey.
	if val := strings.TrimSpace(c.Query("api-key")); val != "" {
		out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
		for _, v := range h.cfg.CodexKey {
			if strings.TrimSpace(v.APIKey) != val {
				out = append(out, v)
			}
		}
		h.cfg.CodexKey = out
		h.cfg.SanitizeCodexKeys()
		h.persist(c)
		return
	}
	if idxStr := c.Query("index"); idxStr != "" {
		var idx int
		_, err := fmt.Sscanf(idxStr, "%d", &idx)
		if err == nil && idx >= 0 && idx < len(h.cfg.CodexKey) {
			h.cfg.CodexKey = append(h.cfg.CodexKey[:idx], h.cfg.CodexKey[idx+1:]...)
			h.cfg.SanitizeCodexKeys()
			h.persist(c)
			return
		}
	}
	c.JSON(400, gin.H{"error": "missing api-key or index"})
}
// normalizeOpenAICompatibilityEntry normalizes an openai-compatibility entry
// in place. It trims the base-url, normalizes headers via
// config.NormalizeHeaders, and trims every api-key entry's key. A nil entry is
// ignored.
//
// An empty base-url is left empty; callers treat that as "remove this
// provider". APIKeyEntries is modified in place, so callers that must not
// mutate shared data should pass an entry whose slice they own.
func normalizeOpenAICompatibilityEntry(entry *config.OpenAICompatibility) {
	if entry == nil {
		return
	}
	entry.BaseURL = strings.TrimSpace(entry.BaseURL)
	entry.Headers = config.NormalizeHeaders(entry.Headers)
	for i := range entry.APIKeyEntries {
		entry.APIKeyEntries[i].APIKey = strings.TrimSpace(entry.APIKeyEntries[i].APIKey)
	}
}
func normalizedOpenAICompatibilityEntries(entries []config.OpenAICompatibility) []config.OpenAICompatibility {
if len(entries) == 0 {
return nil
}
out := make([]config.OpenAICompatibility, len(entries))
for i := range entries {
copyEntry := entries[i]
if len(copyEntry.APIKeyEntries) > 0 {
copyEntry.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), copyEntry.APIKeyEntries...)
}
normalizeOpenAICompatibilityEntry(©Entry)
out[i] = copyEntry
}
return out
}
// normalizeClaudeKey normalizes a claude-api-key entry in place.
//
// It trims api-key, prefix, base-url, proxy-url, and each model's name and
// alias. Headers and excluded models go through the config helpers. Model
// entries whose name and alias are both empty are dropped. A nil entry is
// ignored.
//
// Prefix is trimmed here for parity with normalizeCodexKey and
// normalizeVertexCompatKey, and with PatchClaudeKey, which trims it on patch.
func normalizeClaudeKey(entry *config.ClaudeKey) {
	if entry == nil {
		return
	}
	entry.APIKey = strings.TrimSpace(entry.APIKey)
	entry.Prefix = strings.TrimSpace(entry.Prefix)
	entry.BaseURL = strings.TrimSpace(entry.BaseURL)
	entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
	entry.Headers = config.NormalizeHeaders(entry.Headers)
	entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
	if len(entry.Models) == 0 {
		return
	}
	normalized := make([]config.ClaudeModel, 0, len(entry.Models))
	for i := range entry.Models {
		model := entry.Models[i]
		model.Name = strings.TrimSpace(model.Name)
		model.Alias = strings.TrimSpace(model.Alias)
		if model.Name == "" && model.Alias == "" {
			continue
		}
		normalized = append(normalized, model)
	}
	entry.Models = normalized
}
// normalizeCodexKey normalizes a codex-api-key entry in place.
//
// It trims api-key, prefix, base-url, proxy-url, and each model's name and
// alias. Headers and excluded models go through the config helpers. Model
// entries whose name and alias are both empty are dropped, and the surviving
// models are stored in a freshly allocated slice. A nil entry is ignored.
func normalizeCodexKey(entry *config.CodexKey) {
	if entry == nil {
		return
	}
	for _, field := range []*string{&entry.APIKey, &entry.Prefix, &entry.BaseURL, &entry.ProxyURL} {
		*field = strings.TrimSpace(*field)
	}
	entry.Headers = config.NormalizeHeaders(entry.Headers)
	entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
	if len(entry.Models) == 0 {
		return
	}
	kept := make([]config.CodexModel, 0, len(entry.Models))
	for _, m := range entry.Models {
		m.Name = strings.TrimSpace(m.Name)
		m.Alias = strings.TrimSpace(m.Alias)
		if m.Name != "" || m.Alias != "" {
			kept = append(kept, m)
		}
	}
	entry.Models = kept
}
// normalizeVertexCompatKey normalizes a vertex-api-key entry in place.
//
// It trims api-key, prefix, base-url, proxy-url, and each model's name and
// alias. Headers and excluded models go through the config helpers. A nil
// entry is ignored.
//
// Unlike the Claude/Codex normalizers, a model is kept only when BOTH its name
// and alias are non-empty after trimming.
func normalizeVertexCompatKey(entry *config.VertexCompatKey) {
	if entry == nil {
		return
	}
	for _, field := range []*string{&entry.APIKey, &entry.Prefix, &entry.BaseURL, &entry.ProxyURL} {
		*field = strings.TrimSpace(*field)
	}
	entry.Headers = config.NormalizeHeaders(entry.Headers)
	entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
	if len(entry.Models) == 0 {
		return
	}
	kept := make([]config.VertexCompatModel, 0, len(entry.Models))
	for _, m := range entry.Models {
		m.Name = strings.TrimSpace(m.Name)
		m.Alias = strings.TrimSpace(m.Alias)
		if m.Name != "" && m.Alias != "" {
			kept = append(kept, m)
		}
	}
	entry.Models = kept
}
// sanitizedOAuthModelAlias returns a sanitized copy of entries without
// touching the input map. Channels with no aliases are skipped, each alias
// slice is copied, and the copy is run through Config.SanitizeOAuthModelAlias
// on a scratch config. Returns nil when nothing survives.
func sanitizedOAuthModelAlias(entries map[string][]config.OAuthModelAlias) map[string][]config.OAuthModelAlias {
	copied := make(map[string][]config.OAuthModelAlias, len(entries))
	for channel, aliases := range entries {
		if len(aliases) > 0 {
			copied[channel] = append([]config.OAuthModelAlias(nil), aliases...)
		}
	}
	if len(copied) == 0 {
		return nil
	}
	scratch := config.Config{OAuthModelAlias: copied}
	scratch.SanitizeOAuthModelAlias()
	if len(scratch.OAuthModelAlias) == 0 {
		return nil
	}
	return scratch.OAuthModelAlias
}
// GetAmpCode responds with the complete ampcode configuration, or a zero
// config.AmpCode when the handler or its config is nil.
func (h *Handler) GetAmpCode(c *gin.Context) {
	ampcode := config.AmpCode{}
	if h != nil && h.cfg != nil {
		ampcode = h.cfg.AmpCode
	}
	c.JSON(200, gin.H{"ampcode": ampcode})
}
// GetAmpUpstreamURL responds with the ampcode upstream URL, or "" when the
// handler or its config is nil.
func (h *Handler) GetAmpUpstreamURL(c *gin.Context) {
	upstream := ""
	if h != nil && h.cfg != nil {
		upstream = h.cfg.AmpCode.UpstreamURL
	}
	c.JSON(200, gin.H{"upstream-url": upstream})
}
// PutAmpUpstreamURL sets the ampcode upstream URL, trimming surrounding
// whitespace, via the shared updateStringField helper.
func (h *Handler) PutAmpUpstreamURL(c *gin.Context) {
	setURL := func(value string) {
		h.cfg.AmpCode.UpstreamURL = strings.TrimSpace(value)
	}
	h.updateStringField(c, setURL)
}
// DeleteAmpUpstreamURL resets the ampcode upstream URL to "" and then calls
// h.persist.
func (h *Handler) DeleteAmpUpstreamURL(c *gin.Context) {
	amp := &h.cfg.AmpCode
	amp.UpstreamURL = ""
	h.persist(c)
}
// GetAmpUpstreamAPIKey responds with the ampcode upstream API key, or "" when
// the handler or its config is nil.
func (h *Handler) GetAmpUpstreamAPIKey(c *gin.Context) {
	apiKey := ""
	if h != nil && h.cfg != nil {
		apiKey = h.cfg.AmpCode.UpstreamAPIKey
	}
	c.JSON(200, gin.H{"upstream-api-key": apiKey})
}
// PutAmpUpstreamAPIKey sets the ampcode upstream API key, trimming
// surrounding whitespace, via the shared updateStringField helper.
func (h *Handler) PutAmpUpstreamAPIKey(c *gin.Context) {
	setKey := func(value string) {
		h.cfg.AmpCode.UpstreamAPIKey = strings.TrimSpace(value)
	}
	h.updateStringField(c, setKey)
}
// DeleteAmpUpstreamAPIKey resets the ampcode upstream API key to "" and then
// calls h.persist.
func (h *Handler) DeleteAmpUpstreamAPIKey(c *gin.Context) {
	amp := &h.cfg.AmpCode
	amp.UpstreamAPIKey = ""
	h.persist(c)
}
// GetAmpRestrictManagementToLocalhost responds with the ampcode
// localhost-only management setting. It defaults to true when the handler or
// its config is nil.
func (h *Handler) GetAmpRestrictManagementToLocalhost(c *gin.Context) {
	restrict := true
	if h != nil && h.cfg != nil {
		restrict = h.cfg.AmpCode.RestrictManagementToLocalhost
	}
	c.JSON(200, gin.H{"restrict-management-to-localhost": restrict})
}
// PutAmpRestrictManagementToLocalhost sets the ampcode localhost-only
// management flag via the shared updateBoolField helper.
func (h *Handler) PutAmpRestrictManagementToLocalhost(c *gin.Context) {
	apply := func(restrict bool) {
		h.cfg.AmpCode.RestrictManagementToLocalhost = restrict
	}
	h.updateBoolField(c, apply)
}
// GetAmpModelMappings responds with the ampcode model mappings. When the
// handler or its config is nil it responds with an empty (non-null) list.
func (h *Handler) GetAmpModelMappings(c *gin.Context) {
	if h != nil && h.cfg != nil {
		c.JSON(200, gin.H{"model-mappings": h.cfg.AmpCode.ModelMappings})
		return
	}
	c.JSON(200, gin.H{"model-mappings": []config.AmpModelMapping{}})
}
// PutAmpModelMappings replaces every ampcode model mapping with the "value"
// list from the JSON body, stored as sent. Responds 400 on an unparsable body.
func (h *Handler) PutAmpModelMappings(c *gin.Context) {
	var req struct {
		Value []config.AmpModelMapping `json:"value"`
	}
	if errBind := c.ShouldBindJSON(&req); errBind != nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}
	h.cfg.AmpCode.ModelMappings = req.Value
	h.persist(c)
}
// PatchAmpModelMappings merges the "value" mappings into the existing ampcode
// model mappings. Mappings are matched by their trimmed "from" field: a match
// replaces the existing entry, otherwise the mapping is appended. Incoming
// mappings are stored as sent (not trimmed).
func (h *Handler) PatchAmpModelMappings(c *gin.Context) {
	var req struct {
		Value []config.AmpModelMapping `json:"value"`
	}
	if errBind := c.ShouldBindJSON(&req); errBind != nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}
	posByFrom := make(map[string]int, len(h.cfg.AmpCode.ModelMappings))
	for i := range h.cfg.AmpCode.ModelMappings {
		posByFrom[strings.TrimSpace(h.cfg.AmpCode.ModelMappings[i].From)] = i
	}
	for _, mapping := range req.Value {
		key := strings.TrimSpace(mapping.From)
		if i, found := posByFrom[key]; found {
			h.cfg.AmpCode.ModelMappings[i] = mapping
			continue
		}
		h.cfg.AmpCode.ModelMappings = append(h.cfg.AmpCode.ModelMappings, mapping)
		posByFrom[key] = len(h.cfg.AmpCode.ModelMappings) - 1
	}
	h.persist(c)
}
// DeleteAmpModelMappings removes the ampcode model mappings whose trimmed
// "from" field appears (trimmed) in the body's "value" list.
//
// A missing, unparsable, or empty body clears ALL mappings.
// NOTE(review): unlike DeleteAmpUpstreamAPIKeys, invalid JSON here is not
// rejected with 400 — confirm that wiping everything is intended.
func (h *Handler) DeleteAmpModelMappings(c *gin.Context) {
	var req struct {
		Value []string `json:"value"`
	}
	errBind := c.ShouldBindJSON(&req)
	if errBind != nil || len(req.Value) == 0 {
		h.cfg.AmpCode.ModelMappings = nil
		h.persist(c)
		return
	}
	drop := make(map[string]struct{}, len(req.Value))
	for _, from := range req.Value {
		drop[strings.TrimSpace(from)] = struct{}{}
	}
	kept := make([]config.AmpModelMapping, 0, len(h.cfg.AmpCode.ModelMappings))
	for _, mapping := range h.cfg.AmpCode.ModelMappings {
		if _, remove := drop[strings.TrimSpace(mapping.From)]; !remove {
			kept = append(kept, mapping)
		}
	}
	h.cfg.AmpCode.ModelMappings = kept
	h.persist(c)
}
// GetAmpForceModelMappings responds with whether ampcode model mappings are
// forced. It defaults to false when the handler or its config is nil.
func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
	forced := false
	if h != nil && h.cfg != nil {
		forced = h.cfg.AmpCode.ForceModelMappings
	}
	c.JSON(200, gin.H{"force-model-mappings": forced})
}
// PutAmpForceModelMappings sets the ampcode force-model-mappings flag via the
// shared updateBoolField helper.
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
	apply := func(force bool) {
		h.cfg.AmpCode.ForceModelMappings = force
	}
	h.updateBoolField(c, apply)
}
// GetAmpUpstreamAPIKeys responds with the ampcode upstream API key entries.
// When the handler or its config is nil it responds with an empty (non-null)
// list.
func (h *Handler) GetAmpUpstreamAPIKeys(c *gin.Context) {
	if h != nil && h.cfg != nil {
		c.JSON(200, gin.H{"upstream-api-keys": h.cfg.AmpCode.UpstreamAPIKeys})
		return
	}
	c.JSON(200, gin.H{"upstream-api-keys": []config.AmpUpstreamAPIKeyEntry{}})
}
// PutAmpUpstreamAPIKeys replaces every ampcode upstream API key entry with
// the body's "value" list. The list is first passed through
// normalizeAmpUpstreamAPIKeyEntries, which trims keys and drops entries whose
// upstream-api-key is blank. Responds 400 on an unparsable body.
func (h *Handler) PutAmpUpstreamAPIKeys(c *gin.Context) {
	var req struct {
		Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
	}
	if errBind := c.ShouldBindJSON(&req); errBind != nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}
	h.cfg.AmpCode.UpstreamAPIKeys = normalizeAmpUpstreamAPIKeyEntries(req.Value)
	h.persist(c)
}
// PatchAmpUpstreamAPIKeys merges the body's "value" entries into the ampcode
// upstream API key list, matching on the trimmed upstream-api-key.
//
// Incoming entries with a blank upstream-api-key are skipped. Each stored
// entry has its key trimmed and its api-keys normalized by
// normalizeAPIKeysList. A match replaces the existing entry; otherwise the
// entry is appended.
func (h *Handler) PatchAmpUpstreamAPIKeys(c *gin.Context) {
	var req struct {
		Value []config.AmpUpstreamAPIKeyEntry `json:"value"`
	}
	if errBind := c.ShouldBindJSON(&req); errBind != nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}
	slot := make(map[string]int, len(h.cfg.AmpCode.UpstreamAPIKeys))
	for i := range h.cfg.AmpCode.UpstreamAPIKeys {
		slot[strings.TrimSpace(h.cfg.AmpCode.UpstreamAPIKeys[i].UpstreamAPIKey)] = i
	}
	for _, incoming := range req.Value {
		key := strings.TrimSpace(incoming.UpstreamAPIKey)
		if key == "" {
			continue
		}
		cleaned := config.AmpUpstreamAPIKeyEntry{
			UpstreamAPIKey: key,
			APIKeys:        normalizeAPIKeysList(incoming.APIKeys),
		}
		if i, found := slot[key]; found {
			h.cfg.AmpCode.UpstreamAPIKeys[i] = cleaned
			continue
		}
		h.cfg.AmpCode.UpstreamAPIKeys = append(h.cfg.AmpCode.UpstreamAPIKeys, cleaned)
		slot[key] = len(h.cfg.AmpCode.UpstreamAPIKeys) - 1
	}
	h.persist(c)
}
// DeleteAmpUpstreamAPIKeys removes upstream API key entries by their
// upstream-api-key value.
//
// The body must be JSON of the form {"value": ["<upstream-api-key>", ...]}.
//   - Invalid JSON, or a missing/null "value", yields 400 and persists nothing.
//   - An empty array clears all entries.
//   - Otherwise each listed key is trimmed and blank keys are ignored. If no
//     non-blank key remains, the request is rejected with 400 ("empty value").
func (h *Handler) DeleteAmpUpstreamAPIKeys(c *gin.Context) {
	var req struct {
		Value []string `json:"value"`
	}
	if errBind := c.ShouldBindJSON(&req); errBind != nil {
		c.JSON(400, gin.H{"error": "invalid body"})
		return
	}
	switch {
	case req.Value == nil:
		c.JSON(400, gin.H{"error": "missing value"})
		return
	case len(req.Value) == 0:
		h.cfg.AmpCode.UpstreamAPIKeys = nil
		h.persist(c)
		return
	}
	drop := make(map[string]struct{}, len(req.Value))
	for _, raw := range req.Value {
		if key := strings.TrimSpace(raw); key != "" {
			drop[key] = struct{}{}
		}
	}
	if len(drop) == 0 {
		c.JSON(400, gin.H{"error": "empty value"})
		return
	}
	kept := make([]config.AmpUpstreamAPIKeyEntry, 0, len(h.cfg.AmpCode.UpstreamAPIKeys))
	for _, entry := range h.cfg.AmpCode.UpstreamAPIKeys {
		if _, remove := drop[strings.TrimSpace(entry.UpstreamAPIKey)]; !remove {
			kept = append(kept, entry)
		}
	}
	h.cfg.AmpCode.UpstreamAPIKeys = kept
	h.persist(c)
}
// normalizeAmpUpstreamAPIKeyEntries returns a cleaned copy of entries. Each
// upstream-api-key is trimmed and entries where it becomes blank are dropped.
// Each entry's api-keys list goes through normalizeAPIKeysList. Order is
// preserved. Returns nil when no entry survives (including for empty input).
func normalizeAmpUpstreamAPIKeyEntries(entries []config.AmpUpstreamAPIKeyEntry) []config.AmpUpstreamAPIKeyEntry {
	var cleaned []config.AmpUpstreamAPIKeyEntry
	for i := range entries {
		key := strings.TrimSpace(entries[i].UpstreamAPIKey)
		if key == "" {
			continue
		}
		cleaned = append(cleaned, config.AmpUpstreamAPIKeyEntry{
			UpstreamAPIKey: key,
			APIKeys:        normalizeAPIKeysList(entries[i].APIKeys),
		})
	}
	return cleaned
}
// normalizeAPIKeysList returns keys with surrounding whitespace trimmed and
// blank entries removed, preserving order. It returns nil when no non-blank
// key remains, including for nil or empty input.
func normalizeAPIKeysList(keys []string) []string {
	var cleaned []string
	for _, key := range keys {
		if key = strings.TrimSpace(key); key != "" {
			cleaned = append(cleaned, key)
		}
	}
	return cleaned
}
================================================
FILE: internal/api/handlers/management/handler.go
================================================
// Package management provides the management API handlers and middleware
// for configuring the server and managing auth files.
package management
import (
"crypto/subtle"
"fmt"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"golang.org/x/crypto/bcrypt"
)
// attemptInfo records failed management-authentication attempts for a single
// client IP. Handler.failedAttempts stores one per IP and guards it with
// Handler.attemptsMu.
type attemptInfo struct {
	count        int       // failure count; presumably compared to maxFailures in Middleware — confirm
	blockedUntil time.Time // ban expiry; the zero value means no ban has been set
	lastActivity time.Time // track last activity for cleanup
}
// Tuning for the background purge of per-IP failed-attempt records.
const (
	// attemptCleanupInterval controls how often stale IP entries are purged.
	attemptCleanupInterval = time.Hour
	// attemptMaxIdleTime controls how long an IP can be idle before cleanup.
	attemptMaxIdleTime = 2 * time.Hour
)
// Handler serves the management API endpoints. It aggregates:
//   - the live config reference and the config file path used for persistence;
//   - the auth manager, usage statistics and token store;
//   - the state that Middleware uses to authenticate requests and to throttle
//     repeated failures per client IP.
type Handler struct {
	// cfg is the in-memory config; SetConfig replaces it (e.g. on hot reload).
	cfg *config.Config
	// configFilePath is the path passed to NewHandler; it is empty when the
	// handler was built with NewHandlerWithoutConfigFilePath.
	configFilePath string
	// NOTE(review): mu is not used in the visible code; presumably it
	// serializes config mutation/persistence — confirm against persist.
	mu sync.Mutex
	// attemptsMu guards failedAttempts (see purgeStaleAttempts and Middleware).
	attemptsMu     sync.Mutex
	failedAttempts map[string]*attemptInfo // keyed by client IP
	// authManager is set by NewHandler and may be replaced via SetAuthManager.
	authManager *coreauth.Manager
	// usageStats starts as usage.GetRequestStatistics(); see SetUsageStatistics.
	usageStats *usage.RequestStatistics
	// tokenStore is initialized from sdkAuth.GetTokenStore() in NewHandler.
	tokenStore coreauth.Store
	// localPassword is the runtime-local password accepted for localhost
	// requests (see SetLocalPassword).
	localPassword string
	// allowRemoteOverride is true when MANAGEMENT_PASSWORD is set to a
	// non-blank value; Middleware then allows remote access regardless of config.
	allowRemoteOverride bool
	// envSecret is the trimmed MANAGEMENT_PASSWORD value ("" when unset).
	envSecret string
	// logDir is where main.log is looked up; SetLogDirectory stores it as an
	// absolute path when filepath.Abs succeeds.
	logDir string
	// postAuthHook is called after auth record creation but before
	// persistence (see SetPostAuthHook).
	postAuthHook coreauth.PostAuthHook
}
// NewHandler creates a management handler bound to cfg, the config file at
// configFilePath, and the given auth manager.
//
// If the MANAGEMENT_PASSWORD environment variable holds a non-blank value,
// its trimmed form becomes envSecret and remote management is force-enabled
// (allowRemoteOverride). The constructor also starts the background goroutine
// that purges stale failed-attempt records.
func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Manager) *Handler {
	// os.Getenv yields "" for an unset variable, same as LookupEnv's value.
	secret := strings.TrimSpace(os.Getenv("MANAGEMENT_PASSWORD"))

	h := &Handler{
		cfg:            cfg,
		configFilePath: configFilePath,
		failedAttempts: make(map[string]*attemptInfo),
		authManager:    manager,
		usageStats:     usage.GetRequestStatistics(),
		tokenStore:     sdkAuth.GetTokenStore(),
	}
	h.envSecret = secret
	h.allowRemoteOverride = secret != ""

	h.startAttemptCleanup()
	return h
}
// startAttemptCleanup spawns a goroutine that calls purgeStaleAttempts every
// attemptCleanupInterval, so failedAttempts does not grow without bound.
//
// The goroutine has no stop signal: it keeps the Handler reachable and runs
// until the process exits. NOTE(review): each Handler created (e.g. in tests)
// leaks one such goroutine; consider a stop channel or context if that matters.
func (h *Handler) startAttemptCleanup() {
	purgeLoop := func() {
		ticker := time.NewTicker(attemptCleanupInterval)
		defer ticker.Stop()
		for {
			<-ticker.C
			h.purgeStaleAttempts()
		}
	}
	go purgeLoop()
}
// purgeStaleAttempts deletes failed-attempt records that are not under an
// active ban and whose lastActivity is older than attemptMaxIdleTime. It
// holds attemptsMu for the whole sweep.
func (h *Handler) purgeStaleAttempts() {
	now := time.Now()
	h.attemptsMu.Lock()
	defer h.attemptsMu.Unlock()
	for ip, info := range h.failedAttempts {
		banned := !info.blockedUntil.IsZero() && now.Before(info.blockedUntil)
		idle := now.Sub(info.lastActivity) > attemptMaxIdleTime
		if idle && !banned {
			// Deleting during range is safe in Go.
			delete(h.failedAttempts, ip)
		}
	}
}
// NewHandlerWithoutConfigFilePath creates a management handler that has no
// config file path; it is equivalent to NewHandler(cfg, "", manager).
func NewHandlerWithoutConfigFilePath(cfg *config.Config, manager *coreauth.Manager) *Handler {
	const noConfigFile = ""
	return NewHandler(cfg, noConfigFile, manager)
}
// SetConfig swaps in a new in-memory config reference, e.g. after the server
// hot-reloads its configuration.
func (h *Handler) SetConfig(cfg *config.Config) {
	h.cfg = cfg
}
// SetAuthManager updates the auth manager reference used by management endpoints.
func (h *Handler) SetAuthManager(manager *coreauth.Manager) { h.authManager = manager }
// SetUsageStatistics allows replacing the usage statistics reference.
func (h *Handler) SetUsageStatistics(stats *usage.RequestStatistics) { h.usageStats = stats }
// SetLocalPassword configures the runtime-local password accepted for localhost requests.
func (h *Handler) SetLocalPassword(password string) { h.localPassword = password }
// SetLogDirectory updates the directory where main.log should be looked up.
func (h *Handler) SetLogDirectory(dir string) {
if dir == "" {
return
}
if !filepath.IsAbs(dir) {
if abs, err := filepath.Abs(dir); err == nil {
dir = abs
}
}
h.logDir = dir
}
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
h.postAuthHook = hook
}
// Middleware enforces access control for management endpoints.
// All requests (local and remote) require a valid management key.
// Additionally, remote access requires allow-remote-management=true
// (or a non-empty MANAGEMENT_PASSWORD, which forces it on).
//
// A client counts as local only when c.ClientIP() is exactly "127.0.0.1" or
// "::1". Remote clients are rate limited per IP: every missing or invalid key
// increments a counter, maxFailures failures without an intervening success
// ban the IP for banDuration (403), and a successful login clears the counter.
//
// Accepted keys, checked in order:
//  1. the runtime local password (local clients only);
//  2. the MANAGEMENT_PASSWORD environment secret (constant-time compare);
//  3. the bcrypt hash stored as the remote-management secret key.
//
// NOTE(review): the "remote management key not set" check runs before the
// local-password check, so a local request carrying the correct local password
// is still rejected when neither a secret key nor MANAGEMENT_PASSWORD is set.
// Confirm this is intended.
// NOTE(review): ClientIP may be derived from forwarding headers depending on
// gin's trusted-proxy settings; verify them so remote callers cannot appear local.
func (h *Handler) Middleware() gin.HandlerFunc {
	const maxFailures = 5
	const banDuration = 30 * time.Minute
	return func(c *gin.Context) {
		// Build metadata is advertised on every response, including rejections.
		c.Header("X-CPA-VERSION", buildinfo.Version)
		c.Header("X-CPA-COMMIT", buildinfo.Commit)
		c.Header("X-CPA-BUILD-DATE", buildinfo.BuildDate)
		clientIP := c.ClientIP()
		localClient := clientIP == "127.0.0.1" || clientIP == "::1"
		// Snapshot the config once so the whole request sees consistent settings.
		cfg := h.cfg
		var (
			allowRemote bool
			secretHash  string
		)
		if cfg != nil {
			allowRemote = cfg.RemoteManagement.AllowRemote
			secretHash = cfg.RemoteManagement.SecretKey
		}
		if h.allowRemoteOverride {
			allowRemote = true
		}
		envSecret := h.envSecret
		// fail records a failed attempt; it stays a no-op for local clients.
		fail := func() {}
		if !localClient {
			// Reject banned IPs before doing any key comparison.
			h.attemptsMu.Lock()
			ai := h.failedAttempts[clientIP]
			if ai != nil {
				if !ai.blockedUntil.IsZero() {
					if time.Now().Before(ai.blockedUntil) {
						remaining := time.Until(ai.blockedUntil).Round(time.Second)
						h.attemptsMu.Unlock()
						c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": fmt.Sprintf("IP banned due to too many failed attempts. Try again in %s", remaining)})
						return
					}
					// Ban expired, reset state
					ai.blockedUntil = time.Time{}
					ai.count = 0
				}
			}
			h.attemptsMu.Unlock()
			if !allowRemote {
				c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management disabled"})
				return
			}
			fail = func() {
				h.attemptsMu.Lock()
				aip := h.failedAttempts[clientIP]
				if aip == nil {
					aip = &attemptInfo{}
					h.failedAttempts[clientIP] = aip
				}
				aip.count++
				aip.lastActivity = time.Now()
				if aip.count >= maxFailures {
					// Start a ban and restart counting for when it expires.
					aip.blockedUntil = time.Now().Add(banDuration)
					aip.count = 0
				}
				h.attemptsMu.Unlock()
			}
		}
		if secretHash == "" && envSecret == "" {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management key not set"})
			return
		}
		// Accept either Authorization: Bearer or X-Management-Key
		// (a non-Bearer Authorization value is taken verbatim as the key).
		var provided string
		if ah := c.GetHeader("Authorization"); ah != "" {
			parts := strings.SplitN(ah, " ", 2)
			if len(parts) == 2 && strings.ToLower(parts[0]) == "bearer" {
				provided = parts[1]
			} else {
				provided = ah
			}
		}
		if provided == "" {
			provided = c.GetHeader("X-Management-Key")
		}
		if provided == "" {
			if !localClient {
				fail()
			}
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing management key"})
			return
		}
		if localClient {
			if lp := h.localPassword; lp != "" {
				if subtle.ConstantTimeCompare([]byte(provided), []byte(lp)) == 1 {
					c.Next()
					return
				}
			}
		}
		if envSecret != "" && subtle.ConstantTimeCompare([]byte(provided), []byte(envSecret)) == 1 {
			if !localClient {
				// Successful login clears this IP's failure history.
				h.attemptsMu.Lock()
				if ai := h.failedAttempts[clientIP]; ai != nil {
					ai.count = 0
					ai.blockedUntil = time.Time{}
				}
				h.attemptsMu.Unlock()
			}
			c.Next()
			return
		}
		if secretHash == "" || bcrypt.CompareHashAndPassword([]byte(secretHash), []byte(provided)) != nil {
			if !localClient {
				fail()
			}
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid management key"})
			return
		}
		if !localClient {
			// Successful login clears this IP's failure history.
			h.attemptsMu.Lock()
			if ai := h.failedAttempts[clientIP]; ai != nil {
				ai.count = 0
				ai.blockedUntil = time.Time{}
			}
			h.attemptsMu.Unlock()
		}
		c.Next()
	}
}
// persist saves the current in-memory config to disk.
//
// It writes h.cfg to h.configFilePath while holding h.mu, responds with
// {"status":"ok"} on success or a 500 error otherwise, and reports whether the
// write succeeded.
func (h *Handler) persist(c *gin.Context) bool {
	h.mu.Lock()
	defer h.mu.Unlock()
	// Use the comment-preserving writer so user annotations in the file survive.
	err := config.SaveConfigPreserveComments(h.configFilePath, h.cfg)
	if err == nil {
		c.JSON(http.StatusOK, gin.H{"status": "ok"})
		return true
	}
	c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to save config: %v", err)})
	return false
}
// Helper methods for simple types.
//
// Each helper expects a JSON body of the form {"value": X}. A malformed body
// or a missing "value" yields 400 "invalid body"; otherwise set is called with
// the decoded value and the config is persisted (persist writes the response).

// updateBoolField applies a {"value": bool} body through set and persists.
func (h *Handler) updateBoolField(c *gin.Context, set func(bool)) {
	var payload struct {
		Value *bool `json:"value"`
	}
	bindErr := c.ShouldBindJSON(&payload)
	if bindErr == nil && payload.Value != nil {
		set(*payload.Value)
		h.persist(c)
		return
	}
	c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
}

// updateIntField applies a {"value": int} body through set and persists.
func (h *Handler) updateIntField(c *gin.Context, set func(int)) {
	var payload struct {
		Value *int `json:"value"`
	}
	bindErr := c.ShouldBindJSON(&payload)
	if bindErr == nil && payload.Value != nil {
		set(*payload.Value)
		h.persist(c)
		return
	}
	c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
}

// updateStringField applies a {"value": string} body through set and persists.
func (h *Handler) updateStringField(c *gin.Context, set func(string)) {
	var payload struct {
		Value *string `json:"value"`
	}
	bindErr := c.ShouldBindJSON(&payload)
	if bindErr == nil && payload.Value != nil {
		set(*payload.Value)
		h.persist(c)
		return
	}
	c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
}
================================================
FILE: internal/api/handlers/management/logs.go
================================================
package management
import (
	"bufio"
	"compress/gzip"
	"errors"
	"fmt"
	"io"
	"math"
	"net/http"
	"os"
	"path/filepath"
	"sort"
	"strconv"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
)
// Log file naming and scanner sizing shared by the log endpoints.
const (
	// defaultLogFileName is the active log file; rotated names derive from it.
	defaultLogFileName = "main.log"
	// logScannerInitialBuffer is the initial line buffer size (64 KiB).
	logScannerInitialBuffer = 64 << 10
	// logScannerMaxBuffer is the longest single line the scanner accepts (8 MiB).
	logScannerMaxBuffer = 8 << 20
)
// GetLogs returns log lines with optional incremental loading.
//
// Query parameters:
//   - after: Unix timestamp in seconds; only entries stamped later are returned.
//     Missing, malformed or non-positive values mean "no cutoff".
//   - limit: optional positive integer; only the most recent limit lines are kept.
//
// The JSON response carries "lines", "line-count" (every line scanned, whether
// retained or not) and "latest-timestamp" (newest timestamp seen, never below
// after). Query parameters are validated before the log directory is read, so
// an invalid limit is rejected with 400 even when no log files exist yet.
func (h *Handler) GetLogs(c *gin.Context) {
	if h == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
		return
	}
	if h.cfg == nil {
		c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
		return
	}
	if !h.cfg.LoggingToFile {
		c.JSON(http.StatusBadRequest, gin.H{"error": "logging to file disabled"})
		return
	}
	logDir := h.logDirectory()
	if strings.TrimSpace(logDir) == "" {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
		return
	}
	limit, errLimit := parseLimit(c.Query("limit"))
	if errLimit != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid limit: %v", errLimit)})
		return
	}
	cutoff := parseCutoff(c.Query("after"))
	files, err := h.collectLogFiles(logDir)
	if err != nil {
		if os.IsNotExist(err) {
			// No log directory yet: report an empty page instead of an error.
			c.JSON(http.StatusOK, gin.H{
				"lines":            []string{},
				"line-count":       0,
				"latest-timestamp": cutoff,
			})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log files: %v", err)})
		return
	}
	acc := newLogAccumulator(cutoff, limit)
	for _, path := range files {
		if errProcess := acc.consumeFile(path); errProcess != nil {
			c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file %s: %v", path, errProcess)})
			return
		}
	}
	lines, total, latest := acc.result()
	// Never report a timestamp older than the caller's cursor, so incremental
	// polling does not move backwards when nothing newer was found.
	if latest < cutoff {
		latest = cutoff
	}
	c.JSON(http.StatusOK, gin.H{
		"lines":            lines,
		"line-count":       total,
		"latest-timestamp": latest,
	})
}
// DeleteLogs removes all rotated log files and truncates the active log.
//
// The active main.log is truncated to zero bytes, not removed. Every rotated
// file recognised by isRotatedLogFile is deleted, and the response reports
// how many were removed. Files that vanish concurrently are not treated as
// errors. A missing log directory yields 404.
func (h *Handler) DeleteLogs(c *gin.Context) {
	switch {
	case h == nil:
		c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
		return
	case h.cfg == nil:
		c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
		return
	case !h.cfg.LoggingToFile:
		c.JSON(http.StatusBadRequest, gin.H{"error": "logging to file disabled"})
		return
	}
	logDir := h.logDirectory()
	if strings.TrimSpace(logDir) == "" {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
		return
	}
	dirEntries, errRead := os.ReadDir(logDir)
	if errRead != nil {
		status := http.StatusInternalServerError
		message := fmt.Sprintf("failed to list log directory: %v", errRead)
		if os.IsNotExist(errRead) {
			status, message = http.StatusNotFound, "log directory not found"
		}
		c.JSON(status, gin.H{"error": message})
		return
	}
	removedCount := 0
	for _, de := range dirEntries {
		if de.IsDir() {
			continue
		}
		fileName := de.Name()
		target := filepath.Join(logDir, fileName)
		switch {
		case fileName == defaultLogFileName:
			if errTrunc := os.Truncate(target, 0); errTrunc != nil && !os.IsNotExist(errTrunc) {
				c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to truncate log file: %v", errTrunc)})
				return
			}
		case isRotatedLogFile(fileName):
			if errRemove := os.Remove(target); errRemove != nil && !os.IsNotExist(errRemove) {
				c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to remove %s: %v", fileName, errRemove)})
				return
			}
			removedCount++
		}
	}
	c.JSON(http.StatusOK, gin.H{
		"success": true,
		"message": "Logs cleared successfully",
		"removed": removedCount,
	})
}
// GetRequestErrorLogs lists error request log files when RequestLog is disabled.
// It returns an empty list when RequestLog is enabled.
//
// Only regular files named "error-*.log" in the log directory are listed, each
// with its size in bytes and modification time (Unix seconds), newest first.
// A missing log directory yields an empty list.
func (h *Handler) GetRequestErrorLogs(c *gin.Context) {
	if h == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
		return
	}
	if h.cfg == nil {
		c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
		return
	}
	if h.cfg.RequestLog {
		c.JSON(http.StatusOK, gin.H{"files": []any{}})
		return
	}
	dir := h.logDirectory()
	if strings.TrimSpace(dir) == "" {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
		return
	}
	entries, err := os.ReadDir(dir)
	if err != nil {
		if os.IsNotExist(err) {
			c.JSON(http.StatusOK, gin.H{"files": []any{}})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list request error logs: %v", err)})
		return
	}
	type errorLog struct {
		Name     string `json:"name"`
		Size     int64  `json:"size"`
		Modified int64  `json:"modified"`
	}
	files := make([]errorLog, 0, len(entries))
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		name := entry.Name()
		if !strings.HasPrefix(name, "error-") || !strings.HasSuffix(name, ".log") {
			continue
		}
		info, errInfo := entry.Info()
		if errInfo != nil {
			// The file may have been removed between ReadDir and Info (for
			// example by a concurrent cleanup); that is not a listing failure.
			if os.IsNotExist(errInfo) {
				continue
			}
			c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log info for %s: %v", name, errInfo)})
			return
		}
		files = append(files, errorLog{
			Name:     name,
			Size:     info.Size(),
			Modified: info.ModTime().Unix(),
		})
	}
	// Newest first. SliceStable keeps os.ReadDir's lexical order for files that
	// share a modification second, so repeated calls return the same order.
	sort.SliceStable(files, func(i, j int) bool { return files[i].Modified > files[j].Modified })
	c.JSON(http.StatusOK, gin.H{"files": files})
}
// GetRequestLogByID finds and downloads a request log file by its request ID.
// The ID is matched against the suffix of log file names (format: *-{requestID}.log).
func (h *Handler) GetRequestLogByID(c *gin.Context) {
if h == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
return
}
if h.cfg == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
return
}
dir := h.logDirectory()
if strings.TrimSpace(dir) == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
return
}
requestID := strings.TrimSpace(c.Param("id"))
if requestID == "" {
requestID = strings.TrimSpace(c.Query("id"))
}
if requestID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing request ID"})
return
}
if strings.ContainsAny(requestID, "/\\") {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request ID"})
return
}
entries, err := os.ReadDir(dir)
if err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "log directory not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log directory: %v", err)})
return
}
suffix := "-" + requestID + ".log"
var matchedFile string
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if strings.HasSuffix(name, suffix) {
matchedFile = name
break
}
}
if matchedFile == "" {
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found for the given request ID"})
return
}
dirAbs, errAbs := filepath.Abs(dir)
if errAbs != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
return
}
fullPath := filepath.Clean(filepath.Join(dirAbs, matchedFile))
prefix := dirAbs + string(os.PathSeparator)
if !strings.HasPrefix(fullPath, prefix) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
return
}
info, errStat := os.Stat(fullPath)
if errStat != nil {
if os.IsNotExist(errStat) {
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
return
}
if info.IsDir() {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
return
}
c.FileAttachment(fullPath, matchedFile)
}
// DownloadRequestErrorLog downloads a specific error request log file by name.
//
// The name comes from the :name path parameter. It must be non-empty, contain
// no path separators, and look like "error-*.log". Otherwise the request is
// rejected with 400 (bad name) or 404 (not an error log).
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
	if h == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
		return
	}
	if h.cfg == nil {
		c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
		return
	}
	logDir := h.logDirectory()
	if strings.TrimSpace(logDir) == "" {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
		return
	}
	fileName := strings.TrimSpace(c.Param("name"))
	if fileName == "" || strings.ContainsAny(fileName, "/\\") {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file name"})
		return
	}
	isErrorLog := strings.HasPrefix(fileName, "error-") && strings.HasSuffix(fileName, ".log")
	if !isErrorLog {
		c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
		return
	}
	absDir, errAbs := filepath.Abs(logDir)
	if errAbs != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
		return
	}
	target := filepath.Clean(filepath.Join(absDir, fileName))
	// Defence in depth: the resolved path must stay inside the log directory.
	if !strings.HasPrefix(target, absDir+string(os.PathSeparator)) {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
		return
	}
	stat, errStat := os.Stat(target)
	if errStat != nil {
		if os.IsNotExist(errStat) {
			c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
		return
	}
	if stat.IsDir() {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
		return
	}
	c.FileAttachment(target, fileName)
}
// logDirectory returns the directory holding the log files. An explicit
// SetLogDirectory value wins; otherwise the directory is resolved from the
// current config. A nil handler yields "".
func (h *Handler) logDirectory() string {
	switch {
	case h == nil:
		return ""
	case h.logDir != "":
		return h.logDir
	default:
		return logging.ResolveLogDirectory(h.cfg)
	}
}
// collectLogFiles returns the paths of main.log and its rotated siblings in
// dir, ordered from oldest to newest, so their lines concatenate
// chronologically.
//
// The age key is rotationOrder's value (main.log is 0, the newest); a larger
// order means an older file. Files with equal orders (for example several
// backups rotated within the same second) are ordered by path so the result
// is deterministic. Lexical order also matches chronological order for
// millisecond-suffixed backup names. Directories and unrelated files are
// ignored. The ReadDir error is returned unchanged, so callers can test it
// with os.IsNotExist.
func (h *Handler) collectLogFiles(dir string) ([]string, error) {
	entries, err := os.ReadDir(dir)
	if err != nil {
		return nil, err
	}
	type candidate struct {
		path  string
		order int64
	}
	cands := make([]candidate, 0, len(entries))
	for _, entry := range entries {
		if entry.IsDir() {
			continue
		}
		name := entry.Name()
		if name == defaultLogFileName {
			cands = append(cands, candidate{path: filepath.Join(dir, name), order: 0})
			continue
		}
		if order, ok := rotationOrder(name); ok {
			cands = append(cands, candidate{path: filepath.Join(dir, name), order: order})
		}
	}
	if len(cands) == 0 {
		return []string{}, nil
	}
	// Oldest (largest order) first; ties broken by path for determinism,
	// since sort.Slice alone is not stable.
	sort.Slice(cands, func(i, j int) bool {
		if cands[i].order != cands[j].order {
			return cands[i].order > cands[j].order
		}
		return cands[i].path < cands[j].path
	})
	paths := make([]string, 0, len(cands))
	for _, cand := range cands {
		paths = append(paths, cand.path)
	}
	return paths, nil
}
// logAccumulator collects log lines from one or more files in scan order,
// applying an optional timestamp cutoff and an optional cap on how many of the
// most recent lines are retained.
type logAccumulator struct {
	cutoff  int64    // keep only entries stamped after this Unix time; 0 keeps all
	limit   int      // keep at most this many most-recent lines; 0 is unlimited
	lines   []string // retained lines, in scan order
	total   int      // every line scanned, whether or not it was retained
	latest  int64    // newest timestamp seen on any scanned line
	include bool     // cutoff verdict of the last timestamped line, inherited by untimestamped lines
}

// newLogAccumulator returns an accumulator for the given cutoff and limit. The
// initial capacity is 256, or limit when that is smaller, so small limits do
// not over-allocate.
func newLogAccumulator(cutoff int64, limit int) *logAccumulator {
	capacity := 256
	if limit > 0 && limit < capacity {
		capacity = limit
	}
	return &logAccumulator{
		cutoff: cutoff,
		limit:  limit,
		lines:  make([]string, 0, capacity),
	}
}

// consumeFile scans path line by line and feeds every line to addLine.
//
// A file that no longer exists is skipped, presumably because it was rotated
// or removed after the directory listing. Paths ending in ".gz" are
// decompressed on the fly: rotated backups may be gzip-compressed, and reading
// them as text would emit binary garbage. Handling of damaged archives is
// best-effort, so one bad backup does not fail the whole request:
//   - a ".gz" file without a readable gzip header (empty, truncated, or not
//     gzip at all) is skipped;
//   - a compressed stream that ends early is treated as complete; this can
//     happen, for example, while the backup is still being written.
func (acc *logAccumulator) consumeFile(path string) error {
	file, err := os.Open(path)
	if err != nil {
		if os.IsNotExist(err) {
			return nil
		}
		return err
	}
	defer func() {
		_ = file.Close()
	}()
	var reader io.Reader = file
	compressed := strings.HasSuffix(path, ".gz")
	if compressed {
		gz, errGzip := gzip.NewReader(file)
		if errGzip != nil {
			return nil
		}
		defer func() {
			_ = gz.Close()
		}()
		reader = gz
	}
	scanner := bufio.NewScanner(reader)
	buf := make([]byte, 0, logScannerInitialBuffer)
	scanner.Buffer(buf, logScannerMaxBuffer)
	for scanner.Scan() {
		acc.addLine(scanner.Text())
	}
	if errScan := scanner.Err(); errScan != nil {
		if compressed && errors.Is(errScan, io.ErrUnexpectedEOF) {
			return nil
		}
		return errScan
	}
	return nil
}

// addLine records one raw line (a trailing "\r" is stripped). A timestamped
// line decides, via the cutoff, whether it and the following untimestamped
// continuation lines are retained. With no cutoff every line is retained.
func (acc *logAccumulator) addLine(raw string) {
	line := strings.TrimRight(raw, "\r")
	acc.total++
	if ts := parseTimestamp(line); ts > 0 {
		if ts > acc.latest {
			acc.latest = ts
		}
		acc.include = acc.cutoff == 0 || ts > acc.cutoff
	}
	if acc.cutoff == 0 || acc.include {
		acc.append(line)
	}
}

// append retains line, dropping the oldest retained lines when limit is set
// and exceeded.
func (acc *logAccumulator) append(line string) {
	acc.lines = append(acc.lines, line)
	if acc.limit > 0 && len(acc.lines) > acc.limit {
		acc.lines = acc.lines[len(acc.lines)-acc.limit:]
	}
}

// result returns the retained lines (never nil, so JSON encodes [] rather than
// null), the number of lines scanned, and the newest timestamp seen.
func (acc *logAccumulator) result() ([]string, int, int64) {
	if acc.lines == nil {
		acc.lines = []string{}
	}
	return acc.lines, acc.total, acc.latest
}
func parseCutoff(raw string) int64 {
value := strings.TrimSpace(raw)
if value == "" {
return 0
}
ts, err := strconv.ParseInt(value, 10, 64)
if err != nil || ts <= 0 {
return 0
}
return ts
}
func parseLimit(raw string) (int, error) {
value := strings.TrimSpace(raw)
if value == "" {
return 0, nil
}
limit, err := strconv.Atoi(value)
if err != nil {
return 0, fmt.Errorf("must be a positive integer")
}
if limit <= 0 {
return 0, fmt.Errorf("must be greater than zero")
}
return limit, nil
}
func parseTimestamp(line string) int64 {
if strings.HasPrefix(line, "[") {
line = line[1:]
}
if len(line) < 19 {
return 0
}
candidate := line[:19]
t, err := time.ParseInLocation("2006-01-02 15:04:05", candidate, time.Local)
if err != nil {
return 0
}
return t.Unix()
}
// isRotatedLogFile reports whether name is a rotated backup of main.log.
func isRotatedLogFile(name string) bool {
	_, ok := rotationOrder(name)
	return ok
}

// rotationOrder returns the sort key of a rotated log file: numeric suffixes
// (main.log.N) are tried first, then timestamped names. It returns false when
// name matches neither scheme.
func rotationOrder(name string) (int64, bool) {
	probes := [...]func(string) (int64, bool){numericRotationOrder, timestampRotationOrder}
	for _, probe := range probes {
		if order, ok := probe(name); ok {
			return order, true
		}
	}
	return 0, false
}
// numericRotationOrder recognises backups named "main.log.<N>" and returns N
// as the sort key (a larger N is treated as older).
//
// The suffix must consist of ASCII digits only. strconv on its own would also
// accept a leading sign ("main.log.-1", "main.log.+2"). A negative key would
// sort such a file as newer than the active log, and DeleteLogs would remove
// it as a rotated backup.
func numericRotationOrder(name string) (int64, bool) {
	prefix := defaultLogFileName + "."
	if !strings.HasPrefix(name, prefix) {
		return 0, false
	}
	suffix := name[len(prefix):]
	if suffix == "" {
		return 0, false
	}
	for i := 0; i < len(suffix); i++ {
		if suffix[i] < '0' || suffix[i] > '9' {
			return 0, false
		}
	}
	n, err := strconv.ParseInt(suffix, 10, 64)
	if err != nil {
		// Only reachable on overflow.
		return 0, false
	}
	return n, true
}
// timestampRotationOrder recognises backups named
// "main-<2006-01-02T15-04-05>[.<anything>].log[.gz]". The timestamp is read in
// local time, and anything after its first "." (such as milliseconds) is
// ignored. The key is MaxInt64 minus the Unix time, so newer backups get
// smaller keys and every timestamped key is larger than a numeric one.
func timestampRotationOrder(name string) (int64, bool) {
	ext := filepath.Ext(defaultLogFileName)
	stem := strings.TrimSuffix(defaultLogFileName, ext)
	if stem == "" {
		return 0, false
	}
	lead := stem + "-"
	if !strings.HasPrefix(name, lead) {
		return 0, false
	}
	stamp := strings.TrimSuffix(name[len(lead):], ".gz")
	if ext != "" {
		if !strings.HasSuffix(stamp, ext) {
			return 0, false
		}
		stamp = stamp[:len(stamp)-len(ext)]
	}
	if stamp == "" {
		return 0, false
	}
	if dot := strings.IndexByte(stamp, '.'); dot >= 0 {
		stamp = stamp[:dot]
	}
	when, err := time.ParseInLocation("2006-01-02T15-04-05", stamp, time.Local)
	if err != nil {
		return 0, false
	}
	return math.MaxInt64 - when.Unix(), true
}
================================================
FILE: internal/api/handlers/management/model_definitions.go
================================================
package management
import (
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
// GetStaticModelDefinitions returns static model metadata for a given channel.
// Channel is provided via path param (:channel) or query param (?channel=...).
//
// A missing channel, or one the registry does not know, yields 400. On success
// the response echoes the channel lower-cased together with its models.
func (h *Handler) GetStaticModelDefinitions(c *gin.Context) {
	channel := strings.TrimSpace(c.Param("channel"))
	if len(channel) == 0 {
		channel = strings.TrimSpace(c.Query("channel"))
	}
	if len(channel) == 0 {
		c.JSON(http.StatusBadRequest, gin.H{"error": "channel is required"})
		return
	}
	defs := registry.GetStaticModelDefinitionsByChannel(channel)
	if defs == nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "unknown channel", "channel": channel})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"channel": strings.ToLower(channel),
		"models":  defs,
	})
}
================================================
FILE: internal/api/handlers/management/oauth_callback.go
================================================
package management
import (
"errors"
"net/http"
"net/url"
"strings"
"github.com/gin-gonic/gin"
)
// oauthCallbackRequest is the JSON body accepted by PostOAuthCallback. Code,
// State and Error may be given directly or extracted from the query string of
// RedirectURL; explicitly provided fields take precedence.
type oauthCallbackRequest struct {
	Provider    string `json:"provider"`     // provider name or alias, see NormalizeOAuthProvider
	RedirectURL string `json:"redirect_url"` // optional OAuth redirect URL whose query fills missing fields
	Code        string `json:"code"`
	State       string `json:"state"`
	Error       string `json:"error"`
}
// PostOAuthCallback accepts an OAuth callback submitted through the management
// API and persists it as a callback file for the pending flow.
//
// The state must be valid, belong to a registered session that is still
// pending, and match the requested provider. At least a code or an error must
// be present. Values missing from the body are taken from redirect_url's query
// (error falls back to error_description).
func (h *Handler) PostOAuthCallback(c *gin.Context) {
	if h == nil || h.cfg == nil {
		c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "handler not initialized"})
		return
	}
	var req oauthCallbackRequest
	if errBind := c.ShouldBindJSON(&req); errBind != nil {
		c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"})
		return
	}
	provider, errProvider := NormalizeOAuthProvider(req.Provider)
	if errProvider != nil {
		c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "unsupported provider"})
		return
	}
	state := strings.TrimSpace(req.State)
	code := strings.TrimSpace(req.Code)
	errMsg := strings.TrimSpace(req.Error)
	if raw := strings.TrimSpace(req.RedirectURL); raw != "" {
		parsed, errParse := url.Parse(raw)
		if errParse != nil {
			c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid redirect_url"})
			return
		}
		query := parsed.Query()
		// fill keeps an explicit value and otherwise falls back to the query.
		fill := func(current, key string) string {
			if current != "" {
				return current
			}
			return strings.TrimSpace(query.Get(key))
		}
		state = fill(state, "state")
		code = fill(code, "code")
		errMsg = fill(fill(errMsg, "error"), "error_description")
	}
	switch {
	case state == "":
		c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"})
		return
	case ValidateOAuthState(state) != nil:
		c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
		return
	case code == "" && errMsg == "":
		c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "code or error is required"})
		return
	}
	sessionProvider, sessionStatus, found := GetOAuthSession(state)
	switch {
	case !found:
		c.JSON(http.StatusNotFound, gin.H{"status": "error", "error": "unknown or expired state"})
		return
	case sessionStatus != "":
		c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
		return
	case !strings.EqualFold(sessionProvider, provider):
		c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "provider does not match state"})
		return
	}
	// The write helper re-checks that the session is still pending.
	_, errWrite := WriteOAuthCallbackFileForPendingSession(h.cfg.AuthDir, provider, state, code, errMsg)
	if errWrite == nil {
		c.JSON(http.StatusOK, gin.H{"status": "ok"})
		return
	}
	if errors.Is(errWrite, errOAuthSessionNotPending) {
		c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
		return
	}
	c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to persist oauth callback"})
}
================================================
FILE: internal/api/handlers/management/oauth_sessions.go
================================================
package management
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
// oauthSessionTTL is the default lifetime of a registered OAuth session.
const oauthSessionTTL = 10 * time.Minute

// maxOAuthStateLength is the longest state (in bytes) ValidateOAuthState accepts.
const maxOAuthStateLength = 128

// Sentinel errors returned by the OAuth session helpers.
var errInvalidOAuthState = errors.New("invalid oauth state")

var errUnsupportedOAuthFlow = errors.New("unsupported oauth provider")

var errOAuthSessionNotPending = errors.New("oauth session is not pending")
// oauthSession records one in-flight OAuth flow, keyed by its state value.
type oauthSession struct {
	Provider  string    // lower-cased provider name given at registration
	Status    string    // "" while pending; otherwise the message set via SetError
	CreatedAt time.Time // registration time
	ExpiresAt time.Time // the entry is purged once this moment has passed
}

// oauthSessionStore is a mutex-guarded map of OAuth sessions with a TTL.
// Expired entries are purged lazily at the start of every operation.
type oauthSessionStore struct {
	mu       sync.RWMutex
	ttl      time.Duration
	sessions map[string]oauthSession
}

// newOAuthSessionStore creates an empty store; a non-positive ttl falls back
// to oauthSessionTTL.
func newOAuthSessionStore(ttl time.Duration) *oauthSessionStore {
	store := &oauthSessionStore{
		ttl:      ttl,
		sessions: map[string]oauthSession{},
	}
	if store.ttl <= 0 {
		store.ttl = oauthSessionTTL
	}
	return store
}

// purgeExpiredLocked drops sessions whose non-zero ExpiresAt lies before now.
// The caller must hold s.mu for writing.
func (s *oauthSessionStore) purgeExpiredLocked(now time.Time) {
	for key, sess := range s.sessions {
		if sess.ExpiresAt.IsZero() || !now.After(sess.ExpiresAt) {
			continue
		}
		delete(s.sessions, key)
	}
}

// lockPurged takes the write lock, purges expired sessions, and returns the
// time used for the purge. The caller must release s.mu.
func (s *oauthSessionStore) lockPurged() time.Time {
	now := time.Now()
	s.mu.Lock()
	s.purgeExpiredLocked(now)
	return now
}

// Register starts (or restarts) a pending session for state. Blank state or
// provider values are ignored.
func (s *oauthSessionStore) Register(state, provider string) {
	key := strings.TrimSpace(state)
	name := strings.ToLower(strings.TrimSpace(provider))
	if key == "" || name == "" {
		return
	}
	now := s.lockPurged()
	defer s.mu.Unlock()
	s.sessions[key] = oauthSession{
		Provider:  name,
		CreatedAt: now,
		ExpiresAt: now.Add(s.ttl),
	}
}

// SetError marks an existing session as failed with message (defaulting to
// "Authentication failed") and extends its lifetime by one TTL.
func (s *oauthSessionStore) SetError(state, message string) {
	key := strings.TrimSpace(state)
	if key == "" {
		return
	}
	msg := strings.TrimSpace(message)
	if msg == "" {
		msg = "Authentication failed"
	}
	now := s.lockPurged()
	defer s.mu.Unlock()
	if sess, ok := s.sessions[key]; ok {
		sess.Status = msg
		sess.ExpiresAt = now.Add(s.ttl)
		s.sessions[key] = sess
	}
}

// Complete forgets the session for state.
func (s *oauthSessionStore) Complete(state string) {
	key := strings.TrimSpace(state)
	if key == "" {
		return
	}
	s.lockPurged()
	defer s.mu.Unlock()
	delete(s.sessions, key)
}

// CompleteProvider forgets every session for provider (case-insensitive) and
// returns how many were removed.
func (s *oauthSessionStore) CompleteProvider(provider string) int {
	name := strings.ToLower(strings.TrimSpace(provider))
	if name == "" {
		return 0
	}
	s.lockPurged()
	defer s.mu.Unlock()
	count := 0
	for key, sess := range s.sessions {
		if !strings.EqualFold(sess.Provider, name) {
			continue
		}
		delete(s.sessions, key)
		count++
	}
	return count
}

// Get returns the unexpired session for state, if any.
func (s *oauthSessionStore) Get(state string) (oauthSession, bool) {
	key := strings.TrimSpace(state)
	s.lockPurged()
	defer s.mu.Unlock()
	sess, ok := s.sessions[key]
	return sess, ok
}

// IsPending reports whether state has an unexpired session without an error.
// A blank provider matches any provider; otherwise providers must match
// case-insensitively.
func (s *oauthSessionStore) IsPending(state, provider string) bool {
	key := strings.TrimSpace(state)
	name := strings.ToLower(strings.TrimSpace(provider))
	s.lockPurged()
	defer s.mu.Unlock()
	sess, ok := s.sessions[key]
	switch {
	case !ok, sess.Status != "":
		return false
	case name == "":
		return true
	default:
		return strings.EqualFold(sess.Provider, name)
	}
}
// oauthSessions is the process-wide store behind the package-level helpers.
var oauthSessions = newOAuthSessionStore(oauthSessionTTL)

// RegisterOAuthSession starts tracking a pending OAuth flow for state.
func RegisterOAuthSession(state, provider string) {
	oauthSessions.Register(state, provider)
}

// SetOAuthSessionError marks the flow for state as failed with message.
func SetOAuthSessionError(state, message string) {
	oauthSessions.SetError(state, message)
}

// CompleteOAuthSession stops tracking the flow for state.
func CompleteOAuthSession(state string) {
	oauthSessions.Complete(state)
}

// CompleteOAuthSessionsByProvider stops tracking every flow for provider and
// returns how many were removed.
func CompleteOAuthSessionsByProvider(provider string) int {
	return oauthSessions.CompleteProvider(provider)
}

// GetOAuthSession returns the provider and status of the flow for state;
// ok is false when no unexpired session exists.
func GetOAuthSession(state string) (provider string, status string, ok bool) {
	session, found := oauthSessions.Get(state)
	if found {
		return session.Provider, session.Status, true
	}
	return "", "", false
}

// IsOAuthSessionPending reports whether the flow for state is still pending
// for provider (a blank provider matches any).
func IsOAuthSessionPending(state, provider string) bool {
	return oauthSessions.IsPending(state, provider)
}
// ValidateOAuthState checks that state (ignoring surrounding whitespace) is
// safe to embed in a file name. It must be non-empty, at most
// maxOAuthStateLength bytes, free of path separators and "..", and made only
// of ASCII letters, digits, '-', '_' and '.'. Every error wraps
// errInvalidOAuthState.
func ValidateOAuthState(state string) error {
	trimmed := strings.TrimSpace(state)
	switch {
	case trimmed == "":
		return fmt.Errorf("%w: empty", errInvalidOAuthState)
	case len(trimmed) > maxOAuthStateLength:
		return fmt.Errorf("%w: too long", errInvalidOAuthState)
	case strings.ContainsAny(trimmed, "/\\"):
		return fmt.Errorf("%w: contains path separator", errInvalidOAuthState)
	case strings.Contains(trimmed, ".."):
		return fmt.Errorf("%w: contains '..'", errInvalidOAuthState)
	}
	allowed := func(r rune) bool {
		return (r >= 'a' && r <= 'z') ||
			(r >= 'A' && r <= 'Z') ||
			(r >= '0' && r <= '9') ||
			r == '-' || r == '_' || r == '.'
	}
	if strings.IndexFunc(trimmed, func(r rune) bool { return !allowed(r) }) >= 0 {
		return fmt.Errorf("%w: invalid character", errInvalidOAuthState)
	}
	return nil
}
func NormalizeOAuthProvider(provider string) (string, error) {
switch strings.ToLower(strings.TrimSpace(provider)) {
case "anthropic", "claude":
return "anthropic", nil
case "codex", "openai":
return "codex", nil
case "gemini", "google":
return "gemini", nil
case "iflow", "i-flow":
return "iflow", nil
case "antigravity", "anti-gravity":
return "antigravity", nil
case "qwen":
return "qwen", nil
default:
return "", errUnsupportedOAuthFlow
}
}
// oauthCallbackFilePayload is the JSON document WriteOAuthCallbackFile stores
// for a completed (or failed) OAuth callback. Values are trimmed before writing.
type oauthCallbackFilePayload struct {
	Code  string `json:"code"`
	State string `json:"state"`
	Error string `json:"error"`
}
// WriteOAuthCallbackFile writes an OAuth callback result as JSON to
// "<authDir>/.oauth-<provider>-<state>.oauth" with mode 0600 and returns the
// path written.
//
// provider is canonicalized with NormalizeOAuthProvider. state is trimmed and
// validated with ValidateOAuthState. The trimmed state is used in both the file
// name and the payload; previously the raw value went into the name, so
// surrounding whitespace that passed validation leaked into the file name.
func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) {
	if strings.TrimSpace(authDir) == "" {
		return "", fmt.Errorf("auth dir is empty")
	}
	canonicalProvider, err := NormalizeOAuthProvider(provider)
	if err != nil {
		return "", err
	}
	cleanState := strings.TrimSpace(state)
	if err := ValidateOAuthState(cleanState); err != nil {
		return "", err
	}
	fileName := fmt.Sprintf(".oauth-%s-%s.oauth", canonicalProvider, cleanState)
	filePath := filepath.Join(authDir, fileName)
	payload := oauthCallbackFilePayload{
		Code:  strings.TrimSpace(code),
		State: cleanState,
		Error: strings.TrimSpace(errorMessage),
	}
	data, err := json.Marshal(payload)
	if err != nil {
		return "", fmt.Errorf("marshal oauth callback payload: %w", err)
	}
	if err := os.WriteFile(filePath, data, 0o600); err != nil {
		return "", fmt.Errorf("write oauth callback file: %w", err)
	}
	return filePath, nil
}
// WriteOAuthCallbackFileForPendingSession is WriteOAuthCallbackFile guarded by
// a session check. It returns errOAuthSessionNotPending unless state refers to
// a registered, still-pending session for the normalized provider.
//
// NOTE(review): the pending check and the write are not atomic with respect to
// concurrent session updates.
func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) (string, error) {
	normalized, errNormalize := NormalizeOAuthProvider(provider)
	switch {
	case errNormalize != nil:
		return "", errNormalize
	case !IsOAuthSessionPending(state, normalized):
		return "", errOAuthSessionNotPending
	}
	return WriteOAuthCallbackFile(authDir, normalized, state, code, errorMessage)
}
================================================
FILE: internal/api/handlers/management/quota.go
================================================
package management
import "github.com/gin-gonic/gin"
// Quota exceeded toggles.

// GetSwitchProject reports the quota-exceeded "switch-project" setting.
func (h *Handler) GetSwitchProject(c *gin.Context) {
	enabled := h.cfg.QuotaExceeded.SwitchProject
	c.JSON(200, gin.H{"switch-project": enabled})
}

// PutSwitchProject updates the quota-exceeded "switch-project" setting from a
// {"value": bool} body and persists the config.
func (h *Handler) PutSwitchProject(c *gin.Context) {
	h.updateBoolField(c, func(enabled bool) {
		h.cfg.QuotaExceeded.SwitchProject = enabled
	})
}

// GetSwitchPreviewModel reports the quota-exceeded "switch-preview-model" setting.
func (h *Handler) GetSwitchPreviewModel(c *gin.Context) {
	enabled := h.cfg.QuotaExceeded.SwitchPreviewModel
	c.JSON(200, gin.H{"switch-preview-model": enabled})
}

// PutSwitchPreviewModel updates the quota-exceeded "switch-preview-model"
// setting from a {"value": bool} body and persists the config.
func (h *Handler) PutSwitchPreviewModel(c *gin.Context) {
	h.updateBoolField(c, func(enabled bool) {
		h.cfg.QuotaExceeded.SwitchPreviewModel = enabled
	})
}
================================================
FILE: internal/api/handlers/management/test_store_test.go
================================================
package management
import (
"context"
"sync"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// memoryAuthStore is a mutex-guarded, in-memory auth store used by the
// management tests (presumably as a coreauth.Store stand-in — confirm at use sites).
type memoryAuthStore struct {
	mu    sync.Mutex
	items map[string]*coreauth.Auth
}

// List returns every stored record, in unspecified (map iteration) order.
func (s *memoryAuthStore) List(_ context.Context) ([]*coreauth.Auth, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	all := make([]*coreauth.Auth, 0, len(s.items))
	for _, record := range s.items {
		all = append(all, record)
	}
	return all, nil
}

// Save stores auth under its ID and returns that ID; a nil auth is a no-op.
func (s *memoryAuthStore) Save(_ context.Context, auth *coreauth.Auth) (string, error) {
	if auth == nil {
		return "", nil
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.items == nil {
		// Lazily initialize so the zero value is usable.
		s.items = map[string]*coreauth.Auth{}
	}
	s.items[auth.ID] = auth
	return auth.ID, nil
}

// Delete removes the record with the given ID, if present.
func (s *memoryAuthStore) Delete(_ context.Context, id string) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.items, id)
	return nil
}

// SetBaseDir is a no-op; the in-memory store has no backing directory.
func (s *memoryAuthStore) SetBaseDir(string) {}
================================================
FILE: internal/api/handlers/management/usage.go
================================================
package management
import (
"encoding/json"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
)
type usageExportPayload struct {
Version int `json:"version"`
ExportedAt time.Time `json:"exported_at"`
Usage usage.StatisticsSnapshot `json:"usage"`
}
// usageImportPayload is the JSON document accepted by ImportUsageStatistics. It mirrors
// the export payload without exported_at; encoding/json ignores that extra field when an
// export is re-imported. ImportUsageStatistics accepts Version 0 (field absent) and 1.
type usageImportPayload struct {
Version int `json:"version"`
Usage usage.StatisticsSnapshot `json:"usage"`
}
// GetUsageStatistics responds with the current in-memory usage snapshot and, as a
// convenience, its failure count. A nil handler or one without a statistics store
// responds with a zero-value snapshot.
func (h *Handler) GetUsageStatistics(c *gin.Context) {
	var stats usage.StatisticsSnapshot
	if h != nil && h.usageStats != nil {
		stats = h.usageStats.Snapshot()
	}
	body := gin.H{
		"usage":           stats,
		"failed_requests": stats.FailureCount,
	}
	c.JSON(http.StatusOK, body)
}
// ExportUsageStatistics responds with a version-1, UTC-timestamped copy of the usage
// snapshot suitable for backup or migration; ImportUsageStatistics accepts it back.
// A nil handler or missing statistics store exports a zero-value snapshot.
func (h *Handler) ExportUsageStatistics(c *gin.Context) {
	var snap usage.StatisticsSnapshot
	if h != nil && h.usageStats != nil {
		snap = h.usageStats.Snapshot()
	}
	payload := usageExportPayload{Version: 1, ExportedAt: time.Now().UTC(), Usage: snap}
	c.JSON(http.StatusOK, payload)
}
// ImportUsageStatistics merges a previously exported usage snapshot into memory.
func (h *Handler) ImportUsageStatistics(c *gin.Context) {
if h == nil || h.usageStats == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
return
}
data, err := c.GetRawData()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
return
}
var payload usageImportPayload
if err := json.Unmarshal(data, &payload); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
return
}
if payload.Version != 0 && payload.Version != 1 {
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
return
}
result := h.usageStats.MergeSnapshot(payload.Usage)
snapshot := h.usageStats.Snapshot()
c.JSON(http.StatusOK, gin.H{
"added": result.Added,
"skipped": result.Skipped,
"total_requests": snapshot.TotalRequests,
"failed_requests": snapshot.FailureCount,
})
}
================================================
FILE: internal/api/handlers/management/vertex_import.go
================================================
package management
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// ImportVertexCredential accepts a multipart upload (form field "file") holding a Vertex
// service-account JSON document. It normalizes the document and persists it via
// saveTokenRecord as a "vertex" auth record named vertex-<sanitized project_id>.json.
// The location comes from the "location" form value, then the "location" query
// parameter, and defaults to "us-central1". A missing config or auth directory yields
// 503, bad input 400, and a save failure 500.
func (h *Handler) ImportVertexCredential(c *gin.Context) {
	switch {
	case h == nil || h.cfg == nil:
		c.JSON(http.StatusServiceUnavailable, gin.H{"error": "config unavailable"})
		return
	case h.cfg.AuthDir == "":
		c.JSON(http.StatusServiceUnavailable, gin.H{"error": "auth directory not configured"})
		return
	}

	upload, err := c.FormFile("file")
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "file required"})
		return
	}
	src, err := upload.Open()
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
		return
	}
	defer src.Close()
	raw, err := io.ReadAll(src)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
		return
	}

	var parsed map[string]any
	if err = json.Unmarshal(raw, &parsed); err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json", "message": err.Error()})
		return
	}
	sa, err := vertex.NormalizeServiceAccountMap(parsed)
	if err != nil {
		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid service account", "message": err.Error()})
		return
	}

	projectID := strings.TrimSpace(valueAsString(sa["project_id"]))
	if projectID == "" {
		c.JSON(http.StatusBadRequest, gin.H{"error": "project_id missing"})
		return
	}
	email := strings.TrimSpace(valueAsString(sa["client_email"]))

	// Form value wins over the query string; fall back to the default region.
	location := strings.TrimSpace(c.PostForm("location"))
	if location == "" {
		location = strings.TrimSpace(c.Query("location"))
	}
	if location == "" {
		location = "us-central1"
	}

	fileName := "vertex-" + sanitizeVertexFilePart(projectID) + ".json"
	label := labelForVertex(projectID, email)
	record := &coreauth.Auth{
		ID:       fileName,
		Provider: "vertex",
		FileName: fileName,
		Storage: &vertex.VertexCredentialStorage{
			ServiceAccount: sa,
			ProjectID:      projectID,
			Email:          email,
			Location:       location,
			Type:           "vertex",
		},
		Label: label,
		Metadata: map[string]any{
			"service_account": sa,
			"project_id":      projectID,
			"email":           email,
			"location":        location,
			"type":            "vertex",
			"label":           label,
		},
	}

	ctx := c.Request.Context()
	if ctx == nil {
		ctx = context.Background()
	}
	savedPath, err := h.saveTokenRecord(ctx, record)
	if err != nil {
		c.JSON(http.StatusInternalServerError, gin.H{"error": "save_failed", "message": err.Error()})
		return
	}
	c.JSON(http.StatusOK, gin.H{
		"status":     "ok",
		"auth-file":  savedPath,
		"project_id": projectID,
		"email":      email,
		"location":   location,
	})
}
func valueAsString(v any) string {
if v == nil {
return ""
}
switch t := v.(type) {
case string:
return t
default:
return fmt.Sprint(t)
}
}
// sanitizeVertexFilePart makes s safe to embed in an auth file name. It trims
// surrounding whitespace, turns '/', '\\' and ':' into '_', and turns spaces into '-'.
// An empty result falls back to "vertex".
func sanitizeVertexFilePart(s string) string {
	// The replacements never produce one of the searched characters, so this
	// single-pass replacer matches applying each ReplaceAll in sequence.
	cleaned := strings.NewReplacer("/", "_", "\\", "_", ":", "_", " ", "-").Replace(strings.TrimSpace(s))
	if cleaned == "" {
		return "vertex"
	}
	return cleaned
}
// labelForVertex builds a display label from the trimmed project ID and email. It
// returns "project (email)" when both are present, whichever one is present otherwise,
// and "vertex" when both are blank.
func labelForVertex(projectID, email string) string {
	project := strings.TrimSpace(projectID)
	account := strings.TrimSpace(email)
	switch {
	case project != "" && account != "":
		return project + " (" + account + ")"
	case project != "":
		return project
	case account != "":
		return account
	default:
		return "vertex"
	}
}
================================================
FILE: internal/api/middleware/request_logging.go
================================================
// Package middleware provides HTTP middleware components for the CLI Proxy API server.
// This file contains the request logging middleware that captures comprehensive
// request and response data when enabled through configuration.
package middleware
import (
"bytes"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
// maxErrorOnlyCapturedRequestBodyBytes is the largest declared Content-Length whose
// request body is buffered for logging when full request logging is disabled and only
// error responses are logged (see shouldCaptureRequestBody).
const maxErrorOnlyCapturedRequestBodyBytes int64 = 1 << 20 // 1 MiB
// RequestLoggingMiddleware creates a Gin middleware that records HTTP requests and
// responses through the provided RequestLogger.
//
// Some requests bypass the middleware entirely:
//   - when logger is nil,
//   - GET requests other than /v1/responses websocket upgrades
//     (see shouldSkipMethodForRequestLogging),
//   - paths rejected by shouldLogRequest.
//
// When the logger is disabled, the response wrapper runs in error-only mode
// (logOnErrorOnly). In that mode the request body is captured only for small,
// known-size, non-multipart payloads (see shouldCaptureRequestBody), which avoids large
// per-request memory spikes.
//
// Logging is best-effort. A failure to capture the request or to finalize the log
// never changes the response served to the client.
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
	return func(c *gin.Context) {
		// shouldSkipMethodForRequestLogging returns true for a nil request, so the
		// path is only read for non-nil requests.
		if logger == nil || shouldSkipMethodForRequestLogging(c.Request) || !shouldLogRequest(c.Request.URL.Path) {
			c.Next()
			return
		}

		loggerEnabled := logger.IsEnabled()
		requestInfo, err := captureRequestInfo(c, shouldCaptureRequestBody(loggerEnabled, c.Request))
		if err != nil {
			// Without request details there is nothing to log, but the request itself
			// must still be served.
			c.Next()
			return
		}

		wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
		// With the logger disabled, only error responses end up in the log.
		wrapper.logOnErrorOnly = !loggerEnabled
		c.Writer = wrapper

		c.Next()

		// The response has already been sent, so a finalization failure is
		// deliberately dropped rather than surfaced to the client.
		_ = wrapper.Finalize(c)
	}
}
// shouldSkipMethodForRequestLogging reports whether req bypasses request logging
// because of its method. Nil requests and GET requests are skipped, except GET
// websocket upgrades to /v1/responses. Every non-GET method is logged.
func shouldSkipMethodForRequestLogging(req *http.Request) bool {
	switch {
	case req == nil:
		return true
	case req.Method == http.MethodGet:
		return !isResponsesWebsocketUpgrade(req)
	default:
		return false
	}
}
// isResponsesWebsocketUpgrade reports whether req targets exactly /v1/responses and
// carries an Upgrade header equal to "websocket" (case-insensitive, surrounding
// whitespace ignored). Nil requests and requests without a URL return false.
func isResponsesWebsocketUpgrade(req *http.Request) bool {
	if req == nil || req.URL == nil || req.URL.Path != "/v1/responses" {
		return false
	}
	upgrade := strings.TrimSpace(req.Header.Get("Upgrade"))
	return strings.EqualFold(upgrade, "websocket")
}
// shouldCaptureRequestBody decides whether the request body is buffered for logging.
// With the logger enabled every body is captured. In error-only mode a body is
// captured only when it is not multipart/form-data and its declared Content-Length is
// in (0, maxErrorOnlyCapturedRequestBodyBytes]. Unknown (-1) or zero lengths are skipped.
func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool {
	if loggerEnabled {
		return true
	}
	if req == nil || req.Body == nil {
		return false
	}
	mediaType := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Type")))
	if strings.HasPrefix(mediaType, "multipart/form-data") {
		return false
	}
	size := req.ContentLength
	return size > 0 && size <= maxErrorOnlyCapturedRequestBodyBytes
}
// captureRequestInfo snapshots the incoming request for logging. It records:
//   - the path, with sensitive query parameters masked,
//   - the method,
//   - a copy of the headers,
//   - the request ID,
//   - the capture time.
//
// When captureBody is true the body is read in full and c.Request.Body is replaced by
// an in-memory reader over the same bytes, so later handlers still see the whole payload.
//
// If reading the body fails, c.Request.Body is rebuilt as the bytes already consumed
// followed by the remaining original stream. Downstream handlers then observe the same
// data, and the same read error, they would have seen without this middleware, instead
// of a silently truncated body. The read error is returned.
func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) {
	// Capture URL with sensitive query parameters masked.
	maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
	fullURL := c.Request.URL.Path
	if maskedQuery != "" {
		fullURL += "?" + maskedQuery
	}

	// Copy each value slice so handlers that mutate request headers in place cannot
	// change what is logged.
	headers := make(map[string][]string, len(c.Request.Header))
	for key, values := range c.Request.Header {
		copied := make([]string, len(values))
		copy(copied, values)
		headers[key] = copied
	}

	var body []byte
	if captureBody && c.Request.Body != nil {
		original := c.Request.Body
		bodyBytes, err := io.ReadAll(original)
		if err != nil {
			// Put back what was consumed, followed by the rest of the original stream.
			c.Request.Body = struct {
				io.Reader
				io.Closer
			}{io.MultiReader(bytes.NewReader(bodyBytes), original), original}
			return nil, err
		}
		// Restore the body for the actual request processing.
		c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
		body = bodyBytes
	}

	return &RequestInfo{
		URL:       fullURL,
		Method:    c.Request.Method,
		Headers:   headers,
		Body:      body,
		RequestID: logging.GetGinRequestID(c),
		Timestamp: time.Now(),
	}, nil
}
// shouldLogRequest reports whether requests to path may be logged:
//   - management endpoints (prefixes /v0/management and /management) are never
//     logged, so secrets do not leak,
//   - under /api only the /api/provider routes are logged,
//   - every other path, including module-provided routes, honors the request-log
//     setting.
func shouldLogRequest(path string) bool {
	for _, prefix := range []string{"/v0/management", "/management"} {
		if strings.HasPrefix(path, prefix) {
			return false
		}
	}
	if !strings.HasPrefix(path, "/api") {
		return true
	}
	return strings.HasPrefix(path, "/api/provider")
}
================================================
FILE: internal/api/middleware/request_logging_test.go
================================================
package middleware
import (
"io"
"net/http"
"net/url"
"strings"
"testing"
)
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
tests := []struct {
name string
req *http.Request
skip bool
}{
{
name: "nil request",
req: nil,
skip: true,
},
{
name: "post request should not skip",
req: &http.Request{
Method: http.MethodPost,
URL: &url.URL{Path: "/v1/responses"},
},
skip: false,
},
{
name: "plain get should skip",
req: &http.Request{
Method: http.MethodGet,
URL: &url.URL{Path: "/v1/models"},
Header: http.Header{},
},
skip: true,
},
{
name: "responses websocket upgrade should not skip",
req: &http.Request{
Method: http.MethodGet,
URL: &url.URL{Path: "/v1/responses"},
Header: http.Header{"Upgrade": []string{"websocket"}},
},
skip: false,
},
{
name: "responses get without upgrade should skip",
req: &http.Request{
Method: http.MethodGet,
URL: &url.URL{Path: "/v1/responses"},
Header: http.Header{},
},
skip: true,
},
}
for i := range tests {
got := shouldSkipMethodForRequestLogging(tests[i].req)
if got != tests[i].skip {
t.Fatalf("%s: got skip=%t, want %t", tests[i].name, got, tests[i].skip)
}
}
}
// TestShouldCaptureRequestBody checks the body-capture policy. With the logger enabled
// everything is captured. In error-only mode only small, known-size, non-multipart
// bodies are captured.
func TestShouldCaptureRequestBody(t *testing.T) {
	newReq := func(body string, length int64, contentType string) *http.Request {
		return &http.Request{
			Body:          io.NopCloser(strings.NewReader(body)),
			ContentLength: length,
			Header:        http.Header{"Content-Type": []string{contentType}},
		}
	}
	cases := []struct {
		name          string
		loggerEnabled bool
		req           *http.Request
		want          bool
	}{
		{"logger enabled always captures", true, newReq("{}", -1, "application/json"), true},
		{"nil request", false, nil, false},
		{"small known size json in error-only mode", false, newReq("{}", 2, "application/json"), true},
		{"large known size skipped in error-only mode", false, newReq("x", maxErrorOnlyCapturedRequestBodyBytes+1, "application/json"), false},
		{"unknown size skipped in error-only mode", false, newReq("x", -1, "application/json"), false},
		{"multipart skipped in error-only mode", false, newReq("x", 1, "multipart/form-data; boundary=abc"), false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := shouldCaptureRequestBody(tc.loggerEnabled, tc.req); got != tc.want {
				t.Fatalf("%s: got %t, want %t", tc.name, got, tc.want)
			}
		})
	}
}
================================================
FILE: internal/api/middleware/response_writer.go
================================================
// Package middleware provides Gin HTTP middleware for the CLI Proxy API server.
// It includes a sophisticated response writer wrapper designed to capture and log request and response data,
// including support for streaming responses, without impacting latency.
package middleware
import (
"bytes"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
)
// requestBodyOverrideContextKey is the gin context key under which a handler may store
// a replacement request body ([]byte or string). When set and non-empty,
// extractRequestBody logs it instead of the captured body.
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
// RequestInfo holds the details of an incoming HTTP request captured for logging.
// captureRequestInfo populates it before the handler chain runs. Body is nil when
// body capture was skipped.
type RequestInfo struct {
URL string // URL is the request path plus the masked query string, if any.
Method string // Method is the HTTP method (e.g., GET, POST).
Headers map[string][]string // Headers contains the request headers.
Body []byte // Body is the raw request body, or nil when it was not captured.
RequestID string // RequestID is the request identifier taken from the gin context.
Timestamp time.Time // Timestamp is when the request details were captured.
}
// ResponseWriterWrapper wraps gin.ResponseWriter so the response can be logged.
// Every write reaches the client before any logging work happens.
//   - Non-streaming bodies are buffered in memory, subject to shouldBufferResponseBody.
//   - Streaming bodies are passed to a background goroutine over a buffered channel;
//     chunks are dropped when that channel is full.
//
// The wrapper has no internal locking.
type ResponseWriterWrapper struct {
gin.ResponseWriter
body *bytes.Buffer // body buffers the response body for non-streaming responses.
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
streamWriter logging.StreamingLogWriter // streamWriter receives streaming log entries; set in WriteHeader, cleared in Finalize.
chunkChannel chan []byte // chunkChannel passes response chunks to processStreamingChunks.
streamDone chan struct{} // streamDone is closed when the streaming goroutine completes.
logger logging.RequestLogger // logger is the request logger service.
requestInfo *RequestInfo // requestInfo holds the details of the original request.
statusCode int // statusCode stores the HTTP status code passed to WriteHeader (0 until then).
headers map[string][]string // headers stores a snapshot of the response headers.
logOnErrorOnly bool // logOnErrorOnly enables logging only when an error response is detected.
firstChunkTimestamp time.Time // firstChunkTimestamp records when the first streamed chunk was written (TTFB).
}
// NewResponseWriterWrapper wraps w so that the response written through it can be
// logged via logger alongside the pre-captured requestInfo. The wrapper starts with an
// empty body buffer and header snapshot; logOnErrorOnly defaults to false.
func NewResponseWriterWrapper(w gin.ResponseWriter, logger logging.RequestLogger, requestInfo *RequestInfo) *ResponseWriterWrapper {
	wrapper := new(ResponseWriterWrapper)
	wrapper.ResponseWriter = w
	wrapper.body = new(bytes.Buffer)
	wrapper.logger = logger
	wrapper.requestInfo = requestInfo
	wrapper.headers = map[string][]string{}
	return wrapper
}
// Write forwards data to the client first and only then records it for logging.
//   - Streaming responses with an active chunk channel: the time of the first chunk
//     is remembered (TTFB) and a copy of the data is queued without blocking. The
//     chunk is dropped if the channel is full.
//   - All other responses: the data is appended to the in-memory body buffer when
//     shouldBufferResponseBody allows it.
func (w *ResponseWriterWrapper) Write(data []byte) (int, error) {
	// This write may commit headers implicitly, so refresh the snapshot first.
	w.ensureHeadersCaptured()

	written, writeErr := w.ResponseWriter.Write(data)

	if !w.isStreaming || w.chunkChannel == nil {
		if w.shouldBufferResponseBody() {
			w.body.Write(data)
		}
		return written, writeErr
	}

	if w.firstChunkTimestamp.IsZero() {
		w.firstChunkTimestamp = time.Now()
	}
	chunk := append([]byte(nil), data...) // caller may reuse data after we return
	select {
	case w.chunkChannel <- chunk:
	default:
		// Logger is behind; drop the chunk rather than stall the client.
	}
	return written, writeErr
}
// shouldBufferResponseBody reports whether non-streamed response bytes should be kept
// in memory. That is always the case while the logger is enabled. In error-only mode
// it is only the case once the status is >= 400. Without a recorded status, the
// underlying writer's Status() is used, and 200 when that is unavailable.
func (w *ResponseWriterWrapper) shouldBufferResponseBody() bool {
	if w.logger != nil && w.logger.IsEnabled() {
		return true
	}
	if !w.logOnErrorOnly {
		return false
	}
	status := w.statusCode
	if status == 0 {
		status = http.StatusOK
		if sw, ok := w.ResponseWriter.(interface{ Status() int }); ok && sw != nil {
			status = sw.Status()
		}
	}
	return status >= http.StatusBadRequest
}
// WriteString is the io.StringWriter counterpart of Write. Handlers and helpers that
// write strings directly would otherwise bypass Write and be missing from request logs.
// It applies the same client-first, log-second policy as Write.
func (w *ResponseWriterWrapper) WriteString(data string) (int, error) {
	w.ensureHeadersCaptured()

	written, writeErr := w.ResponseWriter.WriteString(data)

	if !w.isStreaming || w.chunkChannel == nil {
		if w.shouldBufferResponseBody() {
			w.body.WriteString(data)
		}
		return written, writeErr
	}

	if w.firstChunkTimestamp.IsZero() {
		w.firstChunkTimestamp = time.Now()
	}
	select {
	case w.chunkChannel <- []byte(data):
	default:
	}
	return written, writeErr
}
// WriteHeader records statusCode and a snapshot of the response headers, and detects
// whether the response is streamed (see detectStreaming). For streamed responses with
// an enabled logger it also opens a streaming log writer and starts the goroutine that
// feeds it. Finally it forwards the status to the client.
//
// The streaming logger is set up at most once per response. gin tolerates repeated
// WriteHeader calls, and re-initializing would orphan the first processStreamingChunks
// goroutine (its channel would never be closed) together with its open log writer.
// For the same reason, isStreaming is not re-evaluated once a stream writer exists.
// A nil logger or missing request info disables stream logging instead of panicking,
// matching the nil checks in Finalize and shouldBufferResponseBody.
func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
	w.statusCode = statusCode
	w.captureCurrentHeaders()

	if w.streamWriter == nil {
		w.isStreaming = w.detectStreaming(w.ResponseWriter.Header().Get("Content-Type"))
	}

	if w.isStreaming && w.streamWriter == nil && w.logger != nil && w.requestInfo != nil && w.logger.IsEnabled() {
		streamWriter, err := w.logger.LogStreamingRequest(
			w.requestInfo.URL,
			w.requestInfo.Method,
			w.requestInfo.Headers,
			w.requestInfo.Body,
			w.requestInfo.RequestID,
		)
		// On error the response is simply not stream-logged.
		if err == nil {
			w.streamWriter = streamWriter
			// Buffered so chunks can queue while the logger catches up; Write drops
			// chunks instead of blocking when it is full.
			w.chunkChannel = make(chan []byte, 100)
			doneChan := make(chan struct{})
			w.streamDone = doneChan
			go w.processStreamingChunks(doneChan)
			// Record the status and headers in the stream log immediately.
			_ = streamWriter.WriteStatus(statusCode, w.headers)
		}
	}

	w.ResponseWriter.WriteHeader(statusCode)
}
// ensureHeadersCaptured refreshes the wrapper's header snapshot from the underlying
// writer. Calling it repeatedly is harmless: each call re-reads the current headers.
func (w *ResponseWriterWrapper) ensureHeadersCaptured() {
	w.captureCurrentHeaders()
}
// captureCurrentHeaders copies every header currently set on the underlying writer into
// w.headers, allocating the map on first use. Value slices are copied, so later changes
// to the live header slices do not leak into the snapshot. Keys removed from the live
// headers since a previous capture are kept in the snapshot.
func (w *ResponseWriterWrapper) captureCurrentHeaders() {
	if w.headers == nil {
		w.headers = map[string][]string{}
	}
	live := w.ResponseWriter.Header()
	for name := range live {
		snapshot := make([]string, len(live[name]))
		copy(snapshot, live[name])
		w.headers[name] = snapshot
	}
}
// detectStreaming reports whether the response should be logged as a stream.
//   - A Content-Type containing "text/event-stream" means streaming.
//   - Any other non-blank Content-Type (e.g. application/json on error responses)
//     means not streaming.
//   - Only when no Content-Type is set yet does it look for `"stream": true` or
//     `"stream":true` in the captured request body.
func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
	switch {
	case strings.Contains(contentType, "text/event-stream"):
		return true
	case strings.TrimSpace(contentType) != "":
		return false
	case w.requestInfo == nil || len(w.requestInfo.Body) == 0:
		return false
	}
	body := w.requestInfo.Body
	for _, marker := range [][]byte{[]byte(`"stream": true`), []byte(`"stream":true`)} {
		if bytes.Contains(body, marker) {
			return true
		}
	}
	return false
}
// processStreamingChunks drains the chunk channel into the streaming log writer until
// the channel is closed. It then closes done, which lets Finalize wait for every queued
// chunk. A nil done makes it a no-op. If no writer or channel is set, it closes done
// and returns immediately.
func (w *ResponseWriterWrapper) processStreamingChunks(done chan struct{}) {
	if done == nil {
		return
	}
	defer close(done)

	if w.streamWriter == nil {
		return
	}
	chunks := w.chunkChannel
	if chunks == nil {
		return
	}
	for chunk := range chunks {
		w.streamWriter.WriteChunkAsync(chunk)
	}
}
// Finalize completes logging after the handler chain has returned.
//
// A streaming log writer opened in WriteHeader is always completed here (see
// finalizeStream), even if the logger was disabled while the response was in flight.
// Returning early in that case would leave the processStreamingChunks goroutine
// blocked forever on a channel that is never closed, and the log writer open.
//
// Otherwise a single log entry is written. That happens when the logger is enabled,
// or in error-only mode when the final status is >= 400 or the gin context carries
// API_RESPONSE_ERROR entries. A nil logger disables everything.
func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
	if w.logger == nil {
		return nil
	}

	if w.streamWriter != nil {
		return w.finalizeStream(c)
	}

	finalStatusCode := w.statusCode
	if finalStatusCode == 0 {
		finalStatusCode = http.StatusOK
		if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok {
			finalStatusCode = statusWriter.Status()
		}
	}

	var apiErrors []*interfaces.ErrorMessage
	if raw, exists := c.Get("API_RESPONSE_ERROR"); exists {
		if typed, ok := raw.([]*interfaces.ErrorMessage); ok {
			apiErrors = typed
		}
	}

	loggerEnabled := w.logger.IsEnabled()
	hasAPIError := len(apiErrors) > 0 || finalStatusCode >= http.StatusBadRequest
	// With the logger disabled, error-only mode still logs error responses.
	forceLog := w.logOnErrorOnly && hasAPIError && !loggerEnabled
	if !loggerEnabled && !forceLog {
		return nil
	}

	return w.logRequest(
		w.extractRequestBody(c),
		finalStatusCode,
		w.cloneHeaders(),
		w.body.Bytes(),
		w.extractAPIRequest(c),
		w.extractAPIResponse(c),
		w.extractAPIResponseTimestamp(c),
		apiErrors,
		forceLog,
	)
}

// finalizeStream completes a streaming log, in order:
//  1. closes the chunk channel,
//  2. waits for processStreamingChunks to drain it,
//  3. records the first-chunk timestamp and any upstream API request/response found in
//     the gin context,
//  4. closes the writer.
//
// The writer is detached from w before Close; Close's error is returned.
func (w *ResponseWriterWrapper) finalizeStream(c *gin.Context) error {
	if w.chunkChannel != nil {
		close(w.chunkChannel)
		w.chunkChannel = nil
	}
	if w.streamDone != nil {
		<-w.streamDone
		w.streamDone = nil
	}

	writer := w.streamWriter
	w.streamWriter = nil

	writer.SetFirstChunkTimestamp(w.firstChunkTimestamp)
	// Best-effort: a failure here does not prevent the writer from being closed.
	if apiRequest := w.extractAPIRequest(c); len(apiRequest) > 0 {
		_ = writer.WriteAPIRequest(apiRequest)
	}
	if apiResponse := w.extractAPIResponse(c); len(apiResponse) > 0 {
		_ = writer.WriteAPIResponse(apiResponse)
	}
	return writer.Close()
}
// cloneHeaders refreshes the header snapshot and returns a deep copy of it. The caller
// can hand the result to the logger without sharing slices with the wrapper.
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
	w.ensureHeadersCaptured()
	clone := make(map[string][]string, len(w.headers))
	for name, vals := range w.headers {
		clone[name] = append(make([]string, 0, len(vals)), vals...)
	}
	return clone
}
// extractAPIRequest returns the upstream API request body stored in the gin context
// under "API_REQUEST". It returns nil when the value is absent, not a []byte, or empty.
func (w *ResponseWriterWrapper) extractAPIRequest(c *gin.Context) []byte {
	if v, ok := c.Get("API_REQUEST"); ok {
		if data, isBytes := v.([]byte); isBytes && len(data) > 0 {
			return data
		}
	}
	return nil
}
// extractAPIResponse returns the upstream API response body stored in the gin context
// under "API_RESPONSE". It returns nil when the value is absent, not a []byte, or empty.
func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
	if v, ok := c.Get("API_RESPONSE"); ok {
		if data, isBytes := v.([]byte); isBytes && len(data) > 0 {
			return data
		}
	}
	return nil
}
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
if !isExist {
return time.Time{}
}
if t, ok := ts.(time.Time); ok {
return t
}
return time.Time{}
}
// extractRequestBody returns the request body to log.
//   - An override stored in the gin context under requestBodyOverrideContextKey wins:
//     a non-empty []byte override is cloned, and a non-blank string override is
//     converted.
//   - Otherwise the captured request body is returned as-is (not copied).
//   - It returns nil when nothing usable is available.
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
	var override any
	hasOverride := false
	if c != nil {
		override, hasOverride = c.Get(requestBodyOverrideContextKey)
	}
	if hasOverride {
		switch v := override.(type) {
		case []byte:
			if len(v) > 0 {
				return bytes.Clone(v)
			}
		case string:
			if strings.TrimSpace(v) != "" {
				return []byte(v)
			}
		}
	}
	if w.requestInfo == nil || len(w.requestInfo.Body) == 0 {
		return nil
	}
	return w.requestInfo.Body
}
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
if w.requestInfo == nil {
return nil
}
if loggerWithOptions, ok := w.logger.(interface {
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
}); ok {
return loggerWithOptions.LogRequestWithOptions(
w.requestInfo.URL,
w.requestInfo.Method,
w.requestInfo.Headers,
requestBody,
statusCode,
headers,
body,
apiRequestBody,
apiResponseBody,
apiResponseErrors,
forceLog,
w.requestInfo.RequestID,
w.requestInfo.Timestamp,
apiResponseTimestamp,
)
}
return w.logger.LogRequest(
w.requestInfo.URL,
w.requestInfo.Method,
w.requestInfo.Headers,
requestBody,
statusCode,
headers,
body,
apiRequestBody,
apiResponseBody,
apiResponseErrors,
w.requestInfo.RequestID,
w.requestInfo.Timestamp,
apiResponseTimestamp,
)
}
================================================
FILE: internal/api/middleware/response_writer_test.go
================================================
package middleware
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
// TestExtractRequestBodyPrefersOverride checks that the captured request body is used
// by default, and that a []byte override stored in the gin context replaces it.
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
	gin.SetMode(gin.TestMode)
	ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
	wrapper := &ResponseWriterWrapper{requestInfo: &RequestInfo{Body: []byte("original-body")}}

	expectBody := func(want string) {
		t.Helper()
		if got := string(wrapper.extractRequestBody(ctx)); got != want {
			t.Fatalf("request body = %q, want %q", got, want)
		}
	}

	expectBody("original-body")
	ctx.Set(requestBodyOverrideContextKey, []byte("override-body"))
	expectBody("override-body")
}
func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
c.Set(requestBodyOverrideContextKey, "override-as-string")
body := wrapper.extractRequestBody(c)
if string(body) != "override-as-string" {
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
}
}
================================================
FILE: internal/api/modules/amp/amp.go
================================================
// Package amp implements the Amp CLI routing module, providing OAuth-based
// integration with Amp CLI for ChatGPT and Anthropic subscriptions.
package amp
import (
"fmt"
"net/http/httputil"
"strings"
"sync"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
log "github.com/sirupsen/logrus"
)
// Option configures an AmpModule. New applies options in the order they are given.
type Option func(*AmpModule)
// AmpModule implements the RouteModuleV2 interface for Amp CLI integration.
// It provides:
// - Reverse proxy to Amp control plane for OAuth/management
// - Provider-specific route aliases (/api/provider/{provider}/...)
// - Automatic gzip decompression for misconfigured upstreams
// - Model mapping for routing unavailable models to alternatives
type AmpModule struct {
// secretSource supplies the upstream secret; nil unless WithSecretSource was used.
secretSource SecretSource
// proxy is the reverse proxy to the Amp upstream.
proxy *httputil.ReverseProxy
proxyMu sync.RWMutex // protects proxy for hot-reload
// accessManager is set via WithAccessManager.
accessManager *sdkaccess.Manager
// authMiddleware_ is set via WithAuthMiddleware; getAuthMiddleware prefers it over
// the middleware supplied in the registration context.
authMiddleware_ gin.HandlerFunc
// modelMapper routes unavailable models to alternatives; built in Register.
modelMapper *DefaultModelMapper
// enabled is set to false by Register when no upstream URL is configured.
enabled bool
// registerOnce ensures Register only performs route registration once.
registerOnce sync.Once
// restrictToLocalhost controls localhost-only access for management routes (hot-reloadable)
restrictToLocalhost bool
restrictMu sync.RWMutex
// configMu protects lastConfig for partial reload comparison
configMu sync.RWMutex
lastConfig *config.AmpCode
}
// New creates an Amp routing module and applies opts in order. This is the preferred
// constructor. Fields not set by an option keep their zero values; in particular the
// secret source stays nil unless WithSecretSource is given (presumably resolved on
// demand elsewhere — confirm in secret.go).
//
// Example:
//
//	ampModule := amp.New(
//		amp.WithAccessManager(accessManager),
//		amp.WithAuthMiddleware(authMiddleware),
//		amp.WithSecretSource(customSecret),
//	)
func New(opts ...Option) *AmpModule {
	module := new(AmpModule)
	for i := range opts {
		opts[i](module)
	}
	return module
}
// NewLegacy creates a new Amp routing module using the legacy constructor signature.
// It is equivalent to New(WithAccessManager(accessManager), WithAuthMiddleware(authMiddleware))
// and is kept for backwards compatibility.
//
// Deprecated: Use New with WithAccessManager and WithAuthMiddleware instead.
func NewLegacy(accessManager *sdkaccess.Manager, authMiddleware gin.HandlerFunc) *AmpModule {
	return New(
		WithAccessManager(accessManager),
		WithAuthMiddleware(authMiddleware),
	)
}
// WithSecretSource returns an Option that makes the module use source as its secret source.
func WithSecretSource(source SecretSource) Option {
	return func(module *AmpModule) { module.secretSource = source }
}
// WithAccessManager returns an Option that stores am as the module's access manager.
func WithAccessManager(am *sdkaccess.Manager) Option {
	return func(module *AmpModule) { module.accessManager = am }
}
// WithAuthMiddleware returns an Option that sets the authentication middleware used for
// the module's routes. It takes precedence over any middleware supplied at registration.
func WithAuthMiddleware(middleware gin.HandlerFunc) Option {
	return func(module *AmpModule) { module.authMiddleware_ = middleware }
}
// Name returns the module identifier, "amp-routing".
func (m *AmpModule) Name() string {
	const moduleName = "amp-routing"
	return moduleName
}
// forceModelMappings reports whether model mappings should take precedence over local
// API keys, per the most recently stored AmpCode settings. It returns false until
// settings have been stored. The read is guarded by configMu.
func (m *AmpModule) forceModelMappings() bool {
	m.configMu.RLock()
	cfg := m.lastConfig
	force := cfg != nil && cfg.ForceModelMappings
	m.configMu.RUnlock()
	return force
}
// Register sets up the Amp routes on ctx.Engine. It implements RouteModuleV2.
//
// Registration happens at most once (guarded by registerOnce). Later calls do nothing
// and return nil, even if the first call failed. On the first call it:
//   - builds the model mapper from AmpCode.ModelMappings,
//   - stores a private copy of the AmpCode settings for later partial-reload
//     comparisons, under configMu (the lock forceModelMappings and OnConfigUpdated
//     use when reading lastConfig),
//   - applies the restrict-management-to-localhost setting,
//   - registers provider alias routes and management routes with the resolved auth
//     middleware,
//   - enables the upstream proxy when an upstream URL is configured; a failure there
//     is returned as an error.
func (m *AmpModule) Register(ctx modules.Context) error {
	settings := ctx.Config.AmpCode
	upstreamURL := strings.TrimSpace(settings.UpstreamURL)

	// Auth middleware comes from the module option, then the context, then a
	// pass-through fallback.
	auth := m.getAuthMiddleware(ctx)

	var regErr error
	m.registerOnce.Do(func() {
		// Route unavailable models to configured alternatives.
		m.modelMapper = NewModelMapper(settings.ModelMappings)

		// Keep an independent copy for partial-reload comparison. settings itself is
		// passed by pointer to enableUpstreamProxy below. lastConfig is guarded by
		// configMu, so write it under that lock.
		initial := settings
		m.configMu.Lock()
		m.lastConfig = &initial
		m.configMu.Unlock()

		// Localhost restriction for management routes (hot-reloadable).
		m.setRestrictToLocalhost(settings.RestrictManagementToLocalhost)

		// Provider aliases work without an upstream, so they are always registered.
		m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)

		// Management proxy routes are registered once; their middleware gates access
		// while no upstream is available, and auth requires a valid API key.
		m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, auth)

		if upstreamURL == "" {
			log.Debug("amp upstream proxy disabled (no upstream URL configured)")
			log.Debug("amp provider alias routes registered")
			m.enabled = false
			return
		}
		if err := m.enableUpstreamProxy(upstreamURL, &settings); err != nil {
			regErr = fmt.Errorf("failed to create amp proxy: %w", err)
			return
		}
		log.Debug("amp provider alias routes registered")
	})
	return regErr
}
// getAuthMiddleware picks the authentication middleware for the module's routes.
// Precedence: the module's own middleware, then ctx.AuthMiddleware. If neither
// is set, it logs a warning and returns a pass-through handler that only calls
// c.Next().
func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc {
	switch {
	case m.authMiddleware_ != nil:
		return m.authMiddleware_
	case ctx.AuthMiddleware != nil:
		return ctx.AuthMiddleware
	}
	// No middleware anywhere: this should not happen in production.
	log.Warn("amp module: no auth middleware provided, allowing all requests")
	return func(c *gin.Context) { c.Next() }
}
// OnConfigUpdated handles configuration updates with partial reload support.
// Only updates components that have actually changed to avoid unnecessary work.
// Supports hot-reload for: model-mappings, upstream-api-key, upstream-url,
// restrict-management-to-localhost.
//
// A nil cfg is ignored. The method always returns nil: failures to (re)build the
// upstream proxy are logged rather than returned. The new AmpCode settings are
// stored under configMu as the baseline for the next comparison.
func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
	if cfg == nil {
		// Nothing to apply; keep the current settings untouched.
		return nil
	}
	newSettings := cfg.AmpCode
	// Get previous config for comparison
	m.configMu.RLock()
	oldSettings := m.lastConfig
	m.configMu.RUnlock()
	// Apply the localhost restriction when it changed. Also apply it when there
	// is no previous config to compare against, so the setting is never dropped.
	if oldSettings == nil || oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
		m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
	}
	newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
	oldUpstreamURL := ""
	if oldSettings != nil {
		oldUpstreamURL = strings.TrimSpace(oldSettings.UpstreamURL)
	}
	// A disabled module is enabled as soon as an upstream URL is configured.
	if !m.enabled && newUpstreamURL != "" {
		if err := m.enableUpstreamProxy(newUpstreamURL, &newSettings); err != nil {
			log.Errorf("amp config: failed to enable upstream proxy for %s: %v", newUpstreamURL, err)
		}
	}
	// Check model mappings change
	modelMappingsChanged := m.hasModelMappingsChanged(oldSettings, &newSettings)
	if modelMappingsChanged {
		if m.modelMapper != nil {
			m.modelMapper.UpdateMappings(newSettings.ModelMappings)
		} else if m.enabled {
			log.Warnf("amp model mapper not initialized, skipping model mapping update")
		}
	}
	if m.enabled {
		// Check upstream URL change - now supports hot-reload
		if newUpstreamURL == "" && oldUpstreamURL != "" {
			m.setProxy(nil)
			m.enabled = false
		} else if oldUpstreamURL != "" && newUpstreamURL != oldUpstreamURL && newUpstreamURL != "" {
			// Recreate proxy with new URL
			proxy, err := createReverseProxy(newUpstreamURL, m.secretSource)
			if err != nil {
				log.Errorf("amp config: failed to create proxy for new upstream URL %s: %v", newUpstreamURL, err)
			} else {
				m.setProxy(proxy)
			}
		}
		// Check API key change (both default and per-client mappings)
		apiKeyChanged := m.hasAPIKeyChanged(oldSettings, &newSettings)
		upstreamAPIKeysChanged := m.hasUpstreamAPIKeysChanged(oldSettings, &newSettings)
		if apiKeyChanged || upstreamAPIKeysChanged {
			if m.secretSource != nil {
				if ms, ok := m.secretSource.(*MappedSecretSource); ok {
					if apiKeyChanged {
						ms.UpdateDefaultExplicitKey(newSettings.UpstreamAPIKey)
						ms.InvalidateCache()
					}
					if upstreamAPIKeysChanged {
						ms.UpdateMappings(newSettings.UpstreamAPIKeys)
					}
				} else if ms, ok := m.secretSource.(*MultiSourceSecret); ok {
					ms.UpdateExplicitKey(newSettings.UpstreamAPIKey)
					ms.InvalidateCache()
				}
			}
		}
	}
	// Store current config for next comparison
	m.configMu.Lock()
	settingsCopy := newSettings // copy struct
	m.lastConfig = &settingsCopy
	m.configMu.Unlock()
	return nil
}
// enableUpstreamProxy refreshes the secret source from settings, builds a
// reverse proxy to upstreamURL, installs it, and marks the module enabled.
//
// The secret source is handled according to its current type:
//   - none: a MultiSourceSecret seeded with settings.UpstreamAPIKey is created
//     and wrapped in a MappedSecretSource carrying settings.UpstreamAPIKeys;
//   - *MappedSecretSource: its default key and mappings are refreshed and its
//     cache invalidated;
//   - *MultiSourceSecret (legacy): its key is refreshed and its cache
//     invalidated, then it is wrapped in a MappedSecretSource;
//   - anything else: used unchanged.
//
// If the proxy cannot be created, the error is returned and the proxy and the
// enabled flag are left as they were. The secret source may already have been
// updated at that point.
func (m *AmpModule) enableUpstreamProxy(upstreamURL string, settings *config.AmpCode) error {
	switch src := m.secretSource.(type) {
	case nil:
		base := NewMultiSourceSecret(settings.UpstreamAPIKey, 0 /* default 5min */)
		wrapped := NewMappedSecretSource(base)
		wrapped.UpdateMappings(settings.UpstreamAPIKeys)
		m.secretSource = wrapped
	case *MappedSecretSource:
		src.UpdateDefaultExplicitKey(settings.UpstreamAPIKey)
		src.InvalidateCache()
		src.UpdateMappings(settings.UpstreamAPIKeys)
	case *MultiSourceSecret:
		src.UpdateExplicitKey(settings.UpstreamAPIKey)
		src.InvalidateCache()
		wrapped := NewMappedSecretSource(src)
		wrapped.UpdateMappings(settings.UpstreamAPIKeys)
		m.secretSource = wrapped
	}

	proxy, err := createReverseProxy(upstreamURL, m.secretSource)
	if err != nil {
		return err
	}
	m.setProxy(proxy)
	m.enabled = true
	log.Infof("amp upstream proxy enabled for: %s", upstreamURL)
	return nil
}
// hasModelMappingsChanged reports whether the model mappings differ between old
// and new.
//
// When old is nil, it reports a change only if new defines at least one
// mapping. Otherwise entries are compared position by position, after trimming
// whitespace from From/To, on From, To and Regex. Comparing in order (as
// hasUpstreamAPIKeysChanged does) catches removals hidden by duplicate entries.
// For example, [a, b] -> [a, a] is a change, which a one-way map lookup misses.
// It also catches reordering, which presumably affects mapping precedence.
// A reported change only triggers a mapper refresh, so over-reporting is harmless.
func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.AmpCode) bool {
	if old == nil {
		return len(new.ModelMappings) > 0
	}
	if len(old.ModelMappings) != len(new.ModelMappings) {
		return true
	}
	for i := range new.ModelMappings {
		before, after := old.ModelMappings[i], new.ModelMappings[i]
		if strings.TrimSpace(before.From) != strings.TrimSpace(after.From) ||
			strings.TrimSpace(before.To) != strings.TrimSpace(after.To) ||
			before.Regex != after.Regex {
			return true
		}
	}
	return false
}
// hasAPIKeyChanged reports whether the default upstream API key differs between
// old and new, ignoring surrounding whitespace. A nil old config is treated as
// having an empty key.
func (m *AmpModule) hasAPIKeyChanged(old *config.AmpCode, new *config.AmpCode) bool {
	var previous string
	if old != nil {
		previous = old.UpstreamAPIKey
	}
	return strings.TrimSpace(previous) != strings.TrimSpace(new.UpstreamAPIKey)
}
// hasUpstreamAPIKeysChanged reports whether the per-client upstream API key
// mappings differ between old and new.
//
// Entries are compared by position. Within an entry, the upstream key is
// compared after trimming, and the client keys are compared as sets: trimmed,
// blanks dropped, duplicates collapsed, order ignored. When old is nil, a change
// is reported only if new has at least one entry.
func (m *AmpModule) hasUpstreamAPIKeysChanged(old *config.AmpCode, new *config.AmpCode) bool {
	if old == nil {
		return len(new.UpstreamAPIKeys) > 0
	}
	if len(old.UpstreamAPIKeys) != len(new.UpstreamAPIKeys) {
		return true
	}

	// toSet normalizes a client key list so cosmetic differences (whitespace,
	// empty entries, duplicates, ordering) do not count as changes.
	toSet := func(keys []string) map[string]struct{} {
		set := make(map[string]struct{}, len(keys))
		for _, raw := range keys {
			if key := strings.TrimSpace(raw); key != "" {
				set[key] = struct{}{}
			}
		}
		return set
	}

	for idx := range new.UpstreamAPIKeys {
		before, after := old.UpstreamAPIKeys[idx], new.UpstreamAPIKeys[idx]
		if strings.TrimSpace(before.UpstreamAPIKey) != strings.TrimSpace(after.UpstreamAPIKey) {
			return true
		}
		beforeKeys, afterKeys := toSet(before.APIKeys), toSet(after.APIKeys)
		if len(beforeKeys) != len(afterKeys) {
			return true
		}
		for key := range afterKeys {
			if _, found := beforeKeys[key]; !found {
				return true
			}
		}
	}
	return false
}
// GetModelMapper exposes the module's model mapper for tests and debugging.
// Register initializes it from the configured model mappings.
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
	mapper := m.modelMapper
	return mapper
}
// getProxy returns the currently installed upstream proxy, or nil if none is
// installed. The read is guarded by proxyMu so it is safe during hot reloads.
func (m *AmpModule) getProxy() *httputil.ReverseProxy {
	m.proxyMu.RLock()
	current := m.proxy
	m.proxyMu.RUnlock()
	return current
}
// setProxy installs proxy as the upstream proxy. Pass nil to disable it. The
// write is guarded by proxyMu so it is safe during hot reloads.
func (m *AmpModule) setProxy(proxy *httputil.ReverseProxy) {
	m.proxyMu.Lock()
	m.proxy = proxy
	m.proxyMu.Unlock()
}
// IsRestrictedToLocalhost reports whether management routes are currently
// limited to localhost. Safe for concurrent use (guarded by restrictMu).
func (m *AmpModule) IsRestrictedToLocalhost() bool {
	m.restrictMu.RLock()
	restricted := m.restrictToLocalhost
	m.restrictMu.RUnlock()
	return restricted
}
// setRestrictToLocalhost records whether management routes should be limited to
// localhost. Guarded by restrictMu so it can change during a config reload.
func (m *AmpModule) setRestrictToLocalhost(restrict bool) {
	m.restrictMu.Lock()
	m.restrictToLocalhost = restrict
	m.restrictMu.Unlock()
}
================================================
FILE: internal/api/modules/amp/amp_test.go
================================================
package amp
import (
"context"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
)
// TestAmpModule_Name pins the module identifier.
func TestAmpModule_Name(t *testing.T) {
	if got := New().Name(); got != "amp-routing" {
		t.Fatalf("want amp-routing, got %s", got)
	}
}
// TestAmpModule_New checks that the legacy constructor wires its arguments and
// leaves the module disabled with no proxy.
func TestAmpModule_New(t *testing.T) {
	manager := sdkaccess.NewManager()
	m := NewLegacy(manager, func(c *gin.Context) { c.Next() })

	switch {
	case m.accessManager != manager:
		t.Fatal("accessManager not set")
	case m.authMiddleware_ == nil:
		t.Fatal("authMiddleware not set")
	case m.enabled:
		t.Fatal("enabled should be false initially")
	case m.proxy != nil:
		t.Fatal("proxy should be nil initially")
	}
}
// TestAmpModule_Register_WithUpstream verifies that registering with an upstream
// URL enables the module and builds both the proxy and the secret source.
func TestAmpModule_Register_WithUpstream(t *testing.T) {
	gin.SetMode(gin.TestMode)
	engine := gin.New()

	// A live test server gives the module a valid upstream URL.
	upstream := httptest.NewServer(nil)
	defer upstream.Close()

	passThrough := func(c *gin.Context) { c.Next() }
	m := NewLegacy(sdkaccess.NewManager(), passThrough)
	ctx := modules.Context{
		Engine:      engine,
		BaseHandler: &handlers.BaseAPIHandler{},
		Config: &config.Config{AmpCode: config.AmpCode{
			UpstreamURL:    upstream.URL,
			UpstreamAPIKey: "test-key",
		}},
		AuthMiddleware: passThrough,
	}
	if err := m.Register(ctx); err != nil {
		t.Fatalf("register error: %v", err)
	}

	if !m.enabled {
		t.Fatal("module should be enabled with upstream URL")
	}
	if m.proxy == nil {
		t.Fatal("proxy should be initialized")
	}
	if m.secretSource == nil {
		t.Fatal("secretSource should be initialized")
	}
}
// TestAmpModule_Register_WithoutUpstream verifies that an empty upstream URL
// leaves the proxy disabled but still registers the provider alias routes.
func TestAmpModule_Register_WithoutUpstream(t *testing.T) {
	gin.SetMode(gin.TestMode)
	engine := gin.New()

	passThrough := func(c *gin.Context) { c.Next() }
	m := NewLegacy(sdkaccess.NewManager(), passThrough)
	ctx := modules.Context{
		Engine:         engine,
		BaseHandler:    &handlers.BaseAPIHandler{},
		Config:         &config.Config{AmpCode: config.AmpCode{UpstreamURL: ""}},
		AuthMiddleware: passThrough,
	}
	if err := m.Register(ctx); err != nil {
		t.Fatalf("register should not error without upstream: %v", err)
	}
	if m.enabled {
		t.Fatal("module should be disabled without upstream URL")
	}
	if m.proxy != nil {
		t.Fatal("proxy should not be initialized without upstream")
	}

	// Provider aliases do not depend on the upstream proxy.
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, httptest.NewRequest("GET", "/api/provider/openai/models", nil))
	if rec.Code == 404 {
		t.Fatal("provider aliases should be registered even without upstream")
	}
}
// TestAmpModule_Register_InvalidUpstream verifies that an unparsable upstream
// URL surfaces as a registration error.
func TestAmpModule_Register_InvalidUpstream(t *testing.T) {
	gin.SetMode(gin.TestMode)

	passThrough := func(c *gin.Context) { c.Next() }
	m := NewLegacy(sdkaccess.NewManager(), passThrough)
	ctx := modules.Context{
		Engine:         gin.New(),
		BaseHandler:    &handlers.BaseAPIHandler{},
		Config:         &config.Config{AmpCode: config.AmpCode{UpstreamURL: "://invalid-url"}},
		AuthMiddleware: passThrough,
	}
	if m.Register(ctx) == nil {
		t.Fatal("expected error for invalid upstream URL")
	}
}
func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
tmpDir := t.TempDir()
p := filepath.Join(tmpDir, "secrets.json")
if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"v1"}`), 0600); err != nil {
t.Fatal(err)
}
m := &AmpModule{enabled: true}
ms := NewMultiSourceSecretWithPath("", p, time.Minute)
m.secretSource = ms
m.lastConfig = &config.AmpCode{
UpstreamAPIKey: "old-key",
}
// Warm the cache
if _, err := ms.Get(context.Background()); err != nil {
t.Fatal(err)
}
if ms.cache == nil {
t.Fatal("expected cache to be set")
}
// Update config - should invalidate cache
if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x", UpstreamAPIKey: "new-key"}}); err != nil {
t.Fatal(err)
}
if ms.cache != nil {
t.Fatal("expected cache to be invalidated")
}
}
// TestAmpModule_OnConfigUpdated_NotEnabled checks that a disabled module accepts
// an empty config update without error or panic.
func TestAmpModule_OnConfigUpdated_NotEnabled(t *testing.T) {
	var m AmpModule // zero value: disabled
	if err := m.OnConfigUpdated(&config.Config{}); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
}
func TestAmpModule_OnConfigUpdated_URLRemoved(t *testing.T) {
m := &AmpModule{enabled: true}
ms := NewMultiSourceSecret("", 0)
m.secretSource = ms
// Config update with empty URL - should log warning but not error
cfg := &config.Config{AmpCode: config.AmpCode{UpstreamURL: ""}}
if err := m.OnConfigUpdated(cfg); err != nil {
t.Fatalf("unexpected error: %v", err)
}
}
// TestAmpModule_OnConfigUpdated_NonMultiSourceSecret checks that a secret source
// that is neither a MappedSecretSource nor a MultiSourceSecret (here a
// StaticSecretSource) does not break config updates.
func TestAmpModule_OnConfigUpdated_NonMultiSourceSecret(t *testing.T) {
	m := &AmpModule{
		enabled:      true,
		secretSource: NewStaticSecretSource("static-key"),
	}
	update := &config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://example.com"}}
	if err := m.OnConfigUpdated(update); err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
}
// TestAmpModule_AuthMiddleware_Fallback verifies that, with no middleware on the
// module or the context, getAuthMiddleware returns a pass-through handler that
// lets requests reach the route.
func TestAmpModule_AuthMiddleware_Fallback(t *testing.T) {
	gin.SetMode(gin.TestMode)
	engine := gin.New()

	m := &AmpModule{}
	fallback := m.getAuthMiddleware(modules.Context{Engine: engine})
	if fallback == nil {
		t.Fatal("getAuthMiddleware should return a fallback, not nil")
	}

	reached := false
	engine.GET("/test", fallback, func(c *gin.Context) {
		reached = true
		c.String(200, "ok")
	})
	engine.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest("GET", "/test", nil))
	if !reached {
		t.Fatal("fallback middleware should allow requests through")
	}
}
// TestAmpModule_SecretSource_FromConfig verifies that the explicit
// upstream-api-key from config is what the secret source resolves to after
// Register.
func TestAmpModule_SecretSource_FromConfig(t *testing.T) {
	gin.SetMode(gin.TestMode)
	upstream := httptest.NewServer(nil)
	defer upstream.Close()

	passThrough := func(c *gin.Context) { c.Next() }
	m := NewLegacy(sdkaccess.NewManager(), passThrough)
	ctx := modules.Context{
		Engine:      gin.New(),
		BaseHandler: &handlers.BaseAPIHandler{},
		Config: &config.Config{AmpCode: config.AmpCode{
			UpstreamURL:    upstream.URL,
			UpstreamAPIKey: "config-key",
		}},
		AuthMiddleware: passThrough,
	}
	if err := m.Register(ctx); err != nil {
		t.Fatalf("register error: %v", err)
	}

	// With no explicit source configured, enableUpstreamProxy wraps a
	// MultiSourceSecret (seeded with the config key) in a MappedSecretSource.
	if m.secretSource == nil {
		t.Fatal("secretSource should be set")
	}
	key, err := m.secretSource.Get(context.Background())
	if err != nil {
		t.Fatalf("Get error: %v", err)
	}
	if key != "config-key" {
		t.Fatalf("want config-key, got %s", key)
	}
}
// TestAmpModule_ProviderAliasesAlwaysRegistered checks that the provider alias
// routes are reachable whether or not an upstream URL is configured.
func TestAmpModule_ProviderAliasesAlwaysRegistered(t *testing.T) {
	gin.SetMode(gin.TestMode)

	for _, tc := range []struct {
		name     string
		upstream string
	}{
		{name: "with_upstream", upstream: "http://example.com"},
		{name: "without_upstream", upstream: ""},
	} {
		t.Run(tc.name, func(t *testing.T) {
			engine := gin.New()
			passThrough := func(c *gin.Context) { c.Next() }
			m := NewLegacy(sdkaccess.NewManager(), passThrough)
			ctx := modules.Context{
				Engine:         engine,
				BaseHandler:    &handlers.BaseAPIHandler{},
				Config:         &config.Config{AmpCode: config.AmpCode{UpstreamURL: tc.upstream}},
				AuthMiddleware: passThrough,
			}
			// Registration errors only matter when an upstream was requested.
			if err := m.Register(ctx); err != nil && tc.upstream != "" {
				t.Fatalf("register error: %v", err)
			}

			rec := httptest.NewRecorder()
			engine.ServeHTTP(rec, httptest.NewRequest("GET", "/api/provider/openai/models", nil))
			if rec.Code == 404 {
				t.Fatal("provider aliases should be registered")
			}
		})
	}
}
func TestAmpModule_hasUpstreamAPIKeysChanged_DetectsRemovedKeyWithDuplicateInput(t *testing.T) {
m := &AmpModule{}
oldCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
},
}
newCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k1"}},
},
}
if !m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
t.Fatal("expected change to be detected when k2 is removed but new list contains duplicates")
}
}
func TestAmpModule_hasUpstreamAPIKeysChanged_IgnoresEmptyAndWhitespaceKeys(t *testing.T) {
m := &AmpModule{}
oldCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{"k1", "k2"}},
},
}
newCfg := &config.AmpCode{
UpstreamAPIKeys: []config.AmpUpstreamAPIKeyEntry{
{UpstreamAPIKey: "u1", APIKeys: []string{" k1 ", "", "k2", " "}},
},
}
if m.hasUpstreamAPIKeysChanged(oldCfg, newCfg) {
t.Fatal("expected no change when only whitespace/empty entries differ")
}
}
================================================
FILE: internal/api/modules/amp/fallback_handlers.go
================================================
package amp
import (
"bytes"
"io"
"net/http/httputil"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// AmpRouteType labels the routing decision taken for an Amp request; it is
// reported by logAmpRouting.
type AmpRouteType string

// Routing decisions for Amp requests.
const (
	// RouteTypeLocalProvider means a local OAuth provider serves the request (free).
	RouteTypeLocalProvider = AmpRouteType("LOCAL_PROVIDER")
	// RouteTypeModelMapping means the model was remapped to another locally available model (free).
	RouteTypeModelMapping = AmpRouteType("MODEL_MAPPING")
	// RouteTypeAmpCredits means the request is forwarded to ampcode.com (consumes Amp credits).
	RouteTypeAmpCredits = AmpRouteType("AMP_CREDITS")
	// RouteTypeNoProvider means neither a provider nor a fallback is available.
	RouteTypeNoProvider = AmpRouteType("NO_PROVIDER")
)

// MappedModelContextKey is the Gin context key under which the fallback handler
// stores the mapped model name, for handlers (such as the Gemini bridge) that
// cannot read it from the request body.
const MappedModelContextKey = "mapped_model"
// logAmpRouting logs the routing decision for an Amp request with structured
// fields. The fields are: component, route_type, requested_model, path and
// timestamp, plus resolved_model when it differs from the request and provider
// when known.
//
// Local-provider decisions are logged at debug level. Amp-credit forwarding and
// missing providers are logged at warn level. RouteTypeModelMapping is not
// logged here because the caller already logs the mapping, and logging it again
// would duplicate the entry.
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
	fields := log.Fields{
		"component":       "amp-routing",
		"route_type":      string(routeType),
		"requested_model": requestedModel,
		"path":            path,
		"timestamp":       time.Now().Format(time.RFC3339),
	}
	if resolvedModel != "" && resolvedModel != requestedModel {
		fields["resolved_model"] = resolvedModel
	}
	if provider != "" {
		fields["provider"] = provider
	}
	switch routeType {
	case RouteTypeLocalProvider:
		fields["cost"] = "free"
		fields["source"] = "local_oauth"
		log.WithFields(fields).Debugf("amp using local provider for model: %s", requestedModel)
	case RouteTypeModelMapping:
		// Intentionally silent: the mapping is already logged by the caller.
	case RouteTypeAmpCredits:
		fields["cost"] = "amp_credits"
		fields["source"] = "ampcode.com"
		fields["model_id"] = requestedModel // Explicit model_id for easy config reference
		// The hint shows a placeholder target: an empty "to" would be trimmed
		// away by the fallback handler and never applied.
		log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local provider, add to config: ampcode.model-mappings: [{from: \"%s\", to: \"<local-model>\"}]", requestedModel, requestedModel)
	case RouteTypeNoProvider:
		fields["cost"] = "none"
		fields["source"] = "error"
		fields["model_id"] = requestedModel // Explicit model_id for easy config reference
		log.WithFields(fields).Warnf("no provider available for model_id: %s", requestedModel)
	}
}
// FallbackHandler decorates route handlers so that requests whose model has no
// local provider (even after optional model mapping) are forwarded to the Amp
// upstream (ampcode.com) instead of being handled by CLIProxyAPI.
type FallbackHandler struct {
	// getProxy returns the upstream proxy at request time; it may return nil.
	getProxy func() *httputil.ReverseProxy
	// modelMapper optionally redirects requested models to other models.
	modelMapper ModelMapper
	// forceModelMappings reports whether mappings take precedence over local providers.
	forceModelMappings func() bool
}
// NewFallbackHandler returns a FallbackHandler with no model mapper and with
// forced mappings permanently off.
//
// getProxy is called on every request rather than once, so a proxy created (or
// replaced) after route registration is still picked up.
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
	neverForce := func() bool { return false }
	return &FallbackHandler{
		getProxy:           getProxy,
		forceModelMappings: neverForce,
	}
}
// NewFallbackHandlerWithMapper returns a FallbackHandler that consults mapper
// when routing. forceModelMappings decides per request whether mappings beat
// local providers; a nil value means "never force".
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
	fh := &FallbackHandler{
		getProxy:           getProxy,
		modelMapper:        mapper,
		forceModelMappings: forceModelMappings,
	}
	if fh.forceModelMappings == nil {
		fh.forceModelMappings = func() bool { return false }
	}
	return fh
}
// SetModelMapper replaces the handler's model mapper, allowing it to be bound
// after construction. It performs no synchronization of its own.
func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
	fh.modelMapper = mapper
}
// WrapHandler wraps handler with Amp fallback routing.
//
// For each request it buffers the body and extracts the model name (JSON
// "model" field or Gemini-style URL path). If no model can be found, it hands
// the request straight to handler. Otherwise it routes as follows:
//   - default mode: a local provider for the model if one exists, otherwise a
//     model mapping whose target has a local provider;
//   - force mode (forceModelMappings returns true): a usable model mapping
//     first, otherwise a local provider for the model;
//   - if neither yields a provider: the upstream proxy when one is available,
//     else handler, so it can produce its own error response.
//
// When a mapping applies, the body's "model" is rewritten, the mapped name is
// stored under MappedModelContextKey, and the response goes through a
// ResponseRewriter that is given the originally requested name. Requests served
// locally have their Anthropic-Beta header filtered first.
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
	return func(c *gin.Context) {
		path := c.Request.URL.Path

		// Buffer the body: it is inspected here and then replayed to whichever
		// handler or proxy ends up serving the request.
		body, err := io.ReadAll(c.Request.Body)
		if err != nil {
			log.Errorf("amp fallback: failed to read request body: %v", err)
			handler(c)
			return
		}
		replayBody := func() {
			c.Request.Body = io.NopCloser(bytes.NewReader(body))
		}
		replayBody()

		requested := extractModelFromRequest(body, c)
		if requested == "" {
			// No model to route on; proceed with the normal handler.
			handler(c)
			return
		}

		// Strip any dynamic thinking suffix (e.g. "(xhigh)") for provider
		// lookup, but remember it so a mapped model can carry it forward.
		parsed := thinking.ParseSuffix(requested)
		baseModel := parsed.ModelName
		suffix := ""
		if parsed.HasSuffix {
			suffix = "(" + parsed.RawSuffix + ")"
		}

		// lookupMapping returns the mapped model (suffix preserved) and its
		// providers, or "" when no mapping exists or its target has no provider.
		lookupMapping := func() (string, []string) {
			if fh.modelMapper == nil {
				return "", nil
			}
			target := fh.modelMapper.MapModel(requested)
			if target == "" {
				target = fh.modelMapper.MapModel(baseModel)
			}
			if target = strings.TrimSpace(target); target == "" {
				return "", nil
			}
			// Keep the caller's thinking suffix unless the target sets its own.
			if suffix != "" && !thinking.ParseSuffix(target).HasSuffix {
				target += suffix
			}
			targetProviders := util.GetProviderName(thinking.ParseSuffix(target).ModelName)
			if len(targetProviders) == 0 {
				return "", nil
			}
			return target, targetProviders
		}

		resolved := baseModel
		mapped := false
		var providers []string

		// applyMapping switches the request to the mapped model when one is
		// usable, recording the outcome in resolved/mapped/providers.
		applyMapping := func() {
			target, targetProviders := lookupMapping()
			if target == "" {
				return
			}
			body = rewriteModelInRequest(body, target)
			replayBody()
			// Handlers that take the model from the URL (e.g. the Gemini
			// bridge) pick the mapped name up from the context.
			c.Set(MappedModelContextKey, target)
			resolved = target
			mapped = true
			providers = targetProviders
		}

		if fh.forceModelMappings != nil && fh.forceModelMappings() {
			// Force mode: mappings take precedence over local providers.
			applyMapping()
			if !mapped {
				providers = util.GetProviderName(baseModel)
			}
		} else {
			// Default mode: local providers first, mappings as a fallback.
			providers = util.GetProviderName(baseModel)
			if len(providers) == 0 {
				applyMapping()
			}
		}

		if len(providers) == 0 {
			if proxy := fh.getProxy(); proxy != nil {
				// Nothing local can serve it: forward to ampcode.com (uses Amp credits).
				logAmpRouting(RouteTypeAmpCredits, requested, "", "", path)
				replayBody()
				proxy.ServeHTTP(c.Writer, c.Request)
				return
			}
			// No provider, no mapping, no proxy: the wrapped handler reports the error.
			logAmpRouting(RouteTypeNoProvider, requested, "", "", path)
			replayBody()
			handler(c)
			return
		}

		provider := providers[0]
		if !mapped {
			// Served by a local provider for the requested model (free).
			logAmpRouting(RouteTypeLocalProvider, requested, resolved, provider, path)
			filterAntropicBetaHeader(c)
			replayBody()
			handler(c)
			return
		}

		log.Debugf("amp model mapping: request %s -> %s", baseModel, resolved)
		logAmpRouting(RouteTypeModelMapping, requested, resolved, provider, path)
		// The rewriter is given the originally requested model name for the response.
		rewriter := NewResponseRewriter(c.Writer, requested)
		c.Writer = rewriter
		filterAntropicBetaHeader(c)
		replayBody()
		handler(c)
		rewriter.Flush()
		log.Debugf("amp model mapping: response %s -> %s", resolved, requested)
	}
}
// filterAntropicBetaHeader removes the "context-1m-2025-08-07" feature, which
// requires a special subscription, from the request's Anthropic-Beta header. It
// deletes the header when filterBetaFeatures leaves nothing.
//
// It is used on local handling paths, where the Amp proxy is bypassed.
func filterAntropicBetaHeader(c *gin.Context) {
	header := c.Request.Header
	beta := header.Get("Anthropic-Beta")
	if beta == "" {
		return
	}
	filtered := filterBetaFeatures(beta, "context-1m-2025-08-07")
	if filtered == "" {
		header.Del("Anthropic-Beta")
		return
	}
	header.Set("Anthropic-Beta", filtered)
}
// rewriteModelInRequest sets the top-level "model" field of a JSON body to
// newModel. The body is returned unchanged when it has no "model" field, or,
// with a warning logged, when the rewrite fails.
func rewriteModelInRequest(body []byte, newModel string) []byte {
	if !gjson.GetBytes(body, "model").Exists() {
		// Nothing to rewrite (e.g. Gemini requests carry the model in the URL).
		return body
	}
	rewritten, err := sjson.SetBytes(body, "model", newModel)
	if err == nil {
		return rewritten
	}
	log.Warnf("amp model mapping: failed to rewrite model in request body: %v", err)
	return body
}
// extractModelFromRequest returns the model a request targets, or "" if none is
// found. It checks, in order:
//  1. a string "model" field in the JSON body (OpenAI, Claude, ...);
//  2. the :action route parameter, e.g. "gemini-pro:generateContent";
//  3. the *path route parameter in AMP CLI form, e.g.
//     "/publishers/google/models/gemini-3-pro-preview:streamGenerateContent".
//
// A present but empty "model" string yields "" without consulting the URL.
func extractModelFromRequest(body []byte, c *gin.Context) string {
	if field := gjson.GetBytes(body, "model"); field.Exists() && field.Type == gjson.String {
		return field.String()
	}

	// Standard Gemini route: /models/{model}:{method} bound to :action.
	if model, _, _ := strings.Cut(c.Param("action"), ":"); model != "" {
		return model
	}

	// AMP CLI Gemini route: the full tail is bound to *path.
	if _, rest, found := strings.Cut(c.Param("path"), "/models/"); found {
		if model, _, hasMethod := strings.Cut(rest, ":"); hasMethod && model != "" {
			return model
		}
	}
	return ""
}
================================================
FILE: internal/api/modules/amp/fallback_handlers_test.go
================================================
package amp
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"net/http/httputil"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
// TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse
// checks two things. First, a mapped request keeps its thinking suffix
// ("(xhigh)") when it reaches the handler. Second, the response's "model" is
// restored to the name the client originally sent.
func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// Make the mapping target locally available so the mapping is usable.
	const clientID = "test-client-amp-fallback"
	reg := registry.GetGlobalRegistry()
	reg.RegisterClient(clientID, "codex", []*registry.ModelInfo{
		{ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"},
	})
	defer reg.UnregisterClient(clientID)

	mapper := NewModelMapper([]config.AmpModelMapping{{From: "gpt-5.2", To: "test/gpt-5.2"}})
	noProxy := func() *httputil.ReverseProxy { return nil }
	fallback := NewFallbackHandlerWithMapper(noProxy, mapper, nil)

	// echo returns the model it received twice. "model" is expected to be
	// rewritten back on the way out; "seen_model" is expected to stay as-is.
	echo := func(c *gin.Context) {
		var payload struct {
			Model string `json:"model"`
		}
		if err := c.ShouldBindJSON(&payload); err != nil {
			c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
			return
		}
		c.JSON(http.StatusOK, gin.H{"model": payload.Model, "seen_model": payload.Model})
	}

	engine := gin.New()
	engine.POST("/chat/completions", fallback.WrapHandler(echo))

	req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader([]byte(`{"model":"gpt-5.2(xhigh)"}`)))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	engine.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Fatalf("Expected status 200, got %d", rec.Code)
	}
	var got struct {
		Model     string `json:"model"`
		SeenModel string `json:"seen_model"`
	}
	if err := json.Unmarshal(rec.Body.Bytes(), &got); err != nil {
		t.Fatalf("Failed to parse response JSON: %v", err)
	}
	if got.Model != "gpt-5.2(xhigh)" {
		t.Errorf("Expected response model gpt-5.2(xhigh), got %s", got.Model)
	}
	if got.SeenModel != "test/gpt-5.2(xhigh)" {
		t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", got.SeenModel)
	}
}
================================================
FILE: internal/api/modules/amp/gemini_bridge.go
================================================
package amp
import (
"strings"
"github.com/gin-gonic/gin"
)
// createGeminiBridgeHandler adapts AMP CLI's non-standard Gemini paths to a
// standard Gemini handler.
//
//	AMP CLI path:  /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
//	Standard path: /models/gemini-3-pro-preview:streamGenerateContent
//
// The part of the *path route parameter after the first "/models/" (the
// "model:method" text) is appended to c.Params as the :action parameter that
// handler expects. Two conditions trigger a model substitution. An earlier
// handler must have stored a non-empty mapped model under MappedModelContextKey,
// and the action must contain a ':' after a non-empty model. When both hold,
// the model portion is replaced with the mapped model. Paths without
// "/models/" are rejected with HTTP 400.
func createGeminiBridgeHandler(handler gin.HandlerFunc) gin.HandlerFunc {
	const modelsPrefix = "/models/"
	return func(c *gin.Context) {
		_, action, ok := strings.Cut(c.Param("path"), modelsPrefix)
		if !ok {
			c.JSON(400, gin.H{
				"error": "Invalid Gemini API path format",
			})
			return
		}

		if value, exists := c.Get(MappedModelContextKey); exists {
			if mappedModel, isString := value.(string); isString && mappedModel != "" {
				// Keep the ":method" tail and swap in the mapped model name.
				if colon := strings.Index(action, ":"); colon > 0 {
					action = mappedModel + action[colon:]
				}
			}
		}

		c.Params = append(c.Params, gin.Param{Key: "action", Value: action})
		handler(c)
	}
}
================================================
FILE: internal/api/modules/amp/gemini_bridge_test.go
================================================
package amp
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// bridgeCase describes one request path and the action the bridge should
	// hand to the wrapped Gemini handler.
	type bridgeCase struct {
		name           string
		path           string
		mappedModel    string // "" means no mapping middleware is installed
		expectedAction string
	}

	cases := []bridgeCase{
		{
			name:           "no_mapping_uses_url_model",
			path:           "/publishers/google/models/gemini-pro:generateContent",
			expectedAction: "gemini-pro:generateContent",
		},
		{
			name:           "mapped_model_replaces_url_model",
			path:           "/publishers/google/models/gemini-exp:generateContent",
			mappedModel:    "gemini-2.0-flash",
			expectedAction: "gemini-2.0-flash:generateContent",
		},
		{
			name:           "mapping_preserves_method",
			path:           "/publishers/google/models/gemini-2.5-preview:streamGenerateContent",
			mappedModel:    "gemini-flash",
			expectedAction: "gemini-flash:streamGenerateContent",
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			var gotAction string
			bridge := createGeminiBridgeHandler(func(c *gin.Context) {
				gotAction = c.Param("action")
				c.JSON(http.StatusOK, gin.H{"captured": gotAction})
			})

			router := gin.New()
			if mapped := tc.mappedModel; mapped != "" {
				router.Use(func(c *gin.Context) {
					c.Set(MappedModelContextKey, mapped)
					c.Next()
				})
			}
			router.POST("/api/provider/google/v1beta1/*path", bridge)

			rec := httptest.NewRecorder()
			router.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1"+tc.path, nil))

			if rec.Code != http.StatusOK {
				t.Fatalf("Expected status 200, got %d", rec.Code)
			}
			if gotAction != tc.expectedAction {
				t.Errorf("Expected action '%s', got '%s'", tc.expectedAction, gotAction)
			}
		})
	}
}
func TestCreateGeminiBridgeHandler_InvalidPath(t *testing.T) {
	gin.SetMode(gin.TestMode)

	// The wrapped handler would answer 200; the bridge must short-circuit first.
	router := gin.New()
	router.POST("/api/provider/google/v1beta1/*path", createGeminiBridgeHandler(func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"ok": true})
	}))

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/invalid/path", nil))

	if got := rec.Code; got != http.StatusBadRequest {
		t.Errorf("Expected status 400 for invalid path, got %d", got)
	}
}
================================================
FILE: internal/api/modules/amp/model_mapping.go
================================================
// Package amp provides model mapping functionality for routing Amp CLI requests
// to alternative models when the requested model is not available locally.
package amp
import (
"regexp"
"strings"
"sync"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
// ModelMapper aliases model names requested by the Amp CLI.
//
// It exists so that a request for a model with no local provider can be
// redirected to an alternative model that does have one.
type ModelMapper interface {
	// UpdateMappings replaces the active mapping configuration; it is also the
	// hook used on config hot-reload.
	UpdateMappings(mappings []config.AmpModelMapping)

	// MapModel returns the model requestedModel should be routed to, or "" when
	// no mapping applies or the mapped target has no available providers.
	MapModel(requestedModel string) string
}
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
//
// MapModel consults the exact mappings first and falls back to the regex rules
// only when no exact entry matches. UpdateMappings rebuilds both collections
// from scratch under the write lock.
type DefaultModelMapper struct {
	mu       sync.RWMutex      // guards mappings and regexps
	mappings map[string]string // exact: from -> to (normalized lowercase keys)
	regexps  []regexMapping    // regex rules evaluated in order; first match wins
}
// NewModelMapper returns a DefaultModelMapper preloaded with mappings.
//
// Invalid entries are skipped in the same way as UpdateMappings skips them. A
// nil or empty slice yields a mapper that maps nothing.
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
	mapper := new(DefaultModelMapper)
	mapper.mappings = map[string]string{}
	mapper.UpdateMappings(mappings)
	return mapper
}
// MapModel checks if a mapping exists for the requested model and if the
// target model has available local providers. Returns the mapped model name
// or empty string if no valid mapping exists.
//
// If the requested model contains a thinking suffix (e.g., "g25p(8192)"),
// the suffix is preserved in the returned model name (e.g., "gemini-2.5-pro(8192)").
// However, if the mapping target already contains a suffix, the config suffix
// takes priority over the user's suffix.
//
// Exact mappings are tried before regex rules. Both strategies see the same
// whitespace-trimmed base model name, and both are case-insensitive: exact keys
// are stored lower-cased, and regexes are compiled with (?i).
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
	if requestedModel == "" {
		return ""
	}

	m.mu.RLock()
	defer m.mu.RUnlock()

	// Split off any thinking suffix; all lookups use the base model name only.
	requestResult := thinking.ParseSuffix(requestedModel)
	// Trim once so regex rules see the same input as exact lookups. Previously
	// only the exact path was trimmed, so e.g. /^gpt-5$/ missed " gpt-5" while
	// an exact "gpt-5" rule matched it.
	baseModel := strings.TrimSpace(requestResult.ModelName)

	targetModel, exists := m.mappings[strings.ToLower(baseModel)]
	if !exists {
		// Regex rules are evaluated in configuration order; the first match wins.
		for _, rm := range m.regexps {
			if rm.re.MatchString(baseModel) {
				targetModel = rm.to
				exists = true
				break
			}
		}
		if !exists {
			return ""
		}
	}

	// The mapping target may carry its own thinking suffix (config priority).
	targetResult := thinking.ParseSuffix(targetModel)

	// Only map to targets that some provider can serve; look up the base name.
	providers := util.GetProviderName(targetResult.ModelName)
	if len(providers) == 0 {
		log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
		return ""
	}

	// Config's "to" already contains a suffix - use it as-is.
	if targetResult.HasSuffix {
		return targetModel
	}

	// Preserve the user's thinking suffix on the mapped model, skipping empty
	// suffixes so we never return "model()".
	if requestResult.HasSuffix && requestResult.RawSuffix != "" {
		return targetModel + "(" + requestResult.RawSuffix + ")"
	}

	// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
	return targetModel
}
// UpdateMappings replaces the mapper's rules with those in mappings. It runs at
// construction time and again on every config hot-reload.
//
// From and To are whitespace-trimmed, and entries where either is empty are
// skipped with a warning. Non-regex entries are stored under a lower-cased From
// key, so exact lookups are case-insensitive. Regex entries are compiled with a
// "(?i)" prefix for the same reason; patterns that fail to compile are skipped
// with a warning. Regex rules keep their configuration order.
func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
	m.mu.Lock()
	defer m.mu.Unlock()

	exact := make(map[string]string, len(mappings))
	patterns := make([]regexMapping, 0, len(mappings))

	for _, entry := range mappings {
		from, to := strings.TrimSpace(entry.From), strings.TrimSpace(entry.To)
		if from == "" || to == "" {
			log.Warnf("amp model mapping: skipping invalid mapping (from=%q, to=%q)", from, to)
			continue
		}

		if !entry.Regex {
			exact[strings.ToLower(from)] = to
			log.Debugf("amp model mapping registered: %s -> %s", from, to)
			continue
		}

		compiled, err := regexp.Compile("(?i)" + from)
		if err != nil {
			log.Warnf("amp model mapping: invalid regex %q: %v", from, err)
			continue
		}
		patterns = append(patterns, regexMapping{re: compiled, to: to})
		log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to)
	}

	m.mappings = exact
	m.regexps = patterns

	if count := len(exact); count > 0 {
		log.Infof("amp model mapping: loaded %d mapping(s)", count)
	}
	if count := len(patterns); count > 0 {
		log.Infof("amp model mapping: loaded %d regex mapping(s)", count)
	}
}
// GetMappings returns a snapshot of the exact (non-regex) mappings, keyed by
// lower-cased source model. Regex rules are not included. The returned map is a
// fresh copy, so callers may modify it freely (useful for debugging/status).
func (m *DefaultModelMapper) GetMappings() map[string]string {
	m.mu.RLock()
	snapshot := make(map[string]string, len(m.mappings))
	for from, to := range m.mappings {
		snapshot[from] = to
	}
	m.mu.RUnlock()
	return snapshot
}
// regexMapping is a single regex rule: base model names matched by re are
// routed to the model named by to.
type regexMapping struct {
	re *regexp.Regexp // compiled by UpdateMappings with a "(?i)" prefix
	to string         // target model; may itself carry a thinking suffix
}
================================================
FILE: internal/api/modules/amp/model_mapping_test.go
================================================
package amp
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
func TestNewModelMapper(t *testing.T) {
	mapper := NewModelMapper([]config.AmpModelMapping{
		{From: "claude-opus-4.5", To: "claude-sonnet-4"},
		{From: "gpt-5", To: "gemini-2.5-pro"},
	})
	if mapper == nil {
		t.Fatal("Expected non-nil mapper")
	}

	if got := len(mapper.GetMappings()); got != 2 {
		t.Errorf("Expected 2 mappings, got %d", got)
	}
}
func TestNewModelMapper_Empty(t *testing.T) {
	mapper := NewModelMapper(nil)
	if mapper == nil {
		t.Fatal("Expected non-nil mapper")
	}

	if got := len(mapper.GetMappings()); got != 0 {
		t.Errorf("Expected 0 mappings, got %d", got)
	}
}
func TestModelMapper_MapModel_NoProvider(t *testing.T) {
	mapper := NewModelMapper([]config.AmpModelMapping{
		{From: "claude-opus-4.5", To: "claude-sonnet-4"},
	})

	// No client registers the target model here, so the mapping must be refused.
	if got := mapper.MapModel("claude-opus-4.5"); got != "" {
		t.Errorf("Expected empty result when target has no provider, got %s", got)
	}
}
func TestModelMapper_MapModel_WithProvider(t *testing.T) {
// Register a mock provider for the target model
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
defer reg.UnregisterClient("test-client")
mappings := []config.AmpModelMapping{
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
}
mapper := NewModelMapper(mappings)
// With a registered provider, mapping should work
result := mapper.MapModel("claude-opus-4.5")
if result != "claude-sonnet-4" {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}
func TestModelMapper_MapModel_TargetWithThinkingSuffix(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-thinking", "codex", []*registry.ModelInfo{
{ID: "gpt-5.2", OwnedBy: "openai", Type: "codex"},
})
defer reg.UnregisterClient("test-client-thinking")
mappings := []config.AmpModelMapping{
{From: "gpt-5.2-alias", To: "gpt-5.2(xhigh)"},
}
mapper := NewModelMapper(mappings)
result := mapper.MapModel("gpt-5.2-alias")
if result != "gpt-5.2(xhigh)" {
t.Errorf("Expected gpt-5.2(xhigh), got %s", result)
}
}
func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
defer reg.UnregisterClient("test-client2")
mappings := []config.AmpModelMapping{
{From: "Claude-Opus-4.5", To: "claude-sonnet-4"},
}
mapper := NewModelMapper(mappings)
// Should match case-insensitively
result := mapper.MapModel("claude-opus-4.5")
if result != "claude-sonnet-4" {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}
func TestModelMapper_MapModel_NotFound(t *testing.T) {
	mapper := NewModelMapper([]config.AmpModelMapping{
		{From: "claude-opus-4.5", To: "claude-sonnet-4"},
	})

	// A model with no configured mapping maps to nothing.
	if got := mapper.MapModel("unknown-model"); got != "" {
		t.Errorf("Expected empty for unknown model, got %s", got)
	}
}
func TestModelMapper_MapModel_EmptyInput(t *testing.T) {
	mapper := NewModelMapper([]config.AmpModelMapping{
		{From: "claude-opus-4.5", To: "claude-sonnet-4"},
	})

	if got := mapper.MapModel(""); got != "" {
		t.Errorf("Expected empty for empty input, got %s", got)
	}
}
func TestModelMapper_UpdateMappings(t *testing.T) {
	mapper := NewModelMapper(nil)
	if len(mapper.GetMappings()) != 0 {
		t.Error("Expected 0 initial mappings")
	}

	mapper.UpdateMappings([]config.AmpModelMapping{
		{From: "model-a", To: "model-b"},
		{From: "model-c", To: "model-d"},
	})
	if got := len(mapper.GetMappings()); got != 2 {
		t.Errorf("Expected 2 mappings after update, got %d", got)
	}

	// A second update must replace the previous rule set, not extend it.
	mapper.UpdateMappings([]config.AmpModelMapping{
		{From: "model-x", To: "model-y"},
	})
	if got := len(mapper.GetMappings()); got != 1 {
		t.Errorf("Expected 1 mapping after second update, got %d", got)
	}
}
func TestModelMapper_UpdateMappings_SkipsInvalid(t *testing.T) {
	entries := []config.AmpModelMapping{
		{From: "", To: "model-b"},        // rejected: empty from
		{From: "model-a", To: ""},        // rejected: empty to
		{From: "   ", To: "model-b"},     // rejected: whitespace-only from
		{From: "model-c", To: "model-d"}, // accepted
	}

	mapper := NewModelMapper(nil)
	mapper.UpdateMappings(entries)

	if got := len(mapper.GetMappings()); got != 1 {
		t.Errorf("Expected 1 valid mapping, got %d", got)
	}
}
func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) {
	mapper := NewModelMapper([]config.AmpModelMapping{
		{From: "model-a", To: "model-b"},
	})

	// Mutating a returned snapshot must not leak back into the mapper.
	mapper.GetMappings()["new-key"] = "new-value"

	fresh := mapper.GetMappings()
	if got := len(fresh); got != 1 {
		t.Errorf("Expected original to have 1 mapping, got %d", got)
	}
	if _, leaked := fresh["new-key"]; leaked {
		t.Error("Original map was modified")
	}
}
func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-1", "gemini", []*registry.ModelInfo{
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
})
defer reg.UnregisterClient("test-client-regex-1")
mappings := []config.AmpModelMapping{
{From: "^gpt-5$", To: "gemini-2.5-pro", Regex: true},
}
mapper := NewModelMapper(mappings)
// Incoming model has reasoning suffix, regex matches base, suffix is preserved
result := mapper.MapModel("gpt-5(high)")
if result != "gemini-2.5-pro(high)" {
t.Errorf("Expected gemini-2.5-pro(high), got %s", result)
}
}
func TestModelMapper_Regex_ExactPrecedence(t *testing.T) {
	reg := registry.GetGlobalRegistry()
	reg.RegisterClient("test-client-regex-2", "claude", []*registry.ModelInfo{
		{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
	})
	reg.RegisterClient("test-client-regex-3", "gemini", []*registry.ModelInfo{
		{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
	})
	defer reg.UnregisterClient("test-client-regex-2")
	defer reg.UnregisterClient("test-client-regex-3")

	// Both rules match "gpt-5"; the exact rule must win over the regex rule.
	mapper := NewModelMapper([]config.AmpModelMapping{
		{From: "gpt-5", To: "claude-sonnet-4"},
		{From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true},
	})

	if got := mapper.MapModel("gpt-5"); got != "claude-sonnet-4" {
		t.Errorf("Expected claude-sonnet-4, got %s", got)
	}
}
func TestModelMapper_Regex_InvalidPattern_Skipped(t *testing.T) {
	// "(" does not compile; the rule must be dropped without panicking.
	mapper := NewModelMapper([]config.AmpModelMapping{
		{From: "(", To: "target", Regex: true},
	})

	if got := mapper.MapModel("anything"); got != "" {
		t.Errorf("Expected empty result due to invalid regex, got %s", got)
	}
}
func TestModelMapper_Regex_CaseInsensitive(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-4", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
defer reg.UnregisterClient("test-client-regex-4")
mappings := []config.AmpModelMapping{
{From: "^CLAUDE-OPUS-.*$", To: "claude-sonnet-4", Regex: true},
}
mapper := NewModelMapper(mappings)
result := mapper.MapModel("claude-opus-4.5")
if result != "claude-sonnet-4" {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}
func TestModelMapper_SuffixPreservation(t *testing.T) {
reg := registry.GetGlobalRegistry()
// Register test models
reg.RegisterClient("test-client-suffix", "gemini", []*registry.ModelInfo{
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
})
reg.RegisterClient("test-client-suffix-2", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
defer reg.UnregisterClient("test-client-suffix")
defer reg.UnregisterClient("test-client-suffix-2")
tests := []struct {
name string
mappings []config.AmpModelMapping
input string
want string
}{
{
name: "numeric suffix preserved",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p(8192)",
want: "gemini-2.5-pro(8192)",
},
{
name: "level suffix preserved",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p(high)",
want: "gemini-2.5-pro(high)",
},
{
name: "no suffix unchanged",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p",
want: "gemini-2.5-pro",
},
{
name: "config suffix takes priority",
mappings: []config.AmpModelMapping{{From: "alias", To: "gemini-2.5-pro(medium)"}},
input: "alias(high)",
want: "gemini-2.5-pro(medium)",
},
{
name: "regex with suffix preserved",
mappings: []config.AmpModelMapping{{From: "^g25.*", To: "gemini-2.5-pro", Regex: true}},
input: "g25p(8192)",
want: "gemini-2.5-pro(8192)",
},
{
name: "auto suffix preserved",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p(auto)",
want: "gemini-2.5-pro(auto)",
},
{
name: "none suffix preserved",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p(none)",
want: "gemini-2.5-pro(none)",
},
{
name: "case insensitive base lookup with suffix",
mappings: []config.AmpModelMapping{{From: "G25P", To: "gemini-2.5-pro"}},
input: "g25p(high)",
want: "gemini-2.5-pro(high)",
},
{
name: "empty suffix filtered out",
mappings: []config.AmpModelMapping{{From: "g25p", To: "gemini-2.5-pro"}},
input: "g25p()",
want: "gemini-2.5-pro",
},
{
name: "incomplete suffix treated as no suffix",
mappings: []config.AmpModelMapping{{From: "g25p(high", To: "gemini-2.5-pro"}},
input: "g25p(high",
want: "gemini-2.5-pro",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mapper := NewModelMapper(tt.mappings)
got := mapper.MapModel(tt.input)
if got != tt.want {
t.Errorf("MapModel(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
================================================
FILE: internal/api/modules/amp/proxy.go
================================================
package amp
import (
"bytes"
"compress/gzip"
"context"
"errors"
"fmt"
"io"
"net/http"
"net/http/httputil"
"net/url"
"strconv"
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
log "github.com/sirupsen/logrus"
)
// removeQueryValuesMatching deletes every query pair of req whose decoded name
// equals key and whose decoded value equals match. It is used to strip the
// client's own API key from the URL before forwarding upstream.
//
// The raw query is filtered pair by pair instead of being round-tripped through
// url.Values. That way unrelated parameters keep their original order and
// percent-encoding, and pairs that url.ParseQuery would reject (e.g. ones
// containing ';' or malformed escapes) are passed through instead of being
// silently dropped. RawQuery is rewritten only when at least one pair was
// removed.
//
// A nil request, a nil URL, or an empty match is a no-op; an empty match is
// never treated as a credential.
func removeQueryValuesMatching(req *http.Request, key string, match string) {
	if req == nil || req.URL == nil || match == "" || req.URL.RawQuery == "" {
		return
	}

	pairs := strings.Split(req.URL.RawQuery, "&")
	kept := make([]string, 0, len(pairs))
	removed := false
	for _, pair := range pairs {
		rawName, rawValue, _ := strings.Cut(pair, "=")
		// Compare on decoded text, as url.Values would, so encoded forms of
		// the credential (e.g. "se%63ret") are still recognized.
		name, nameErr := url.QueryUnescape(rawName)
		value, valueErr := url.QueryUnescape(rawValue)
		if nameErr == nil && valueErr == nil && name == key && value == match {
			removed = true
			continue
		}
		kept = append(kept, pair)
	}

	if !removed {
		// Leave the query byte-for-byte untouched.
		return
	}
	req.URL.RawQuery = strings.Join(kept, "&")
}
// readCloser pairs an io.Reader with the io.Closer of a different stream.
//
// ModifyResponse uses it to re-prepend bytes it peeked from an upstream body
// while still closing the real upstream body when the consumer calls Close.
type readCloser struct {
	r io.Reader // data handed to callers of Read
	c io.Closer // target of Close
}

// Read reads from the wrapped reader.
func (rc *readCloser) Read(p []byte) (int, error) {
	return rc.r.Read(p)
}

// Close closes the wrapped closer and returns its error.
func (rc *readCloser) Close() error {
	return rc.c.Close()
}
// createReverseProxy builds the reverse proxy used to forward requests to the
// Amp upstream at upstreamURL.
//
// Outgoing requests (Director):
//   - the Host is set to the upstream host;
//   - client credentials (Authorization, X-Api-Key, X-Goog-Api-Key, and "key" /
//     "auth_token" query values equal to the client's API key) are removed,
//     along with headers stripped by misc.ScrubProxyAndFingerprintHeaders;
//   - the key returned by secretSource is injected as X-Api-Key and as a
//     Bearer Authorization header. On a secret-source error the request is
//     forwarded without auth and a warning is logged.
//
// Incoming responses (ModifyResponse): for 2xx, non-SSE responses that have no
// Content-Encoding but whose body starts with the gzip magic bytes, the body is
// decompressed in memory and Content-Length is updated. Any failure falls back
// to passing the original bytes through unchanged.
//
// Transport errors (ErrorHandler) produce a 502 JSON body. Client
// cancellations (context.Canceled) are neither logged nor answered.
//
// The only error returned is for an unparsable upstreamURL.
func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputil.ReverseProxy, error) {
	parsed, err := url.Parse(upstreamURL)
	if err != nil {
		return nil, fmt.Errorf("invalid amp upstream url: %w", err)
	}

	proxy := httputil.NewSingleHostReverseProxy(parsed)
	originalDirector := proxy.Director

	// Modify outgoing requests to inject API key and fix routing
	proxy.Director = func(req *http.Request) {
		originalDirector(req)
		req.Host = parsed.Host

		// The client's credentials only authenticated it against CLI Proxy API;
		// they must not reach the upstream. Our own key is injected below.
		req.Header.Del("Authorization")
		req.Header.Del("X-Api-Key")
		req.Header.Del("X-Goog-Api-Key")

		// Remove proxy, client identity, and browser fingerprint headers
		misc.ScrubProxyAndFingerprintHeaders(req)

		// Remove query-based credentials only when they match the authenticated
		// client API key. This keeps client auth material away from the Amp
		// upstream without breaking unrelated upstream query parameters.
		clientKey := getClientAPIKeyFromContext(req.Context())
		removeQueryValuesMatching(req, "key", clientKey)
		removeQueryValuesMatching(req, "auth_token", clientKey)

		// Note: We do NOT filter Anthropic-Beta headers in the proxy path.
		// Users going through the ampcode.com proxy are paying for the service
		// and should get all features, including the 1M context window
		// (context-1m-2025-08-07).

		// Inject API key from secret source (only uses upstream-api-key from config)
		if key, err := secretSource.Get(req.Context()); err == nil && key != "" {
			req.Header.Set("X-Api-Key", key)
			req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
		} else if err != nil {
			log.Warnf("amp secret source error (continuing without auth): %v", err)
		}
	}

	// Handle gzip bodies that arrive without a Content-Encoding header. This is
	// the proxy-level counterpart of the inline handler's gzip handling.
	proxy.ModifyResponse = func(resp *http.Response) error {
		// Only process successful responses
		if resp.StatusCode < 200 || resp.StatusCode >= 300 {
			return nil
		}
		// An explicit Content-Encoding is left for the client to handle.
		if resp.Header.Get("Content-Encoding") != "" {
			return nil
		}
		// SSE must stream through; it cannot be buffered and decompressed.
		if isStreamingResponse(resp) {
			return nil
		}

		// Keep a reference to the upstream body so it is always closed.
		originalBody := resp.Body

		// Peek at the first two bytes to look for the gzip magic number.
		header := make([]byte, 2)
		n, _ := io.ReadFull(originalBody, header)

		// n < 2 means the body is too short to be gzip.
		if n < 2 || header[0] != 0x1f || header[1] != 0x8b {
			// Not gzip: restore peeked bytes while preserving Close behavior.
			resp.Body = &readCloser{
				r: io.MultiReader(bytes.NewReader(header[:n]), originalBody),
				c: originalBody,
			}
			return nil
		}

		rest, err := io.ReadAll(originalBody)
		if err != nil {
			// io.ReadAll returns the bytes it consumed before failing. Stitch
			// them back between the peeked prefix and the remaining stream so
			// the client gets the body in order, without a hole in the middle.
			log.Warnf("amp proxy: reading gzip body failed, passing through undecoded: %v", err)
			resp.Body = &readCloser{
				r: io.MultiReader(bytes.NewReader(header[:n]), bytes.NewReader(rest), originalBody),
				c: originalBody,
			}
			return nil
		}

		// Everything is in memory now; the upstream body can be released.
		_ = originalBody.Close()

		// Reconstruct complete gzipped data
		gzippedData := append(header[:n], rest...)

		gzipReader, err := gzip.NewReader(bytes.NewReader(gzippedData))
		if err != nil {
			log.Warnf("amp proxy: gzip header detected but decompress failed: %v", err)
			resp.Body = io.NopCloser(bytes.NewReader(gzippedData))
			return nil
		}
		decompressed, err := io.ReadAll(gzipReader)
		_ = gzipReader.Close()
		if err != nil {
			log.Warnf("amp proxy: gzip decompress error: %v", err)
			resp.Body = io.NopCloser(bytes.NewReader(gzippedData))
			return nil
		}

		// Replace the body and make the headers describe the decompressed content.
		resp.Body = io.NopCloser(bytes.NewReader(decompressed))
		resp.ContentLength = int64(len(decompressed))
		// Drops any empty-valued Content-Encoding entry left by upstream.
		resp.Header.Del("Content-Encoding")
		resp.Header.Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10))
		log.Debugf("amp proxy: decompressed gzip response (%d -> %d bytes)", len(gzippedData), len(decompressed))
		return nil
	}

	// Error handler for proxy failures
	proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
		// Client-side cancellations are common during polling; suppress logging in this case
		if errors.Is(err, context.Canceled) {
			return
		}
		log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err)
		rw.Header().Set("Content-Type", "application/json")
		rw.WriteHeader(http.StatusBadGateway)
		_, _ = rw.Write([]byte(`{"error":"amp_upstream_proxy_error","message":"Failed to reach Amp upstream"}`))
	}

	return proxy, nil
}
// isStreamingResponse reports whether resp is a Server-Sent Events stream,
// i.e. its Content-Type contains "text/event-stream". This is the only kind
// of response ModifyResponse must not buffer.
//
// Chunked transfer encoding alone does not count: it is a transport-level
// detail, and many JSON APIs use it for ordinary responses that can be read
// in full and decompressed.
//
// Media types are case-insensitive (RFC 9110), so the header is lower-cased
// before matching. A nil response is reported as not streaming.
func isStreamingResponse(resp *http.Response) bool {
	if resp == nil {
		return false
	}
	contentType := strings.ToLower(resp.Header.Get("Content-Type"))
	return strings.Contains(contentType, "text/event-stream")
}
// proxyHandler exposes a ReverseProxy as a gin handler by serving each request
// through proxy with gin's underlying ResponseWriter and *http.Request.
func proxyHandler(proxy *httputil.ReverseProxy) gin.HandlerFunc {
	serve := proxy.ServeHTTP
	return func(c *gin.Context) {
		serve(c.Writer, c.Request)
	}
}
// filterBetaFeatures removes a specific beta feature from comma-separated list
func filterBetaFeatures(header, featureToRemove string) string {
features := strings.Split(header, ",")
filtered := make([]string, 0, len(features))
for _, feature := range features {
trimmed := strings.TrimSpace(feature)
if trimmed != "" && trimmed != featureToRemove {
filtered = append(filtered, trimmed)
}
}
return strings.Join(filtered, ",")
}
================================================
FILE: internal/api/modules/amp/proxy_test.go
================================================
package amp
import (
"bytes"
"compress/gzip"
"context"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// Helper: compress data with gzip
func gzipBytes(b []byte) []byte {
var buf bytes.Buffer
zw := gzip.NewWriter(&buf)
zw.Write(b)
zw.Close()
return buf.Bytes()
}
// mkResp builds an in-memory *http.Response with the given status, headers and
// body. ContentLength is set to len(body), and a nil header is replaced by an
// empty http.Header so callers can always use resp.Header safely.
func mkResp(status int, hdr http.Header, body []byte) *http.Response {
	resp := &http.Response{
		StatusCode:    status,
		Header:        hdr,
		Body:          io.NopCloser(bytes.NewReader(body)),
		ContentLength: int64(len(body)),
	}
	if resp.Header == nil {
		resp.Header = http.Header{}
	}
	return resp
}
func TestCreateReverseProxy_ValidURL(t *testing.T) {
	proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("key"))
	switch {
	case err != nil:
		t.Fatalf("expected no error, got: %v", err)
	case proxy == nil:
		t.Fatal("expected proxy to be created")
	}
}
func TestCreateReverseProxy_InvalidURL(t *testing.T) {
	// A URL with no scheme must be rejected at construction time.
	if _, err := createReverseProxy("://invalid", NewStaticSecretSource("key")); err == nil {
		t.Fatal("expected error for invalid URL")
	}
}
func TestModifyResponse_GzipScenarios(t *testing.T) {
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
if err != nil {
t.Fatal(err)
}
goodJSON := []byte(`{"ok":true}`)
good := gzipBytes(goodJSON)
truncated := good[:10]
corrupted := append([]byte{0x1f, 0x8b}, []byte("notgzip")...)
cases := []struct {
name string
header http.Header
body []byte
status int
wantBody []byte
wantCE string
}{
{
name: "decompresses_valid_gzip_no_header",
header: http.Header{},
body: good,
status: 200,
wantBody: goodJSON,
wantCE: "",
},
{
name: "skips_when_ce_present",
header: http.Header{"Content-Encoding": []string{"gzip"}},
body: good,
status: 200,
wantBody: good,
wantCE: "gzip",
},
{
name: "passes_truncated_unchanged",
header: http.Header{},
body: truncated,
status: 200,
wantBody: truncated,
wantCE: "",
},
{
name: "passes_corrupted_unchanged",
header: http.Header{},
body: corrupted,
status: 200,
wantBody: corrupted,
wantCE: "",
},
{
name: "non_gzip_unchanged",
header: http.Header{},
body: []byte("plain"),
status: 200,
wantBody: []byte("plain"),
wantCE: "",
},
{
name: "empty_body",
header: http.Header{},
body: []byte{},
status: 200,
wantBody: []byte{},
wantCE: "",
},
{
name: "single_byte_body",
header: http.Header{},
body: []byte{0x1f},
status: 200,
wantBody: []byte{0x1f},
wantCE: "",
},
{
name: "skips_non_2xx_status",
header: http.Header{},
body: good,
status: 404,
wantBody: good,
wantCE: "",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
resp := mkResp(tc.status, tc.header, tc.body)
if err := proxy.ModifyResponse(resp); err != nil {
t.Fatalf("ModifyResponse error: %v", err)
}
got, err := io.ReadAll(resp.Body)
if err != nil {
t.Fatalf("ReadAll error: %v", err)
}
if !bytes.Equal(got, tc.wantBody) {
t.Fatalf("body mismatch:\nwant: %q\ngot: %q", tc.wantBody, got)
}
if ce := resp.Header.Get("Content-Encoding"); ce != tc.wantCE {
t.Fatalf("Content-Encoding: want %q, got %q", tc.wantCE, ce)
}
})
}
}
func TestModifyResponse_UpdatesContentLengthHeader(t *testing.T) {
	proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
	if err != nil {
		t.Fatal(err)
	}

	plain := []byte(`{"message":"test response"}`)
	compressed := gzipBytes(plain)

	// Upstream sends a gzip body together with a Content-Length of the
	// compressed size. After decompression that length would be stale, so
	// both the header and resp.ContentLength must be rewritten.
	staleLength := fmt.Sprintf("%d", len(compressed))
	resp := mkResp(200, http.Header{"Content-Length": []string{staleLength}}, compressed)

	if err := proxy.ModifyResponse(resp); err != nil {
		t.Fatalf("ModifyResponse error: %v", err)
	}

	got, _ := io.ReadAll(resp.Body)
	if !bytes.Equal(got, plain) {
		t.Fatalf("body should be decompressed, got: %q, want: %q", got, plain)
	}

	if wantCL, gotCL := fmt.Sprintf("%d", len(plain)), resp.Header.Get("Content-Length"); gotCL != wantCL {
		t.Fatalf("Content-Length header mismatch: want %q (decompressed), got %q", wantCL, gotCL)
	}
	if resp.ContentLength != int64(len(plain)) {
		t.Fatalf("resp.ContentLength mismatch: want %d, got %d", len(plain), resp.ContentLength)
	}
}
func TestModifyResponse_SkipsStreamingResponses(t *testing.T) {
	proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
	if err != nil {
		t.Fatal(err)
	}

	compressed := gzipBytes([]byte(`{"ok":true}`))

	t.Run("sse_skips_decompression", func(t *testing.T) {
		sseHeader := http.Header{"Content-Type": []string{"text/event-stream"}}
		resp := mkResp(200, sseHeader, compressed)
		if err := proxy.ModifyResponse(resp); err != nil {
			t.Fatalf("ModifyResponse error: %v", err)
		}

		// An SSE body must reach the client exactly as upstream sent it.
		if got, _ := io.ReadAll(resp.Body); !bytes.Equal(got, compressed) {
			t.Fatal("SSE response should not be decompressed")
		}
	})
}
// TestModifyResponse_DecompressesChunkedJSON checks that chunked transfer
// encoding alone does not mark a response as streaming, so a gzip body is
// still decompressed.
func TestModifyResponse_DecompressesChunkedJSON(t *testing.T) {
	proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource("k"))
	if err != nil {
		t.Fatal(err)
	}
	goodJSON := []byte(`{"ok":true}`)
	gzipped := gzipBytes(goodJSON)
	t.Run("chunked_json_decompresses", func(t *testing.T) {
		// Chunked JSON responses (like thread APIs) should be decompressed.
		resp := mkResp(200, http.Header{"Transfer-Encoding": []string{"chunked"}}, gzipped)
		if err := proxy.ModifyResponse(resp); err != nil {
			t.Fatalf("ModifyResponse error: %v", err)
		}
		// Should decompress because it's not SSE.
		got, err := io.ReadAll(resp.Body)
		if err != nil {
			t.Fatalf("ReadAll error: %v", err)
		}
		if !bytes.Equal(got, goodJSON) {
			t.Fatalf("chunked JSON should be decompressed, got: %q, want: %q", got, goodJSON)
		}
	})
}
// TestReverseProxy_InjectsHeaders verifies that the proxy forwards the
// configured upstream secret as both X-Api-Key and a Bearer Authorization header.
func TestReverseProxy_InjectsHeaders(t *testing.T) {
	gotHeaders := make(chan http.Header, 1)
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotHeaders <- r.Header.Clone()
		w.WriteHeader(200)
		w.Write([]byte(`ok`))
	}))
	defer upstream.Close()
	proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("secret"))
	if err != nil {
		t.Fatal(err)
	}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		proxy.ServeHTTP(w, r)
	}))
	defer srv.Close()
	res, err := http.Get(srv.URL + "/test")
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	// If proxying failed, the upstream handler may never have run and the
	// receive below would block until the test binary times out; fail fast.
	if res.StatusCode != http.StatusOK {
		t.Fatalf("proxy returned status %d, want %d", res.StatusCode, http.StatusOK)
	}
	hdr := <-gotHeaders
	if hdr.Get("X-Api-Key") != "secret" {
		t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
	}
	if hdr.Get("Authorization") != "Bearer secret" {
		t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
	}
}
// TestReverseProxy_EmptySecret verifies that no upstream credentials are
// injected when the secret source yields an empty string.
func TestReverseProxy_EmptySecret(t *testing.T) {
	gotHeaders := make(chan http.Header, 1)
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotHeaders <- r.Header.Clone()
		w.WriteHeader(200)
		w.Write([]byte(`ok`))
	}))
	defer upstream.Close()
	proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource(""))
	if err != nil {
		t.Fatal(err)
	}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		proxy.ServeHTTP(w, r)
	}))
	defer srv.Close()
	res, err := http.Get(srv.URL + "/test")
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	// If proxying failed, the upstream handler may never have run and the
	// receive below would block until the test binary times out; fail fast.
	if res.StatusCode != http.StatusOK {
		t.Fatalf("proxy returned status %d, want %d", res.StatusCode, http.StatusOK)
	}
	hdr := <-gotHeaders
	// Should NOT inject headers when secret is empty.
	if hdr.Get("X-Api-Key") != "" {
		t.Fatalf("X-Api-Key should not be set, got: %q", hdr.Get("X-Api-Key"))
	}
	if authVal := hdr.Get("Authorization"); authVal != "" && authVal != "Bearer " {
		t.Fatalf("Authorization should not be set, got: %q", authVal)
	}
}
// TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery verifies that
// client-supplied credentials (headers, and query parameters whose value equals
// the authenticated client key) never reach the upstream, while upstream
// credentials are injected and unrelated query parameters are preserved.
func TestReverseProxy_StripsClientCredentialsFromHeadersAndQuery(t *testing.T) {
	type captured struct {
		headers http.Header
		query   string
	}
	got := make(chan captured, 1)
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		got <- captured{headers: r.Header.Clone(), query: r.URL.RawQuery}
		w.WriteHeader(200)
		w.Write([]byte(`ok`))
	}))
	defer upstream.Close()
	proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("upstream"))
	if err != nil {
		t.Fatal(err)
	}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Simulate clientAPIKeyMiddleware injection (per-request).
		ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "client-key")
		proxy.ServeHTTP(w, r.WithContext(ctx))
	}))
	defer srv.Close()
	req, err := http.NewRequest(http.MethodGet, srv.URL+"/test?key=client-key&key=keep&auth_token=client-key&foo=bar", nil)
	if err != nil {
		t.Fatal(err)
	}
	req.Header.Set("Authorization", "Bearer client-key")
	req.Header.Set("X-Api-Key", "client-key")
	req.Header.Set("X-Goog-Api-Key", "client-key")
	res, err := http.DefaultClient.Do(req)
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	// If proxying failed, the upstream handler may never have run and the
	// receive below would block until the test binary times out; fail fast.
	if res.StatusCode != http.StatusOK {
		t.Fatalf("proxy returned status %d, want %d", res.StatusCode, http.StatusOK)
	}
	c := <-got
	// These are client-provided credentials and must not reach the upstream.
	if v := c.headers.Get("X-Goog-Api-Key"); v != "" {
		t.Fatalf("X-Goog-Api-Key should be stripped, got: %q", v)
	}
	// We inject upstream Authorization/X-Api-Key, so the client auth must not survive.
	if v := c.headers.Get("Authorization"); v != "Bearer upstream" {
		t.Fatalf("Authorization should be upstream-injected, got: %q", v)
	}
	if v := c.headers.Get("X-Api-Key"); v != "upstream" {
		t.Fatalf("X-Api-Key should be upstream-injected, got: %q", v)
	}
	// Query-based credentials should be stripped only when they match the
	// authenticated client key; unrelated values and parameters must remain.
	if strings.Contains(c.query, "auth_token=client-key") || strings.Contains(c.query, "key=client-key") {
		t.Fatalf("query credentials should be stripped, got raw query: %q", c.query)
	}
	if !strings.Contains(c.query, "key=keep") || !strings.Contains(c.query, "foo=bar") {
		t.Fatalf("expected query to keep non-credential params, got raw query: %q", c.query)
	}
}
// TestReverseProxy_InjectsMappedSecret_FromRequestContext verifies that a
// client key found in the request context is mapped to its configured
// upstream key, which is then injected into the upstream request.
func TestReverseProxy_InjectsMappedSecret_FromRequestContext(t *testing.T) {
	gotHeaders := make(chan http.Header, 1)
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotHeaders <- r.Header.Clone()
		w.WriteHeader(200)
		w.Write([]byte(`ok`))
	}))
	defer upstream.Close()
	defaultSource := NewStaticSecretSource("default")
	mapped := NewMappedSecretSource(defaultSource)
	mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
		{
			UpstreamAPIKey: "u1",
			APIKeys:        []string{"k1"},
		},
	})
	proxy, err := createReverseProxy(upstream.URL, mapped)
	if err != nil {
		t.Fatal(err)
	}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Simulate clientAPIKeyMiddleware injection (per-request).
		ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k1")
		proxy.ServeHTTP(w, r.WithContext(ctx))
	}))
	defer srv.Close()
	res, err := http.Get(srv.URL + "/test")
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	// If proxying failed, the upstream handler may never have run and the
	// receive below would block until the test binary times out; fail fast.
	if res.StatusCode != http.StatusOK {
		t.Fatalf("proxy returned status %d, want %d", res.StatusCode, http.StatusOK)
	}
	hdr := <-gotHeaders
	if hdr.Get("X-Api-Key") != "u1" {
		t.Fatalf("X-Api-Key missing or wrong, got: %q", hdr.Get("X-Api-Key"))
	}
	if hdr.Get("Authorization") != "Bearer u1" {
		t.Fatalf("Authorization missing or wrong, got: %q", hdr.Get("Authorization"))
	}
}
// TestReverseProxy_MappedSecret_FallsBackToDefault verifies that a client key
// with no configured mapping falls back to the default upstream secret.
func TestReverseProxy_MappedSecret_FallsBackToDefault(t *testing.T) {
	gotHeaders := make(chan http.Header, 1)
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotHeaders <- r.Header.Clone()
		w.WriteHeader(200)
		w.Write([]byte(`ok`))
	}))
	defer upstream.Close()
	defaultSource := NewStaticSecretSource("default")
	mapped := NewMappedSecretSource(defaultSource)
	mapped.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
		{
			UpstreamAPIKey: "u1",
			APIKeys:        []string{"k1"},
		},
	})
	proxy, err := createReverseProxy(upstream.URL, mapped)
	if err != nil {
		t.Fatal(err)
	}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := context.WithValue(r.Context(), clientAPIKeyContextKey{}, "k2")
		proxy.ServeHTTP(w, r.WithContext(ctx))
	}))
	defer srv.Close()
	res, err := http.Get(srv.URL + "/test")
	if err != nil {
		t.Fatal(err)
	}
	res.Body.Close()
	// If proxying failed, the upstream handler may never have run and the
	// receive below would block until the test binary times out; fail fast.
	if res.StatusCode != http.StatusOK {
		t.Fatalf("proxy returned status %d, want %d", res.StatusCode, http.StatusOK)
	}
	hdr := <-gotHeaders
	if hdr.Get("X-Api-Key") != "default" {
		t.Fatalf("X-Api-Key fallback missing or wrong, got: %q", hdr.Get("X-Api-Key"))
	}
	if hdr.Get("Authorization") != "Bearer default" {
		t.Fatalf("Authorization fallback missing or wrong, got: %q", hdr.Get("Authorization"))
	}
}
// TestReverseProxy_ErrorHandler verifies that an unreachable upstream yields a
// 502 with a JSON body identifying amp_upstream_proxy_error.
func TestReverseProxy_ErrorHandler(t *testing.T) {
	// Point proxy to a non-routable address to trigger error.
	proxy, err := createReverseProxy("http://127.0.0.1:1", NewStaticSecretSource(""))
	if err != nil {
		t.Fatal(err)
	}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		proxy.ServeHTTP(w, r)
	}))
	defer srv.Close()
	res, err := http.Get(srv.URL + "/any")
	if err != nil {
		t.Fatal(err)
	}
	body, err := io.ReadAll(res.Body)
	res.Body.Close()
	if err != nil {
		t.Fatalf("ReadAll error: %v", err)
	}
	if res.StatusCode != http.StatusBadGateway {
		t.Fatalf("want 502, got %d", res.StatusCode)
	}
	if !bytes.Contains(body, []byte(`"amp_upstream_proxy_error"`)) {
		t.Fatalf("unexpected body: %s", body)
	}
	if ct := res.Header.Get("Content-Type"); ct != "application/json" {
		t.Fatalf("content-type: want application/json, got %s", ct)
	}
}
// TestReverseProxy_ErrorHandler_ContextCanceled invokes the proxy's
// ErrorHandler directly with context.Canceled on an already-canceled request
// and expects it to write no generic JSON error body.
func TestReverseProxy_ErrorHandler_ContextCanceled(t *testing.T) {
	proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource(""))
	if err != nil {
		t.Fatal(err)
	}

	canceledCtx, cancel := context.WithCancel(context.Background())
	cancel() // the request is canceled before the handler ever sees it

	recorder := httptest.NewRecorder()
	request := httptest.NewRequest(http.MethodGet, "/test", nil).WithContext(canceledCtx)
	proxy.ErrorHandler(recorder, request, context.Canceled)

	if written := recorder.Body.Bytes(); len(written) > 0 {
		t.Fatalf("expected empty body for canceled context, got: %s", written)
	}
}
// TestReverseProxy_FullRoundTrip_Gzip verifies end to end that a gzip body sent
// by the upstream without a Content-Encoding header reaches the client
// decompressed.
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
	// Upstream returns gzipped JSON without Content-Encoding header.
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(200)
		w.Write(gzipBytes([]byte(`{"upstream":"ok"}`)))
	}))
	defer upstream.Close()
	proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key"))
	if err != nil {
		t.Fatal(err)
	}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		proxy.ServeHTTP(w, r)
	}))
	defer srv.Close()
	res, err := http.Get(srv.URL + "/test")
	if err != nil {
		t.Fatal(err)
	}
	body, err := io.ReadAll(res.Body)
	res.Body.Close()
	if err != nil {
		t.Fatalf("ReadAll error: %v", err)
	}
	expected := []byte(`{"upstream":"ok"}`)
	if !bytes.Equal(body, expected) {
		t.Fatalf("want decompressed JSON, got: %s", body)
	}
}
// TestReverseProxy_FullRoundTrip_PlainJSON verifies end to end that an
// uncompressed JSON body passes through the proxy unchanged.
func TestReverseProxy_FullRoundTrip_PlainJSON(t *testing.T) {
	// Upstream returns plain JSON.
	upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(200)
		w.Write([]byte(`{"plain":"json"}`))
	}))
	defer upstream.Close()
	proxy, err := createReverseProxy(upstream.URL, NewStaticSecretSource("key"))
	if err != nil {
		t.Fatal(err)
	}
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		proxy.ServeHTTP(w, r)
	}))
	defer srv.Close()
	res, err := http.Get(srv.URL + "/test")
	if err != nil {
		t.Fatal(err)
	}
	body, err := io.ReadAll(res.Body)
	res.Body.Close()
	if err != nil {
		t.Fatalf("ReadAll error: %v", err)
	}
	expected := []byte(`{"plain":"json"}`)
	if !bytes.Equal(body, expected) {
		t.Fatalf("want plain JSON unchanged, got: %s", body)
	}
}
// TestIsStreamingResponse checks that only an SSE Content-Type classifies a
// response as streaming; chunked transfer encoding alone does not.
func TestIsStreamingResponse(t *testing.T) {
	type streamCase struct {
		name   string
		header http.Header
		want   bool
	}
	for _, tc := range []streamCase{
		{"sse", http.Header{"Content-Type": []string{"text/event-stream"}}, true},
		// Chunked is transport-level framing, not streaming.
		{"chunked_not_streaming", http.Header{"Transfer-Encoding": []string{"chunked"}}, false},
		{"normal_json", http.Header{"Content-Type": []string{"application/json"}}, false},
		{"empty", http.Header{}, false},
	} {
		t.Run(tc.name, func(t *testing.T) {
			if got := isStreamingResponse(&http.Response{Header: tc.header}); got != tc.want {
				t.Fatalf("want %v, got %v", tc.want, got)
			}
		})
	}
}
// TestFilterBetaFeatures checks removal of a single feature token from a
// comma-separated beta header, including edge positions, absence, an empty
// header and surrounding whitespace.
func TestFilterBetaFeatures(t *testing.T) {
	const ctx1m = "context-1m-2025-08-07"
	cases := []struct {
		name, header, featureToRemove, expected string
	}{
		{"Remove context-1m from middle", "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07,oauth-2025-04-20", ctx1m, "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20"},
		{"Remove context-1m from start", "context-1m-2025-08-07,fine-grained-tool-streaming-2025-05-14", ctx1m, "fine-grained-tool-streaming-2025-05-14"},
		{"Remove context-1m from end", "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07", ctx1m, "fine-grained-tool-streaming-2025-05-14"},
		{"Feature not present", "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20", ctx1m, "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20"},
		{"Only feature to remove", "context-1m-2025-08-07", ctx1m, ""},
		{"Empty header", "", ctx1m, ""},
		{"Header with spaces", "fine-grained-tool-streaming-2025-05-14, context-1m-2025-08-07 , oauth-2025-04-20", ctx1m, "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := filterBetaFeatures(tc.header, tc.featureToRemove); got != tc.expected {
				t.Errorf("filterBetaFeatures() = %q, want %q", got, tc.expected)
			}
		})
	}
}
================================================
FILE: internal/api/modules/amp/response_rewriter.go
================================================
package amp
import (
"bytes"
"net/http"
"strings"
"github.com/gin-gonic/gin"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ResponseRewriter wraps a gin.ResponseWriter to intercept and modify the response body.
// It is used to rewrite model names in responses when model mapping is in effect.
// Non-streaming output is buffered until Flush; streaming output is rewritten
// and forwarded chunk by chunk as it is written.
type ResponseRewriter struct {
	gin.ResponseWriter
	// body buffers non-streaming output until Flush rewrites and forwards it.
	body *bytes.Buffer
	// originalModel is the model name written back into response payloads;
	// when empty, no model-name rewriting is performed.
	originalModel string
	// isStreaming is decided on the first Write from the Content-Type header
	// (any value containing "stream"); streamed chunks bypass body.
	isStreaming bool
}
// NewResponseRewriter wraps w so that model names in the response are replaced
// with originalModel before the bytes reach the client. The returned rewriter
// starts with an empty buffer and in non-streaming mode.
func NewResponseRewriter(w gin.ResponseWriter, originalModel string) *ResponseRewriter {
	rewriter := &ResponseRewriter{ResponseWriter: w, originalModel: originalModel}
	rewriter.body = new(bytes.Buffer)
	return rewriter
}
// Write intercepts response writes.
//
// On the first write, the Content-Type header decides the mode. If it contains
// "stream" (which includes text/event-stream), the response is treated as
// streaming: each chunk is rewritten via rewriteStreamChunk, forwarded, and
// flushed immediately. Otherwise data is buffered in rw.body so that Flush can
// rewrite the complete JSON document in one pass.
//
// Per the io.Writer contract, the returned count refers to bytes of data that
// were consumed, not to bytes that reached the underlying writer. Rewriting can
// change a chunk's length, because model names differ in length. Reporting the
// rewritten length would make callers such as io.Copy fail with ErrShortWrite
// or "invalid write result", so len(data) is returned on success.
func (rw *ResponseRewriter) Write(data []byte) (int, error) {
	// Detect streaming on first write.
	if rw.body.Len() == 0 && !rw.isStreaming {
		contentType := rw.Header().Get("Content-Type")
		// "stream" also matches "text/event-stream".
		rw.isStreaming = strings.Contains(contentType, "stream")
	}
	if !rw.isStreaming {
		return rw.body.Write(data)
	}
	if _, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data)); err != nil {
		// A partial write of the rewritten chunk cannot be mapped back to a
		// byte count of data, so report nothing consumed.
		return 0, err
	}
	if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
		flusher.Flush()
	}
	return len(data), nil
}
// Flush forwards output to the client.
//
// In streaming mode, chunks have already been written, so Flush only flushes
// the underlying writer. In buffered mode, it rewrites the accumulated body
// with rewriteModelInResponse, writes it, and then clears the buffer. Clearing
// ensures that a repeated Flush, or a Flush after further writes, cannot send
// the same payload to the client twice. Write errors are logged, not returned,
// because http.Flusher has no error result.
func (rw *ResponseRewriter) Flush() {
	if rw.isStreaming {
		if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
			flusher.Flush()
		}
		return
	}
	if rw.body.Len() == 0 {
		return
	}
	rewritten := rw.rewriteModelInResponse(rw.body.Bytes())
	if _, err := rw.ResponseWriter.Write(rewritten); err != nil {
		log.Warnf("amp response rewriter: failed to write rewritten response: %v", err)
	}
	// Reset only after the write: rewritten may alias the buffer's storage
	// when no rewrite was needed.
	rw.body.Reset()
}
// modelFieldPaths lists the gjson/sjson paths at which a model name may appear
// in a response payload. rewriteModelInResponse overwrites each path that exists
// with the original model name; paths absent from a payload are left alone.
var modelFieldPaths = []string{"message.model", "model", "modelVersion", "response.model", "response.modelVersion"}
// rewriteModelInResponse returns data with two transformations applied:
//
//  1. Amp compatibility: when the "content" array contains a tool_use block,
//     every block whose type is "thinking" is removed. This happens regardless
//     of originalModel.
//  2. When originalModel is non-empty, every path in modelFieldPaths that
//     exists in data is overwritten with originalModel.
//
// Each sjson edit is committed only when it succeeds. On failure, sjson's result
// slice is discarded, so a failed edit leaves the payload as it was instead of
// replacing it with an empty/nil body.
func (rw *ResponseRewriter) rewriteModelInResponse(data []byte) []byte {
	// 1. Amp Compatibility: Suppress thinking blocks if tool use is detected.
	// The Amp client struggles when both thinking and tool_use blocks are present.
	if gjson.GetBytes(data, `content.#(type=="tool_use")`).Exists() {
		filtered := gjson.GetBytes(data, `content.#(type!="thinking")#`)
		if filtered.Exists() {
			originalCount := gjson.GetBytes(data, "content.#").Int()
			filteredCount := filtered.Get("#").Int()
			if originalCount > filteredCount {
				updated, err := sjson.SetBytes(data, "content", filtered.Value())
				if err != nil {
					log.Warnf("Amp ResponseRewriter: failed to suppress thinking blocks: %v", err)
				} else {
					data = updated
					log.Debugf("Amp ResponseRewriter: Suppressed %d thinking blocks due to tool usage", originalCount-filteredCount)
					// Log the result for verification
					log.Debugf("Amp ResponseRewriter: Resulting content: %s", gjson.GetBytes(data, "content").String())
				}
			}
		}
	}
	if rw.originalModel == "" {
		return data
	}
	// 2. Put the client-facing model name back at every known location.
	for _, path := range modelFieldPaths {
		if !gjson.GetBytes(data, path).Exists() {
			continue
		}
		updated, err := sjson.SetBytes(data, path, rw.originalModel)
		if err != nil {
			log.Warnf("Amp ResponseRewriter: failed to rewrite model at %s: %v", path, err)
			continue
		}
		data = updated
	}
	return data
}
// rewriteStreamChunk rewrites model names inside an SSE chunk.
//
// The chunk is processed line by line. Each line that starts with "data: " and
// whose payload begins with '{' is passed through rewriteModelInResponse. Every
// other line, and the "\n" separators, is kept exactly as received. When no
// original model is recorded, the chunk is returned untouched.
func (rw *ResponseRewriter) rewriteStreamChunk(chunk []byte) []byte {
	if rw.originalModel == "" {
		return chunk
	}
	const dataPrefix = "data: "
	newline := []byte("\n")
	parts := bytes.Split(chunk, newline)
	for idx := range parts {
		if !bytes.HasPrefix(parts[idx], []byte(dataPrefix)) {
			continue
		}
		payload := parts[idx][len(dataPrefix):]
		if len(payload) == 0 || payload[0] != '{' {
			continue
		}
		rewritten := rw.rewriteModelInResponse(payload)
		// Build a fresh slice so no line shares backing storage with another.
		line := make([]byte, 0, len(dataPrefix)+len(rewritten))
		line = append(line, dataPrefix...)
		parts[idx] = append(line, rewritten...)
	}
	return bytes.Join(parts, newline)
}
================================================
FILE: internal/api/modules/amp/response_rewriter_test.go
================================================
package amp
import (
"testing"
)
// TestRewriteModelInResponse_TopLevel checks rewriting of a top-level "model" field.
func TestRewriteModelInResponse_TopLevel(t *testing.T) {
	const want = `{"id":"resp_1","model":"gpt-5.2-codex","output":[]}`
	rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
	got := string(rw.rewriteModelInResponse([]byte(`{"id":"resp_1","model":"gpt-5.3-codex","output":[]}`)))
	if got != want {
		t.Errorf("expected %s, got %s", want, got)
	}
}
// TestRewriteModelInResponse_ResponseModel checks rewriting of "response.model"
// in a response.completed event.
func TestRewriteModelInResponse_ResponseModel(t *testing.T) {
	const want = `{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"completed"}}`
	rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
	got := string(rw.rewriteModelInResponse([]byte(`{"type":"response.completed","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"completed"}}`)))
	if got != want {
		t.Errorf("expected %s, got %s", want, got)
	}
}
// TestRewriteModelInResponse_ResponseCreated checks rewriting of "response.model"
// in a response.created event.
func TestRewriteModelInResponse_ResponseCreated(t *testing.T) {
	const want = `{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.2-codex","status":"in_progress"}}`
	rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
	got := string(rw.rewriteModelInResponse([]byte(`{"type":"response.created","response":{"id":"resp_1","model":"gpt-5.3-codex","status":"in_progress"}}`)))
	if got != want {
		t.Errorf("expected %s, got %s", want, got)
	}
}
// TestRewriteModelInResponse_NoModelField checks that a payload without any
// known model path is returned unchanged.
func TestRewriteModelInResponse_NoModelField(t *testing.T) {
	payload := []byte(`{"type":"response.output_item.added","item":{"id":"item_1","type":"message"}}`)
	rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
	if got := rw.rewriteModelInResponse(payload); string(got) != string(payload) {
		t.Errorf("expected no modification, got %s", string(got))
	}
}
// TestRewriteModelInResponse_EmptyOriginalModel checks that no model rewriting
// happens when the rewriter has no original model recorded.
func TestRewriteModelInResponse_EmptyOriginalModel(t *testing.T) {
	payload := []byte(`{"model":"gpt-5.3-codex"}`)
	rw := &ResponseRewriter{originalModel: ""}
	if got := rw.rewriteModelInResponse(payload); string(got) != string(payload) {
		t.Errorf("expected no modification when originalModel is empty, got %s", string(got))
	}
}
// TestRewriteStreamChunk_SSEWithResponseModel checks rewriting of
// "response.model" inside a single SSE data line, preserving the framing.
func TestRewriteStreamChunk_SSEWithResponseModel(t *testing.T) {
	const want = "data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.2-codex\",\"status\":\"completed\"}}\n\n"
	rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
	got := string(rw.rewriteStreamChunk([]byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp_1\",\"model\":\"gpt-5.3-codex\",\"status\":\"completed\"}}\n\n")))
	if got != want {
		t.Errorf("expected %s, got %s", want, got)
	}
}
// TestRewriteStreamChunk_MultipleEvents checks that a chunk holding several SSE
// events is rewritten where a model field is present.
func TestRewriteStreamChunk_MultipleEvents(t *testing.T) {
	input := []byte("data: {\"type\":\"response.created\",\"response\":{\"model\":\"gpt-5.3-codex\"}}\n\ndata: {\"type\":\"response.output_item.added\",\"item\":{\"id\":\"item_1\"}}\n\n")
	rw := &ResponseRewriter{originalModel: "gpt-5.2-codex"}
	output := rw.rewriteStreamChunk(input)
	unchanged := string(output) == string(input)
	if unchanged {
		t.Error("expected response.model to be rewritten in SSE stream")
	}
	if hasRewritten := contains(output, []byte(`"model":"gpt-5.2-codex"`)); !hasRewritten {
		t.Errorf("expected rewritten model in output, got %s", string(output))
	}
}
// TestRewriteStreamChunk_MessageModel checks rewriting of "message.model" in an
// Anthropic-style SSE event.
func TestRewriteStreamChunk_MessageModel(t *testing.T) {
	const want = "data: {\"message\":{\"model\":\"claude-opus-4.5\",\"role\":\"assistant\"}}\n\n"
	rw := &ResponseRewriter{originalModel: "claude-opus-4.5"}
	got := string(rw.rewriteStreamChunk([]byte("data: {\"message\":{\"model\":\"claude-sonnet-4\",\"role\":\"assistant\"}}\n\n")))
	if got != want {
		t.Errorf("expected %s, got %s", want, got)
	}
}
// contains reports whether substr occurs anywhere within data. An empty substr
// is always found; a substr longer than data never is.
func contains(data, substr []byte) bool {
	needle := string(substr)
	for start := 0; start+len(substr) <= len(data); start++ {
		if string(data[start:start+len(substr)]) == needle {
			return true
		}
	}
	return false
}
================================================
FILE: internal/api/modules/amp/routes.go
================================================
package amp
import (
"context"
"errors"
"net"
"net/http"
"net/http/httputil"
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
log "github.com/sirupsen/logrus"
)
// clientAPIKeyContextKey is the context key used to pass the client API key
// from gin.Context to the request context for SecretSource lookup.
// It is an unexported empty struct type, so keys set by other packages cannot
// collide with it.
type clientAPIKeyContextKey struct{}
// clientAPIKeyMiddleware copies the authenticated client API key from
// gin.Context["apiKey"] into the request's context.Context under
// clientAPIKeyContextKey. SecretSource implementations can then read it to pick
// a per-client upstream key. Missing, non-string or empty values leave the
// request untouched. The chain always continues.
func clientAPIKeyMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		raw, exists := c.Get("apiKey")
		keyStr, isString := raw.(string)
		if exists && isString && keyStr != "" {
			withKey := context.WithValue(c.Request.Context(), clientAPIKeyContextKey{}, keyStr)
			c.Request = c.Request.WithContext(withKey)
		}
		c.Next()
	}
}
// getClientAPIKeyFromContext returns the client API key stored in ctx by
// clientAPIKeyMiddleware. If no value is stored, or it is not a string, the
// empty string is returned.
func getClientAPIKeyFromContext(ctx context.Context) string {
	keyStr, _ := ctx.Value(clientAPIKeyContextKey{}).(string)
	return keyStr
}
// localhostOnlyMiddleware returns a middleware that, when the module is
// restricted to localhost, lets only loopback peers through. The restriction
// is read on every request, so changing it takes effect without a restart.
//
// The check uses the TCP peer address (Request.RemoteAddr), not forwarding
// headers, so clients cannot spoof it with X-Forwarded-For and similar headers.
// Unparseable addresses and non-loopback peers are rejected with 403.
func (m *AmpModule) localhostOnlyMiddleware() gin.HandlerFunc {
	deny := func(c *gin.Context) {
		c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
			"error": "Access denied: management routes restricted to localhost",
		})
	}
	return func(c *gin.Context) {
		if !m.IsRestrictedToLocalhost() {
			c.Next()
			return
		}
		remoteAddr := c.Request.RemoteAddr
		// RemoteAddr is normally "IP:port" or "[IPv6]:port"; fall back to the
		// raw value if it has no port.
		host := remoteAddr
		if h, _, err := net.SplitHostPort(remoteAddr); err == nil {
			host = h
		}
		switch ip := net.ParseIP(host); {
		case ip == nil:
			log.Warnf("amp management: invalid RemoteAddr %s, denying access", remoteAddr)
			deny(c)
		case !ip.IsLoopback():
			log.Warnf("amp management: non-localhost connection from %s attempted access, denying", remoteAddr)
			deny(c)
		default:
			c.Next()
		}
	}
}
// noCORSMiddleware disables CORS for management routes to block browser-based
// cross-origin access. It clears any Access-Control-Allow-* headers set
// earlier in the chain and rejects OPTIONS preflight requests with 403.
func noCORSMiddleware() gin.HandlerFunc {
	corsHeaders := [...]string{
		"Access-Control-Allow-Origin",
		"Access-Control-Allow-Methods",
		"Access-Control-Allow-Headers",
		"Access-Control-Allow-Credentials",
	}
	return func(c *gin.Context) {
		for _, name := range corsHeaders {
			c.Header(name, "")
		}
		if c.Request.Method == http.MethodOptions {
			c.AbortWithStatus(http.StatusForbidden)
			return
		}
		c.Next()
	}
}
// managementAvailabilityMiddleware short-circuits management routes with 503
// while no upstream proxy is configured (m.getProxy() returns nil). It also
// marks the request to be skipped by gin request logging, which avoids noisy
// warnings and accidental exposure.
func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		if m.getProxy() != nil {
			c.Next()
			return
		}
		logging.SkipGinRequestLogging(c)
		c.AbortWithStatusJSON(http.StatusServiceUnavailable, gin.H{
			"error": "amp upstream proxy not available",
		})
	}
}
// wrapManagementAuth returns a handler that skips auth for requests whose path
// equals one of prefixes or lies beneath it (prefix followed by '/'). Every
// other request is passed to auth. A path such as "/threads.rss" therefore
// does NOT match the prefix "/threads".
func wrapManagementAuth(auth gin.HandlerFunc, prefixes ...string) gin.HandlerFunc {
	bypassed := func(path string) bool {
		for _, prefix := range prefixes {
			if !strings.HasPrefix(path, prefix) {
				continue
			}
			if rest := path[len(prefix):]; rest == "" || rest[0] == '/' {
				return true
			}
		}
		return false
	}
	return func(c *gin.Context) {
		if bypassed(c.Request.URL.Path) {
			c.Next()
			return
		}
		auth(c)
	}
}
// registerManagementRoutes registers Amp management proxy routes.
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
// Uses dynamic middleware and proxy getter for hot-reload support.
// The auth middleware validates Authorization header against configured API keys.
//
// Middleware order for /api/* routes: availability check, CORS stripping,
// localhost restriction, auth (when non-nil), then client API key injection.
// Root-level routes (/threads, /docs, /settings, /auth, ...) use the same chain,
// but with auth wrapped by wrapManagementAuth. That wrapper skips auth for paths
// under /threads, /auth, /docs and /settings.
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, auth gin.HandlerFunc) {
	ampAPI := engine.Group("/api")
	// Always disable CORS for management routes to prevent browser-based attacks
	ampAPI.Use(m.managementAvailabilityMiddleware(), noCORSMiddleware())
	// Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost())
	ampAPI.Use(m.localhostOnlyMiddleware())
	// Apply authentication middleware - requires valid API key in Authorization header
	var authWithBypass gin.HandlerFunc
	if auth != nil {
		ampAPI.Use(auth)
		// Root-level routes below use this variant; /api/* routes always require auth.
		authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
	}
	// Inject client API key into request context for per-client upstream routing
	ampAPI.Use(clientAPIKeyMiddleware())
	// Dynamic proxy handler that uses m.getProxy() for hot-reload support
	proxyHandler := func(c *gin.Context) {
		// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
		defer func() {
			if rec := recover(); rec != nil {
				if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
					// Upstream already wrote the status (often 404) before the client/stream ended.
					return
				}
				// Any other panic is re-raised unchanged.
				panic(rec)
			}
		}()
		proxy := m.getProxy()
		if proxy == nil {
			c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
			return
		}
		proxy.ServeHTTP(c.Writer, c.Request)
	}
	// Management routes - these are proxied directly to Amp upstream
	ampAPI.Any("/internal", proxyHandler)
	ampAPI.Any("/internal/*path", proxyHandler)
	ampAPI.Any("/user", proxyHandler)
	ampAPI.Any("/user/*path", proxyHandler)
	ampAPI.Any("/auth", proxyHandler)
	ampAPI.Any("/auth/*path", proxyHandler)
	ampAPI.Any("/meta", proxyHandler)
	ampAPI.Any("/meta/*path", proxyHandler)
	ampAPI.Any("/ads", proxyHandler)
	ampAPI.Any("/telemetry", proxyHandler)
	ampAPI.Any("/telemetry/*path", proxyHandler)
	ampAPI.Any("/threads", proxyHandler)
	ampAPI.Any("/threads/*path", proxyHandler)
	ampAPI.Any("/otel", proxyHandler)
	ampAPI.Any("/otel/*path", proxyHandler)
	ampAPI.Any("/tab", proxyHandler)
	ampAPI.Any("/tab/*path", proxyHandler)
	// Root-level routes that AMP CLI expects without /api prefix
	// These need the same security middleware as the /api/* routes (dynamic for hot-reload)
	rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()}
	if authWithBypass != nil {
		rootMiddleware = append(rootMiddleware, authWithBypass)
	}
	// Add clientAPIKeyMiddleware after auth for per-client upstream routing
	rootMiddleware = append(rootMiddleware, clientAPIKeyMiddleware())
	// Each registration appends proxyHandler as the final handler. All calls
	// append the same value at the same index, so sharing rootMiddleware's
	// backing array between them is harmless.
	engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
	engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
	engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
	engine.GET("/docs/*path", append(rootMiddleware, proxyHandler)...)
	engine.GET("/settings", append(rootMiddleware, proxyHandler)...)
	engine.GET("/settings/*path", append(rootMiddleware, proxyHandler)...)
	engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
	engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...)
	// Root-level auth routes for CLI login flow
	// Amp uses multiple auth routes: /auth/cli-login, /auth/callback, /auth/sign-in, /auth/logout
	// We proxy all /auth/* to support the complete OAuth flow
	engine.Any("/auth", append(rootMiddleware, proxyHandler)...)
	engine.Any("/auth/*path", append(rootMiddleware, proxyHandler)...)
	// Google v1beta1 passthrough with OAuth fallback
	// AMP CLI uses non-standard paths like /publishers/google/models/...
	// We bridge these to our standard Gemini handler to enable local OAuth.
	// If no local OAuth is available, falls back to ampcode.com proxy.
	geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
	geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
	geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
		return m.getProxy()
	}, m.modelMapper, m.forceModelMappings)
	geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
	// Route POST model calls through Gemini bridge with FallbackHandler.
	// FallbackHandler checks provider -> mapping -> proxy fallback automatically.
	// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
	ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
		if c.Request.Method == "POST" {
			if path := c.Param("path"); strings.Contains(path, "/models/") {
				// POST with /models/ path -> use Gemini bridge with fallback handler
				// FallbackHandler will check provider/mapping and proxy if needed
				geminiV1Beta1Handler(c)
				return
			}
		}
		// Non-POST or no local provider available -> proxy upstream
		proxyHandler(c)
	})
}
// registerProviderAliases mounts the Amp CLI provider-alias routes under
// /api/provider/:provider, so that requests such as
//
//	/api/provider/openai/v1/chat/completions
//	/api/provider/anthropic/v1/messages
//	/api/provider/google/v1beta/models
//
// are served by the local handlers. POST endpoints are wrapped by a
// FallbackHandler, which can apply model mappings or forward the request to the
// Amp upstream proxy. The proxy is obtained through m.getProxy on every use,
// so a proxy replaced at runtime is picked up. When auth is non-nil, it guards
// every route in the group.
func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, auth gin.HandlerFunc) {
	// Protocol-specific handlers, all built on the same base handler.
	openaiH := openai.NewOpenAIAPIHandler(baseHandler)
	geminiH := gemini.NewGeminiAPIHandler(baseHandler)
	claudeH := claude.NewClaudeCodeAPIHandler(baseHandler)
	responsesH := openai.NewOpenAIResponsesAPIHandler(baseHandler)

	// The proxy is looked up through a closure rather than captured once,
	// so hot-reloaded proxies are honoured.
	fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
		return m.getProxy()
	}, m.modelMapper, m.forceModelMappings)

	providers := engine.Group("/api/provider")
	if auth != nil {
		providers.Use(auth)
	}
	// Expose the caller's API key to per-client upstream key routing.
	providers.Use(clientAPIKeyMiddleware())
	byProvider := providers.Group("/:provider")

	// listModels picks the model-listing handler from the :provider segment.
	// Anything other than anthropic/google is treated as OpenAI-compatible
	// (openai, groq, cerebras, ...). Listing has no body, so it is not wrapped.
	listModels := func(c *gin.Context) {
		switch strings.ToLower(c.Param("provider")) {
		case "anthropic":
			claudeH.ClaudeModels(c)
		case "google":
			geminiH.GeminiModels(c)
		default:
			openaiH.OpenAIModels(c)
		}
	}

	// mountOpenAICompatible registers the OpenAI-style endpoints on g. It is
	// used at the provider root (providers such as groq/cerebras omit /v1)
	// and again under /v1.
	mountOpenAICompatible := func(g *gin.RouterGroup) {
		g.GET("/models", listModels)
		g.POST("/chat/completions", fallback.WrapHandler(openaiH.ChatCompletions))
		g.POST("/completions", fallback.WrapHandler(openaiH.Completions))
		g.POST("/responses", fallback.WrapHandler(responsesH.Responses))
	}

	mountOpenAICompatible(byProvider)

	v1 := byProvider.Group("/v1")
	mountOpenAICompatible(v1)
	// Claude/Anthropic-compatible endpoints.
	v1.POST("/messages", fallback.WrapHandler(claudeH.ClaudeMessages))
	v1.POST("/messages/count_tokens", fallback.WrapHandler(claudeH.ClaudeCountTokens))

	// Gemini native API: the model name is carried in the URL path (*action).
	v1beta := byProvider.Group("/v1beta")
	v1beta.GET("/models", geminiH.GeminiModels)
	v1beta.POST("/models/*action", fallback.WrapHandler(geminiH.GeminiHandler))
	v1beta.GET("/models/*action", geminiH.GeminiGetHandler)
}
================================================
FILE: internal/api/modules/amp/routes_test.go
================================================
package amp
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
)
// TestRegisterManagementRoutes checks that every Amp management path is
// registered and reaches the upstream reverse proxy. It also covers the
// Google v1beta1 bridge for GET and for POST without a /models/ segment.
func TestRegisterManagementRoutes(t *testing.T) {
	gin.SetMode(gin.TestMode)
	r := gin.New()

	m := &AmpModule{
		restrictToLocalhost: false, // disable localhost restriction for tests
	}

	// Upstream stand-in that records whether the proxy forwarded a request.
	proxyCalled := false
	mockProxy := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		proxyCalled = true
		w.WriteHeader(http.StatusOK)
		_, _ = w.Write([]byte("proxied"))
	}))
	defer mockProxy.Close()

	// Fail loudly if the proxy cannot be built, instead of continuing with a
	// nil proxy and producing confusing downstream failures.
	proxy, err := createReverseProxy(mockProxy.URL, NewStaticSecretSource(""))
	if err != nil {
		t.Fatalf("failed to create reverse proxy: %v", err)
	}
	m.setProxy(proxy)

	base := &handlers.BaseAPIHandler{}
	m.registerManagementRoutes(r, base, nil)

	srv := httptest.NewServer(r)
	defer srv.Close()

	managementPaths := []struct {
		path   string
		method string
	}{
		{"/api/internal", http.MethodGet},
		{"/api/internal/some/path", http.MethodGet},
		{"/api/user", http.MethodGet},
		{"/api/user/profile", http.MethodGet},
		{"/api/auth", http.MethodGet},
		{"/api/auth/login", http.MethodGet},
		{"/api/meta", http.MethodGet},
		{"/api/telemetry", http.MethodGet},
		{"/api/threads", http.MethodGet},
		{"/threads/", http.MethodGet},
		{"/threads.rss", http.MethodGet}, // Root-level route (no /api prefix)
		{"/api/otel", http.MethodGet},
		{"/api/tab", http.MethodGet},
		{"/api/tab/some/path", http.MethodGet},
		{"/auth", http.MethodGet},              // Root-level auth route
		{"/auth/cli-login", http.MethodGet},    // CLI login flow
		{"/auth/callback", http.MethodGet},     // OAuth callback
		// Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST
		{"/api/provider/google/v1beta1/models", http.MethodGet},
		{"/api/provider/google/v1beta1/models", http.MethodPost},
	}

	for _, tc := range managementPaths {
		t.Run(tc.method+" "+tc.path, func(t *testing.T) {
			proxyCalled = false

			req, err := http.NewRequest(tc.method, srv.URL+tc.path, nil)
			if err != nil {
				t.Fatalf("failed to build request: %v", err)
			}
			resp, err := http.DefaultClient.Do(req)
			if err != nil {
				t.Fatalf("request failed: %v", err)
			}
			defer resp.Body.Close()

			if resp.StatusCode == http.StatusNotFound {
				t.Fatalf("route %s %s not registered", tc.method, tc.path)
			}
			if !proxyCalled {
				t.Fatalf("proxy handler not called for %s %s", tc.method, tc.path)
			}
		})
	}
}
func TestRegisterProviderAliases_AllProvidersRegistered(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
// Minimal base handler setup (no need to initialize, just check routing)
base := &handlers.BaseAPIHandler{}
// Track if auth middleware was called
authCalled := false
authMiddleware := func(c *gin.Context) {
authCalled = true
c.Header("X-Auth", "ok")
// Abort with success to avoid calling the actual handler (which needs full setup)
c.AbortWithStatus(http.StatusOK)
}
m := &AmpModule{authMiddleware_: authMiddleware}
m.registerProviderAliases(r, base, authMiddleware)
paths := []struct {
path string
method string
}{
{"/api/provider/openai/models", http.MethodGet},
{"/api/provider/anthropic/models", http.MethodGet},
{"/api/provider/google/models", http.MethodGet},
{"/api/provider/groq/models", http.MethodGet},
{"/api/provider/openai/chat/completions", http.MethodPost},
{"/api/provider/anthropic/v1/messages", http.MethodPost},
{"/api/provider/google/v1beta/models", http.MethodGet},
}
for _, tc := range paths {
t.Run(tc.path, func(t *testing.T) {
authCalled = false
req := httptest.NewRequest(tc.method, tc.path, nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code == http.StatusNotFound {
t.Fatalf("route %s %s not registered", tc.method, tc.path)
}
if !authCalled {
t.Fatalf("auth middleware not executed for %s", tc.path)
}
if w.Header().Get("X-Auth") != "ok" {
t.Fatalf("auth middleware header not set for %s", tc.path)
}
})
}
}
// TestRegisterProviderAliases_DynamicModelsHandler checks that the shared
// /models route resolves for several provider names. The routes are matched
// by the :provider wildcard, so every name must be served.
func TestRegisterProviderAliases_DynamicModelsHandler(t *testing.T) {
	gin.SetMode(gin.TestMode)
	engine := gin.New()

	abortOK := func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }
	m := &AmpModule{authMiddleware_: abortOK}
	m.registerProviderAliases(engine, &handlers.BaseAPIHandler{}, abortOK)

	for _, name := range []string{"openai", "anthropic", "google", "groq", "cerebras"} {
		t.Run(name, func(t *testing.T) {
			rec := httptest.NewRecorder()
			engine.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/api/provider/"+name+"/models", nil))
			// Only registration matters here; the auth stub answers with 200.
			if rec.Code == http.StatusNotFound {
				t.Fatalf("models route not found for provider: %s", name)
			}
		})
	}
}
func TestRegisterProviderAliases_V1Routes(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
base := &handlers.BaseAPIHandler{}
m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }}
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
v1Paths := []struct {
path string
method string
}{
{"/api/provider/openai/v1/models", http.MethodGet},
{"/api/provider/openai/v1/chat/completions", http.MethodPost},
{"/api/provider/openai/v1/completions", http.MethodPost},
{"/api/provider/anthropic/v1/messages", http.MethodPost},
{"/api/provider/anthropic/v1/messages/count_tokens", http.MethodPost},
}
for _, tc := range v1Paths {
t.Run(tc.path, func(t *testing.T) {
req := httptest.NewRequest(tc.method, tc.path, nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code == http.StatusNotFound {
t.Fatalf("v1 route %s %s not registered", tc.method, tc.path)
}
})
}
}
func TestRegisterProviderAliases_V1BetaRoutes(t *testing.T) {
gin.SetMode(gin.TestMode)
r := gin.New()
base := &handlers.BaseAPIHandler{}
m := &AmpModule{authMiddleware_: func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) }}
m.registerProviderAliases(r, base, func(c *gin.Context) { c.AbortWithStatus(http.StatusOK) })
v1betaPaths := []struct {
path string
method string
}{
{"/api/provider/google/v1beta/models", http.MethodGet},
{"/api/provider/google/v1beta/models/generateContent", http.MethodPost},
}
for _, tc := range v1betaPaths {
t.Run(tc.path, func(t *testing.T) {
req := httptest.NewRequest(tc.method, tc.path, nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code == http.StatusNotFound {
t.Fatalf("v1beta route %s %s not registered", tc.method, tc.path)
}
})
}
}
// TestRegisterProviderAliases_NoAuthMiddleware checks that provider-alias
// routes are still registered when registerProviderAliases receives a nil
// auth middleware. This exercises the path where the `auth != nil` guard
// skips installing one.
func TestRegisterProviderAliases_NoAuthMiddleware(t *testing.T) {
	gin.SetMode(gin.TestMode)
	r := gin.New()
	base := &handlers.BaseAPIHandler{}

	m := &AmpModule{authMiddleware_: nil} // No auth middleware
	// The auth argument (not m.authMiddleware_) is what registerProviderAliases
	// consults, so it must be nil for this test to cover the unauthenticated path.
	m.registerProviderAliases(r, base, nil)

	// Inspect the route table instead of serving requests. Without an auth
	// stub to short-circuit, a request would reach real handlers backed by an
	// empty BaseAPIHandler.
	registered := make(map[string]bool)
	for _, ri := range r.Routes() {
		registered[ri.Method+" "+ri.Path] = true
	}

	for _, want := range []string{
		http.MethodGet + " /api/provider/:provider/models",
		http.MethodPost + " /api/provider/:provider/chat/completions",
		http.MethodGet + " /api/provider/:provider/v1/models",
		http.MethodPost + " /api/provider/:provider/v1/messages",
		http.MethodPost + " /api/provider/:provider/v1beta/models/*action",
	} {
		if !registered[want] {
			t.Fatalf("routes should register even without auth middleware; missing %s", want)
		}
	}
}
// TestLocalhostOnlyMiddleware_PreventsSpoofing checks that, with the localhost
// restriction enabled, access is decided by the TCP peer address alone. A
// forged X-Forwarded-For header must not grant access.
func TestLocalhostOnlyMiddleware_PreventsSpoofing(t *testing.T) {
	gin.SetMode(gin.TestMode)
	engine := gin.New()

	m := &AmpModule{restrictToLocalhost: true}
	engine.Use(m.localhostOnlyMiddleware())
	engine.GET("/test", func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	})

	cases := []struct {
		name           string
		remoteAddr     string
		forwardedFor   string
		expectedStatus int
		description    string
	}{
		{"spoofed_header_remote_connection", "192.168.1.100:12345", "127.0.0.1", http.StatusForbidden, "Spoofed X-Forwarded-For header should be ignored"},
		{"real_localhost_ipv4", "127.0.0.1:54321", "", http.StatusOK, "Real localhost IPv4 connection should work"},
		{"real_localhost_ipv6", "[::1]:54321", "", http.StatusOK, "Real localhost IPv6 connection should work"},
		{"remote_ipv4", "203.0.113.42:8080", "", http.StatusForbidden, "Remote IPv4 connection should be blocked"},
		{"remote_ipv6", "[2001:db8::1]:9090", "", http.StatusForbidden, "Remote IPv6 connection should be blocked"},
		{"spoofed_localhost_ipv6", "203.0.113.42:8080", "::1", http.StatusForbidden, "Spoofed X-Forwarded-For with IPv6 localhost should be ignored"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest(http.MethodGet, "/test", nil)
			req.RemoteAddr = tc.remoteAddr
			if tc.forwardedFor != "" {
				req.Header.Set("X-Forwarded-For", tc.forwardedFor)
			}

			rec := httptest.NewRecorder()
			engine.ServeHTTP(rec, req)

			if got := rec.Code; got != tc.expectedStatus {
				t.Errorf("%s: expected status %d, got %d", tc.description, tc.expectedStatus, got)
			}
		})
	}
}
// TestLocalhostOnlyMiddleware_HotReload checks that toggling the localhost
// restriction at runtime takes effect immediately for an already-installed
// middleware.
func TestLocalhostOnlyMiddleware_HotReload(t *testing.T) {
	gin.SetMode(gin.TestMode)
	engine := gin.New()

	m := &AmpModule{restrictToLocalhost: true}
	engine.Use(m.localhostOnlyMiddleware())
	engine.GET("/test", func(c *gin.Context) {
		c.String(http.StatusOK, "ok")
	})

	// fromLAN issues GET /test from a non-loopback peer and returns the status.
	fromLAN := func() int {
		req := httptest.NewRequest(http.MethodGet, "/test", nil)
		req.RemoteAddr = "192.168.1.100:12345"
		rec := httptest.NewRecorder()
		engine.ServeHTTP(rec, req)
		return rec.Code
	}

	// Restriction enabled: remote peer is rejected.
	if code := fromLAN(); code != http.StatusForbidden {
		t.Errorf("Expected 403 when restriction enabled, got %d", code)
	}

	// Hot-reload off: remote peer is allowed.
	m.setRestrictToLocalhost(false)
	if code := fromLAN(); code != http.StatusOK {
		t.Errorf("Expected 200 after disabling restriction, got %d", code)
	}

	// Hot-reload on again: remote peer is rejected once more.
	m.setRestrictToLocalhost(true)
	if code := fromLAN(); code != http.StatusForbidden {
		t.Errorf("Expected 403 after re-enabling restriction, got %d", code)
	}
}
================================================
FILE: internal/api/modules/amp/secret.go
================================================
package amp
import (
"context"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
)
// SecretSource supplies the API key used for Amp upstream requests.
// Implementations may inspect ctx (for example to apply per-client key
// mappings) and may cache what they read.
type SecretSource interface {
	Get(ctx context.Context) (string, error)
}

// cachedSecret is a file-derived key together with the time at which it
// expires. Entries are never modified after creation; updateCache always
// installs a fresh one, so a pointer read under the lock remains consistent
// after the lock is released.
type cachedSecret struct {
	value     string
	expiresAt time.Time
}

// MultiSourceSecret resolves the Amp API key using this precedence:
//  1. the explicit key from configuration (highest priority),
//  2. the environment variable named by envKey (AMP_API_KEY),
//  3. the "apiKey@https://ampcode.com/" entry of the JSON file at filePath
//     (lowest priority; the result is cached for cacheTTL).
//
// explicitKey and cache can change at runtime (UpdateExplicitKey,
// InvalidateCache), so they are only read or written while holding mu.
type MultiSourceSecret struct {
	explicitKey string
	envKey      string
	filePath    string
	cacheTTL    time.Duration
	mu          sync.RWMutex
	cache       *cachedSecret
}

// Compile-time check that MultiSourceSecret implements SecretSource.
var _ SecretSource = (*MultiSourceSecret)(nil)

// NewMultiSourceSecret creates a secret source that reads the file-based key
// from ~/.local/share/amp/secrets.json. A cacheTTL of 0 selects the default
// of 5 minutes.
//
// If the home directory cannot be determined, the file source is disabled.
// This avoids silently falling back to a path relative to the current
// working directory.
func NewMultiSourceSecret(explicitKey string, cacheTTL time.Duration) *MultiSourceSecret {
	filePath := ""
	if home, err := os.UserHomeDir(); err == nil && home != "" {
		filePath = filepath.Join(home, ".local", "share", "amp", "secrets.json")
	}
	return NewMultiSourceSecretWithPath(explicitKey, filePath, cacheTTL)
}

// NewMultiSourceSecretWithPath creates a secret source that reads the
// file-based key from filePath (useful for tests). An empty filePath disables
// the file source. A cacheTTL of 0 selects the default of 5 minutes.
func NewMultiSourceSecretWithPath(explicitKey string, filePath string, cacheTTL time.Duration) *MultiSourceSecret {
	if cacheTTL == 0 {
		cacheTTL = 5 * time.Minute // Default 5 minute cache
	}
	return &MultiSourceSecret{
		explicitKey: strings.TrimSpace(explicitKey),
		envKey:      "AMP_API_KEY",
		filePath:    filePath,
		cacheTTL:    cacheTTL,
	}
}

// Get returns the Amp API key using the precedence config > env > file.
//
// The file-based result is cached for cacheTTL. A read or parse error is
// returned to the caller once, and an empty value is cached, so subsequent
// calls within the TTL return "" with a nil error rather than re-reading a
// broken or missing file on every request.
func (s *MultiSourceSecret) Get(ctx context.Context) (string, error) {
	// Snapshot mutable state under the read lock: UpdateExplicitKey may
	// rewrite explicitKey concurrently (configuration hot reload).
	s.mu.RLock()
	explicit := s.explicitKey
	cached := s.cache
	s.mu.RUnlock()

	// Precedence 1: explicit config key.
	if explicit != "" {
		return explicit, nil
	}

	// Precedence 2: environment variable.
	if envValue := strings.TrimSpace(os.Getenv(s.envKey)); envValue != "" {
		return envValue, nil
	}

	// Precedence 3: secrets file, served from cache while it is fresh.
	if cached != nil && time.Now().Before(cached.expiresAt) {
		return cached.value, nil
	}

	// NOTE(review): the file is read outside the lock, so an InvalidateCache
	// that runs concurrently can be overwritten by this (possibly older) read.
	key, err := s.readFromFile()
	if err != nil {
		// Cache the empty result to avoid hammering a missing/broken file.
		s.updateCache("")
		return "", err
	}
	s.updateCache(key)
	return key, nil
}

// readFromFile reads the Amp API key from the JSON secrets file. A missing
// file, or an empty filePath, yields ("", nil). Other read failures and
// malformed JSON are reported as errors.
func (s *MultiSourceSecret) readFromFile() (string, error) {
	if s.filePath == "" {
		return "", nil // File source disabled.
	}
	content, err := os.ReadFile(s.filePath)
	if err != nil {
		if os.IsNotExist(err) {
			return "", nil // Missing file is not an error, just no key available
		}
		return "", fmt.Errorf("failed to read amp secrets from %s: %w", s.filePath, err)
	}
	var secrets map[string]string
	if err := json.Unmarshal(content, &secrets); err != nil {
		return "", fmt.Errorf("failed to parse amp secrets from %s: %w", s.filePath, err)
	}
	return strings.TrimSpace(secrets["apiKey@https://ampcode.com/"]), nil
}

// updateCache installs a new cache entry for value that expires after cacheTTL.
func (s *MultiSourceSecret) updateCache(value string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.cache = &cachedSecret{
		value:     value,
		expiresAt: time.Now().Add(s.cacheTTL),
	}
}

// InvalidateCache clears the cached secret, forcing a fresh read on next Get.
func (s *MultiSourceSecret) InvalidateCache() {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.cache = nil
}

// UpdateExplicitKey replaces the config-provided key (whitespace-trimmed) and
// clears the cache. It is a no-op on a nil receiver.
func (s *MultiSourceSecret) UpdateExplicitKey(key string) {
	if s == nil {
		return
	}
	s.mu.Lock()
	s.explicitKey = strings.TrimSpace(key)
	s.cache = nil
	s.mu.Unlock()
}
// StaticSecretSource is a SecretSource that always yields the same key,
// regardless of the request context. It is intended for tests.
type StaticSecretSource struct {
	key string
}

// NewStaticSecretSource returns a StaticSecretSource holding key with
// surrounding whitespace removed.
func NewStaticSecretSource(key string) *StaticSecretSource {
	trimmed := strings.TrimSpace(key)
	return &StaticSecretSource{key: trimmed}
}

// Get ignores its context argument and returns the stored key. It never fails.
func (s *StaticSecretSource) Get(_ context.Context) (string, error) {
	return s.key, nil
}
// MappedSecretSource layers per-client upstream key overrides on top of a
// default SecretSource. When the request context carries a client API key
// with a configured mapping, the mapped upstream key is returned. Otherwise
// the lookup is delegated to the default source.
type MappedSecretSource struct {
	defaultSource SecretSource
	mu            sync.RWMutex
	lookup        map[string]string // clientKey -> upstreamKey
}

// NewMappedSecretSource returns a MappedSecretSource with no mappings that
// delegates to defaultSource.
func NewMappedSecretSource(defaultSource SecretSource) *MappedSecretSource {
	src := &MappedSecretSource{defaultSource: defaultSource}
	src.lookup = map[string]string{}
	return src
}

// Get returns the upstream key mapped to the client API key found in ctx. If
// there is no client key or no non-empty mapping for it, it returns the
// default source's result.
func (s *MappedSecretSource) Get(ctx context.Context) (string, error) {
	clientKey := getClientAPIKeyFromContext(ctx)
	if clientKey == "" {
		return s.defaultSource.Get(ctx)
	}

	s.mu.RLock()
	upstreamKey := s.lookup[clientKey]
	s.mu.RUnlock()

	if upstreamKey != "" {
		return upstreamKey, nil
	}
	return s.defaultSource.Get(ctx)
}

// UpdateMappings replaces the client-to-upstream key table with one built
// from entries. Keys are whitespace-trimmed. Entries without an upstream key
// and blank client keys are ignored. When a client key appears more than
// once, the first mapping is kept and a warning is logged.
func (s *MappedSecretSource) UpdateMappings(entries []config.AmpUpstreamAPIKeyEntry) {
	next := make(map[string]string)
	for _, e := range entries {
		upstream := strings.TrimSpace(e.UpstreamAPIKey)
		if upstream == "" {
			continue
		}
		for _, raw := range e.APIKeys {
			client := strings.TrimSpace(raw)
			_, taken := next[client]
			switch {
			case client == "":
				// Blank client keys contribute nothing.
			case taken:
				log.Warnf("amp upstream-api-keys: client API key appears in multiple entries; using first mapping.")
			default:
				next[client] = upstream
			}
		}
	}

	s.mu.Lock()
	defer s.mu.Unlock()
	s.lookup = next
}

// UpdateDefaultExplicitKey forwards key to the default source when that
// source is a *MultiSourceSecret; otherwise it does nothing.
func (s *MappedSecretSource) UpdateDefaultExplicitKey(key string) {
	ms, ok := s.defaultSource.(*MultiSourceSecret)
	if !ok {
		return
	}
	ms.UpdateExplicitKey(key)
}

// InvalidateCache clears the default source's cache when it is a
// *MultiSourceSecret; otherwise it does nothing.
func (s *MappedSecretSource) InvalidateCache() {
	ms, ok := s.defaultSource.(*MultiSourceSecret)
	if !ok {
		return
	}
	ms.InvalidateCache()
}
================================================
FILE: internal/api/modules/amp/secret_test.go
================================================
package amp
import (
"context"
"encoding/json"
"os"
"path/filepath"
"sync"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
"github.com/sirupsen/logrus/hooks/test"
)
func TestMultiSourceSecret_PrecedenceOrder(t *testing.T) {
ctx := context.Background()
cases := []struct {
name string
configKey string
envKey string
fileJSON string
want string
}{
{"config_wins", "cfg", "env", `{"apiKey@https://ampcode.com/":"file"}`, "cfg"},
{"env_wins_when_no_cfg", "", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"},
{"file_when_no_cfg_env", "", "", `{"apiKey@https://ampcode.com/":"file"}`, "file"},
{"empty_cfg_trims_then_env", " ", "env", `{"apiKey@https://ampcode.com/":"file"}`, "env"},
{"empty_env_then_file", "", " ", `{"apiKey@https://ampcode.com/":"file"}`, "file"},
{"missing_file_returns_empty", "", "", "", ""},
{"all_empty_returns_empty", " ", " ", `{"apiKey@https://ampcode.com/":" "}`, ""},
}
for _, tc := range cases {
tc := tc // capture range variable
t.Run(tc.name, func(t *testing.T) {
tmpDir := t.TempDir()
secretsPath := filepath.Join(tmpDir, "secrets.json")
if tc.fileJSON != "" {
if err := os.WriteFile(secretsPath, []byte(tc.fileJSON), 0600); err != nil {
t.Fatal(err)
}
}
t.Setenv("AMP_API_KEY", tc.envKey)
s := NewMultiSourceSecretWithPath(tc.configKey, secretsPath, 100*time.Millisecond)
got, err := s.Get(ctx)
if err != nil && tc.fileJSON != "" && json.Valid([]byte(tc.fileJSON)) {
t.Fatalf("unexpected error: %v", err)
}
if got != tc.want {
t.Fatalf("want %q, got %q", tc.want, got)
}
})
}
}
// TestMultiSourceSecret_CacheBehavior checks three things about file-derived
// keys: they are served from cache until the entry expires, they are re-read
// after expiry, and they are re-read immediately after InvalidateCache.
// Expiry is forced by back-dating the cached entry rather than sleeping, so
// the test does not depend on scheduler timing.
func TestMultiSourceSecret_CacheBehavior(t *testing.T) {
	ctx := context.Background()
	// Ensure an inherited AMP_API_KEY cannot take precedence over the file.
	t.Setenv("AMP_API_KEY", "")

	p := filepath.Join(t.TempDir(), "secrets.json")
	writeKey := func(v string) {
		t.Helper()
		if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"`+v+`"}`), 0600); err != nil {
			t.Fatal(err)
		}
	}

	writeKey("v1")
	// A long TTL guarantees the cache cannot expire on its own mid-test.
	s := NewMultiSourceSecretWithPath("", p, time.Hour)

	// First read populates the cache.
	got1, err := s.Get(ctx)
	if err != nil {
		t.Fatalf("Get failed: %v", err)
	}
	if got1 != "v1" {
		t.Fatalf("expected v1, got %s", got1)
	}

	// While the entry is fresh, a changed file is not observed.
	writeKey("v2")
	got2, err := s.Get(ctx)
	if err != nil || got2 != "v1" {
		t.Fatalf("cache hit expected v1, got %q (err %v)", got2, err)
	}

	// Force expiry: the next Get must re-read the file.
	s.mu.Lock()
	if s.cache == nil {
		s.mu.Unlock()
		t.Fatal("expected a cached entry after Get")
	}
	s.cache = &cachedSecret{value: s.cache.value, expiresAt: time.Now().Add(-time.Second)}
	s.mu.Unlock()

	got3, err := s.Get(ctx)
	if err != nil || got3 != "v2" {
		t.Fatalf("cache miss expected v2, got %q (err %v)", got3, err)
	}

	// InvalidateCache forces an immediate re-read.
	writeKey("v3")
	s.InvalidateCache()
	got4, err := s.Get(ctx)
	if err != nil || got4 != "v3" {
		t.Fatalf("invalidate expected v3, got %q (err %v)", got4, err)
	}
}
// TestMultiSourceSecret_FileHandling covers file-source edge cases:
//   - a missing file (no error, empty key),
//   - malformed JSON (error),
//   - a JSON object without the key (empty key),
//   - a whitespace-only key (empty after trimming).
func TestMultiSourceSecret_FileHandling(t *testing.T) {
	ctx := context.Background()
	// Clear AMP_API_KEY so an inherited value cannot shadow the file source
	// (the env variable takes precedence over the file).
	t.Setenv("AMP_API_KEY", "")

	// sourceWithFile writes content to a fresh secrets.json and returns a
	// secret source reading it.
	sourceWithFile := func(t *testing.T, content string) *MultiSourceSecret {
		t.Helper()
		p := filepath.Join(t.TempDir(), "secrets.json")
		if err := os.WriteFile(p, []byte(content), 0600); err != nil {
			t.Fatal(err)
		}
		return NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
	}

	t.Run("missing_file_no_error", func(t *testing.T) {
		// A path under a fresh temp dir is guaranteed not to exist on any OS.
		p := filepath.Join(t.TempDir(), "nonexistent", "secrets.json")
		s := NewMultiSourceSecretWithPath("", p, 100*time.Millisecond)
		got, err := s.Get(ctx)
		if err != nil {
			t.Fatalf("expected no error for missing file, got: %v", err)
		}
		if got != "" {
			t.Fatalf("expected empty string, got %q", got)
		}
	})

	t.Run("invalid_json", func(t *testing.T) {
		s := sourceWithFile(t, `{invalid json`)
		if _, err := s.Get(ctx); err == nil {
			t.Fatal("expected error for invalid JSON")
		}
	})

	t.Run("missing_key_in_json", func(t *testing.T) {
		s := sourceWithFile(t, `{"other":"value"}`)
		got, err := s.Get(ctx)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if got != "" {
			t.Fatalf("expected empty string for missing key, got %q", got)
		}
	})

	t.Run("empty_key_value", func(t *testing.T) {
		s := sourceWithFile(t, `{"apiKey@https://ampcode.com/":" "}`)
		got, err := s.Get(ctx)
		if err != nil {
			t.Fatalf("unexpected error: %v", err)
		}
		if got != "" {
			t.Fatalf("expected empty after trim, got %q", got)
		}
	})
}
// TestMultiSourceSecret_Concurrency hammers Get from many goroutines. Every
// call must return the file-derived key without error.
func TestMultiSourceSecret_Concurrency(t *testing.T) {
	// Ensure an inherited AMP_API_KEY cannot take precedence over the file.
	t.Setenv("AMP_API_KEY", "")

	p := filepath.Join(t.TempDir(), "secrets.json")
	if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"concurrent"}`), 0600); err != nil {
		t.Fatal(err)
	}
	s := NewMultiSourceSecretWithPath("", p, 5*time.Second)
	ctx := context.Background()

	const goroutines = 50
	const iterations = 100

	var wg sync.WaitGroup
	for i := 0; i < goroutines; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for j := 0; j < iterations; j++ {
				val, err := s.Get(ctx)
				// t.Errorf is safe to call from other goroutines; report the
				// actual problem instead of forwarding a nil error for a bad value.
				if err != nil {
					t.Errorf("concurrency error: %v", err)
					return
				}
				if val != "concurrent" {
					t.Errorf("concurrency error: unexpected value %q", val)
					return
				}
			}
		}()
	}
	wg.Wait()
}
// TestStaticSecretSource checks that StaticSecretSource returns its key
// verbatim apart from whitespace trimming, and never fails.
func TestStaticSecretSource(t *testing.T) {
	ctx := context.Background()

	cases := []struct {
		name    string
		in      string
		want    string
		failFmt string
	}{
		{"returns_provided_key", "test-key-123", "test-key-123", "want test-key-123, got %q"},
		{"trims_whitespace", " test-key ", "test-key", "want test-key, got %q"},
		{"empty_string", "", "", "want empty string, got %q"},
	}

	for _, tc := range cases {
		tc := tc
		t.Run(tc.name, func(t *testing.T) {
			got, err := NewStaticSecretSource(tc.in).Get(ctx)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if got != tc.want {
				t.Fatalf(tc.failFmt, got)
			}
		})
	}
}
// TestMultiSourceSecret_CacheEmptyResult checks that a missing secrets file is
// cached as an empty key, so the file is not re-read on every call. It also
// checks that a file created later is picked up once that entry expires.
// Expiry is forced by back-dating the cached entry rather than sleeping, so
// the test does not depend on scheduler timing.
func TestMultiSourceSecret_CacheEmptyResult(t *testing.T) {
	// Ensure an inherited AMP_API_KEY cannot take precedence over the file.
	t.Setenv("AMP_API_KEY", "")

	p := filepath.Join(t.TempDir(), "nonexistent.json")
	// A long TTL guarantees the cache cannot expire on its own mid-test.
	s := NewMultiSourceSecretWithPath("", p, time.Hour)
	ctx := context.Background()

	// First call: file doesn't exist, the empty result is cached.
	got1, err := s.Get(ctx)
	if err != nil {
		t.Fatalf("expected no error for missing file, got: %v", err)
	}
	if got1 != "" {
		t.Fatalf("expected empty string, got %q", got1)
	}

	// Create the file now.
	if err := os.WriteFile(p, []byte(`{"apiKey@https://ampcode.com/":"new-value"}`), 0600); err != nil {
		t.Fatal(err)
	}

	// Second call: still empty (served from cache), the new file is not read.
	if got2, _ := s.Get(ctx); got2 != "" {
		t.Fatalf("cache should return empty, got %q", got2)
	}

	// Force expiry: the next Get must read the newly created file.
	s.mu.Lock()
	if s.cache == nil {
		s.mu.Unlock()
		t.Fatal("expected a cached entry after Get")
	}
	s.cache = &cachedSecret{value: s.cache.value, expiresAt: time.Now().Add(-time.Second)}
	s.mu.Unlock()

	got3, err := s.Get(ctx)
	if err != nil {
		t.Fatalf("unexpected error after cache expiry: %v", err)
	}
	if got3 != "new-value" {
		t.Fatalf("after cache expiry, expected new-value, got %q", got3)
	}
}
// TestMappedSecretSource_UsesMappingFromContext checks that a mapped client
// key in the context selects its upstream key. An unmapped client key must
// fall back to the default source.
func TestMappedSecretSource_UsesMappingFromContext(t *testing.T) {
	s := NewMappedSecretSource(NewStaticSecretSource("default"))
	s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
		{UpstreamAPIKey: "u1", APIKeys: []string{"k1"}},
	})

	// withClientKey builds a request context carrying the given client API key.
	withClientKey := func(key string) context.Context {
		return context.WithValue(context.Background(), clientAPIKeyContextKey{}, key)
	}

	got, err := s.Get(withClientKey("k1"))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != "u1" {
		t.Fatalf("want u1, got %q", got)
	}

	// k2 has no mapping, so the default source answers.
	got, err = s.Get(withClientKey("k2"))
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if got != "default" {
		t.Fatalf("want default fallback, got %q", got)
	}
}
// TestMappedSecretSource_DuplicateClientKey_FirstWins checks that when the
// same client key is mapped by two entries, the earlier entry's upstream key
// is used.
func TestMappedSecretSource_DuplicateClientKey_FirstWins(t *testing.T) {
	s := NewMappedSecretSource(NewStaticSecretSource("default"))
	s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
		{UpstreamAPIKey: "u1", APIKeys: []string{"k1"}},
		{UpstreamAPIKey: "u2", APIKeys: []string{"k1"}},
	})

	got, err := s.Get(context.WithValue(context.Background(), clientAPIKeyContextKey{}, "k1"))
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case got != "u1":
		t.Fatalf("want u1 (first wins), got %q", got)
	}
}
// TestMappedSecretSource_DuplicateClientKey_LogsWarning checks that
// UpdateMappings logs a warning when a client key is mapped more than once.
func TestMappedSecretSource_DuplicateClientKey_LogsWarning(t *testing.T) {
	logger := log.StandardLogger()

	// test.NewLocal adds its hook to the shared standard logger permanently,
	// and hook.Reset only clears recorded entries. Install a copy of the
	// current hook set for this test and restore the original afterwards, so
	// the capturing hook does not leak into other tests.
	scoped := make(log.LevelHooks, len(logger.Hooks))
	for level, hooks := range logger.Hooks {
		scoped[level] = append([]log.Hook(nil), hooks...)
	}
	original := logger.ReplaceHooks(scoped)
	t.Cleanup(func() { logger.ReplaceHooks(original) })

	hook := test.NewLocal(logger)
	defer hook.Reset()

	defaultSource := NewStaticSecretSource("default")
	s := NewMappedSecretSource(defaultSource)
	s.UpdateMappings([]config.AmpUpstreamAPIKeyEntry{
		{
			UpstreamAPIKey: "u1",
			APIKeys:        []string{"k1"},
		},
		{
			UpstreamAPIKey: "u2",
			APIKeys:        []string{"k1"},
		},
	})

	foundWarning := false
	for _, entry := range hook.AllEntries() {
		if entry.Level == log.WarnLevel && entry.Message == "amp upstream-api-keys: client API key appears in multiple entries; using first mapping." {
			foundWarning = true
			break
		}
	}
	if !foundWarning {
		t.Fatal("expected warning log for duplicate client key, but none was found")
	}
}
================================================
FILE: internal/api/modules/modules.go
================================================
// Package modules provides a pluggable routing module system for extending
// the API server with optional features without modifying core routing logic.
package modules
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
)
// Context bundles the dependencies handed to a routing module at registration
// time (see RouteModuleV2.Register):
//   - the Gin engine to attach routes to,
//   - the shared BaseAPIHandler used to construct SDK-specific handlers,
//   - the current configuration,
//   - the resolved authentication middleware for routes that require API keys.
type Context struct {
	// Engine is the Gin engine that module routes are registered on.
	Engine *gin.Engine
	// BaseHandler is the shared base used to build SDK-specific handlers.
	BaseHandler *handlers.BaseAPIHandler
	// Config is the configuration in effect at registration time; later
	// changes are delivered through OnConfigUpdated.
	Config *config.Config
	// AuthMiddleware protects routes that require API keys.
	// NOTE(review): confirm whether callers may pass nil here.
	AuthMiddleware gin.HandlerFunc
}
// RouteModule is the original (V1) module interface: a pluggable set of routes
// that registers directly on the Gin engine and reacts to configuration
// reloads independently of the core server.
//
// Deprecated: Use RouteModuleV2 for new modules. This interface is kept for
// backwards compatibility and will be removed in a future version.
type RouteModule interface {
	// Name returns a human-readable identifier for the module.
	Name() string
	// Register sets up routes and handlers for this module.
	// It receives the Gin engine, base handlers, and current configuration.
	// It returns an error if registration fails; RegisterModule passes that
	// error back to its caller.
	Register(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, cfg *config.Config) error
	// OnConfigUpdated is called when the configuration is reloaded, so the
	// module can respond to the change. It returns an error if the update
	// cannot be applied.
	OnConfigUpdated(cfg *config.Config) error
}
// RouteModuleV2 is the preferred interface for pluggable route bundles that
// integrate with the API server without modifying its core routing logic.
// All dependencies arrive in a single Context value, and registration is
// expected to be idempotent.
type RouteModuleV2 interface {
	// Name returns a unique identifier used in logs and diagnostics.
	Name() string
	// Register attaches the module's routes using the dependencies in mc.
	// Calling it more than once must not register duplicate routes.
	Register(mc Context) error
	// OnConfigUpdated is invoked after a configuration hot reload so the
	// module can refresh cached state or emit warnings.
	OnConfigUpdated(cfg *config.Config) error
}
// RegisterModule wires mod into the server described by ctx. It accepts either
// a RouteModuleV2 or a legacy RouteModule, so existing modules keep working
// while new ones adopt the Context-based API.
//
// V2 modules receive ctx as-is. V1 modules receive the engine, base handler,
// and config unpacked from ctx. Any other value produces an error that names
// its dynamic type.
//
// Example usage:
//
//	ctx := modules.Context{
//		Engine:         engine,
//		BaseHandler:    baseHandler,
//		Config:         cfg,
//		AuthMiddleware: authMiddleware,
//	}
//	if err := modules.RegisterModule(ctx, ampModule); err != nil {
//		log.Errorf("Failed to register module: %v", err)
//	}
func RegisterModule(ctx Context, mod interface{}) error {
	switch module := mod.(type) {
	case RouteModuleV2:
		// Preferred path: the module pulls what it needs from ctx.
		return module.Register(ctx)
	case RouteModule:
		// Legacy path, kept for backwards compatibility.
		return module.Register(ctx.Engine, ctx.BaseHandler, ctx.Config)
	default:
		return fmt.Errorf("unsupported module type %T (must implement RouteModule or RouteModuleV2)", mod)
	}
}
================================================
FILE: internal/api/server.go
================================================
// Package api provides the HTTP API server implementation for the CLI Proxy API.
// It includes the main server struct, routing setup, middleware for CORS and authentication,
// and integration with various AI API handlers (OpenAI, Claude, Gemini).
// The server supports hot-reloading of clients and configuration.
package api
import (
	"context"
	"crypto/subtle"
	"errors"
	"fmt"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"reflect"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/access"
	managementHandlers "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules"
	ampmodule "github.com/router-for-me/CLIProxyAPI/v6/internal/api/modules/amp"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/managementasset"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
	sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
	"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
	"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
	"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
	"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
	sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
	"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
	log "github.com/sirupsen/logrus"
	"gopkg.in/yaml.v3"
)
// oauthCallbackSuccessHTML is the page body that the OAuth callback endpoints
// registered in setupRoutes send back to the browser, with status 200 and a
// text/html content type. The same page is sent even when the provider's
// redirect carried an error parameter.
const oauthCallbackSuccessHTML = `Authentication successful Authentication successful! You can close this window.
This window will close automatically in 5 seconds.
`
// serverOptionConfig collects the settings supplied through ServerOption
// functions. NewServer fills it in before it builds the server.
type serverOptionConfig struct {
	// extraMiddleware is installed on the engine in order, after the built-in
	// logging and recovery middleware.
	extraMiddleware []gin.HandlerFunc
	// engineConfigurator, if set, receives the freshly created engine before
	// any middleware is added.
	engineConfigurator func(*gin.Engine)
	// routerConfigurator, if set, runs after the built-in routes and the Amp
	// module have been registered.
	routerConfigurator func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)
	// requestLoggerFactory builds the request logger. NewServer defaults it to
	// defaultRequestLoggerFactory and does not call it in commercial mode.
	requestLoggerFactory func(*config.Config, string) logging.RequestLogger
	// localPassword is a runtime-only management password accepted for
	// localhost requests.
	localPassword string
	// The keepAlive* fields configure the optional /keep-alive endpoint.
	// WithKeepAliveEndpoint sets all three together.
	keepAliveEnabled   bool
	keepAliveTimeout   time.Duration
	keepAliveOnTimeout func()
	// postAuthHook, if non-nil, is forwarded to the management handler.
	postAuthHook auth.PostAuthHook
}
// ServerOption customises HTTP server construction. NewServer applies options
// in the order they are given, before it creates the Gin engine. For settings
// that hold a single value, a later option overrides an earlier one.
type ServerOption func(*serverOptionConfig)
// defaultRequestLoggerFactory builds the file-backed request logger that is
// used when no WithRequestLoggerFactory option is supplied.
//
// configPath is the path of the configuration file. The logger receives that
// file's directory along with the log directory resolved from cfg.
func defaultRequestLoggerFactory(cfg *config.Config, configPath string) logging.RequestLogger {
	return logging.NewFileRequestLogger(
		cfg.RequestLog,
		logging.ResolveLogDirectory(cfg),
		filepath.Dir(configPath),
		cfg.ErrorLogsMaxFiles,
	)
}
// WithMiddleware adds extra Gin middleware to install during server
// construction. Repeated calls accumulate. The handlers run in the order they
// are given, after the built-in logging and recovery middleware.
func WithMiddleware(mw ...gin.HandlerFunc) ServerOption {
	return func(opts *serverOptionConfig) {
		opts.extraMiddleware = append(opts.extraMiddleware, mw...)
	}
}
// WithEngineConfigurator installs a hook that receives the newly created Gin
// engine before any middleware is attached. If it is supplied more than once,
// only the last hook is kept.
func WithEngineConfigurator(fn func(*gin.Engine)) ServerOption {
	return func(opts *serverOptionConfig) { opts.engineConfigurator = fn }
}
// WithRouterConfigurator sets a callback that NewServer invokes after the
// default routes and the Amp module have been registered. Callers use it to
// attach additional routes.
//
// Only one router configurator is kept. Each call replaces any configurator
// set by an earlier option rather than appending to it, so only the last one
// supplied runs.
func WithRouterConfigurator(fn func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)) ServerOption {
	return func(cfg *serverOptionConfig) {
		cfg.routerConfigurator = fn
	}
}
// WithLocalManagementPassword sets a runtime-only management password that is
// accepted for localhost requests. A non-empty value also causes NewServer to
// register the management routes even when no secret key is configured.
func WithLocalManagementPassword(password string) ServerOption {
	return func(opts *serverOptionConfig) { opts.localPassword = password }
}
// WithKeepAliveEndpoint enables the /keep-alive endpoint. If no heartbeat
// arrives within timeout, onTimeout is invoked. The option does nothing when
// timeout is not positive or onTimeout is nil.
func WithKeepAliveEndpoint(timeout time.Duration, onTimeout func()) ServerOption {
	return func(opts *serverOptionConfig) {
		valid := timeout > 0 && onTimeout != nil
		if !valid {
			return
		}
		opts.keepAliveEnabled, opts.keepAliveTimeout, opts.keepAliveOnTimeout = true, timeout, onTimeout
	}
}
// WithRequestLoggerFactory overrides how the request logger is built. The
// factory receives the configuration and the config file path.
//
// It is not consulted in commercial mode. If the factory is nil or returns
// nil, no request-logging middleware is installed.
func WithRequestLoggerFactory(factory func(*config.Config, string) logging.RequestLogger) ServerOption {
	return func(opts *serverOptionConfig) { opts.requestLoggerFactory = factory }
}
// WithPostAuthHook supplies a hook that is called after an auth record is
// created. NewServer passes it to the management handler. A nil hook leaves
// the handler's setting untouched.
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
	return func(opts *serverOptionConfig) { opts.postAuthHook = hook }
}
// Server represents the main API server.
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
type Server struct {
	// engine is the Gin web framework engine instance.
	engine *gin.Engine
	// server is the underlying HTTP server.
	server *http.Server
	// handlers contains the API handlers for processing requests.
	handlers *handlers.BaseAPIHandler
	// cfg holds the current server configuration. UpdateClients replaces it on reload.
	cfg *config.Config
	// oldConfigYaml stores a YAML snapshot of the previous configuration for change detection.
	// This prevents issues when the config object is modified in place by Management API.
	oldConfigYaml []byte
	// accessManager handles request authentication providers.
	accessManager *sdkaccess.Manager
	// requestLogger is the request logger instance for dynamic configuration updates.
	// It is nil in commercial mode or when the factory produced no logger.
	requestLogger logging.RequestLogger
	// loggerToggle enables or disables requestLogger. It is the logger's
	// SetEnabled method when the logger provides one, and nil otherwise.
	loggerToggle func(bool)
	// configFilePath is the absolute path to the YAML config file for persistence.
	configFilePath string
	// currentPath is the working directory captured when the server was constructed.
	currentPath string
	// wsRouteMu guards wsRoutes.
	wsRouteMu sync.Mutex
	// wsRoutes tracks registered websocket upgrade paths, so each path is attached at most once.
	wsRoutes map[string]struct{}
	// wsAuthChanged, when set, is called with the old and new WebsocketAuth
	// values after a reload changes that setting.
	wsAuthChanged func(bool, bool)
	// wsAuthEnabled mirrors cfg.WebsocketAuth. Websocket routes enforce auth only while it is true.
	wsAuthEnabled atomic.Bool
	// mgmt is the management API handler.
	mgmt *managementHandlers.Handler
	// ampModule is the Amp routing module for model mapping hot-reload
	ampModule *ampmodule.AmpModule
	// managementRoutesRegistered tracks whether the management routes have been attached to the engine.
	managementRoutesRegistered atomic.Bool
	// managementRoutesEnabled controls whether management endpoints serve real handlers.
	managementRoutesEnabled atomic.Bool
	// envManagementSecret indicates whether MANAGEMENT_PASSWORD is set to a non-blank value.
	envManagementSecret bool
	// localPassword is the runtime-only management password. When it is
	// non-empty, /keep-alive also requires it.
	localPassword string
	// The keepAlive* fields back the optional /keep-alive endpoint; see enableKeepAlive.
	keepAliveEnabled   bool
	keepAliveTimeout   time.Duration
	keepAliveOnTimeout func()
	// keepAliveHeartbeat carries one signal per accepted /keep-alive request
	// to the watchdog. It is created with a buffer of 1.
	keepAliveHeartbeat chan struct{}
	// keepAliveStop asks the watchdog goroutine to exit. Stop sends on it.
	keepAliveStop chan struct{}
}
// NewServer creates and initializes a new API server instance.
// It sets up the Gin engine, middleware, routes, and handlers.
//
// Parameters:
//   - cfg: The server configuration
//   - authManager: core runtime auth manager
//   - accessManager: request authentication manager
//   - configFilePath: path to the YAML configuration file. It is used for
//     persistence, request-log placement, and locating management assets.
//   - opts: optional ServerOption values, applied in order. Nil entries are skipped.
//
// Returns:
//   - *Server: A new server instance
func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdkaccess.Manager, configFilePath string, opts ...ServerOption) *Server {
	optionState := &serverOptionConfig{
		requestLoggerFactory: defaultRequestLoggerFactory,
	}
	for _, opt := range opts {
		if opt != nil {
			opt(optionState)
		}
	}

	// Set gin mode
	if !cfg.Debug {
		gin.SetMode(gin.ReleaseMode)
	}

	// Create gin engine
	engine := gin.New()
	if optionState.engineConfigurator != nil {
		optionState.engineConfigurator(engine)
	}

	// Add middleware
	engine.Use(logging.GinLogrusLogger())
	engine.Use(logging.GinLogrusRecovery())
	for _, mw := range optionState.extraMiddleware {
		engine.Use(mw)
	}

	// Add request logging middleware (positioned after recovery, before auth).
	// Request logging is never installed in commercial mode.
	var requestLogger logging.RequestLogger
	var toggle func(bool)
	if !cfg.CommercialMode {
		if optionState.requestLoggerFactory != nil {
			requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
		}
		if requestLogger != nil {
			engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
			if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
				toggle = setter.SetEnabled
			}
		}
	}

	engine.Use(corsMiddleware())

	wd, err := os.Getwd()
	if err != nil {
		// currentPath is meant to be a directory, so fall back to the
		// directory holding the config file rather than the file path itself.
		wd = filepath.Dir(configFilePath)
	}

	envAdminPassword, envAdminPasswordSet := os.LookupEnv("MANAGEMENT_PASSWORD")
	envAdminPassword = strings.TrimSpace(envAdminPassword)
	envManagementSecret := envAdminPasswordSet && envAdminPassword != ""

	// Create server instance
	s := &Server{
		engine:              engine,
		handlers:            handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
		cfg:                 cfg,
		accessManager:       accessManager,
		requestLogger:       requestLogger,
		loggerToggle:        toggle,
		configFilePath:      configFilePath,
		currentPath:         wd,
		envManagementSecret: envManagementSecret,
		wsRoutes:            make(map[string]struct{}),
	}
	s.wsAuthEnabled.Store(cfg.WebsocketAuth)

	// Save the initial YAML snapshot that UpdateClients diffs against. If
	// marshalling fails, the snapshot stays empty and the next reload treats
	// the previous configuration as unknown.
	if snapshot, errMarshal := yaml.Marshal(cfg); errMarshal != nil {
		log.Warnf("failed to snapshot initial configuration: %v", errMarshal)
	} else {
		s.oldConfigYaml = snapshot
	}

	s.applyAccessConfig(nil, cfg)
	if authManager != nil {
		authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
	}
	managementasset.SetCurrentConfig(cfg)
	auth.SetQuotaCooldownDisabled(cfg.DisableCooling)

	// Initialize management handler
	s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
	if optionState.localPassword != "" {
		s.mgmt.SetLocalPassword(optionState.localPassword)
	}
	logDir := logging.ResolveLogDirectory(cfg)
	s.mgmt.SetLogDirectory(logDir)
	if optionState.postAuthHook != nil {
		s.mgmt.SetPostAuthHook(optionState.postAuthHook)
	}
	s.localPassword = optionState.localPassword

	// Setup routes
	s.setupRoutes()

	// Register the Amp module through the V2 interface. One auth middleware
	// instance is shared by the module constructor and its registration context.
	authMiddleware := AuthMiddleware(accessManager)
	s.ampModule = ampmodule.NewLegacy(accessManager, authMiddleware)
	ctx := modules.Context{
		Engine:         engine,
		BaseHandler:    s.handlers,
		Config:         cfg,
		AuthMiddleware: authMiddleware,
	}
	if errRegister := modules.RegisterModule(ctx, s.ampModule); errRegister != nil {
		log.Errorf("Failed to register Amp module: %v", errRegister)
	}

	// Apply additional router configurators from options
	if optionState.routerConfigurator != nil {
		optionState.routerConfigurator(engine, s.handlers, cfg)
	}

	// Register management routes when configuration or environment secrets are available,
	// or when a local management password is provided (e.g. TUI mode).
	hasManagementSecret := cfg.RemoteManagement.SecretKey != "" || envManagementSecret || s.localPassword != ""
	s.managementRoutesEnabled.Store(hasManagementSecret)
	if hasManagementSecret {
		s.registerManagementRoutes()
	}

	if optionState.keepAliveEnabled {
		s.enableKeepAlive(optionState.keepAliveTimeout, optionState.keepAliveOnTimeout)
	}

	// Create HTTP server.
	// net.JoinHostPort brackets IPv6 literals ("::1" becomes "[::1]:8317"),
	// which plain "%s:%d" formatting gets wrong. Brackets already present in
	// the configured host are stripped first so they are not doubled.
	listenHost := strings.TrimSuffix(strings.TrimPrefix(cfg.Host, "["), "]")
	s.server = &http.Server{
		Addr:    net.JoinHostPort(listenHost, fmt.Sprintf("%d", cfg.Port)),
		Handler: engine,
		// Bound the time a client may take to send request headers, to
		// mitigate slowloris-style clients. Only header reading is covered;
		// request bodies and long-lived streaming responses are not limited.
		ReadHeaderTimeout: 30 * time.Second,
	}

	return s
}
// setupRoutes configures the API routes for the server.
// It defines the endpoints and associates them with their respective handlers.
// Management routes are not registered here. registerManagementRoutes attaches
// them lazily once a management secret is available.
func (s *Server) setupRoutes() {
	s.engine.GET("/management.html", s.serveManagementControlPanel)
	openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers)
	geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers)
	geminiCLIHandlers := gemini.NewGeminiCLIAPIHandler(s.handlers)
	claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(s.handlers)
	openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(s.handlers)

	// OpenAI compatible API routes
	v1 := s.engine.Group("/v1")
	v1.Use(AuthMiddleware(s.accessManager))
	{
		v1.GET("/models", s.unifiedModelsHandler(openaiHandlers, claudeCodeHandlers))
		v1.POST("/chat/completions", openaiHandlers.ChatCompletions)
		v1.POST("/completions", openaiHandlers.Completions)
		v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
		v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
		v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket)
		v1.POST("/responses", openaiResponsesHandlers.Responses)
		v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
	}

	// Gemini compatible API routes
	v1beta := s.engine.Group("/v1beta")
	v1beta.Use(AuthMiddleware(s.accessManager))
	{
		v1beta.GET("/models", geminiHandlers.GeminiModels)
		v1beta.POST("/models/*action", geminiHandlers.GeminiHandler)
		v1beta.GET("/models/*action", geminiHandlers.GeminiGetHandler)
	}

	// Root endpoint
	s.engine.GET("/", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{
			"message": "CLI Proxy API Server",
			"endpoints": []string{
				"POST /v1/chat/completions",
				"POST /v1/completions",
				"GET /v1/models",
			},
		})
	})
	s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)

	// OAuth callback endpoints reuse the main server port. Each one records the
	// provider redirect for the pending login session identified by "state".
	// Note that the Google route records under the "gemini" provider key.
	s.engine.GET("/anthropic/callback", s.providerOAuthCallbackHandler("anthropic"))
	s.engine.GET("/codex/callback", s.providerOAuthCallbackHandler("codex"))
	s.engine.GET("/google/callback", s.providerOAuthCallbackHandler("gemini"))
	s.engine.GET("/iflow/callback", s.providerOAuthCallbackHandler("iflow"))
	s.engine.GET("/antigravity/callback", s.providerOAuthCallbackHandler("antigravity"))

	// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
}

// providerOAuthCallbackHandler builds the handler for one provider's OAuth
// redirect.
//
// It reads "code", "state", and "error" from the query string, falling back to
// "error_description" when "error" is empty. When a state is present, it writes
// these values for the pending session via
// WriteOAuthCallbackFileForPendingSession. The browser always receives
// oauthCallbackSuccessHTML.
func (s *Server) providerOAuthCallbackHandler(provider string) gin.HandlerFunc {
	return func(c *gin.Context) {
		code := c.Query("code")
		state := c.Query("state")
		errStr := c.Query("error")
		if errStr == "" {
			errStr = c.Query("error_description")
		}
		if state != "" {
			// The write result is discarded. s.cfg is read per request so the
			// handler follows config reloads.
			_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, provider, state, code, errStr)
		}
		c.Header("Content-Type", "text/html; charset=utf-8")
		c.String(http.StatusOK, oauthCallbackSuccessHTML)
	}
}
// AttachWebsocketRoute mounts handler as a GET route on the primary Gin engine,
// for websocket upgrades. Apart from the standard middleware stack, only
// conditional authentication is added.
//
// Path normalisation:
//   - the path is trimmed of surrounding whitespace;
//   - an empty path defaults to "/v1/ws";
//   - a missing leading slash is added.
//
// Each path is attached at most once, and later calls for the same path are
// ignored. Authentication runs only while websocket auth is enabled, and the
// flag is checked per request, so a reload that toggles it takes effect
// without re-registering the route.
func (s *Server) AttachWebsocketRoute(path string, handler http.Handler) {
	if s == nil || s.engine == nil || handler == nil {
		return
	}

	route := strings.TrimSpace(path)
	switch {
	case route == "":
		route = "/v1/ws"
	case !strings.HasPrefix(route, "/"):
		route = "/" + route
	}

	// Claim the path under the lock, then register outside it.
	s.wsRouteMu.Lock()
	_, duplicate := s.wsRoutes[route]
	if !duplicate {
		s.wsRoutes[route] = struct{}{}
	}
	s.wsRouteMu.Unlock()
	if duplicate {
		return
	}

	requireAuth := AuthMiddleware(s.accessManager)
	s.engine.GET(route,
		func(c *gin.Context) {
			if s.wsAuthEnabled.Load() {
				requireAuth(c)
				return
			}
			c.Next()
		},
		func(c *gin.Context) {
			handler.ServeHTTP(c.Writer, c.Request)
			c.Abort()
		},
	)
}
// registerManagementRoutes attaches the /v0/management route group to the
// engine. It registers the routes at most once per Server: the
// compare-and-swap on managementRoutesRegistered makes later calls return
// immediately.
//
// Each request in the group passes through managementAvailabilityMiddleware
// first, which answers 404 while managementRoutesEnabled is false. It then
// passes through the management handler's own middleware.
//
// NOTE(review): the info log below is emitted on first registration even when
// the trigger was MANAGEMENT_PASSWORD or a local password rather than a
// secret key.
func (s *Server) registerManagementRoutes() {
	if s == nil || s.engine == nil || s.mgmt == nil {
		return
	}
	// Routes cannot be registered twice on the same engine, so only the first caller proceeds.
	if !s.managementRoutesRegistered.CompareAndSwap(false, true) {
		return
	}
	log.Info("management routes registered after secret key configuration")
	mgmt := s.engine.Group("/v0/management")
	mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
	{
		// Usage statistics.
		mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
		mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
		mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
		// Whole-configuration access.
		mgmt.GET("/config", s.mgmt.GetConfig)
		mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
		mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
		mgmt.GET("/latest-version", s.mgmt.GetLatestVersion)
		// Individual settings. PUT and PATCH map to the same handler for scalar values.
		mgmt.GET("/debug", s.mgmt.GetDebug)
		mgmt.PUT("/debug", s.mgmt.PutDebug)
		mgmt.PATCH("/debug", s.mgmt.PutDebug)
		mgmt.GET("/logging-to-file", s.mgmt.GetLoggingToFile)
		mgmt.PUT("/logging-to-file", s.mgmt.PutLoggingToFile)
		mgmt.PATCH("/logging-to-file", s.mgmt.PutLoggingToFile)
		mgmt.GET("/logs-max-total-size-mb", s.mgmt.GetLogsMaxTotalSizeMB)
		mgmt.PUT("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
		mgmt.PATCH("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
		mgmt.GET("/error-logs-max-files", s.mgmt.GetErrorLogsMaxFiles)
		mgmt.PUT("/error-logs-max-files", s.mgmt.PutErrorLogsMaxFiles)
		mgmt.PATCH("/error-logs-max-files", s.mgmt.PutErrorLogsMaxFiles)
		mgmt.GET("/usage-statistics-enabled", s.mgmt.GetUsageStatisticsEnabled)
		mgmt.PUT("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
		mgmt.PATCH("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
		mgmt.GET("/proxy-url", s.mgmt.GetProxyURL)
		mgmt.PUT("/proxy-url", s.mgmt.PutProxyURL)
		mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL)
		mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL)
		mgmt.POST("/api-call", s.mgmt.APICall)
		mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject)
		mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
		mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
		mgmt.GET("/quota-exceeded/switch-preview-model", s.mgmt.GetSwitchPreviewModel)
		mgmt.PUT("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
		mgmt.PATCH("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
		// List-valued settings. PATCH has a dedicated handler for these.
		mgmt.GET("/api-keys", s.mgmt.GetAPIKeys)
		mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
		mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)
		mgmt.DELETE("/api-keys", s.mgmt.DeleteAPIKeys)
		mgmt.GET("/gemini-api-key", s.mgmt.GetGeminiKeys)
		mgmt.PUT("/gemini-api-key", s.mgmt.PutGeminiKeys)
		mgmt.PATCH("/gemini-api-key", s.mgmt.PatchGeminiKey)
		mgmt.DELETE("/gemini-api-key", s.mgmt.DeleteGeminiKey)
		// Logs.
		mgmt.GET("/logs", s.mgmt.GetLogs)
		mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
		mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
		mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
		mgmt.GET("/request-log-by-id/:id", s.mgmt.GetRequestLogByID)
		mgmt.GET("/request-log", s.mgmt.GetRequestLog)
		mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
		mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
		mgmt.GET("/ws-auth", s.mgmt.GetWebsocketAuth)
		mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth)
		mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth)
		// Amp module settings.
		mgmt.GET("/ampcode", s.mgmt.GetAmpCode)
		mgmt.GET("/ampcode/upstream-url", s.mgmt.GetAmpUpstreamURL)
		mgmt.PUT("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
		mgmt.PATCH("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
		mgmt.DELETE("/ampcode/upstream-url", s.mgmt.DeleteAmpUpstreamURL)
		mgmt.GET("/ampcode/upstream-api-key", s.mgmt.GetAmpUpstreamAPIKey)
		mgmt.PUT("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
		mgmt.PATCH("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
		mgmt.DELETE("/ampcode/upstream-api-key", s.mgmt.DeleteAmpUpstreamAPIKey)
		mgmt.GET("/ampcode/restrict-management-to-localhost", s.mgmt.GetAmpRestrictManagementToLocalhost)
		mgmt.PUT("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
		mgmt.PATCH("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
		mgmt.GET("/ampcode/model-mappings", s.mgmt.GetAmpModelMappings)
		mgmt.PUT("/ampcode/model-mappings", s.mgmt.PutAmpModelMappings)
		mgmt.PATCH("/ampcode/model-mappings", s.mgmt.PatchAmpModelMappings)
		mgmt.DELETE("/ampcode/model-mappings", s.mgmt.DeleteAmpModelMappings)
		mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
		mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
		mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
		mgmt.GET("/ampcode/upstream-api-keys", s.mgmt.GetAmpUpstreamAPIKeys)
		mgmt.PUT("/ampcode/upstream-api-keys", s.mgmt.PutAmpUpstreamAPIKeys)
		mgmt.PATCH("/ampcode/upstream-api-keys", s.mgmt.PatchAmpUpstreamAPIKeys)
		mgmt.DELETE("/ampcode/upstream-api-keys", s.mgmt.DeleteAmpUpstreamAPIKeys)
		// Retry and routing behaviour.
		mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
		mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
		mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry)
		mgmt.GET("/max-retry-interval", s.mgmt.GetMaxRetryInterval)
		mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
		mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
		mgmt.GET("/force-model-prefix", s.mgmt.GetForceModelPrefix)
		mgmt.PUT("/force-model-prefix", s.mgmt.PutForceModelPrefix)
		mgmt.PATCH("/force-model-prefix", s.mgmt.PutForceModelPrefix)
		mgmt.GET("/routing/strategy", s.mgmt.GetRoutingStrategy)
		mgmt.PUT("/routing/strategy", s.mgmt.PutRoutingStrategy)
		mgmt.PATCH("/routing/strategy", s.mgmt.PutRoutingStrategy)
		// Provider key lists.
		mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
		mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
		mgmt.PATCH("/claude-api-key", s.mgmt.PatchClaudeKey)
		mgmt.DELETE("/claude-api-key", s.mgmt.DeleteClaudeKey)
		mgmt.GET("/codex-api-key", s.mgmt.GetCodexKeys)
		mgmt.PUT("/codex-api-key", s.mgmt.PutCodexKeys)
		mgmt.PATCH("/codex-api-key", s.mgmt.PatchCodexKey)
		mgmt.DELETE("/codex-api-key", s.mgmt.DeleteCodexKey)
		mgmt.GET("/openai-compatibility", s.mgmt.GetOpenAICompat)
		mgmt.PUT("/openai-compatibility", s.mgmt.PutOpenAICompat)
		mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat)
		mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat)
		mgmt.GET("/vertex-api-key", s.mgmt.GetVertexCompatKeys)
		mgmt.PUT("/vertex-api-key", s.mgmt.PutVertexCompatKeys)
		mgmt.PATCH("/vertex-api-key", s.mgmt.PatchVertexCompatKey)
		mgmt.DELETE("/vertex-api-key", s.mgmt.DeleteVertexCompatKey)
		mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels)
		mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels)
		mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels)
		mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
		mgmt.GET("/oauth-model-alias", s.mgmt.GetOAuthModelAlias)
		mgmt.PUT("/oauth-model-alias", s.mgmt.PutOAuthModelAlias)
		mgmt.PATCH("/oauth-model-alias", s.mgmt.PatchOAuthModelAlias)
		mgmt.DELETE("/oauth-model-alias", s.mgmt.DeleteOAuthModelAlias)
		// Auth files and credential import.
		mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
		mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
		mgmt.GET("/model-definitions/:channel", s.mgmt.GetStaticModelDefinitions)
		mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
		mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
		mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
		mgmt.PATCH("/auth-files/status", s.mgmt.PatchAuthFileStatus)
		mgmt.PATCH("/auth-files/fields", s.mgmt.PatchAuthFileFields)
		mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
		// Provider login flows.
		mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
		mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
		mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
		mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
		mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
		mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
		mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
		mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
		mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
		mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
	}
}
// managementAvailabilityMiddleware gates the management route group on the
// managementRoutesEnabled flag. Once attached, the management routes stay
// registered, so disabling management at runtime works by answering 404 here
// instead of removing routes.
func (s *Server) managementAvailabilityMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		if s.managementRoutesEnabled.Load() {
			c.Next()
			return
		}
		c.AbortWithStatus(http.StatusNotFound)
	}
}
// serveManagementControlPanel serves the management.html control panel.
//
// It answers 404 in these cases:
//   - there is no configuration;
//   - the panel is disabled;
//   - no asset path can be resolved;
//   - a missing asset cannot be provisioned.
//
// When the asset is absent, EnsureLatestManagementHTML is called synchronously
// before serving. Presumably this downloads it using the configured proxy and
// repository; verify in managementasset. Any other stat failure is logged and
// answered with 500.
func (s *Server) serveManagementControlPanel(c *gin.Context) {
	cfg := s.cfg
	if cfg == nil || cfg.RemoteManagement.DisableControlPanel {
		c.AbortWithStatus(http.StatusNotFound)
		return
	}

	panelPath := managementasset.FilePath(s.configFilePath)
	if strings.TrimSpace(panelPath) == "" {
		c.AbortWithStatus(http.StatusNotFound)
		return
	}

	_, statErr := os.Stat(panelPath)
	switch {
	case statErr == nil:
		// Asset already on disk; serve it below.
	case os.IsNotExist(statErr):
		// A detached context keeps a client disconnect from cancelling the
		// control panel bootstrap.
		staticDir := managementasset.StaticDir(s.configFilePath)
		if ready := managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository); !ready {
			c.AbortWithStatus(http.StatusNotFound)
			return
		}
	default:
		log.WithError(statErr).Error("failed to stat management control panel asset")
		c.AbortWithStatus(http.StatusInternalServerError)
		return
	}

	c.File(panelPath)
}
// enableKeepAlive turns on the /keep-alive endpoint and starts the watchdog
// goroutine (watchKeepAlive). The watchdog calls onTimeout if no heartbeat
// arrives within timeout. A non-positive timeout or a nil callback leaves
// keep-alive disabled.
//
// NOTE(review): there is no guard against repeated calls. A second call would
// try to register /keep-alive again and start another watchdog. NewServer
// calls it at most once.
func (s *Server) enableKeepAlive(timeout time.Duration, onTimeout func()) {
	if onTimeout == nil || timeout <= 0 {
		return
	}
	s.keepAliveEnabled, s.keepAliveTimeout, s.keepAliveOnTimeout = true, timeout, onTimeout
	// Both channels are buffered with capacity 1; their senders use non-blocking sends.
	s.keepAliveHeartbeat, s.keepAliveStop = make(chan struct{}, 1), make(chan struct{}, 1)
	s.engine.GET("/keep-alive", s.handleKeepAlive)
	go s.watchKeepAlive()
}
// handleKeepAlive records a heartbeat from the client and replies
// {"status":"ok"}.
//
// When a local management password is configured, the request must carry it.
// It can be sent either as a Bearer token in Authorization or in
// X-Local-Password. X-Local-Password is consulted only when Authorization is
// empty. A mismatch, checked in constant time, gets 401.
func (s *Server) handleKeepAlive(c *gin.Context) {
	if expected := s.localPassword; expected != "" {
		token := strings.TrimSpace(c.GetHeader("Authorization"))
		if scheme, rest, found := strings.Cut(token, " "); found && strings.EqualFold(scheme, "bearer") {
			token = rest
		}
		if token == "" {
			token = strings.TrimSpace(c.GetHeader("X-Local-Password"))
		}
		if subtle.ConstantTimeCompare([]byte(token), []byte(expected)) != 1 {
			c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid password"})
			return
		}
	}

	s.signalKeepAlive()
	c.JSON(http.StatusOK, gin.H{"status": "ok"})
}
// signalKeepAlive posts a heartbeat to the watchdog without blocking. If a
// heartbeat is already pending, the new one is dropped. It does nothing when
// keep-alive is disabled.
func (s *Server) signalKeepAlive() {
	if s.keepAliveEnabled {
		select {
		case s.keepAliveHeartbeat <- struct{}{}:
		default:
		}
	}
}
// watchKeepAlive is the watchdog loop started by enableKeepAlive. Each
// heartbeat restarts the idle countdown. If the countdown expires, the loop
// logs a warning, calls keepAliveOnTimeout, and exits. A signal on
// keepAliveStop ends the loop without calling the callback.
func (s *Server) watchKeepAlive() {
	if !s.keepAliveEnabled {
		return
	}

	idle := time.NewTimer(s.keepAliveTimeout)
	defer idle.Stop()

	for {
		select {
		case <-s.keepAliveStop:
			return
		case <-s.keepAliveHeartbeat:
			// Stop and drain before Reset, so an expiry that already fired
			// cannot be observed right after the countdown restarts.
			if !idle.Stop() {
				select {
				case <-idle.C:
				default:
				}
			}
			idle.Reset(s.keepAliveTimeout)
		case <-idle.C:
			log.Warnf("keep-alive endpoint idle for %s, shutting down", s.keepAliveTimeout)
			if onTimeout := s.keepAliveOnTimeout; onTimeout != nil {
				onTimeout()
			}
			return
		}
	}
}
// unifiedModelsHandler returns the handler for GET /v1/models. Requests whose
// User-Agent begins with "claude-cli" go to the Claude handler's model
// listing. Every other client goes to the OpenAI handler's listing.
func (s *Server) unifiedModelsHandler(openaiHandler *openai.OpenAIAPIHandler, claudeHandler *claude.ClaudeCodeAPIHandler) gin.HandlerFunc {
	return func(c *gin.Context) {
		if !strings.HasPrefix(c.GetHeader("User-Agent"), "claude-cli") {
			openaiHandler.OpenAIModels(c)
			return
		}
		claudeHandler.ClaudeModels(c)
	}
}
// Start begins listening for and serving HTTP or HTTPS requests.
//
// It blocks until the server stops. When Stop shuts the server down
// gracefully (http.ErrServerClosed), Start returns nil. Any other serve
// failure is returned wrapped with %w, so callers can inspect the cause with
// errors.Is or errors.As.
//
// TLS is used when cfg.TLS.Enable is set. In that case both tls.cert and
// tls.key must be non-empty after trimming whitespace.
//
// Returns:
//   - error: An error if the server is not initialized, the TLS settings are
//     incomplete, or serving fails
func (s *Server) Start() error {
	if s == nil || s.server == nil {
		return errors.New("failed to start HTTP server: server not initialized")
	}

	useTLS := s.cfg != nil && s.cfg.TLS.Enable
	if useTLS {
		cert := strings.TrimSpace(s.cfg.TLS.Cert)
		key := strings.TrimSpace(s.cfg.TLS.Key)
		if cert == "" || key == "" {
			return errors.New("failed to start HTTPS server: tls.cert or tls.key is empty")
		}
		log.Debugf("Starting API server on %s with TLS", s.server.Addr)
		if errServeTLS := s.server.ListenAndServeTLS(cert, key); errServeTLS != nil && !errors.Is(errServeTLS, http.ErrServerClosed) {
			return fmt.Errorf("failed to start HTTPS server: %w", errServeTLS)
		}
		return nil
	}

	log.Debugf("Starting API server on %s", s.server.Addr)
	if errServe := s.server.ListenAndServe(); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
		return fmt.Errorf("failed to start HTTP server: %w", errServe)
	}
	return nil
}
// Stop gracefully shuts down the API server.
//
// It first signals the keep-alive watchdog (if enabled) to exit. It then calls
// http.Server.Shutdown, which stops accepting new connections and waits for
// active requests to finish until ctx is done. Hijacked connections such as
// websockets are not tracked by Shutdown.
//
// Calling Stop on a nil or uninitialised Server is a no-op, matching the nil
// checks in Start.
//
// Parameters:
//   - ctx: The context bounding the graceful shutdown
//
// Returns:
//   - error: the Shutdown error wrapped with %w (for example ctx's error when
//     the deadline passes before connections drain)
func (s *Server) Stop(ctx context.Context) error {
	if s == nil || s.server == nil {
		return nil
	}

	log.Debug("Stopping API server...")

	if s.keepAliveEnabled {
		select {
		case s.keepAliveStop <- struct{}{}:
		default:
		}
	}

	// Shutdown the HTTP server.
	if err := s.server.Shutdown(ctx); err != nil {
		return fmt.Errorf("failed to shutdown HTTP server: %w", err)
	}

	log.Debug("API server stopped")
	return nil
}
// corsMiddleware returns middleware that adds permissive CORS headers to every
// response. It allows any origin, any request header, and the methods GET,
// POST, PUT, PATCH, DELETE, and OPTIONS. Preflight OPTIONS requests are
// answered immediately with 204 No Content and never reach later handlers.
func corsMiddleware() gin.HandlerFunc {
	return func(c *gin.Context) {
		c.Header("Access-Control-Allow-Origin", "*")
		c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
		c.Header("Access-Control-Allow-Headers", "*")

		// The headers above fully answer a preflight request.
		if c.Request.Method != "OPTIONS" {
			c.Next()
			return
		}
		c.AbortWithStatus(http.StatusNoContent)
	}
}
// applyAccessConfig reconciles the access manager's authentication providers
// with newCfg. oldCfg is the previous configuration and is nil on the first
// application.
//
// It does nothing for a nil server, a nil access manager, or a nil newCfg. A
// failure from ApplyAccessProviders is logged rather than silently dropped.
// It does not abort the caller, which keeps running with the access manager in
// whatever state ApplyAccessProviders left it.
func (s *Server) applyAccessConfig(oldCfg, newCfg *config.Config) {
	if s == nil || s.accessManager == nil || newCfg == nil {
		return
	}
	if _, err := access.ApplyAccessProviders(s.accessManager, oldCfg, newCfg); err != nil {
		log.Errorf("failed to apply access providers: %v", err)
	}
}
// UpdateClients applies a new configuration to the running server.
// It is called when the configuration or authentication tokens change.
//
// The previous configuration is rebuilt from the YAML snapshot saved at the end
// of the last call instead of being kept as a pointer, so the comparison is not
// affected by reference sharing with the caller. Most settings are re-applied
// only when they differ from that snapshot, or unconditionally when there is no
// usable snapshot (first call, or a snapshot that fails to decode).
//
// Parameters:
//   - cfg: The new application configuration. A nil cfg is ignored.
func (s *Server) UpdateClients(cfg *config.Config) {
	if s == nil || cfg == nil {
		return
	}

	// Reconstruct old config from YAML snapshot to avoid reference sharing issues.
	var oldCfg *config.Config
	if len(s.oldConfigYaml) > 0 {
		if errUnmarshal := yaml.Unmarshal(s.oldConfigYaml, &oldCfg); errUnmarshal != nil {
			// Never diff against a partially decoded snapshot; treat this as a
			// first load so every setting is re-applied.
			log.Warnf("failed to decode previous config snapshot, reapplying all settings: %v", errUnmarshal)
			oldCfg = nil
		}
	}

	// Update request logger enabled state if it has changed.
	previousRequestLog := false
	if oldCfg != nil {
		previousRequestLog = oldCfg.RequestLog
	}
	if s.requestLogger != nil && (oldCfg == nil || previousRequestLog != cfg.RequestLog) {
		if s.loggerToggle != nil {
			s.loggerToggle(cfg.RequestLog)
		} else if toggler, ok := s.requestLogger.(interface{ SetEnabled(bool) }); ok {
			toggler.SetEnabled(cfg.RequestLog)
		}
	}

	if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
		if err := logging.ConfigureLogOutput(cfg); err != nil {
			log.Errorf("failed to reconfigure log output: %v", err)
		}
	}

	if oldCfg == nil || oldCfg.UsageStatisticsEnabled != cfg.UsageStatisticsEnabled {
		usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
	}

	if s.requestLogger != nil && (oldCfg == nil || oldCfg.ErrorLogsMaxFiles != cfg.ErrorLogsMaxFiles) {
		if setter, ok := s.requestLogger.(interface{ SetErrorLogsMaxFiles(int) }); ok {
			setter.SetErrorLogsMaxFiles(cfg.ErrorLogsMaxFiles)
		}
	}

	if oldCfg == nil || oldCfg.DisableCooling != cfg.DisableCooling {
		auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
	}

	// Retry settings are pushed on every update, not only when they change.
	if s.handlers != nil && s.handlers.AuthManager != nil {
		s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second, cfg.MaxRetryCredentials)
	}

	// Update log level dynamically when debug flag changes.
	if oldCfg == nil || oldCfg.Debug != cfg.Debug {
		util.SetLogLevel(cfg)
	}

	// Management routes: when the secret comes from the environment
	// (MANAGEMENT_PASSWORD) they stay enabled; otherwise they follow whether
	// the configured secret key is set.
	prevSecretEmpty := true
	if oldCfg != nil {
		prevSecretEmpty = oldCfg.RemoteManagement.SecretKey == ""
	}
	newSecretEmpty := cfg.RemoteManagement.SecretKey == ""
	if s.envManagementSecret {
		s.registerManagementRoutes()
		if s.managementRoutesEnabled.CompareAndSwap(false, true) {
			log.Info("management routes enabled via MANAGEMENT_PASSWORD")
		} else {
			s.managementRoutesEnabled.Store(true)
		}
	} else {
		switch {
		case prevSecretEmpty && !newSecretEmpty:
			s.registerManagementRoutes()
			if s.managementRoutesEnabled.CompareAndSwap(false, true) {
				log.Info("management routes enabled after secret key update")
			} else {
				s.managementRoutesEnabled.Store(true)
			}
		case !prevSecretEmpty && newSecretEmpty:
			if s.managementRoutesEnabled.CompareAndSwap(true, false) {
				log.Info("management routes disabled after secret key removal")
			} else {
				s.managementRoutesEnabled.Store(false)
			}
		default:
			s.managementRoutesEnabled.Store(!newSecretEmpty)
		}
	}

	s.applyAccessConfig(oldCfg, cfg)
	s.cfg = cfg
	s.wsAuthEnabled.Store(cfg.WebsocketAuth)
	if oldCfg != nil && s.wsAuthChanged != nil && oldCfg.WebsocketAuth != cfg.WebsocketAuth {
		s.wsAuthChanged(oldCfg.WebsocketAuth, cfg.WebsocketAuth)
	}
	managementasset.SetCurrentConfig(cfg)

	// Save YAML snapshot for next comparison. On failure the snapshot is
	// cleared so the next call re-applies every setting.
	snapshot, errMarshal := yaml.Marshal(cfg)
	if errMarshal != nil {
		log.Warnf("failed to snapshot config for change detection: %v", errMarshal)
		snapshot = nil
	}
	s.oldConfigYaml = snapshot

	if s.handlers != nil {
		s.handlers.UpdateClients(&cfg.SDKConfig)
	}
	if s.mgmt != nil {
		s.mgmt.SetConfig(cfg)
		if s.handlers != nil {
			s.mgmt.SetAuthManager(s.handlers.AuthManager)
		}
	}

	// Notify Amp module only when Amp config has changed.
	ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode)
	if ampConfigChanged {
		if s.ampModule != nil {
			log.Debugf("triggering amp module config update")
			if err := s.ampModule.OnConfigUpdated(cfg); err != nil {
				log.Errorf("failed to update Amp module config: %v", err)
			}
		} else {
			log.Warnf("amp module is nil, skipping config update")
		}
	}

	// Count client sources from configuration and auth store.
	tokenStore := sdkAuth.GetTokenStore()
	if dirSetter, ok := tokenStore.(interface{ SetBaseDir(string) }); ok {
		dirSetter.SetBaseDir(cfg.AuthDir)
	}
	authEntries := util.CountAuthFiles(context.Background(), tokenStore)
	geminiAPIKeyCount := len(cfg.GeminiKey)
	claudeAPIKeyCount := len(cfg.ClaudeKey)
	codexAPIKeyCount := len(cfg.CodexKey)
	vertexAICompatCount := len(cfg.VertexCompatAPIKey)
	openAICompatCount := 0
	for i := range cfg.OpenAICompatibility {
		entry := cfg.OpenAICompatibility[i]
		openAICompatCount += len(entry.APIKeyEntries)
	}
	total := authEntries + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
	fmt.Printf("server clients and configuration updated: %d clients (%d auth entries + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
		total,
		authEntries,
		geminiAPIKeyCount,
		claudeAPIKeyCount,
		codexAPIKeyCount,
		vertexAICompatCount,
		openAICompatCount,
	)
}
// SetWebsocketAuthChangeHandler registers fn to be invoked by UpdateClients
// when a previous configuration snapshot exists and its websocket-auth flag
// differs from the new one; fn receives the old value first, then the new
// value. Passing nil removes the handler. Calling it on a nil *Server does
// nothing.
func (s *Server) SetWebsocketAuthChangeHandler(fn func(bool, bool)) {
	if s != nil {
		s.wsAuthChanged = fn
	}
}
// (management handlers moved to internal/api/handlers/management)
// AuthMiddleware returns a Gin middleware that authenticates each request with
// manager. A nil manager lets every request through (legacy behaviour).
//
// On success, the authentication result (when non-nil) is exposed to later
// handlers through the Gin context keys "apiKey" (principal), "accessProvider",
// and, only when non-empty, "accessMetadata". On failure the request is aborted
// with the status code and message carried by the authentication error; 5xx
// failures are additionally logged.
func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
	return func(c *gin.Context) {
		if manager == nil {
			c.Next()
			return
		}

		result, authErr := manager.Authenticate(c.Request.Context(), c.Request)
		if authErr != nil {
			status := authErr.HTTPStatusCode()
			if status >= http.StatusInternalServerError {
				log.Errorf("authentication middleware error: %v", authErr)
			}
			c.AbortWithStatusJSON(status, gin.H{"error": authErr.Message})
			return
		}

		if result != nil {
			c.Set("apiKey", result.Principal)
			c.Set("accessProvider", result.Provider)
			if len(result.Metadata) > 0 {
				c.Set("accessMetadata", result.Metadata)
			}
		}
		c.Next()
	}
}
================================================
FILE: internal/api/server_test.go
================================================
package api
import (
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"testing"
"time"
gin "github.com/gin-gonic/gin"
proxyconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// newTestServer builds a Server in Gin test mode rooted in a fresh temporary
// directory. Its config accepts the single API key "test-key", enables debug
// logging, and points AuthDir at a pre-created "auth" subdirectory.
func newTestServer(t *testing.T) *Server {
	t.Helper()
	gin.SetMode(gin.TestMode)

	baseDir := t.TempDir()
	authDir := filepath.Join(baseDir, "auth")
	if err := os.MkdirAll(authDir, 0o700); err != nil {
		t.Fatalf("failed to create auth dir: %v", err)
	}

	cfg := &proxyconfig.Config{
		SDKConfig:              sdkconfig.SDKConfig{APIKeys: []string{"test-key"}},
		Port:                   0,
		AuthDir:                authDir,
		Debug:                  true,
		LoggingToFile:          false,
		UsageStatisticsEnabled: false,
	}
	return NewServer(
		cfg,
		auth.NewManager(nil, nil, nil),
		sdkaccess.NewManager(),
		filepath.Join(baseDir, "config.yaml"),
	)
}
// TestAmpProviderModelRoutes checks that each Amp provider model-listing route
// answers an authenticated GET with the expected status and a body containing
// a provider-specific marker. Every case runs against its own fresh server.
func TestAmpProviderModelRoutes(t *testing.T) {
	type routeCase struct {
		name         string
		path         string
		wantStatus   int
		wantContains string
	}
	cases := []routeCase{
		{name: "openai root models", path: "/api/provider/openai/models", wantStatus: http.StatusOK, wantContains: `"object":"list"`},
		{name: "groq root models", path: "/api/provider/groq/models", wantStatus: http.StatusOK, wantContains: `"object":"list"`},
		{name: "openai models", path: "/api/provider/openai/v1/models", wantStatus: http.StatusOK, wantContains: `"object":"list"`},
		{name: "anthropic models", path: "/api/provider/anthropic/v1/models", wantStatus: http.StatusOK, wantContains: `"data"`},
		{name: "google models v1", path: "/api/provider/google/v1/models", wantStatus: http.StatusOK, wantContains: `"models"`},
		{name: "google models v1beta", path: "/api/provider/google/v1beta/models", wantStatus: http.StatusOK, wantContains: `"models"`},
	}

	for _, rc := range cases {
		rc := rc
		t.Run(rc.name, func(t *testing.T) {
			srv := newTestServer(t)
			recorder := httptest.NewRecorder()
			request := httptest.NewRequest(http.MethodGet, rc.path, nil)
			request.Header.Set("Authorization", "Bearer test-key")

			srv.engine.ServeHTTP(recorder, request)

			body := recorder.Body.String()
			if recorder.Code != rc.wantStatus {
				t.Fatalf("unexpected status code for %s: got %d want %d; body=%s", rc.path, recorder.Code, rc.wantStatus, body)
			}
			if !strings.Contains(body, rc.wantContains) {
				t.Fatalf("response body for %s missing %q: %s", rc.path, rc.wantContains, body)
			}
		})
	}
}
// TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory verifies that when
// ./logs in the working directory cannot be used as a directory, a forced
// error request log is written under <auth-dir>/logs and never under the
// directory that holds the config file.
func TestDefaultRequestLoggerFactory_UsesResolvedLogDirectory(t *testing.T) {
	t.Setenv("WRITABLE_PATH", "")
	t.Setenv("writable_path", "")

	startDir, errGetwd := os.Getwd()
	if errGetwd != nil {
		t.Fatalf("failed to get current working directory: %v", errGetwd)
	}
	root := t.TempDir()
	if errChdir := os.Chdir(root); errChdir != nil {
		t.Fatalf("failed to switch working directory: %v", errChdir)
	}
	defer func() {
		if errChdirBack := os.Chdir(startDir); errChdirBack != nil {
			t.Fatalf("failed to restore working directory: %v", errChdirBack)
		}
	}()

	// Occupy ./logs with a regular file so ResolveLogDirectory cannot use it
	// and has to fall back to auth-dir/logs.
	if errWriteFile := os.WriteFile(filepath.Join(root, "logs"), []byte("not-a-directory"), 0o644); errWriteFile != nil {
		t.Fatalf("failed to create blocking logs file: %v", errWriteFile)
	}

	configDir := filepath.Join(root, "config")
	if errMkdirConfig := os.MkdirAll(configDir, 0o755); errMkdirConfig != nil {
		t.Fatalf("failed to create config dir: %v", errMkdirConfig)
	}
	configPath := filepath.Join(configDir, "config.yaml")
	authDir := filepath.Join(root, "auth")
	if errMkdirAuth := os.MkdirAll(authDir, 0o700); errMkdirAuth != nil {
		t.Fatalf("failed to create auth dir: %v", errMkdirAuth)
	}

	cfg := &proxyconfig.Config{
		SDKConfig:         proxyconfig.SDKConfig{RequestLog: false},
		AuthDir:           authDir,
		ErrorLogsMaxFiles: 10,
	}
	logger := defaultRequestLoggerFactory(cfg, configPath)
	fileLogger, ok := logger.(*internallogging.FileRequestLogger)
	if !ok {
		t.Fatalf("expected *FileRequestLogger, got %T", logger)
	}

	// jsonHeaders returns a fresh map per call so request and response headers
	// never share state.
	jsonHeaders := func() map[string][]string {
		return map[string][]string{"Content-Type": {"application/json"}}
	}
	if errLog := fileLogger.LogRequestWithOptions(
		"/v1/chat/completions",
		http.MethodPost,
		jsonHeaders(),
		[]byte(`{"input":"hello"}`),
		http.StatusBadGateway,
		jsonHeaders(),
		[]byte(`{"error":"upstream failure"}`),
		nil,
		nil,
		nil,
		true,
		"issue-1711",
		time.Now(),
		time.Now(),
	); errLog != nil {
		t.Fatalf("failed to write forced error request log: %v", errLog)
	}

	isErrorLog := func(name string) bool {
		return strings.HasPrefix(name, "error-") && strings.HasSuffix(name, ".log")
	}

	authLogsDir := filepath.Join(authDir, "logs")
	authEntries, errReadAuthDir := os.ReadDir(authLogsDir)
	if errReadAuthDir != nil {
		t.Fatalf("failed to read auth logs dir %s: %v", authLogsDir, errReadAuthDir)
	}
	foundInAuthDir := false
	for _, entry := range authEntries {
		if isErrorLog(entry.Name()) {
			foundInAuthDir = true
			break
		}
	}
	if !foundInAuthDir {
		t.Fatalf("expected forced error log in auth fallback dir %s, got entries: %+v", authLogsDir, authEntries)
	}

	configLogsDir := filepath.Join(configDir, "logs")
	configEntries, errReadConfigDir := os.ReadDir(configLogsDir)
	if errReadConfigDir != nil && !os.IsNotExist(errReadConfigDir) {
		t.Fatalf("failed to inspect config logs dir %s: %v", configLogsDir, errReadConfigDir)
	}
	for _, entry := range configEntries {
		if isErrorLog(entry.Name()) {
			t.Fatalf("unexpected forced error log in config dir %s", configLogsDir)
		}
	}
}
================================================
FILE: internal/auth/antigravity/auth.go
================================================
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
package antigravity
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
// TokenResponse represents OAuth token response from Google.
// It is the JSON body decoded by ExchangeCodeForTokens from TokenEndpoint.
type TokenResponse struct {
	// AccessToken is the bearer token for subsequent API calls.
	AccessToken  string `json:"access_token"`
	// RefreshToken is used to obtain new access tokens; it is empty when the
	// server omits it.
	RefreshToken string `json:"refresh_token"`
	// ExpiresIn is the access-token lifetime reported by the server
	// (seconds, per RFC 6749).
	ExpiresIn    int64  `json:"expires_in"`
	// TokenType is the token type reported by the server.
	TokenType    string `json:"token_type"`
}
// userInfo represents Google user profile.
// Only the email field of the UserInfoEndpoint response is decoded; it is
// what FetchUserInfo returns.
type userInfo struct {
	Email string `json:"email"`
}
// AntigravityAuth handles Antigravity OAuth authentication: building the
// authorization URL, exchanging codes for tokens, looking up the account email,
// and discovering or onboarding the Cloud Code project.
// Construct it with NewAntigravityAuth. The methods call httpClient.Do
// directly, so a zero value (nil httpClient) cannot make requests.
type AntigravityAuth struct {
	// httpClient is used for every outbound request made by this service.
	httpClient *http.Client
}
// NewAntigravityAuth creates a new Antigravity auth service.
//
// A non-nil httpClient is used as-is and cfg is then ignored. Otherwise a new
// client is built from cfg's SDK proxy settings via util.SetProxy; a nil cfg is
// treated as an empty configuration.
func NewAntigravityAuth(cfg *config.Config, httpClient *http.Client) *AntigravityAuth {
	client := httpClient
	if client == nil {
		effective := cfg
		if effective == nil {
			effective = &config.Config{}
		}
		client = util.SetProxy(&effective.SDKConfig, &http.Client{})
	}
	return &AntigravityAuth{httpClient: client}
}
// BuildAuthURL returns the Google OAuth2 authorization URL for the Antigravity
// client. It requests offline access with prompt=consent and the Scopes list,
// and embeds state for the caller to check on callback.
//
// A blank (whitespace-only) redirectURI falls back to the local callback
// server at http://localhost:<CallbackPort>/oauth-callback; any other value is
// used verbatim.
func (o *AntigravityAuth) BuildAuthURL(state, redirectURI string) string {
	redirect := redirectURI
	if strings.TrimSpace(redirect) == "" {
		redirect = fmt.Sprintf("http://localhost:%d/oauth-callback", CallbackPort)
	}
	query := url.Values{
		"access_type":   {"offline"},
		"client_id":     {ClientID},
		"prompt":        {"consent"},
		"redirect_uri":  {redirect},
		"response_type": {"code"},
		"scope":         {strings.Join(Scopes, " ")},
		"state":         {state},
	}
	return AuthEndpoint + "?" + query.Encode()
}
// ExchangeCodeForTokens exchanges an authorization code for access and refresh
// tokens by POSTing a form-encoded authorization_code grant to TokenEndpoint.
// redirectURI is sent as-is.
//
// A status outside 200-299 yields an error that includes the trimmed response
// body (at most 8 KiB) when it is non-empty.
func (o *AntigravityAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string) (*TokenResponse, error) {
	form := url.Values{
		"code":          {code},
		"client_id":     {ClientID},
		"client_secret": {ClientSecret},
		"redirect_uri":  {redirectURI},
		"grant_type":    {"authorization_code"},
	}
	req, errReq := http.NewRequestWithContext(ctx, http.MethodPost, TokenEndpoint, strings.NewReader(form.Encode()))
	if errReq != nil {
		return nil, fmt.Errorf("antigravity token exchange: create request: %w", errReq)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	resp, errDo := o.httpClient.Do(req)
	if errDo != nil {
		return nil, fmt.Errorf("antigravity token exchange: execute request: %w", errDo)
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("antigravity token exchange: close body error: %v", errClose)
		}
	}()

	if is2xx := resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices; !is2xx {
		snippet, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
		if errRead != nil {
			return nil, fmt.Errorf("antigravity token exchange: read response: %w", errRead)
		}
		if detail := strings.TrimSpace(string(snippet)); detail != "" {
			return nil, fmt.Errorf("antigravity token exchange: request failed: status %d: %s", resp.StatusCode, detail)
		}
		return nil, fmt.Errorf("antigravity token exchange: request failed: status %d", resp.StatusCode)
	}

	token := new(TokenResponse)
	if errDecode := json.NewDecoder(resp.Body).Decode(token); errDecode != nil {
		return nil, fmt.Errorf("antigravity token exchange: decode response: %w", errDecode)
	}
	return token, nil
}
// FetchUserInfo returns the Google account email for accessToken, read from
// UserInfoEndpoint. Surrounding whitespace is trimmed from both the token and
// the returned email.
//
// It fails when the token is blank, when the endpoint answers with a status
// outside 200-299 (the error carries at most 8 KiB of the trimmed body when
// non-empty), when the body cannot be decoded, or when it has no email.
func (o *AntigravityAuth) FetchUserInfo(ctx context.Context, accessToken string) (string, error) {
	token := strings.TrimSpace(accessToken)
	if token == "" {
		return "", fmt.Errorf("antigravity userinfo: missing access token")
	}

	req, errReq := http.NewRequestWithContext(ctx, http.MethodGet, UserInfoEndpoint, nil)
	if errReq != nil {
		return "", fmt.Errorf("antigravity userinfo: create request: %w", errReq)
	}
	req.Header.Set("Authorization", "Bearer "+token)

	resp, errDo := o.httpClient.Do(req)
	if errDo != nil {
		return "", fmt.Errorf("antigravity userinfo: execute request: %w", errDo)
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("antigravity userinfo: close body error: %v", errClose)
		}
	}()

	if is2xx := resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices; !is2xx {
		snippet, errRead := io.ReadAll(io.LimitReader(resp.Body, 8<<10))
		if errRead != nil {
			return "", fmt.Errorf("antigravity userinfo: read response: %w", errRead)
		}
		if detail := strings.TrimSpace(string(snippet)); detail != "" {
			return "", fmt.Errorf("antigravity userinfo: request failed: status %d: %s", resp.StatusCode, detail)
		}
		return "", fmt.Errorf("antigravity userinfo: request failed: status %d", resp.StatusCode)
	}

	var profile userInfo
	if errDecode := json.NewDecoder(resp.Body).Decode(&profile); errDecode != nil {
		return "", fmt.Errorf("antigravity userinfo: decode response: %w", errDecode)
	}
	if email := strings.TrimSpace(profile.Email); email != "" {
		return email, nil
	}
	return "", fmt.Errorf("antigravity userinfo: response missing email")
}
// FetchProjectID returns the Cloud Code project ID for the authenticated user.
//
// It first calls loadCodeAssist and reads cloudaicompanionProject, which may be
// a bare string or an object with an "id" field. When no project is assigned
// yet, it falls back to OnboardUser with the tier marked isDefault in
// allowedTiers (or "legacy-tier" when none qualifies) and returns its result.
// A status outside 200-299 from loadCodeAssist is returned as an error
// carrying the trimmed response body.
func (o *AntigravityAuth) FetchProjectID(ctx context.Context, accessToken string) (string, error) {
	payload, errMarshal := json.Marshal(map[string]any{
		"metadata": map[string]string{
			"ideType":    "ANTIGRAVITY",
			"platform":   "PLATFORM_UNSPECIFIED",
			"pluginType": "GEMINI",
		},
	})
	if errMarshal != nil {
		return "", fmt.Errorf("marshal request body: %w", errMarshal)
	}

	loadURL := fmt.Sprintf("%s/%s:loadCodeAssist", APIEndpoint, APIVersion)
	req, errReq := http.NewRequestWithContext(ctx, http.MethodPost, loadURL, strings.NewReader(string(payload)))
	if errReq != nil {
		return "", fmt.Errorf("create request: %w", errReq)
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("User-Agent", APIUserAgent)
	req.Header.Set("X-Goog-Api-Client", APIClient)
	req.Header.Set("Client-Metadata", ClientMetadata)

	resp, errDo := o.httpClient.Do(req)
	if errDo != nil {
		return "", fmt.Errorf("execute request: %w", errDo)
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("antigravity loadCodeAssist: close body error: %v", errClose)
		}
	}()

	raw, errRead := io.ReadAll(resp.Body)
	if errRead != nil {
		return "", fmt.Errorf("read response: %w", errRead)
	}
	if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
		return "", fmt.Errorf("request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
	}

	var loadResp map[string]any
	if errDecode := json.Unmarshal(raw, &loadResp); errDecode != nil {
		return "", fmt.Errorf("decode response: %w", errDecode)
	}

	// cloudaicompanionProject is either a bare ID string or {"id": "..."}.
	var projectID string
	switch project := loadResp["cloudaicompanionProject"].(type) {
	case string:
		projectID = strings.TrimSpace(project)
	case map[string]any:
		if id, ok := project["id"].(string); ok {
			projectID = strings.TrimSpace(id)
		}
	}
	if projectID != "" {
		return projectID, nil
	}

	// No project assigned yet: onboard with the first default tier that has a
	// non-blank id.
	tierID := "legacy-tier"
	tiers, _ := loadResp["allowedTiers"].([]any)
	for _, rawTier := range tiers {
		tier, ok := rawTier.(map[string]any)
		if !ok {
			continue
		}
		if isDefault, _ := tier["isDefault"].(bool); !isDefault {
			continue
		}
		if id, okID := tier["id"].(string); okID && strings.TrimSpace(id) != "" {
			tierID = strings.TrimSpace(id)
			break
		}
	}

	onboardedID, errOnboard := o.OnboardUser(ctx, accessToken, tierID)
	if errOnboard != nil {
		return "", errOnboard
	}
	return onboardedID, nil
}
// OnboardUser provisions a Cloud Code project for the user by calling the
// onboardUser endpoint with tierID and polling until the returned operation
// reports "done".
//
// It makes at most 5 attempts. Each request is bounded by a 30-second timeout
// derived from ctx (a nil ctx is treated as context.Background()). While the
// operation is pending it waits 2 seconds between attempts, and it abandons
// that wait as soon as ctx is cancelled.
//
// On completion the project ID is read from response.cloudaicompanionProject,
// which may be a bare string or an object with an "id" field.
//
// It returns an error when:
//   - a request cannot be built, sent, or read;
//   - the server answers with a non-200 status (the message carries at most
//     200 bytes of the trimmed body);
//   - the completed operation has no project ID;
//   - ctx is cancelled while waiting;
//   - the operation is still pending after the last attempt.
func (o *AntigravityAuth) OnboardUser(ctx context.Context, accessToken, tierID string) (string, error) {
	const (
		maxAttempts    = 5
		requestTimeout = 30 * time.Second
		pollInterval   = 2 * time.Second
	)

	log.Infof("Antigravity: onboarding user with tier: %s", tierID)
	if ctx == nil {
		ctx = context.Background()
	}

	requestBody := map[string]any{
		"tierId": tierID,
		"metadata": map[string]string{
			"ideType":    "ANTIGRAVITY",
			"platform":   "PLATFORM_UNSPECIFIED",
			"pluginType": "GEMINI",
		},
	}
	rawBody, errMarshal := json.Marshal(requestBody)
	if errMarshal != nil {
		return "", fmt.Errorf("marshal request body: %w", errMarshal)
	}
	endpointURL := fmt.Sprintf("%s/%s:onboardUser", APIEndpoint, APIVersion)

	for attempt := 1; attempt <= maxAttempts; attempt++ {
		log.Debugf("Polling attempt %d/%d", attempt, maxAttempts)

		reqCtx, cancel := context.WithTimeout(ctx, requestTimeout)
		req, errRequest := http.NewRequestWithContext(reqCtx, http.MethodPost, endpointURL, strings.NewReader(string(rawBody)))
		if errRequest != nil {
			cancel()
			return "", fmt.Errorf("create request: %w", errRequest)
		}
		req.Header.Set("Authorization", "Bearer "+accessToken)
		req.Header.Set("Content-Type", "application/json")
		req.Header.Set("User-Agent", APIUserAgent)
		req.Header.Set("X-Goog-Api-Client", APIClient)
		req.Header.Set("Client-Metadata", ClientMetadata)

		resp, errDo := o.httpClient.Do(req)
		if errDo != nil {
			cancel()
			return "", fmt.Errorf("execute request: %w", errDo)
		}
		bodyBytes, errRead := io.ReadAll(resp.Body)
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("close body error: %v", errClose)
		}
		// Release the per-request timer now rather than deferring inside the loop.
		cancel()
		if errRead != nil {
			return "", fmt.Errorf("read response: %w", errRead)
		}

		if resp.StatusCode != http.StatusOK {
			responseErr := strings.TrimSpace(string(bodyBytes))
			if len(responseErr) > 200 {
				responseErr = responseErr[:200]
			}
			return "", fmt.Errorf("http %d: %s", resp.StatusCode, responseErr)
		}

		var data map[string]any
		if errDecode := json.Unmarshal(bodyBytes, &data); errDecode != nil {
			return "", fmt.Errorf("decode response: %w", errDecode)
		}
		if done, _ := data["done"].(bool); done {
			projectID := ""
			if responseData, okResp := data["response"].(map[string]any); okResp {
				switch projectValue := responseData["cloudaicompanionProject"].(type) {
				case map[string]any:
					if id, okID := projectValue["id"].(string); okID {
						projectID = strings.TrimSpace(id)
					}
				case string:
					projectID = strings.TrimSpace(projectValue)
				}
			}
			if projectID == "" {
				return "", fmt.Errorf("no project_id in response")
			}
			log.Infof("Successfully fetched project_id: %s", projectID)
			return projectID, nil
		}

		// Operation still pending: wait before the next poll (but not after the
		// final attempt), and stop early if the caller gives up.
		if attempt == maxAttempts {
			break
		}
		timer := time.NewTimer(pollInterval)
		select {
		case <-ctx.Done():
			timer.Stop()
			return "", fmt.Errorf("onboard user: %w", ctx.Err())
		case <-timer.C:
		}
	}
	return "", fmt.Errorf("onboard user: operation not done after %d attempts", maxAttempts)
}
================================================
FILE: internal/auth/antigravity/constants.go
================================================
// Package antigravity provides OAuth2 authentication functionality for the Antigravity provider.
package antigravity
// OAuth client credentials and configuration.
//
// ClientID identifies the OAuth client in BuildAuthURL and
// ExchangeCodeForTokens; ClientSecret is sent only in the token exchange.
// Both are compiled into the binary, so they are not confidential from anyone
// who holds it.
// CallbackPort is the local port of the default redirect URI
// (http://localhost:<CallbackPort>/oauth-callback) that BuildAuthURL uses when
// no redirect URI is supplied.
const (
	ClientID     = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
	ClientSecret = "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"
	CallbackPort = 51121
)
// Scopes defines the OAuth scopes required for Antigravity authentication.
// BuildAuthURL joins them with single spaces into the "scope" query parameter.
var Scopes = []string{
	"https://www.googleapis.com/auth/cloud-platform",
	"https://www.googleapis.com/auth/userinfo.email",
	"https://www.googleapis.com/auth/userinfo.profile",
	"https://www.googleapis.com/auth/cclog",
	"https://www.googleapis.com/auth/experimentsandconfigs",
}
// OAuth2 endpoints for Google authentication.
//
//   - TokenEndpoint: authorization-code exchange (ExchangeCodeForTokens).
//   - AuthEndpoint: consent page URL built by BuildAuthURL.
//   - UserInfoEndpoint: profile lookup used by FetchUserInfo to obtain the
//     account email; the alt=json query parameter is already included.
const (
	TokenEndpoint    = "https://oauth2.googleapis.com/token"
	AuthEndpoint     = "https://accounts.google.com/o/oauth2/v2/auth"
	UserInfoEndpoint = "https://www.googleapis.com/oauth2/v1/userinfo?alt=json"
)
// Antigravity API configuration used by FetchProjectID and OnboardUser.
//
// APIEndpoint and APIVersion form the request URLs
// "<APIEndpoint>/<APIVersion>:loadCodeAssist" and ":onboardUser".
// APIUserAgent, APIClient and ClientMetadata are sent as the User-Agent,
// X-Goog-Api-Client and Client-Metadata headers on those requests.
//
// NOTE(review): ClientMetadata reports ideType IDE_UNSPECIFIED, while the JSON
// bodies built in FetchProjectID/OnboardUser send ideType ANTIGRAVITY; confirm
// the mismatch is intentional.
const (
	APIEndpoint    = "https://cloudcode-pa.googleapis.com"
	APIVersion     = "v1internal"
	APIUserAgent   = "google-api-nodejs-client/9.15.1"
	APIClient      = "google-cloud-sdk vscode_cloudshelleditor/0.1"
	ClientMetadata = `{"ideType":"IDE_UNSPECIFIED","platform":"PLATFORM_UNSPECIFIED","pluginType":"GEMINI"}`
)
================================================
FILE: internal/auth/antigravity/filename.go
================================================
package antigravity
import (
"fmt"
"strings"
)
// CredentialFileName returns the file name under which Antigravity credentials
// are stored. Surrounding whitespace is trimmed from email. A non-empty result
// produces "antigravity-<email>.json", so several accounts can coexist; an
// empty one produces the plain "antigravity.json".
func CredentialFileName(email string) string {
	if trimmed := strings.TrimSpace(email); trimmed != "" {
		return fmt.Sprintf("antigravity-%s.json", trimmed)
	}
	return "antigravity.json"
}
================================================
FILE: internal/auth/claude/anthropic.go
================================================
package claude
// PKCECodes holds PKCE verification codes for OAuth2 PKCE flow.
// CodeChallenge is sent in the authorization URL (GenerateAuthURL, method S256)
// and CodeVerifier in the token exchange (ExchangeCodeForTokens).
type PKCECodes struct {
	// CodeVerifier is the cryptographically random string used to correlate
	// the authorization request to the token request
	CodeVerifier  string `json:"code_verifier"`
	// CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded
	CodeChallenge string `json:"code_challenge"`
}
// ClaudeTokenData holds OAuth token information from Anthropic.
type ClaudeTokenData struct {
	// AccessToken is the OAuth2 access token for API access
	AccessToken  string `json:"access_token"`
	// RefreshToken is used to obtain new access tokens
	RefreshToken string `json:"refresh_token"`
	// Email is the Anthropic account email
	Email        string `json:"email"`
	// Expire is the access-token expiry time, formatted as RFC 3339 by the
	// functions in this package that populate it (serialized as "expired").
	Expire       string `json:"expired"`
}
// ClaudeAuthBundle aggregates authentication data after OAuth flow completion.
type ClaudeAuthBundle struct {
	// APIKey is the Anthropic API key obtained from token exchange.
	// NOTE(review): ExchangeCodeForTokens in this package does not set it, so
	// it is empty unless populated elsewhere — confirm.
	APIKey      string          `json:"api_key"`
	// TokenData contains the OAuth tokens from the authentication flow
	TokenData   ClaudeTokenData `json:"token_data"`
	// LastRefresh is the timestamp of the last token refresh (RFC 3339 when
	// set by ExchangeCodeForTokens)
	LastRefresh string          `json:"last_refresh"`
}
================================================
FILE: internal/auth/claude/anthropic_auth.go
================================================
// Package claude provides OAuth2 authentication functionality for Anthropic's Claude API.
// This package implements the complete OAuth2 flow with PKCE (Proof Key for Code Exchange)
// for secure authentication with Claude API, including token exchange, refresh, and storage.
package claude
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
)
// OAuth configuration constants for Claude/Anthropic.
//
//   - AuthURL: authorization page the user is sent to (GenerateAuthURL).
//   - TokenURL: endpoint for code exchange and token refresh; requests are JSON.
//   - ClientID: OAuth client identifier sent with every request. No client
//     secret is sent; the code exchange relies on the PKCE verifier instead.
//   - RedirectURI: fixed local callback used in both the authorization request
//     and the code exchange, so the two always match.
const (
	AuthURL     = "https://claude.ai/oauth/authorize"
	TokenURL    = "https://api.anthropic.com/v1/oauth/token"
	ClientID    = "9d1c250a-e61b-44d9-88ed-5944d1962f5e"
	RedirectURI = "http://localhost:54545/callback"
)
// tokenResponse represents the response structure from Anthropic's OAuth token endpoint.
// It contains access token, refresh token, and associated user/organization information.
// ExchangeCodeForTokens and RefreshTokens read only the two tokens, ExpiresIn
// (interpreted as seconds), and Account.EmailAddress; TokenType, Organization,
// and Account.UUID are decoded but not used by those functions.
type tokenResponse struct {
	AccessToken  string `json:"access_token"`
	RefreshToken string `json:"refresh_token"`
	TokenType    string `json:"token_type"`
	ExpiresIn    int    `json:"expires_in"`
	Organization struct {
		UUID string `json:"uuid"`
		Name string `json:"name"`
	} `json:"organization"`
	Account struct {
		UUID         string `json:"uuid"`
		EmailAddress string `json:"email_address"`
	} `json:"account"`
}
// ClaudeAuth handles Anthropic OAuth2 authentication flow.
// It provides methods for generating authorization URLs, exchanging codes for
// tokens (using PKCE for the code exchange), and refreshing expired tokens.
// Construct it with NewClaudeAuth. The methods call httpClient.Do directly, so
// a zero value (nil httpClient) cannot make requests.
type ClaudeAuth struct {
	// httpClient sends all token endpoint requests.
	httpClient *http.Client
}
// NewClaudeAuth creates a new Anthropic authentication service.
// The HTTP client comes from NewAnthropicHttpClient, configured with cfg's SDK
// (proxy) settings. Per the original author, that client uses a custom TLS
// transport with a Firefox fingerprint to get past Cloudflare's TLS
// fingerprinting on Anthropic domains (implementation not shown here).
//
// Parameters:
//   - cfg: The application configuration containing proxy settings; nil is
//     treated as an empty configuration (mirroring NewAntigravityAuth)
//
// Returns:
//   - *ClaudeAuth: A new Claude authentication service instance
func NewClaudeAuth(cfg *config.Config) *ClaudeAuth {
	if cfg == nil {
		// Avoid a nil dereference on cfg.SDKConfig; fall back to defaults.
		cfg = &config.Config{}
	}
	return &ClaudeAuth{
		httpClient: NewAnthropicHttpClient(&cfg.SDKConfig),
	}
}
// GenerateAuthURL builds the Anthropic OAuth authorization URL for the PKCE
// flow. The URL carries the client ID, the fixed RedirectURI, the requested
// scopes, the S256 code challenge from pkceCodes, and state.
//
// Parameters:
//   - state: A random state parameter for CSRF protection
//   - pkceCodes: The PKCE codes; only CodeChallenge is used here
//
// Returns:
//   - string: The complete authorization URL
//   - string: The state parameter, returned unchanged
//   - error: An error if pkceCodes is nil
func (o *ClaudeAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string, string, error) {
	if pkceCodes == nil {
		return "", "", fmt.Errorf("PKCE codes are required")
	}

	query := make(url.Values, 8)
	query.Set("code", "true")
	query.Set("client_id", ClientID)
	query.Set("response_type", "code")
	query.Set("redirect_uri", RedirectURI)
	query.Set("scope", "org:create_api_key user:profile user:inference")
	query.Set("code_challenge", pkceCodes.CodeChallenge)
	query.Set("code_challenge_method", "S256")
	query.Set("state", state)

	return AuthURL + "?" + query.Encode(), state, nil
}
// parseCodeAndState splits a pasted callback value of the form "code#state".
// Everything before the first '#' is the code. The state is the text between
// the first and second '#' (any further segments are ignored); it is empty
// when no '#' is present.
//
// Parameters:
//   - code: The raw code parameter from the OAuth callback
//
// Returns:
//   - parsedCode: The extracted authorization code
//   - parsedState: The extracted state parameter if present
func (o *ClaudeAuth) parseCodeAndState(code string) (parsedCode, parsedState string) {
	head, tail, hasState := strings.Cut(code, "#")
	if !hasState {
		return head, ""
	}
	state, _, _ := strings.Cut(tail, "#")
	return head, state
}
// ExchangeCodeForTokens exchanges an authorization code for OAuth tokens using
// the PKCE authorization_code grant against TokenURL (JSON request body).
//
// The code may arrive as "code#state"; in that case the embedded state
// replaces the state argument in the request. The state is only forwarded to
// the server; it is not compared locally.
//
// Parameters:
//   - ctx: The context for the request
//   - code: The authorization code from the OAuth callback, optionally "code#state"
//   - state: The state to send when code does not carry one
//   - pkceCodes: The PKCE codes; CodeVerifier is sent as code_verifier
//
// Returns:
//   - *ClaudeAuthBundle: Token data with an RFC 3339 expiry computed from
//     expires_in, and LastRefresh set to the current time; APIKey is left empty
//   - error: An error if pkceCodes is nil, the request fails, the server
//     responds with a status other than 200, or the response cannot be parsed
func (o *ClaudeAuth) ExchangeCodeForTokens(ctx context.Context, code, state string, pkceCodes *PKCECodes) (*ClaudeAuthBundle, error) {
	if pkceCodes == nil {
		return nil, fmt.Errorf("PKCE codes are required for token exchange")
	}

	authCode, callbackState := o.parseCodeAndState(code)
	payload := map[string]interface{}{
		"grant_type":    "authorization_code",
		"client_id":     ClientID,
		"redirect_uri":  RedirectURI,
		"code":          authCode,
		"code_verifier": pkceCodes.CodeVerifier,
		"state":         state,
	}
	if callbackState != "" {
		payload["state"] = callbackState
	}
	jsonBody, errMarshal := json.Marshal(payload)
	if errMarshal != nil {
		return nil, fmt.Errorf("failed to marshal request body: %w", errMarshal)
	}

	req, errReq := http.NewRequestWithContext(ctx, http.MethodPost, TokenURL, strings.NewReader(string(jsonBody)))
	if errReq != nil {
		return nil, fmt.Errorf("failed to create token request: %w", errReq)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	resp, errDo := o.httpClient.Do(req)
	if errDo != nil {
		return nil, fmt.Errorf("token exchange request failed: %w", errDo)
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("failed to close response body: %v", errClose)
		}
	}()

	body, errRead := io.ReadAll(resp.Body)
	if errRead != nil {
		return nil, fmt.Errorf("failed to read token response: %w", errRead)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
	}

	var tokenResp tokenResponse
	if errParse := json.Unmarshal(body, &tokenResp); errParse != nil {
		return nil, fmt.Errorf("failed to parse token response: %w", errParse)
	}

	expiresAt := time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second)
	return &ClaudeAuthBundle{
		TokenData: ClaudeTokenData{
			AccessToken:  tokenResp.AccessToken,
			RefreshToken: tokenResp.RefreshToken,
			Email:        tokenResp.Account.EmailAddress,
			Expire:       expiresAt.Format(time.RFC3339),
		},
		LastRefresh: time.Now().Format(time.RFC3339),
	}, nil
}
// RefreshTokens refreshes the access token using the refresh token.
// This method exchanges a valid refresh token for a new access token,
// extending the user's authenticated session.
//
// If the token endpoint does not return a new refresh token (no rotation), the
// refresh token that was supplied is carried over into the result. Callers that
// persist the result therefore never overwrite a working refresh token with an
// empty one.
//
// Parameters:
//   - ctx: The context for the request
//   - refreshToken: The refresh token to use for getting new access token
//
// Returns:
//   - *ClaudeTokenData: The new token data; Expire is RFC 3339, computed from
//     expires_in relative to now
//   - error: An error if refreshToken is empty, the request fails, the server
//     responds with a status other than 200, or the response cannot be parsed
func (o *ClaudeAuth) RefreshTokens(ctx context.Context, refreshToken string) (*ClaudeTokenData, error) {
	if refreshToken == "" {
		return nil, fmt.Errorf("refresh token is required")
	}

	reqBody := map[string]interface{}{
		"client_id":     ClientID,
		"grant_type":    "refresh_token",
		"refresh_token": refreshToken,
	}
	jsonBody, err := json.Marshal(reqBody)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal request body: %w", err)
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenURL, strings.NewReader(string(jsonBody)))
	if err != nil {
		return nil, fmt.Errorf("failed to create refresh request: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json")

	resp, err := o.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("token refresh request failed: %w", err)
	}
	defer func() {
		// Log close failures, consistent with ExchangeCodeForTokens.
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("failed to close response body: %v", errClose)
		}
	}()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read refresh response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body))
	}

	var tokenResp tokenResponse
	if err = json.Unmarshal(body, &tokenResp); err != nil {
		return nil, fmt.Errorf("failed to parse token response: %w", err)
	}

	newRefreshToken := tokenResp.RefreshToken
	if newRefreshToken == "" {
		// The server did not rotate the refresh token; keep the one we used.
		newRefreshToken = refreshToken
	}

	return &ClaudeTokenData{
		AccessToken:  tokenResp.AccessToken,
		RefreshToken: newRefreshToken,
		Email:        tokenResp.Account.EmailAddress,
		Expire:       time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
	}, nil
}
// CreateTokenStorage turns an authentication bundle into a ClaudeTokenStorage
// ready to be persisted.
//
// The access token, refresh token, email and expiry are copied from
// bundle.TokenData, and the last-refresh timestamp from bundle.LastRefresh.
// IDToken, Type and Metadata are left at their zero values.
//
// Parameters:
//   - bundle: The authentication bundle containing token data (must be non-nil)
//
// Returns:
//   - *ClaudeTokenStorage: A newly allocated token storage instance
func (o *ClaudeAuth) CreateTokenStorage(bundle *ClaudeAuthBundle) *ClaudeTokenStorage {
	td := bundle.TokenData
	return &ClaudeTokenStorage{
		Email:        td.Email,
		AccessToken:  td.AccessToken,
		RefreshToken: td.RefreshToken,
		Expire:       td.Expire,
		LastRefresh:  bundle.LastRefresh,
	}
}
// RefreshTokensWithRetry calls RefreshTokens repeatedly until one attempt
// succeeds or the attempt budget is exhausted.
//
// Between attempts it waits a linearly growing delay: 1s after the first
// failure, 2s after the second, and so on. If ctx is cancelled during a wait,
// the wait is abandoned and ctx.Err() is returned.
//
// A maxRetries value below 1 is treated as 1. At least one attempt is always
// made, so the final error always wraps a real cause instead of nil. An empty
// refresh token is rejected immediately, because retrying it cannot succeed.
//
// Parameters:
//   - ctx: The context for the request and the waits between attempts
//   - refreshToken: The refresh token to use
//   - maxRetries: The maximum number of attempts (values < 1 mean one attempt)
//
// Returns:
//   - *ClaudeTokenData: The refreshed token data
//   - error: ctx.Err() on cancellation, or an error wrapping the last failure
func (o *ClaudeAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken string, maxRetries int) (*ClaudeTokenData, error) {
	if refreshToken == "" {
		return nil, fmt.Errorf("refresh token is required")
	}
	attempts := maxRetries
	if attempts < 1 {
		attempts = 1
	}
	var lastErr error
	for attempt := 0; attempt < attempts; attempt++ {
		if attempt > 0 {
			timer := time.NewTimer(time.Duration(attempt) * time.Second)
			select {
			case <-ctx.Done():
				timer.Stop()
				return nil, ctx.Err()
			case <-timer.C:
			}
		}
		tokenData, err := o.RefreshTokens(ctx, refreshToken)
		if err == nil {
			return tokenData, nil
		}
		lastErr = err
		log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
	}
	return nil, fmt.Errorf("token refresh failed after %d attempts: %w", attempts, lastErr)
}
// UpdateTokenStorage applies freshly refreshed token data to an existing
// token storage and stamps LastRefresh with the current time (RFC 3339).
//
// The access token and expiry are always replaced. The refresh token and email
// are replaced only when the new data carries a non-empty value. A refresh
// response may legitimately omit them, and blanking them would make the stored
// credential unusable or lose the account identity.
//
// Parameters:
//   - storage: The existing token storage to update (must be non-nil)
//   - tokenData: The new token data to apply (must be non-nil)
func (o *ClaudeAuth) UpdateTokenStorage(storage *ClaudeTokenStorage, tokenData *ClaudeTokenData) {
	storage.AccessToken = tokenData.AccessToken
	if tokenData.RefreshToken != "" {
		storage.RefreshToken = tokenData.RefreshToken
	}
	storage.LastRefresh = time.Now().Format(time.RFC3339)
	if tokenData.Email != "" {
		storage.Email = tokenData.Email
	}
	storage.Expire = tokenData.Expire
}
================================================
FILE: internal/auth/claude/errors.go
================================================
// Package claude provides authentication and token management functionality
// for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
// and retrieval for maintaining authenticated sessions with the Claude API.
package claude
import (
"errors"
"fmt"
"net/http"
)
// OAuthError is an error reported by an OAuth 2.0 endpoint: an error code with
// an optional human-readable description and reference URI. The JSON tags use
// the standard OAuth response field names. StatusCode is never serialized.
type OAuthError struct {
	// Code is the OAuth error code (for example "access_denied").
	Code string `json:"error"`
	// Description is an optional human-readable explanation of the error.
	Description string `json:"error_description,omitempty"`
	// URI optionally points at a web page documenting the error.
	URI string `json:"error_uri,omitempty"`
	// StatusCode is the HTTP status code that accompanied the error.
	StatusCode int `json:"-"`
}

// Error renders the error as "OAuth error <code>: <description>". When there
// is no description it renders "OAuth error: <code>" instead.
func (e *OAuthError) Error() string {
	if e.Description == "" {
		return fmt.Sprintf("OAuth error: %s", e.Code)
	}
	return fmt.Sprintf("OAuth error %s: %s", e.Code, e.Description)
}

// NewOAuthError returns an *OAuthError holding the given code, description and
// HTTP status code. Its URI field is left empty.
func NewOAuthError(code, description string, statusCode int) *OAuthError {
	e := new(OAuthError)
	e.Code = code
	e.Description = description
	e.StatusCode = statusCode
	return e
}
// AuthenticationError describes a failure in the authentication flow. It has a
// machine-readable Type, a human-readable Message, a numeric Code and an
// optional underlying Cause.
type AuthenticationError struct {
	// Type is the type of authentication error (e.g. "invalid_state").
	Type string `json:"type"`
	// Message is a human-readable message describing the error.
	Message string `json:"message"`
	// Code is usually an HTTP status code. Some sentinels use a special
	// non-HTTP value (e.g. the port-in-use exit code).
	Code int `json:"code"`
	// Cause is the underlying error that caused this authentication error.
	Cause error `json:"-"`
}

// Error returns "<type>: <message>", followed by " (caused by: <cause>)" when
// a cause is attached.
func (e *AuthenticationError) Error() string {
	if e.Cause != nil {
		return fmt.Sprintf("%s: %s (caused by: %v)", e.Type, e.Message, e.Cause)
	}
	return fmt.Sprintf("%s: %s", e.Type, e.Message)
}

// Unwrap exposes Cause to errors.Is and errors.As, so callers can match the
// underlying error (for example a net.OpError or context.DeadlineExceeded)
// through an AuthenticationError. It returns nil when no cause is attached.
func (e *AuthenticationError) Unwrap() error {
	return e.Cause
}
// Sentinel authentication errors for the Claude OAuth flow. Wrap one with an
// underlying cause via NewAuthenticationError, which copies Type, Message and
// Code into a fresh value and leaves the sentinel untouched.
//
// No sentinel is currently defined for the "token_expired" type, although
// GetUserFriendlyMessage recognizes it.
var (
	// ErrInvalidState signals that the OAuth state parameter failed validation.
	ErrInvalidState = &AuthenticationError{
		Code:    http.StatusBadRequest,
		Type:    "invalid_state",
		Message: "OAuth state parameter is invalid",
	}
	// ErrCodeExchangeFailed signals that trading the authorization code for
	// tokens did not succeed.
	ErrCodeExchangeFailed = &AuthenticationError{
		Code:    http.StatusBadRequest,
		Type:    "code_exchange_failed",
		Message: "Failed to exchange authorization code for tokens",
	}
	// ErrServerStartFailed signals that the local OAuth callback server could
	// not be started.
	ErrServerStartFailed = &AuthenticationError{
		Code:    http.StatusInternalServerError,
		Type:    "server_start_failed",
		Message: "Failed to start OAuth callback server",
	}
	// ErrPortInUse signals that the OAuth callback port is already bound.
	// Its Code is not an HTTP status but the special exit code 13.
	ErrPortInUse = &AuthenticationError{
		Code:    13,
		Type:    "port_in_use",
		Message: "OAuth callback port is already in use",
	}
	// ErrCallbackTimeout signals that no OAuth callback arrived in time.
	ErrCallbackTimeout = &AuthenticationError{
		Code:    http.StatusRequestTimeout,
		Type:    "callback_timeout",
		Message: "Timeout waiting for OAuth callback",
	}
)
// NewAuthenticationError derives a new *AuthenticationError from baseErr
// (typically one of the package sentinels). The result has baseErr's Type,
// Message and Code, with cause attached as the underlying error. baseErr itself
// is not modified; it must be non-nil.
func NewAuthenticationError(baseErr *AuthenticationError, cause error) *AuthenticationError {
	derived := *baseErr
	derived.Cause = cause
	return &derived
}
// IsAuthenticationError reports whether err, or any error in its wrap chain,
// is an *AuthenticationError. A nil err yields false.
func IsAuthenticationError(err error) bool {
	return errors.As(err, new(*AuthenticationError))
}
// IsOAuthError reports whether err, or any error in its wrap chain, is an
// *OAuthError. A nil err yields false.
func IsOAuthError(err error) bool {
	return errors.As(err, new(*OAuthError))
}
// GetUserFriendlyMessage returns a user-friendly error message based on the error type.
func GetUserFriendlyMessage(err error) string {
switch {
case IsAuthenticationError(err):
var authErr *AuthenticationError
errors.As(err, &authErr)
switch authErr.Type {
case "token_expired":
return "Your authentication has expired. Please log in again."
case "token_invalid":
return "Your authentication is invalid. Please log in again."
case "authentication_required":
return "Please log in to continue."
case "port_in_use":
return "The required port is already in use. Please close any applications using port 3000 and try again."
case "callback_timeout":
return "Authentication timed out. Please try again."
case "browser_open_failed":
return "Could not open your browser automatically. Please copy and paste the URL manually."
default:
return "Authentication failed. Please try again."
}
case IsOAuthError(err):
var oauthErr *OAuthError
errors.As(err, &oauthErr)
switch oauthErr.Code {
case "access_denied":
return "Authentication was cancelled or denied."
case "invalid_request":
return "Invalid authentication request. Please try again."
case "server_error":
return "Authentication server error. Please try again later."
default:
return fmt.Sprintf("Authentication failed: %s", oauthErr.Description)
}
default:
return "An unexpected error occurred. Please try again."
}
}
================================================
FILE: internal/auth/claude/html_templates.go
================================================
// Package claude provides authentication and token management functionality
// for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
// and retrieval for maintaining authenticated sessions with the Claude API.
package claude
// LoginSuccessHtml is the page the local OAuth callback server serves on
// /success after a successful login.
//
// generateSuccessHTML in oauth_server.go post-processes it in two ways. It
// replaces the {{SETUP_NOTICE}} placeholder with SetupNoticeHtml or with an
// empty string. It also substitutes every {{PLATFORM_URL}} placeholder with the
// platform URL.
//
// NOTE(review): the page text promises the window closes automatically after
// 10 seconds, and generateSuccessHTML substitutes {{PLATFORM_URL}}. Neither a
// closing script nor that placeholder is visible in this copy of the template
// (markup appears to have been stripped). Confirm against the real file.
const LoginSuccessHtml = `
Authentication Successful - Claude
✓
Authentication Successful!
You have successfully authenticated with Claude. You can now close this window and return to your terminal to continue.
{{SETUP_NOTICE}}
This window will close automatically in 10 seconds
`
// SetupNoticeHtml is the fragment substituted for {{SETUP_NOTICE}} in
// LoginSuccessHtml when the success page is requested with
// setup_required=true. Any {{PLATFORM_URL}} placeholder in it is replaced with
// the platform URL by generateSuccessHTML.
//
// NOTE(review): the sentence "please visit the Claude to configure" reads as if
// a link or word is missing, presumably markup stripped from this copy. Verify
// the rendered text in the real template.
const SetupNoticeHtml = `
Additional Setup Required
To complete your setup, please visit the Claude to configure your account.
`
================================================
FILE: internal/auth/claude/oauth_server.go
================================================
// Package claude provides authentication and token management functionality
// for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
// and retrieval for maintaining authenticated sessions with the Claude API.
package claude
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// OAuthServer is a small local HTTP server that receives the OAuth redirect.
// It captures the authorization code and state, or an error, and hands them to
// whoever is blocked in WaitForCallback.
type OAuthServer struct {
	// server is the HTTP server created by Start; nil when not running.
	server *http.Server
	// port is the TCP port the server binds to.
	port int
	// resultChan carries the outcome of the callback handler to WaitForCallback.
	resultChan chan *OAuthResult
	// errorChan carries failures of the serving goroutine to WaitForCallback.
	errorChan chan error
	// mu guards server and running.
	mu sync.Mutex
	// running is true between a successful Start and the matching Stop.
	running bool
}

// OAuthResult is what the callback handler observed. On success Code and State
// are set. On failure only Error is set, to the provider's "error" parameter or
// to "no_code" / "no_state".
type OAuthResult struct {
	// Code is the authorization code returned by the OAuth provider.
	Code string
	// State is the state value echoed back by the provider (CSRF protection).
	State string
	// Error describes why the callback could not be used; empty on success.
	Error string
}

// NewOAuthServer returns an OAuthServer for the given port. It is not started.
//
// Both channels are buffered with capacity 1. The callback handler and the
// serving goroutine can therefore deposit one value without a reader already
// waiting.
//
// Parameters:
//   - port: The port number on which the server should listen
//
// Returns:
//   - *OAuthServer: A new, stopped OAuthServer instance
func NewOAuthServer(port int) *OAuthServer {
	srv := &OAuthServer{port: port}
	srv.resultChan = make(chan *OAuthResult, 1)
	srv.errorChan = make(chan error, 1)
	return srv
}
// Start binds the callback port and begins serving /callback and /success.
//
// The TCP listener is created synchronously, so a port conflict is returned to
// the caller directly. There is no probe-then-rebind window for another process
// to grab the port, and no need to sleep while a goroutine starts up. Serving
// then continues in a background goroutine. Any serve failure other than a
// normal shutdown is delivered through errorChan, where WaitForCallback picks
// it up.
//
// Returns:
//   - error: if the server is already running or the port cannot be bound
func (s *OAuthServer) Start() error {
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.running {
		return fmt.Errorf("server is already running")
	}
	addr := fmt.Sprintf(":%d", s.port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("port %d is already in use: %w", s.port, err)
	}
	mux := http.NewServeMux()
	mux.HandleFunc("/callback", s.handleCallback)
	mux.HandleFunc("/success", s.handleSuccess)
	srv := &http.Server{
		Addr:         addr,
		Handler:      mux,
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}
	s.server = srv
	s.running = true
	// Use the local srv rather than s.server. Stop may set the field to nil
	// under the lock, and reading it here without the lock would race.
	go func() {
		// Serve closes the listener when it returns.
		if errServe := srv.Serve(listener); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
			// Non-blocking send: errorChan holds one value and may have no
			// reader. Blocking here would leak this goroutine.
			select {
			case s.errorChan <- fmt.Errorf("server failed to start: %w", errServe):
			default:
				log.Errorf("OAuth callback server error: %v", errServe)
			}
		}
	}()
	return nil
}
// Stop shuts the callback server down gracefully. It is a no-op when the server
// is not running.
//
// Shutdown is bounded by the earlier of ctx's deadline and five seconds.
// Regardless of the shutdown result, the server is marked as stopped
// afterwards.
//
// Parameters:
//   - ctx: Parent context for the shutdown
//
// Returns:
//   - error: The error reported by http.Server.Shutdown, if any
func (s *OAuthServer) Stop(ctx context.Context) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	srv := s.server
	if srv == nil || !s.running {
		return nil
	}

	log.Debug("Stopping OAuth callback server")

	// Never wait longer than five seconds, even when ctx has no deadline.
	shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()

	shutdownErr := srv.Shutdown(shutdownCtx)
	s.running, s.server = false, nil
	return shutdownErr
}
// WaitForCallback blocks until the callback handler produces a result, the
// serving goroutine reports an error, or timeout elapses, whichever happens
// first.
//
// Parameters:
//   - timeout: The maximum time to wait
//
// Returns:
//   - *OAuthResult: The callback result (its Error field may still be set)
//   - error: A server error, or a timeout error
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	select {
	case result := <-s.resultChan:
		return result, nil
	case serverErr := <-s.errorChan:
		return nil, serverErr
	case <-deadline.C:
		return nil, fmt.Errorf("timeout waiting for OAuth callback")
	}
}
// handleCallback serves /callback, the redirect target of the OAuth provider.
//
// Only GET is accepted. The provider's "error" parameter, a missing "code" and
// a missing "state" are each reported to the waiter as a failed OAuthResult and
// answered with 400. A valid callback is forwarded as a successful result and
// the browser is redirected to /success.
//
// Parameters:
//   - w: The HTTP response writer
//   - r: The HTTP request
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
	log.Debug("Received OAuth callback")

	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	params := r.URL.Query()
	code, state, providerErr := params.Get("code"), params.Get("state"), params.Get("error")

	// reject hands a failed result to the waiter and answers the browser with 400.
	reject := func(resultErr, body string) {
		s.sendResult(&OAuthResult{Error: resultErr})
		http.Error(w, body, http.StatusBadRequest)
	}

	switch {
	case providerErr != "":
		log.Errorf("OAuth error received: %s", providerErr)
		reject(providerErr, fmt.Sprintf("OAuth error: %s", providerErr))
	case code == "":
		log.Error("No authorization code received")
		reject("no_code", "No authorization code received")
	case state == "":
		log.Error("No state parameter received")
		reject("no_state", "No state parameter received")
	default:
		s.sendResult(&OAuthResult{Code: code, State: state})
		http.Redirect(w, r, "/success", http.StatusFound)
	}
}
// handleSuccess serves the HTML success page shown after authentication.
//
// The optional query parameter setup_required=true adds the setup notice.
// platform_url, which comes straight from the request URL, overrides the
// console link only when it is a plain https URL containing no quote,
// angle-bracket, backtick, backslash or whitespace characters. Anything else
// falls back to the default console URL. Without this check, anyone able to
// make the browser open this local page could inject markup or a javascript:
// link into it.
//
// Parameters:
//   - w: The HTTP response writer
//   - r: The HTTP request
func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
	log.Debug("Serving success page")
	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	w.WriteHeader(http.StatusOK)
	query := r.URL.Query()
	setupRequired := query.Get("setup_required") == "true"
	platformURL := query.Get("platform_url")
	if !strings.HasPrefix(platformURL, "https://") || strings.ContainsAny(platformURL, "\"'<>`\\ \t\r\n") {
		platformURL = "https://console.anthropic.com/"
	}
	successHTML := s.generateSuccessHTML(setupRequired, platformURL)
	if _, err := w.Write([]byte(successHTML)); err != nil {
		log.Errorf("Failed to write success page: %v", err)
	}
}
// generateSuccessHTML renders the success page from LoginSuccessHtml.
//
// Every {{PLATFORM_URL}} placeholder is replaced with platformURL after
// HTML-escaping it (the same five characters html.EscapeString escapes). The
// value can originate from a request parameter, and it is spliced into markup.
// The {{SETUP_NOTICE}} placeholder becomes SetupNoticeHtml (with its own
// {{PLATFORM_URL}} filled in) when setupRequired is true, and is removed
// otherwise.
//
// Parameters:
//   - setupRequired: Whether to include the additional-setup notice
//   - platformURL: The platform URL to link to (escaped before insertion)
//
// Returns:
//   - string: The HTML content for the success page
func (s *OAuthServer) generateSuccessHTML(setupRequired bool, platformURL string) string {
	escaper := strings.NewReplacer(
		"&", "&amp;",
		"'", "&#39;",
		"<", "&lt;",
		">", "&gt;",
		"\"", "&#34;",
	)
	safeURL := escaper.Replace(platformURL)

	page := strings.ReplaceAll(LoginSuccessHtml, "{{PLATFORM_URL}}", safeURL)
	notice := ""
	if setupRequired {
		notice = strings.ReplaceAll(SetupNoticeHtml, "{{PLATFORM_URL}}", safeURL)
	}
	return strings.Replace(page, "{{SETUP_NOTICE}}", notice, 1)
}
// sendResult offers result to resultChan without blocking. If the single
// buffered slot is already occupied, the result is dropped and a warning is
// logged, so the HTTP handler never stalls.
//
// Parameters:
//   - result: The OAuth result to deliver
func (s *OAuthServer) sendResult(result *OAuthResult) {
	select {
	default:
		log.Warn("OAuth result channel is full, result dropped")
	case s.resultChan <- result:
		log.Debug("OAuth result sent to channel")
	}
}
// isPortAvailable tries to bind the server's TCP port on all interfaces and
// immediately releases it again.
//
// The answer is only a snapshot: another process may claim the port between
// this check and a later bind.
//
// Returns:
//   - bool: true if the port could be bound at the time of the check
func (s *OAuthServer) isPortAvailable() bool {
	ln, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
	if err == nil {
		_ = ln.Close()
	}
	return err == nil
}
// IsRunning reports, under the server mutex, whether Start has succeeded and
// Stop has not yet been called.
//
// Returns:
//   - bool: True if the server is running, false otherwise
func (s *OAuthServer) IsRunning() bool {
	s.mu.Lock()
	running := s.running
	s.mu.Unlock()
	return running
}
================================================
FILE: internal/auth/claude/pkce.go
================================================
// Package claude provides authentication and token management functionality
// for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
// and retrieval for maintaining authenticated sessions with the Claude API.
package claude
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
)
// GeneratePKCECodes creates a fresh PKCE (RFC 7636) verifier/challenge pair.
// The verifier is a random URL-safe string. The challenge is derived from it
// with the S256 method. Only the client holding the verifier can later redeem
// the authorization code.
//
// Returns:
//   - *PKCECodes: The generated code verifier and its challenge
//   - error: Non-nil if random generation failed
func GeneratePKCECodes() (*PKCECodes, error) {
	verifier, err := generateCodeVerifier()
	if err != nil {
		return nil, fmt.Errorf("failed to generate code verifier: %w", err)
	}
	codes := &PKCECodes{CodeVerifier: verifier}
	codes.CodeChallenge = generateCodeChallenge(verifier)
	return codes, nil
}
// generateCodeVerifier returns a PKCE code verifier built from 96 bytes read
// from crypto/rand. The bytes are encoded with unpadded URL-safe base64,
// giving exactly 128 characters, the maximum length RFC 7636 allows.
func generateCodeVerifier() (string, error) {
	var entropy [96]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", fmt.Errorf("failed to generate random bytes: %w", err)
	}
	return base64.RawURLEncoding.EncodeToString(entropy[:]), nil
}
// generateCodeChallenge creates a SHA256 hash of the code verifier
// and encodes it using URL-safe base64 encoding without padding
func generateCodeChallenge(codeVerifier string) string {
hash := sha256.Sum256([]byte(codeVerifier))
return base64.URLEncoding.WithPadding(base64.NoPadding).EncodeToString(hash[:])
}
================================================
FILE: internal/auth/claude/token.go
================================================
// Package claude provides authentication and token management functionality
// for Anthropic's Claude AI services. It handles OAuth2 token storage, serialization,
// and retrieval for maintaining authenticated sessions with the Claude API.
package claude
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
)
// ClaudeTokenStorage is the persisted form of a Claude OAuth credential.
// SaveTokenToFile writes it as JSON. Before writing, it sets Type to "claude"
// and merges Metadata into the top-level object.
type ClaudeTokenStorage struct {
// IDToken is the JWT ID token. Nothing in this package sets it.
IDToken string `json:"id_token"`
// AccessToken is the OAuth2 access token used for authenticating API requests.
AccessToken string `json:"access_token"`
// RefreshToken is exchanged for a new access token when the current one expires.
RefreshToken string `json:"refresh_token"`
// LastRefresh is when the tokens were last obtained or refreshed; this package
// writes it as an RFC 3339 timestamp.
LastRefresh string `json:"last_refresh"`
// Email is the Anthropic account email address associated with this token.
Email string `json:"email"`
// Type identifies the provider; SaveTokenToFile always sets it to "claude".
Type string `json:"type"`
// Expire is the access-token expiry, written by this package as an RFC 3339
// timestamp. Note that the JSON key is "expired", not "expire".
// NOTE(review): changing the key would presumably break reading token files
// that already exist. Keep it as is unless readers are migrated.
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks (SetMetadata).
// It is excluded from direct JSON encoding (json:"-"). SaveTokenToFile instead
// flattens it into the top-level object via misc.MergeMetadata.
Metadata map[string]any `json:"-"`
}
// SetMetadata stores meta for merging at save time. The map is kept by
// reference, not copied, so later changes to meta are visible here.
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile writes the token storage as JSON to authFilePath. Before
// writing, it sets Type to "claude" and merges Metadata into the top-level
// object.
//
// The parent directory is created with 0700 permissions. A new file is created
// with 0600 permissions, because it holds live OAuth credentials; os.Create
// would use 0666 before umask. The JSON payload is built before the file is
// opened, so a metadata merge failure leaves any existing token file untouched
// instead of truncated. The Close error is checked, because it can report a
// deferred write failure.
//
// Parameters:
//   - authFilePath: The full path where the token file should be saved
//
// Returns:
//   - error: An error if the operation fails, nil otherwise
func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
	misc.LogSavingCredentials(authFilePath)
	ts.Type = "claude"

	if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}

	// Merge first: a failure here must not truncate an existing token file.
	data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
	if errMerge != nil {
		return fmt.Errorf("failed to merge metadata: %w", errMerge)
	}

	f, err := os.OpenFile(authFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return fmt.Errorf("failed to create token file: %w", err)
	}
	if err = json.NewEncoder(f).Encode(data); err != nil {
		_ = f.Close()
		return fmt.Errorf("failed to write token to file: %w", err)
	}
	if err = f.Close(); err != nil {
		return fmt.Errorf("failed to close token file: %w", err)
	}
	return nil
}
================================================
FILE: internal/auth/claude/utls_transport.go
================================================
// Package claude provides authentication functionality for Anthropic's Claude API.
// This file implements a custom HTTP transport using utls to bypass TLS fingerprinting.
package claude
import (
"net/http"
"strings"
"sync"
tls "github.com/refraction-networking/utls"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"golang.org/x/net/http2"
"golang.org/x/net/proxy"
)
// utlsRoundTripper is an http.RoundTripper that speaks HTTP/2 over a utls
// connection presenting a Chrome TLS fingerprint. It keeps one cached HTTP/2
// client connection per host.
type utlsRoundTripper struct {
	// mu guards connections and pending, and is the Locker behind every
	// pending sync.Cond.
	mu sync.Mutex
	// connections maps a host name to its cached HTTP/2 client connection.
	connections map[string]*http2.ClientConn
	// pending marks hosts whose connection is currently being established;
	// other goroutines wait on the associated Cond instead of dialing again.
	pending map[string]*sync.Cond
	// dialer opens the underlying TCP connections (direct or via a proxy).
	dialer proxy.Dialer
}

// newUtlsRoundTripper builds a round tripper that dials directly by default.
//
// When cfg is non-nil, the proxy dialer built from cfg.ProxyURL is used instead,
// unless building it fails (logged), the mode is ModeInherit, or no dialer is
// returned. In those cases the direct dialer remains.
func newUtlsRoundTripper(cfg *config.SDKConfig) *utlsRoundTripper {
	rt := &utlsRoundTripper{
		connections: make(map[string]*http2.ClientConn),
		pending:     make(map[string]*sync.Cond),
		dialer:      proxy.Direct,
	}
	if cfg == nil {
		return rt
	}
	proxyDialer, mode, errBuild := proxyutil.BuildDialer(cfg.ProxyURL)
	switch {
	case errBuild != nil:
		log.Errorf("failed to configure proxy dialer for %q: %v", cfg.ProxyURL, errBuild)
	case mode != proxyutil.ModeInherit && proxyDialer != nil:
		rt.dialer = proxyDialer
	}
	return rt
}
// getOrCreateConnection returns a usable cached HTTP/2 connection for host,
// dialing addr to create one if necessary.
//
// At most one goroutine dials a given host at a time. Others wait on the
// host's pending sync.Cond. After every wake-up they re-check both the cache
// and the pending marker, because Broadcast wakes all waiters at once and only
// the first to re-acquire the lock may become the next dialer. Without the
// loop, several woken waiters would each dial, and each would overwrite (and
// thereby leak) the previous connection in t.connections.
//
// The dial itself happens without holding t.mu.
func (t *utlsRoundTripper) getOrCreateConnection(host, addr string) (*http2.ClientConn, error) {
	t.mu.Lock()
	for {
		if h2Conn, ok := t.connections[host]; ok && h2Conn.CanTakeNewRequest() {
			t.mu.Unlock()
			return h2Conn, nil
		}
		waitCond, ok := t.pending[host]
		if !ok {
			break
		}
		// Wait releases t.mu while blocked and re-acquires it before returning.
		waitCond.Wait()
	}

	// This goroutine is now the only dialer for host.
	cond := sync.NewCond(&t.mu)
	t.pending[host] = cond
	t.mu.Unlock()

	h2Conn, err := t.createConnection(host, addr)

	t.mu.Lock()
	defer t.mu.Unlock()
	delete(t.pending, host)
	cond.Broadcast()
	if err != nil {
		return nil, err
	}
	t.connections[host] = h2Conn
	return h2Conn, nil
}
// createConnection dials addr and performs a utls handshake that presents a
// Chrome ClientHello (tls.HelloChrome_Auto) with ServerName set to host. It
// then wraps the TLS connection in an HTTP/2 client connection.
//
// A Chrome fingerprint is used because it is closer to Node.js/OpenSSL (what
// the real Claude Code client uses) than Firefox's. That reduces the mismatch
// between the TLS layer and the HTTP headers. The underlying connection is
// closed if either the handshake or the HTTP/2 setup fails.
func (t *utlsRoundTripper) createConnection(host, addr string) (*http2.ClientConn, error) {
	rawConn, err := t.dialer.Dial("tcp", addr)
	if err != nil {
		return nil, err
	}

	uconn := tls.UClient(rawConn, &tls.Config{ServerName: host}, tls.HelloChrome_Auto)
	if err = uconn.Handshake(); err != nil {
		rawConn.Close()
		return nil, err
	}

	h2Conn, err := (&http2.Transport{}).NewClientConn(uconn)
	if err != nil {
		uconn.Close()
		return nil, err
	}
	return h2Conn, nil
}
// RoundTrip implements http.RoundTripper by sending req over a cached (or newly
// created) utls HTTP/2 connection.
//
// The bare host name (no port, no IPv6 brackets) is the TLS ServerName and the
// cache key. When the URL has no port, 443 is used, with brackets restored for
// IPv6 literals. If the round trip fails, the connection is evicted from the
// cache, provided it is still the cached one.
//
// Per the http.RoundTripper contract, the request body is closed even when no
// connection could be obtained.
func (t *utlsRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	hostname := req.URL.Hostname()
	addr := req.URL.Host
	if req.URL.Port() == "" {
		// IPv6 literals contain ':' themselves and need brackets before a port
		// can be appended.
		if strings.Contains(hostname, ":") {
			addr = "[" + hostname + "]:443"
		} else {
			addr = hostname + ":443"
		}
	}

	h2Conn, err := t.getOrCreateConnection(hostname, addr)
	if err != nil {
		if req.Body != nil {
			_ = req.Body.Close()
		}
		return nil, err
	}

	resp, err := h2Conn.RoundTrip(req)
	if err != nil {
		// The connection may be broken; drop it so the next request redials.
		t.mu.Lock()
		if cached, ok := t.connections[hostname]; ok && cached == h2Conn {
			delete(t.connections, hostname)
		}
		t.mu.Unlock()
		return nil, err
	}
	return resp, nil
}
// NewAnthropicHttpClient returns an *http.Client whose transport is a utls
// round tripper presenting a Chrome TLS fingerprint, for Anthropic domains.
// cfg may be nil. When set, its ProxyURL selects the dialer used for
// connections.
func NewAnthropicHttpClient(cfg *config.SDKConfig) *http.Client {
	transport := newUtlsRoundTripper(cfg)
	return &http.Client{Transport: transport}
}
================================================
FILE: internal/auth/codex/errors.go
================================================
package codex
import (
"errors"
"fmt"
"net/http"
)
// OAuthError is an error reported by an OAuth 2.0 endpoint: an error code with
// an optional human-readable description and reference URI. The JSON tags use
// the standard OAuth response field names. StatusCode is never serialized.
type OAuthError struct {
	// Code is the OAuth error code (for example "access_denied").
	Code string `json:"error"`
	// Description is an optional human-readable explanation of the error.
	Description string `json:"error_description,omitempty"`
	// URI optionally points at a web page documenting the error.
	URI string `json:"error_uri,omitempty"`
	// StatusCode is the HTTP status code that accompanied the error.
	StatusCode int `json:"-"`
}

// Error renders the error as "OAuth error <code>: <description>". When there
// is no description it renders "OAuth error: <code>" instead.
func (e *OAuthError) Error() string {
	if e.Description == "" {
		return fmt.Sprintf("OAuth error: %s", e.Code)
	}
	return fmt.Sprintf("OAuth error %s: %s", e.Code, e.Description)
}

// NewOAuthError returns an *OAuthError holding the given code, description and
// HTTP status code. Its URI field is left empty.
func NewOAuthError(code, description string, statusCode int) *OAuthError {
	e := new(OAuthError)
	e.Code = code
	e.Description = description
	e.StatusCode = statusCode
	return e
}
// AuthenticationError describes a failure in the authentication flow. It has a
// machine-readable Type, a human-readable Message, a numeric Code and an
// optional underlying Cause.
type AuthenticationError struct {
	// Type is the type of authentication error (e.g. "invalid_state").
	Type string `json:"type"`
	// Message is a human-readable message describing the error.
	Message string `json:"message"`
	// Code is usually an HTTP status code. Some sentinels use a special
	// non-HTTP value (e.g. the port-in-use exit code).
	Code int `json:"code"`
	// Cause is the underlying error that caused this authentication error.
	Cause error `json:"-"`
}

// Error returns "<type>: <message>", followed by " (caused by: <cause>)" when
// a cause is attached.
func (e *AuthenticationError) Error() string {
	if e.Cause != nil {
		return fmt.Sprintf("%s: %s (caused by: %v)", e.Type, e.Message, e.Cause)
	}
	return fmt.Sprintf("%s: %s", e.Type, e.Message)
}

// Unwrap exposes Cause to errors.Is and errors.As, so callers can match the
// underlying error (for example a net.OpError or context.DeadlineExceeded)
// through an AuthenticationError. It returns nil when no cause is attached.
func (e *AuthenticationError) Unwrap() error {
	return e.Cause
}
// Sentinel authentication errors for the Codex OAuth flow. Wrap one with an
// underlying cause via NewAuthenticationError, which copies Type, Message and
// Code into a fresh value and leaves the sentinel untouched.
//
// No sentinel is currently defined for the "token_expired" type, although
// GetUserFriendlyMessage recognizes it.
var (
	// ErrInvalidState signals that the OAuth state parameter failed validation.
	ErrInvalidState = &AuthenticationError{
		Code:    http.StatusBadRequest,
		Type:    "invalid_state",
		Message: "OAuth state parameter is invalid",
	}
	// ErrCodeExchangeFailed signals that trading the authorization code for
	// tokens did not succeed.
	ErrCodeExchangeFailed = &AuthenticationError{
		Code:    http.StatusBadRequest,
		Type:    "code_exchange_failed",
		Message: "Failed to exchange authorization code for tokens",
	}
	// ErrServerStartFailed signals that the local OAuth callback server could
	// not be started.
	ErrServerStartFailed = &AuthenticationError{
		Code:    http.StatusInternalServerError,
		Type:    "server_start_failed",
		Message: "Failed to start OAuth callback server",
	}
	// ErrPortInUse signals that the OAuth callback port is already bound.
	// Its Code is not an HTTP status but the special exit code 13.
	ErrPortInUse = &AuthenticationError{
		Code:    13,
		Type:    "port_in_use",
		Message: "OAuth callback port is already in use",
	}
	// ErrCallbackTimeout signals that no OAuth callback arrived in time.
	ErrCallbackTimeout = &AuthenticationError{
		Code:    http.StatusRequestTimeout,
		Type:    "callback_timeout",
		Message: "Timeout waiting for OAuth callback",
	}
	// ErrBrowserOpenFailed signals that the system browser could not be
	// launched for the authentication URL.
	ErrBrowserOpenFailed = &AuthenticationError{
		Code:    http.StatusInternalServerError,
		Type:    "browser_open_failed",
		Message: "Failed to open browser for authentication",
	}
)
// NewAuthenticationError returns a fresh *AuthenticationError that copies
// Type, Message and Code from baseErr and records cause as its Cause.
// baseErr itself is never modified, so the shared Err* templates stay intact.
// baseErr must be non-nil.
func NewAuthenticationError(baseErr *AuthenticationError, cause error) *AuthenticationError {
	wrapped := &AuthenticationError{Cause: cause}
	wrapped.Type, wrapped.Message, wrapped.Code = baseErr.Type, baseErr.Message, baseErr.Code
	return wrapped
}
// IsAuthenticationError reports whether err, or any error in its wrap chain,
// is an *AuthenticationError.
func IsAuthenticationError(err error) bool {
	var target *AuthenticationError
	return errors.As(err, &target)
}
// IsOAuthError reports whether err, or any error in its wrap chain,
// is an *OAuthError.
func IsOAuthError(err error) bool {
	var target *OAuthError
	return errors.As(err, &target)
}
// GetUserFriendlyMessage returns a user-friendly error message based on the error type.
func GetUserFriendlyMessage(err error) string {
switch {
case IsAuthenticationError(err):
var authErr *AuthenticationError
errors.As(err, &authErr)
switch authErr.Type {
case "token_expired":
return "Your authentication has expired. Please log in again."
case "token_invalid":
return "Your authentication is invalid. Please log in again."
case "authentication_required":
return "Please log in to continue."
case "port_in_use":
return "The required port is already in use. Please close any applications using port 3000 and try again."
case "callback_timeout":
return "Authentication timed out. Please try again."
case "browser_open_failed":
return "Could not open your browser automatically. Please copy and paste the URL manually."
default:
return "Authentication failed. Please try again."
}
case IsOAuthError(err):
var oauthErr *OAuthError
errors.As(err, &oauthErr)
switch oauthErr.Code {
case "access_denied":
return "Authentication was cancelled or denied."
case "invalid_request":
return "Invalid authentication request. Please try again."
case "server_error":
return "Authentication server error. Please try again later."
default:
return fmt.Sprintf("Authentication failed: %s", oauthErr.Description)
}
default:
return "An unexpected error occurred. Please try again."
}
}
================================================
FILE: internal/auth/codex/filename.go
================================================
package codex
import (
"fmt"
"strings"
"unicode"
)
// CredentialFileName returns the filename used to persist Codex OAuth credentials.
//
// The name is built from dash-joined segments plus ".json":
//   - no plan:   "<prefix>-<email>.json"
//   - team plan: "<prefix>-<hashAccountID>-<email>-team.json" (the account hash
//     disambiguates multiple team workspaces sharing one email)
//   - any other: "<prefix>-<email>-<plan>.json"
//
// prefix is "codex" when includeProviderPrefix is true and empty otherwise.
// planType is normalized via normalizePlanTypeForFilename; email is trimmed.
//
// NOTE(review): with includeProviderPrefix == false the empty prefix still
// contributes a separator, so names start with "-". Kept as-is because the
// names are persisted; confirm whether callers rely on it before changing.
func CredentialFileName(email, planType, hashAccountID string, includeProviderPrefix bool) string {
	var prefix string
	if includeProviderPrefix {
		prefix = "codex"
	}

	trimmedEmail := strings.TrimSpace(email)
	segments := []string{prefix}
	switch plan := normalizePlanTypeForFilename(planType); plan {
	case "":
		segments = append(segments, trimmedEmail)
	case "team":
		segments = append(segments, hashAccountID, trimmedEmail, plan)
	default:
		segments = append(segments, trimmedEmail, plan)
	}
	return fmt.Sprintf("%s.json", strings.Join(segments, "-"))
}
// normalizePlanTypeForFilename turns a free-form plan type into a filename-safe
// slug. Every run of characters that are neither Unicode letters nor digits
// acts as a separator; the remaining pieces are lowercased and joined with "-".
// Input with no letters or digits yields "".
func normalizePlanTypeForFilename(planType string) string {
	isSeparator := func(r rune) bool {
		return !(unicode.IsLetter(r) || unicode.IsDigit(r))
	}
	fields := strings.FieldsFunc(strings.TrimSpace(planType), isSeparator)
	lowered := make([]string, 0, len(fields))
	for _, field := range fields {
		// Fields contain only letters/digits, so no further trimming is needed.
		lowered = append(lowered, strings.ToLower(field))
	}
	return strings.Join(lowered, "-")
}
================================================
FILE: internal/auth/codex/html_templates.go
================================================
package codex
// LoginSuccessHtml is the HTML template for the page shown after a successful
// OAuth2 authentication with Codex. It informs the user that the authentication
// was successful and provides a countdown timer to automatically close the window.
//
// The {{SETUP_NOTICE}} placeholder is replaced by generateSuccessHTML with either
// SetupNoticeHtml or an empty string. generateSuccessHTML also substitutes
// {{PLATFORM_URL}}; NOTE(review): that placeholder is not visible in this
// extracted text, so confirm whether the real template still uses it.
const LoginSuccessHtml = `
Authentication Successful - Codex
✓
Authentication Successful!
You have successfully authenticated with Codex. You can now close this window and return to your terminal to continue.
{{SETUP_NOTICE}}
This window will close automatically in 10 seconds
`
// SetupNoticeHtml is the HTML template for the section that provides instructions
// for additional setup. This is displayed on the success page when further actions
// are required from the user.
//
// generateSuccessHTML inserts it in place of {{SETUP_NOTICE}} when the success
// page is requested with setup_required=true.
const SetupNoticeHtml = `
Additional Setup Required
To complete your setup, please visit the Codex to configure your account.
`
================================================
FILE: internal/auth/codex/jwt_parser.go
================================================
package codex
import (
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"time"
)
// JWTClaims represents the claims section of a JSON Web Token (JWT).
// It includes standard claims like issuer, subject, and expiration time, as well as
// custom claims specific to OpenAI's authentication.
//
// It is populated by ParseJWTToken via encoding/json. No signature check is
// performed, so treat these values as informational only.
type JWTClaims struct {
AtHash string `json:"at_hash"`
Aud []string `json:"aud"`
AuthProvider string `json:"auth_provider"`
// AuthTime, Exp, Iat and Rat are presumably JWT NumericDate values (Unix seconds) — confirm.
AuthTime int `json:"auth_time"`
Email string `json:"email"`
EmailVerified bool `json:"email_verified"`
Exp int `json:"exp"`
// CodexAuthInfo is OpenAI's namespaced custom claim holding ChatGPT account data.
CodexAuthInfo CodexAuthInfo `json:"https://api.openai.com/auth"`
Iat int `json:"iat"`
Iss string `json:"iss"`
Jti string `json:"jti"`
Rat int `json:"rat"`
Sid string `json:"sid"`
Sub string `json:"sub"`
}
// Organizations defines the structure for organization details within the JWT claims.
// It holds information about the user's organization, such as ID, role, and title.
// Entries appear in CodexAuthInfo.Organizations. GenerateAuthURL requests them
// via id_token_add_organizations=true.
type Organizations struct {
ID string `json:"id"`
IsDefault bool `json:"is_default"`
Role string `json:"role"`
Title string `json:"title"`
}
// CodexAuthInfo contains authentication-related details specific to Codex.
// This includes ChatGPT account information, subscription status, and user/organization IDs.
type CodexAuthInfo struct {
// ChatgptAccountID is what JWTClaims.GetAccountID returns.
ChatgptAccountID string `json:"chatgpt_account_id"`
// ChatgptPlanType is the raw plan name (e.g. "plus", "team"); CredentialFileName normalizes it.
ChatgptPlanType string `json:"chatgpt_plan_type"`
// The subscription window fields are kept as `any` because their JSON shape is not fixed here.
ChatgptSubscriptionActiveStart any `json:"chatgpt_subscription_active_start"`
ChatgptSubscriptionActiveUntil any `json:"chatgpt_subscription_active_until"`
// NOTE(review): time.Time only unmarshals RFC 3339 strings; any other format
// makes the whole ParseJWTToken call fail. Verify against real tokens.
ChatgptSubscriptionLastChecked time.Time `json:"chatgpt_subscription_last_checked"`
ChatgptUserID string `json:"chatgpt_user_id"`
Groups []any `json:"groups"`
Organizations []Organizations `json:"organizations"`
UserID string `json:"user_id"`
}
// ParseJWTToken decodes the payload of a compact JWT ("header.payload.signature")
// into JWTClaims.
//
// The signature is NOT verified. Use this only to read identity details from a
// token that was just received directly from the authorization server.
// It returns an error if the token does not have exactly three dot-separated
// parts, if the payload is not valid base64url, or if it is not valid JSON
// for JWTClaims. On error the returned claims are always nil.
func ParseJWTToken(token string) (*JWTClaims, error) {
	segments := strings.Split(token, ".")
	if n := len(segments); n != 3 {
		return nil, fmt.Errorf("invalid JWT token format: expected 3 parts, got %d", n)
	}

	payload, decodeErr := base64URLDecode(segments[1])
	if decodeErr != nil {
		return nil, fmt.Errorf("failed to decode JWT claims: %w", decodeErr)
	}

	claims := new(JWTClaims)
	if unmarshalErr := json.Unmarshal(payload, claims); unmarshalErr != nil {
		return nil, fmt.Errorf("failed to unmarshal JWT claims: %w", unmarshalErr)
	}
	return claims, nil
}
// base64URLDecode decodes URL-safe base64 (RFC 4648 §5), accepting input with
// or without trailing "=" padding.
//
// JWT segments omit padding. Inputs whose length leaves a remainder of 2 or 3
// modulo 4 are padded back to a multiple of 4 before decoding with the padded
// alphabet. Already padded input is decoded as-is. A remainder of 1 can never
// be valid base64, and the decoder rejects it.
func base64URLDecode(data string) ([]byte, error) {
	if rem := len(data) % 4; rem > 1 {
		data += strings.Repeat("=", 4-rem)
	}
	return base64.URLEncoding.DecodeString(data)
}
// GetUserEmail returns the "email" claim of the token; it is empty when the
// claim was absent.
func (c *JWTClaims) GetUserEmail() string {
	email := c.Email
	return email
}
// GetAccountID returns the ChatGPT account identifier carried in OpenAI's
// custom auth claim (chatgpt_account_id). Note that this is not the standard
// "sub" claim, which is available separately as Sub.
func (c *JWTClaims) GetAccountID() string {
	info := &c.CodexAuthInfo
	return info.ChatgptAccountID
}
================================================
FILE: internal/auth/codex/oauth_server.go
================================================
package codex
import (
"context"
"errors"
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// OAuthServer handles the local HTTP server for OAuth callbacks.
// It listens for the authorization code response from the OAuth provider
// and captures the necessary parameters to complete the authentication flow.
//
// Start binds ":<port>", i.e. all interfaces rather than only loopback. The
// callback handler publishes at most one pending result; WaitForCallback
// consumes it.
type OAuthServer struct {
// server is the underlying HTTP server instance; nil while stopped.
server *http.Server
// port is the port number on which the server listens
port int
// resultChan carries callback outcomes (capacity 1 from NewOAuthServer;
// sendResult drops results when it is full).
resultChan chan *OAuthResult
// errorChan carries asynchronous server failures (capacity 1).
errorChan chan error
// mu guards server and running in Start, Stop and IsRunning; the HTTP
// handlers do not take it.
mu sync.Mutex
// running indicates whether the server is currently running
running bool
}
// OAuthResult contains the result of the OAuth callback.
// It holds either the authorization code and state for successful authentication
// or an error message if the authentication failed.
// handleCallback sets either Code+State or Error, never both.
type OAuthResult struct {
// Code is the authorization code received from the OAuth provider
Code string
// State is the state parameter used to prevent CSRF attacks; handleCallback
// only checks it is non-empty, so the caller must compare it with the value it sent.
State string
// Error contains any error message if the OAuth flow failed: the provider's
// "error" query value, or "no_code" / "no_state" for missing parameters.
Error string
}
// NewOAuthServer returns a stopped callback server for the given port.
//
// Both result and error channels are buffered with capacity 1, so a single
// outcome can be recorded before anyone calls WaitForCallback. Call Start to
// begin listening.
func NewOAuthServer(port int) *OAuthServer {
	srv := new(OAuthServer)
	srv.port = port
	srv.resultChan = make(chan *OAuthResult, 1)
	srv.errorChan = make(chan error, 1)
	return srv
}
// Start binds the callback port and begins serving /auth/callback and /success.
//
// The port is bound synchronously, before Start returns. A failure (for
// example, the port is already taken) is therefore reported directly as
// "port N is already in use: <cause>". The earlier approach probed the port,
// released it, and then called ListenAndServe in a goroutine. That left a race
// window, only reported bind failures asynchronously while returning nil, and
// needed a sleep to "let the server start". Errors raised later by Serve
// (other than http.ErrServerClosed) are delivered on errorChan for
// WaitForCallback to pick up.
//
// Returns an error if the server is already running or the port cannot be bound.
func (s *OAuthServer) Start() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.running {
		return fmt.Errorf("server is already running")
	}

	addr := fmt.Sprintf(":%d", s.port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("port %d is already in use: %w", s.port, err)
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/auth/callback", s.handleCallback)
	mux.HandleFunc("/success", s.handleSuccess)

	srv := &http.Server{
		Addr:         addr,
		Handler:      mux,
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}
	s.server = srv
	s.running = true

	// srv is captured locally: Stop may set s.server to nil concurrently.
	// Serve closes the listener when it returns, including after Shutdown.
	go func() {
		if serveErr := srv.Serve(listener); serveErr != nil && !errors.Is(serveErr, http.ErrServerClosed) {
			// Non-blocking send: if an earlier error is still unread, do not
			// leave this goroutine stuck forever.
			select {
			case s.errorChan <- fmt.Errorf("server failed to start: %w", serveErr):
			default:
				log.Errorf("OAuth callback server error dropped: %v", serveErr)
			}
		}
	}()

	return nil
}
// Stop shuts the callback server down gracefully.
//
// Shutdown is bounded by 5 seconds on top of ctx. Whatever Shutdown returns,
// the server is afterwards marked as not running and forgotten. Calling Stop
// on a server that is not running is a no-op and returns nil.
func (s *OAuthServer) Stop(ctx context.Context) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	srv := s.server
	if srv == nil || !s.running {
		return nil
	}

	log.Debug("Stopping OAuth callback server")

	// Bound the graceful shutdown so a stuck connection cannot hang the caller.
	timeoutCtx, cancelShutdown := context.WithTimeout(ctx, 5*time.Second)
	defer cancelShutdown()

	shutdownErr := srv.Shutdown(timeoutCtx)
	s.server, s.running = nil, false
	return shutdownErr
}
// WaitForCallback blocks until the callback handler publishes a result, the
// server reports an error, or timeout elapses, whichever happens first.
//
// A result is returned even when its Error field is set; the caller decides
// how to handle provider-side failures. On timeout it returns
// "timeout waiting for OAuth callback".
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	select {
	case res := <-s.resultChan:
		return res, nil
	case serverErr := <-s.errorChan:
		return nil, serverErr
	case <-deadline.C:
		return nil, fmt.Errorf("timeout waiting for OAuth callback")
	}
}
// handleCallback serves /auth/callback, where the provider redirects the browser.
//
// Only GET is accepted (405 otherwise). Failures are checked in this order:
// a provider "error" parameter, a missing "code", a missing "state". Each one
// publishes an OAuthResult with Error set and answers 400. On success the code
// and state are published and the browser is redirected to /success. The
// state value is not compared here; that is the caller's job.
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
	log.Debug("Received OAuth callback")

	if r.Method != http.MethodGet {
		http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
		return
	}

	params := r.URL.Query()

	// reject publishes a failed result to the waiting flow and then answers
	// the browser with a 400 carrying message.
	reject := func(reason, message string) {
		s.sendResult(&OAuthResult{Error: reason})
		http.Error(w, message, http.StatusBadRequest)
	}

	if providerErr := params.Get("error"); providerErr != "" {
		log.Errorf("OAuth error received: %s", providerErr)
		reject(providerErr, fmt.Sprintf("OAuth error: %s", providerErr))
		return
	}

	authCode := params.Get("code")
	if authCode == "" {
		log.Error("No authorization code received")
		reject("no_code", "No authorization code received")
		return
	}

	stateValue := params.Get("state")
	if stateValue == "" {
		log.Error("No state parameter received")
		reject("no_state", "No state parameter received")
		return
	}

	s.sendResult(&OAuthResult{Code: authCode, State: stateValue})
	http.Redirect(w, r, "/success", http.StatusFound)
}
// handleSuccess serves /success with the HTML confirmation page.
//
// Query parameters: setup_required=true adds the setup notice section;
// platform_url overrides the link target and defaults to
// https://platform.openai.com. The status line is written before the body is
// rendered, so a write failure can only be logged.
func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
	log.Debug("Serving success page")

	headers := w.Header()
	headers.Set("Content-Type", "text/html; charset=utf-8")
	w.WriteHeader(http.StatusOK)

	params := r.URL.Query()
	wantsSetup := params.Get("setup_required") == "true"
	targetURL := params.Get("platform_url")
	if len(targetURL) == 0 {
		targetURL = "https://platform.openai.com"
	}

	page := s.generateSuccessHTML(wantsSetup, targetURL)
	if _, writeErr := w.Write([]byte(page)); writeErr != nil {
		log.Errorf("Failed to write success page: %v", writeErr)
	}
}
// generateSuccessHTML renders the success page from LoginSuccessHtml.
//
// platformURL comes straight from the /success query string, so it is
// attacker-controllable (anyone can send the user a crafted localhost link).
// Before substitution it is therefore:
//   - restricted to http:// or https://; anything else, such as javascript:
//     or data:, falls back to https://platform.openai.com;
//   - HTML-escaped with the same replacements as html.EscapeString, so it
//     cannot break out of an attribute or inject markup.
//
// Every {{PLATFORM_URL}} placeholder receives the sanitized URL. The first
// {{SETUP_NOTICE}} is replaced by SetupNoticeHtml (itself URL-substituted)
// when setupRequired is true, and by "" otherwise.
func (s *OAuthServer) generateSuccessHTML(setupRequired bool, platformURL string) string {
	safeURL := strings.TrimSpace(platformURL)
	lowered := strings.ToLower(safeURL)
	if !strings.HasPrefix(lowered, "https://") && !strings.HasPrefix(lowered, "http://") {
		safeURL = "https://platform.openai.com"
	}
	// Mirrors html.EscapeString without adding an import to this file.
	safeURL = strings.NewReplacer(
		"&", "&amp;",
		"'", "&#39;",
		"<", "&lt;",
		">", "&gt;",
		`"`, "&#34;",
	).Replace(safeURL)

	page := strings.ReplaceAll(LoginSuccessHtml, "{{PLATFORM_URL}}", safeURL)

	notice := ""
	if setupRequired {
		notice = strings.ReplaceAll(SetupNoticeHtml, "{{PLATFORM_URL}}", safeURL)
	}
	return strings.Replace(page, "{{SETUP_NOTICE}}", notice, 1)
}
// sendResult offers result to resultChan without blocking the HTTP handler.
// If a previous result is still unread (the channel is full), the new one is
// discarded and a warning is logged.
func (s *OAuthServer) sendResult(result *OAuthResult) {
	delivered := false
	select {
	case s.resultChan <- result:
		delivered = true
	default:
	}

	if !delivered {
		log.Warn("OAuth result channel is full, result dropped")
		return
	}
	log.Debug("OAuth result sent to channel")
}
// isPortAvailable reports whether a TCP listener can currently be opened on
// ":<port>". The probe listener is closed again immediately, so the answer can
// be stale by the time the caller acts on it.
func (s *OAuthServer) isPortAvailable() bool {
	probe, err := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
	if err != nil {
		return false
	}
	_ = probe.Close()
	return true
}
// IsRunning reports, under the server mutex, whether Start has succeeded and
// Stop has not yet been called.
func (s *OAuthServer) IsRunning() bool {
	s.mu.Lock()
	running := s.running
	s.mu.Unlock()
	return running
}
================================================
FILE: internal/auth/codex/openai.go
================================================
package codex
// PKCECodes holds the verification codes for the OAuth2 PKCE (Proof Key for Code Exchange) flow.
// PKCE is an extension to the Authorization Code flow to prevent CSRF and authorization code injection attacks.
// Produced by GeneratePKCECodes: the challenge goes into the authorization URL,
// and the verifier is sent later with the token exchange.
type PKCECodes struct {
// CodeVerifier is the cryptographically random string used to correlate
// the authorization request to the token request
CodeVerifier string `json:"code_verifier"`
// CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded
CodeChallenge string `json:"code_challenge"`
}
// CodexTokenData holds the OAuth token information obtained from OpenAI.
// It includes the ID token, access token, refresh token, and associated user details.
type CodexTokenData struct {
// IDToken is the JWT ID token containing user claims
IDToken string `json:"id_token"`
// AccessToken is the OAuth2 access token for API access
AccessToken string `json:"access_token"`
// RefreshToken is used to obtain new access tokens
RefreshToken string `json:"refresh_token"`
// AccountID is the OpenAI account identifier (from the ID token's chatgpt_account_id)
AccountID string `json:"account_id"`
// Email is the OpenAI account email
Email string `json:"email"`
// Expire is the access-token expiry as an RFC 3339 string (now + expires_in
// when produced by the token exchange/refresh). Note the JSON key is "expired".
Expire string `json:"expired"`
}
// CodexAuthBundle aggregates all authentication-related data after the OAuth flow is complete.
// This includes the API key, token data, and the timestamp of the last refresh.
type CodexAuthBundle struct {
// APIKey is the OpenAI API key obtained from token exchange.
// NOTE(review): ExchangeCodeForTokensWithRedirect never sets it; confirm whether anything still fills it in.
APIKey string `json:"api_key"`
// TokenData contains the OAuth tokens from the authentication flow
TokenData CodexTokenData `json:"token_data"`
// LastRefresh is the RFC 3339 timestamp of the last token refresh
LastRefresh string `json:"last_refresh"`
}
================================================
FILE: internal/auth/codex/openai_auth.go
================================================
// Package codex provides authentication and token management for OpenAI's Codex API.
// It handles the OAuth2 flow, including generating authorization URLs, exchanging
// authorization codes for tokens, and refreshing expired tokens. The package also
// defines data structures for storing and managing Codex authentication credentials.
package codex
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
// OAuth configuration constants for OpenAI Codex
const (
// AuthURL is the authorization endpoint the user's browser is sent to (see GenerateAuthURL).
AuthURL = "https://auth.openai.com/oauth/authorize"
// TokenURL serves both the authorization-code exchange and the refresh-token grant.
TokenURL = "https://auth.openai.com/oauth/token"
// ClientID is the public OAuth client identifier; the requests in this file
// send no client secret and rely on PKCE instead.
ClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
// RedirectURI is the default local callback used by GenerateAuthURL and ExchangeCodeForTokens.
RedirectURI = "http://localhost:1455/auth/callback"
)
// CodexAuth handles the OpenAI OAuth2 authentication flow.
// It manages the HTTP client and provides methods for generating authorization URLs,
// exchanging authorization codes for tokens, and refreshing access tokens.
type CodexAuth struct {
// httpClient performs all token-endpoint requests; NewCodexAuth applies the configured proxy.
httpClient *http.Client
}
// NewCodexAuth builds a CodexAuth whose HTTP client is routed through the
// proxy settings in cfg.SDKConfig (via util.SetProxy). cfg must be non-nil.
func NewCodexAuth(cfg *config.Config) *CodexAuth {
	client := util.SetProxy(&cfg.SDKConfig, &http.Client{})
	return &CodexAuth{httpClient: client}
}
// GenerateAuthURL builds the browser authorization URL for the PKCE
// authorization-code flow.
//
// The URL targets AuthURL with the fixed RedirectURI, requests the scopes
// "openid email profile offline_access", forces a login prompt, asks for
// organizations in the ID token, and carries state plus the S256 code
// challenge. It returns an error when pkceCodes is nil.
func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string, error) {
	if pkceCodes == nil {
		return "", fmt.Errorf("PKCE codes are required")
	}

	query := make(url.Values, 10)
	query.Set("client_id", ClientID)
	query.Set("response_type", "code")
	query.Set("redirect_uri", RedirectURI)
	query.Set("scope", "openid email profile offline_access")
	query.Set("state", state)
	query.Set("code_challenge", pkceCodes.CodeChallenge)
	query.Set("code_challenge_method", "S256")
	query.Set("prompt", "login")
	query.Set("id_token_add_organizations", "true")
	query.Set("codex_cli_simplified_flow", "true")

	// Encode sorts parameters by key, so the resulting URL is deterministic.
	return AuthURL + "?" + query.Encode(), nil
}
// ExchangeCodeForTokens redeems an authorization code obtained through the
// default localhost callback. It is shorthand for
// ExchangeCodeForTokensWithRedirect with RedirectURI.
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
	defaultRedirect := RedirectURI
	return o.ExchangeCodeForTokensWithRedirect(ctx, code, defaultRedirect, pkceCodes)
}
// ExchangeCodeForTokensWithRedirect redeems an authorization code at TokenURL
// using the caller's redirect URI. The URI must match the one used when the
// code was issued, which is what lets alternate flows such as device login
// reuse this logic.
//
// pkceCodes must be non-nil and redirectURI must not be blank (it is trimmed
// before use). A non-200 response is turned into an error carrying the status
// and body. If the returned ID token cannot be parsed, a warning is logged and
// AccountID/Email are left empty. Expire is now + expires_in and LastRefresh
// is now, both formatted as RFC 3339.
func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
	if pkceCodes == nil {
		return nil, fmt.Errorf("PKCE codes are required for token exchange")
	}
	redirect := strings.TrimSpace(redirectURI)
	if redirect == "" {
		return nil, fmt.Errorf("redirect URI is required for token exchange")
	}

	form := url.Values{}
	form.Set("grant_type", "authorization_code")
	form.Set("client_id", ClientID)
	form.Set("code", code)
	form.Set("redirect_uri", redirect)
	form.Set("code_verifier", pkceCodes.CodeVerifier)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenURL, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, fmt.Errorf("failed to create token request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	resp, err := o.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("token exchange request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read token response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(payload))
	}

	var parsed struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		IDToken      string `json:"id_token"`
		TokenType    string `json:"token_type"`
		ExpiresIn    int    `json:"expires_in"`
	}
	if err = json.Unmarshal(payload, &parsed); err != nil {
		return nil, fmt.Errorf("failed to parse token response: %w", err)
	}

	// Identity details are best-effort: a malformed ID token is logged, not fatal.
	claims, err := ParseJWTToken(parsed.IDToken)
	if err != nil {
		log.Warnf("Failed to parse ID token: %v", err)
	}
	var accountID, email string
	if claims != nil {
		accountID, email = claims.GetAccountID(), claims.GetUserEmail()
	}

	tokens := CodexTokenData{
		IDToken:      parsed.IDToken,
		AccessToken:  parsed.AccessToken,
		RefreshToken: parsed.RefreshToken,
		AccountID:    accountID,
		Email:        email,
		Expire:       time.Now().Add(time.Duration(parsed.ExpiresIn) * time.Second).Format(time.RFC3339),
	}
	return &CodexAuthBundle{
		TokenData:   tokens,
		LastRefresh: time.Now().Format(time.RFC3339),
	}, nil
}
// RefreshTokens exchanges refreshToken for a fresh token set at TokenURL.
//
// It posts a refresh_token grant (client_id, scope "openid profile email") and
// requires HTTP 200. A non-200 response becomes an error that includes the
// response body; isNonRetryableRefreshErr inspects that text. AccountID/Email
// are taken from the returned ID token when it parses; a parse failure is
// only logged. Expire is now + expires_in, formatted as RFC 3339.
//
// If the response carries no new refresh_token, the token that was passed in
// is returned instead. RFC 6749 §6 makes rotation optional, and returning an
// empty value would make UpdateTokenStorage erase the stored credential.
func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*CodexTokenData, error) {
	if refreshToken == "" {
		return nil, fmt.Errorf("refresh token is required")
	}
	data := url.Values{
		"client_id":     {ClientID},
		"grant_type":    {"refresh_token"},
		"refresh_token": {refreshToken},
		"scope":         {"openid profile email"},
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, TokenURL, strings.NewReader(data.Encode()))
	if err != nil {
		return nil, fmt.Errorf("failed to create refresh request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	resp, err := o.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("token refresh request failed: %w", err)
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read refresh response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body))
	}

	var tokenResp struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		IDToken      string `json:"id_token"`
		TokenType    string `json:"token_type"`
		ExpiresIn    int    `json:"expires_in"`
	}
	if err = json.Unmarshal(body, &tokenResp); err != nil {
		return nil, fmt.Errorf("failed to parse refresh response: %w", err)
	}

	// Extract account ID and email from the ID token (best-effort).
	claims, err := ParseJWTToken(tokenResp.IDToken)
	if err != nil {
		log.Warnf("Failed to parse refreshed ID token: %v", err)
	}
	accountID := ""
	email := ""
	if claims != nil {
		accountID = claims.GetAccountID()
		email = claims.GetUserEmail()
	}

	// Keep the current refresh token when the server did not rotate it.
	nextRefreshToken := tokenResp.RefreshToken
	if nextRefreshToken == "" {
		nextRefreshToken = refreshToken
	}

	return &CodexTokenData{
		IDToken:      tokenResp.IDToken,
		AccessToken:  tokenResp.AccessToken,
		RefreshToken: nextRefreshToken,
		AccountID:    accountID,
		Email:        email,
		Expire:       time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
	}, nil
}
// CreateTokenStorage copies the tokens, identity fields and timestamps of
// bundle into a new CodexTokenStorage. Type and Metadata are left at their
// zero values. bundle must be non-nil.
func (o *CodexAuth) CreateTokenStorage(bundle *CodexAuthBundle) *CodexTokenStorage {
	td := bundle.TokenData
	return &CodexTokenStorage{
		IDToken:      td.IDToken,
		AccessToken:  td.AccessToken,
		RefreshToken: td.RefreshToken,
		AccountID:    td.AccountID,
		LastRefresh:  bundle.LastRefresh,
		Email:        td.Email,
		Expire:       td.Expire,
	}
}
// RefreshTokensWithRetry calls RefreshTokens up to maxRetries times.
//
// Before attempt k+1 (k >= 1) it waits k seconds, a linear backoff of 1s,
// 2s, ..., and aborts with ctx.Err() if ctx is cancelled during the wait.
// Errors classified by isNonRetryableRefreshErr (a reused refresh token) are
// returned immediately without further attempts. maxRetries values below 1
// are treated as 1; previously the loop was skipped and the function returned
// an error wrapping a nil cause. After the final failure, the last error is
// wrapped with the number of attempts made.
func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken string, maxRetries int) (*CodexTokenData, error) {
	attempts := maxRetries
	if attempts < 1 {
		attempts = 1
	}

	var lastErr error
	for attempt := 0; attempt < attempts; attempt++ {
		if attempt > 0 {
			// Wait before retry
			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			case <-time.After(time.Duration(attempt) * time.Second):
			}
		}
		tokenData, err := o.RefreshTokens(ctx, refreshToken)
		if err == nil {
			return tokenData, nil
		}
		if isNonRetryableRefreshErr(err) {
			log.Warnf("Token refresh attempt %d failed with non-retryable error: %v", attempt+1, err)
			return nil, err
		}
		lastErr = err
		log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
	}
	return nil, fmt.Errorf("token refresh failed after %d attempts: %w", attempts, lastErr)
}
// isNonRetryableRefreshErr reports whether a refresh failure should not be
// retried. Currently that means the error text mentions "refresh_token_reused"
// (case-insensitive); retrying a reused token cannot succeed. A nil error is
// never non-retryable.
func isNonRetryableRefreshErr(err error) bool {
	const reusedMarker = "refresh_token_reused"
	return err != nil && strings.Contains(strings.ToLower(err.Error()), reusedMarker)
}
// UpdateTokenStorage merges freshly refreshed tokenData into storage and
// stamps LastRefresh with the current time (RFC 3339).
//
// AccessToken and Expire are always replaced. RefreshToken, IDToken,
// AccountID and Email are replaced only when tokenData carries a non-empty
// value. A refresh response may omit the refresh token or ID token, and
// blindly copying empty strings would destroy the stored credential and the
// account identity. If either argument is nil, the call is a no-op.
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
	if storage == nil || tokenData == nil {
		return
	}

	storage.AccessToken = tokenData.AccessToken
	storage.Expire = tokenData.Expire
	storage.LastRefresh = time.Now().Format(time.RFC3339)

	if tokenData.RefreshToken != "" {
		storage.RefreshToken = tokenData.RefreshToken
	}
	if tokenData.IDToken != "" {
		storage.IDToken = tokenData.IDToken
	}
	if tokenData.AccountID != "" {
		storage.AccountID = tokenData.AccountID
	}
	if tokenData.Email != "" {
		storage.Email = tokenData.Email
	}
}
================================================
FILE: internal/auth/codex/openai_auth_test.go
================================================
package codex
import (
"context"
"io"
"net/http"
"strings"
"sync/atomic"
"testing"
)
// roundTripFunc adapts a plain function to the http.RoundTripper interface so
// tests can stub HTTP responses without a real server.
type roundTripFunc func(*http.Request) (*http.Response, error)
// RoundTrip implements http.RoundTripper by delegating to f.
func (f roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
return f(req)
}
// TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce verifies that a
// refresh failure whose body contains "refresh_token_reused" is returned
// immediately instead of being retried. The stub transport always answers 400
// with that code, and the test expects exactly one HTTP round trip even though
// three attempts are allowed.
func TestRefreshTokensWithRetry_NonRetryableOnlyAttemptsOnce(t *testing.T) {
var calls int32
auth := &CodexAuth{
httpClient: &http.Client{
// Stub transport: counts calls and never touches the network.
Transport: roundTripFunc(func(req *http.Request) (*http.Response, error) {
atomic.AddInt32(&calls, 1)
return &http.Response{
StatusCode: http.StatusBadRequest,
Body: io.NopCloser(strings.NewReader(`{"error":"invalid_grant","code":"refresh_token_reused"}`)),
Header: make(http.Header),
Request: req,
}, nil
}),
},
}
_, err := auth.RefreshTokensWithRetry(context.Background(), "dummy_refresh_token", 3)
if err == nil {
t.Fatalf("expected error for non-retryable refresh failure")
}
// The original (unwrapped) refresh error must surface, not the "after N attempts" summary.
if !strings.Contains(strings.ToLower(err.Error()), "refresh_token_reused") {
t.Fatalf("expected refresh_token_reused in error, got: %v", err)
}
if got := atomic.LoadInt32(&calls); got != 1 {
t.Fatalf("expected 1 refresh attempt, got %d", got)
}
}
================================================
FILE: internal/auth/codex/pkce.go
================================================
// Package codex provides authentication and token management functionality
// for OpenAI's Codex AI services. It handles OAuth2 PKCE (Proof Key for Code Exchange)
// code generation for secure authentication flows.
package codex
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
)
// GeneratePKCECodes creates a fresh PKCE pair as described in RFC 7636. The
// verifier is a random, unpadded base64url string, and the challenge is its
// S256 transform (base64url of its SHA-256 digest). It fails only if the
// system random source fails.
func GeneratePKCECodes() (*PKCECodes, error) {
	verifier, err := generateCodeVerifier()
	if err != nil {
		return nil, fmt.Errorf("failed to generate code verifier: %w", err)
	}
	codes := &PKCECodes{CodeVerifier: verifier}
	codes.CodeChallenge = generateCodeChallenge(verifier)
	return codes, nil
}
// generateCodeVerifier returns a high-entropy PKCE code verifier.
//
// It draws 96 bytes from crypto/rand and encodes them as unpadded base64url.
// That yields exactly 128 characters from [A-Za-z0-9-_], the maximum length
// RFC 7636 allows.
func generateCodeVerifier() (string, error) {
	raw := make([]byte, 96)
	if _, err := rand.Read(raw); err != nil {
		return "", fmt.Errorf("failed to generate random bytes: %w", err)
	}
	// RawURLEncoding is URLEncoding without padding.
	return base64.RawURLEncoding.EncodeToString(raw), nil
}
// generateCodeChallenge computes the RFC 7636 "S256" challenge for
// codeVerifier: BASE64URL-ENCODE(SHA256(verifier)) without padding. This value
// goes into the authorization URL; the verifier itself is revealed only at
// token exchange.
func generateCodeChallenge(codeVerifier string) string {
	digest := sha256.Sum256([]byte(codeVerifier))
	return base64.RawURLEncoding.EncodeToString(digest[:])
}
================================================
FILE: internal/auth/codex/token.go
================================================
// Package codex provides authentication and token management functionality
// for OpenAI's Codex AI services. It handles OAuth2 token storage, serialization,
// and retrieval for maintaining authenticated sessions with the Codex API.
package codex
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
)
// CodexTokenStorage is the persisted form of an OpenAI Codex OAuth session.
// Besides the OAuth2 tokens, it records the account identity and the refresh and
// expiry timestamps so the credential can be reloaded later.
type CodexTokenStorage struct {
	IDToken      string `json:"id_token"`      // JWT ID token carrying the user's identity claims
	AccessToken  string `json:"access_token"`  // bearer token used on API requests
	RefreshToken string `json:"refresh_token"` // used to obtain a new access token after expiry
	AccountID    string `json:"account_id"`    // OpenAI account identifier bound to the token
	LastRefresh  string `json:"last_refresh"`  // timestamp of the most recent refresh
	Email        string `json:"email"`         // account email address
	Type         string `json:"type"`          // provider tag; SaveTokenToFile stamps "codex"
	Expire       string `json:"expired"`       // access-token expiry timestamp (JSON key is "expired")

	// Metadata holds extra key/value pairs injected by hooks. It is excluded from
	// direct JSON encoding so that saving can flatten it into the top-level object.
	Metadata map[string]any `json:"-"`
}

// SetMetadata replaces the hook-supplied metadata attached to this storage.
// The map is kept by reference (not copied) and is merged in on the next save.
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
	ts.Metadata = meta
}
// SaveTokenToFile writes the Codex credential to authFilePath as one JSON object.
// It forces Type to "codex" and merges any hook-injected Metadata into the
// top-level object via misc.MergeMetadata. Missing parent directories are created
// with mode 0700.
//
// The JSON document is built before the file is opened, so a merge failure leaves
// an existing credential file untouched. Newly created files get mode 0600
// because they hold live OAuth tokens. An error from closing the file is reported,
// because a failed close can mean the data never reached disk.
//
// Parameters:
//   - authFilePath: The full path where the token file should be saved
//
// Returns:
//   - error: An error if merging, creating, writing, or closing fails, nil otherwise
func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) (err error) {
	misc.LogSavingCredentials(authFilePath)
	ts.Type = "codex"

	// Merge before touching the filesystem so a failure cannot truncate an existing file.
	data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
	if errMerge != nil {
		return fmt.Errorf("failed to merge metadata: %w", errMerge)
	}
	if err = os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}

	var f *os.File
	f, err = os.OpenFile(authFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return fmt.Errorf("failed to create token file: %w", err)
	}
	defer func() {
		// Only surface the close error when nothing earlier failed.
		if errClose := f.Close(); errClose != nil && err == nil {
			err = fmt.Errorf("failed to close token file: %w", errClose)
		}
	}()

	if err = json.NewEncoder(f).Encode(data); err != nil {
		return fmt.Errorf("failed to write token to file: %w", err)
	}
	return nil
}
================================================
FILE: internal/auth/empty/token.go
================================================
// Package empty provides a no-operation token storage implementation.
// This package is used when authentication tokens are not required or when
// using API key-based authentication instead of OAuth tokens for any provider.
package empty
// EmptyStorage is a no-op TokenStorage. It is used when there is no token to
// persist, for example when a provider authenticates with an API key instead of
// OAuth tokens.
type EmptyStorage struct {
	// Type names the storage kind; SaveTokenToFile sets it to "empty".
	Type string `json:"type"`
}

// SaveTokenToFile satisfies the TokenStorage interface without touching the
// filesystem. The path argument is ignored and only the Type tag is stamped.
//
// Returns:
//   - error: always nil
func (ts *EmptyStorage) SaveTokenToFile(_ string) error {
	const storageType = "empty"
	ts.Type = storageType
	return nil
}
================================================
FILE: internal/auth/gemini/gemini_auth.go
================================================
// Package gemini provides authentication and token management functionality
// for Google's Gemini AI services. It handles OAuth2 authentication flows,
// including obtaining tokens via web-based authorization, storing tokens,
// and refreshing them when they expire.
package gemini
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
// OAuth configuration constants for Gemini.
//
// ClientID and ClientSecret identify the OAuth client used for the Google login.
// Both are compiled into the binary. createTokenStorage also copies them into the
// stored token map. NOTE(review): presumably this is the public "installed
// application" client shared with the upstream Gemini CLI, whose secret is not
// confidential — confirm.
const (
	ClientID     = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
	ClientSecret = "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"
	// DefaultCallbackPort is the localhost port for the OAuth redirect listener
	// when WebLoginOptions.CallbackPort is not set (> 0).
	DefaultCallbackPort = 8085
)

// Scopes lists the OAuth scopes requested during Gemini login: Cloud Platform
// access plus the user's email and basic profile. The email is read back from
// the userinfo endpoint after login and stored alongside the token.
var Scopes = []string{
	"https://www.googleapis.com/auth/cloud-platform",
	"https://www.googleapis.com/auth/userinfo.email",
	"https://www.googleapis.com/auth/userinfo.profile",
}
// GeminiAuth drives the Google OAuth2 flow used for Gemini. It obtains tokens
// through an interactive browser login, packages them for storage, and returns
// HTTP clients that refresh the access token automatically. It carries no state.
type GeminiAuth struct{}

// WebLoginOptions tunes the interactive OAuth login.
type WebLoginOptions struct {
	// NoBrowser skips launching a browser; the authorization URL is only printed.
	NoBrowser bool
	// CallbackPort, when greater than zero, replaces DefaultCallbackPort for the
	// local redirect listener.
	CallbackPort int
	// Prompt, when non-nil, asks the user to paste the callback URL by hand if the
	// local listener has not received it after a short wait.
	Prompt func(string) (string, error)
}

// NewGeminiAuth returns a ready-to-use GeminiAuth.
func NewGeminiAuth() *GeminiAuth {
	return new(GeminiAuth)
}
// GetAuthenticatedClient returns an *http.Client that attaches Gemini OAuth2
// credentials to each request and refreshes the access token when it expires.
//
// If cfg is non-nil, a proxy transport is built from cfg.ProxyURL and installed as
// the oauth2 HTTP client in ctx. This routes the code exchange and later token
// refreshes through that proxy. A proxy that fails to build is logged and ignored.
// When ts holds no token, the interactive web login runs and ts is overwritten in
// place with freshly created storage. The caller's ProjectID is carried over.
//
// Parameters:
//   - ctx: Context for OAuth HTTP calls
//   - ts: Gemini token storage; must be non-nil and is replaced after a web login
//   - cfg: Configuration supplying the proxy URL; nil means no proxy
//   - opts: Optional parameters to customize callback port, browser and prompt behavior
//
// Returns:
//   - *http.Client: An HTTP client configured with authentication
//   - error: An error if login fails or the stored token cannot be decoded, nil otherwise
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
	if ts == nil {
		return nil, fmt.Errorf("gemini token storage is nil")
	}
	callbackPort := DefaultCallbackPort
	if opts != nil && opts.CallbackPort > 0 {
		callbackPort = opts.CallbackPort
	}
	callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)

	// oauth2 picks up its HTTP client from the context value oauth2.HTTPClient.
	// A proxy build failure is deliberately non-fatal: log it and use the default client.
	if cfg != nil {
		transport, _, errBuild := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
		if errBuild != nil {
			log.Errorf("%v", errBuild)
		} else if transport != nil {
			proxyClient := &http.Client{Transport: transport}
			ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
		}
	}

	// Configure the OAuth2 client.
	conf := &oauth2.Config{
		ClientID:     ClientID,
		ClientSecret: ClientSecret,
		RedirectURL:  callbackURL, // This will be used by the local server.
		Scopes:       Scopes,
		Endpoint:     google.Endpoint,
	}

	var token *oauth2.Token
	var err error
	// If no token is found in storage, initiate the web-based OAuth flow.
	if ts.Token == nil {
		fmt.Printf("Could not load token from file, starting OAuth flow.\n")
		token, err = g.getTokenFromWeb(ctx, conf, opts)
		if err != nil {
			return nil, fmt.Errorf("failed to get token from web: %w", err)
		}
		// After getting a new token, create a new token storage object with user info.
		newTs, errCreateTokenStorage := g.createTokenStorage(ctx, conf, token, ts.ProjectID)
		if errCreateTokenStorage != nil {
			log.Errorf("Warning: failed to create token storage: %v", errCreateTokenStorage)
			return nil, errCreateTokenStorage
		}
		*ts = *newTs
	}

	// ts.Token is untyped (createTokenStorage stores a map[string]any), so round-trip
	// it through JSON to obtain a typed *oauth2.Token.
	rawToken, errMarshal := json.Marshal(ts.Token)
	if errMarshal != nil {
		return nil, fmt.Errorf("failed to marshal stored token: %w", errMarshal)
	}
	if err = json.Unmarshal(rawToken, &token); err != nil {
		return nil, fmt.Errorf("failed to unmarshal token: %w", err)
	}
	// A JSON null (e.g. a nil map) decodes to a nil pointer; refuse to build a
	// client that has no credentials at all.
	if token == nil {
		return nil, fmt.Errorf("stored token is empty")
	}
	// Return an HTTP client that automatically handles token refreshing.
	return conf.Client(ctx, token), nil
}
// createTokenStorage builds a GeminiTokenStorage for a freshly obtained token.
// It calls Google's userinfo endpoint with the token to read the user's email.
// It then converts the token into a map and adds token_uri, client_id,
// client_secret, scopes and universe_domain fields before storing it.
// A missing email is reported on stdout but is not an error.
//
// Parameters:
//   - ctx: The context for the HTTP request
//   - oauthConf: The OAuth2 configuration used to build the authenticated client
//   - token: The OAuth2 token to use for authentication; must be non-nil
//   - projectID: The Google Cloud Project ID to associate with this token
//
// Returns:
//   - *GeminiTokenStorage: A new token storage object with user information
//   - error: An error if the userinfo request fails or the token cannot be encoded
func (g *GeminiAuth) createTokenStorage(ctx context.Context, oauthConf *oauth2.Config, token *oauth2.Token, projectID string) (*GeminiTokenStorage, error) {
	if token == nil {
		return nil, fmt.Errorf("token is nil")
	}
	httpClient := oauthConf.Client(ctx, token)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
	if err != nil {
		return nil, fmt.Errorf("could not get user info: %w", err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
	resp, err := httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("failed to execute request: %w", err)
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Printf("warn: failed to close response body: %v", errClose)
		}
	}()

	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read user info response: %w", err)
	}
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		return nil, fmt.Errorf("get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
	}

	emailResult := gjson.GetBytes(bodyBytes, "email")
	if emailResult.Exists() && emailResult.Type == gjson.String {
		fmt.Printf("Authenticated user email: %s\n", emailResult.String())
	} else {
		fmt.Println("Failed to get user email from token")
	}

	// Convert the typed token into a generic map so the extra fields below can be added.
	jsonData, err := json.Marshal(token)
	if err != nil {
		return nil, fmt.Errorf("failed to marshal token: %w", err)
	}
	var ifToken map[string]any
	if err = json.Unmarshal(jsonData, &ifToken); err != nil {
		return nil, fmt.Errorf("failed to unmarshal token: %w", err)
	}
	ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
	ifToken["client_id"] = ClientID
	ifToken["client_secret"] = ClientSecret
	// Copy so callers mutating the stored map cannot alter the package-level Scopes.
	ifToken["scopes"] = append([]string(nil), Scopes...)
	ifToken["universe_domain"] = "googleapis.com"

	return &GeminiTokenStorage{
		Token:     ifToken,
		ProjectID: projectID,
		Email:     emailResult.String(),
	}, nil
}
// getTokenFromWeb initiates the web-based OAuth2 authorization flow.
// It starts a local HTTP server to listen for the callback from Google's auth server,
// opens the user's browser to the authorization URL, and exchanges the received
// authorization code for an access token.
//
// Parameters:
// - ctx: The context for the HTTP client
// - config: The OAuth2 configuration
// - opts: Optional parameters to customize browser and prompt behavior
//
// Returns:
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
// - error: An error if the token acquisition fails, nil otherwise
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
callbackPort := DefaultCallbackPort
if opts != nil && opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
callbackURL := fmt.Sprintf("http://localhost:%d/oauth2callback", callbackPort)
// Use a channel to pass the authorization code from the HTTP handler to the main function.
codeChan := make(chan string, 1)
errChan := make(chan error, 1)
// Create a new HTTP server with its own multiplexer.
mux := http.NewServeMux()
server := &http.Server{Addr: fmt.Sprintf(":%d", callbackPort), Handler: mux}
config.RedirectURL = callbackURL
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
if err := r.URL.Query().Get("error"); err != "" {
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
select {
case errChan <- fmt.Errorf("authentication failed via callback: %s", err):
default:
}
return
}
code := r.URL.Query().Get("code")
if code == "" {
_, _ = fmt.Fprint(w, "Authentication failed: code not found.")
select {
case errChan <- fmt.Errorf("code not found in callback"):
default:
}
return
}
_, _ = fmt.Fprint(w, "Authentication successful! You can close this window.
")
select {
case codeChan <- code:
default:
}
})
// Start the server in a goroutine.
go func() {
if err := server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Errorf("ListenAndServe(): %v", err)
select {
case errChan <- err:
default:
}
}
}()
// Open the authorization URL in the user's browser.
authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
noBrowser := false
if opts != nil {
noBrowser = opts.NoBrowser
}
if !noBrowser {
fmt.Println("Opening browser for authentication...")
// Check if browser is available
if !browser.IsAvailable() {
log.Warn("No browser available on this system")
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
} else {
if err := browser.OpenURL(authURL); err != nil {
authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err)
log.Warn(codex.GetUserFriendlyMessage(authErr))
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Please manually open this URL in your browser:\n\n%s\n", authURL)
// Log platform info for debugging
platformInfo := browser.GetPlatformInfo()
log.Debugf("Browser platform info: %+v", platformInfo)
} else {
log.Debug("Browser opened successfully")
}
}
} else {
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Please open this URL in your browser:\n\n%s\n", authURL)
}
fmt.Println("Waiting for authentication callback...")
// Wait for the authorization code or an error.
var authCode string
timeoutTimer := time.NewTimer(5 * time.Minute)
defer timeoutTimer.Stop()
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts != nil && opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case code := <-codeChan:
authCode = code
break waitForCallback
case err := <-errChan:
return nil, err
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case code := <-codeChan:
authCode = code
break waitForCallback
case err := <-errChan:
return nil, err
default:
}
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
if err != nil {
return nil, err
}
parsed, err := misc.ParseOAuthCallback(input)
if err != nil {
return nil, err
}
if parsed == nil {
continue
}
if parsed.Error != "" {
return nil, fmt.Errorf("authentication failed via callback: %s", parsed.Error)
}
if parsed.Code == "" {
return nil, fmt.Errorf("code not found in callback")
}
authCode = parsed.Code
break waitForCallback
case <-timeoutTimer.C:
return nil, fmt.Errorf("oauth flow timed out")
}
}
// Shutdown the server.
if err := server.Shutdown(ctx); err != nil {
log.Errorf("Failed to shut down server: %v", err)
}
// Exchange the authorization code for a token.
token, err := config.Exchange(ctx, authCode)
if err != nil {
return nil, fmt.Errorf("failed to exchange token: %w", err)
}
fmt.Println("Authentication successful.")
return token, nil
}
================================================
FILE: internal/auth/gemini/gemini_token.go
================================================
// Package gemini provides authentication and token management functionality
// for Google's Gemini AI services. It handles OAuth2 token storage, serialization,
// and retrieval for maintaining authenticated sessions with the Gemini API.
package gemini
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
log "github.com/sirupsen/logrus"
)
// GeminiTokenStorage is the persisted form of a Google Gemini OAuth session.
// Alongside the raw token it records the linked Cloud project, the account
// email, and whether the project was auto-selected and its API verified.
type GeminiTokenStorage struct {
	Token     any    `json:"token"`      // raw OAuth2 token data (access/refresh tokens and extras)
	ProjectID string `json:"project_id"` // Google Cloud project bound to this credential
	Email     string `json:"email"`      // email address of the authenticated user
	Auto      bool   `json:"auto"`       // true when the project ID was chosen automatically
	Checked   bool   `json:"checked"`    // true once the Cloud AI API was verified as enabled
	Type      string `json:"type"`       // provider tag; SaveTokenToFile stamps "gemini"

	// Metadata holds extra key/value pairs injected by hooks. It is excluded from
	// direct JSON encoding so that saving can flatten it into the top-level object.
	Metadata map[string]any `json:"-"`
}

// SetMetadata replaces the hook-supplied metadata attached to this storage.
// The map is kept by reference (not copied) and is merged in on the next save.
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
	ts.Metadata = meta
}
// SaveTokenToFile writes the Gemini credential to authFilePath as indented JSON.
// It forces Type to "gemini" and merges any hook-injected Metadata into the
// top-level object via misc.MergeMetadata. Missing parent directories are created
// with mode 0700.
//
// Newly created files get mode 0600 because they hold live OAuth tokens. A failure
// to close the file is both logged and returned, because a failed close can mean
// the data never reached disk.
//
// Parameters:
//   - authFilePath: The full path where the token file should be saved
//
// Returns:
//   - error: An error if merging, creating, writing, or closing fails, nil otherwise
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) (err error) {
	misc.LogSavingCredentials(authFilePath)
	ts.Type = "gemini"

	// Merge before touching the filesystem so a failure cannot truncate an existing file.
	data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
	if errMerge != nil {
		return fmt.Errorf("failed to merge metadata: %w", errMerge)
	}
	if err = os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}

	var f *os.File
	f, err = os.OpenFile(authFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return fmt.Errorf("failed to create token file: %w", err)
	}
	defer func() {
		if errClose := f.Close(); errClose != nil {
			log.Errorf("failed to close file: %v", errClose)
			// Keep the first (more specific) error if one already occurred.
			if err == nil {
				err = fmt.Errorf("failed to close token file: %w", errClose)
			}
		}
	}()

	enc := json.NewEncoder(f)
	enc.SetIndent("", " ")
	if err = enc.Encode(data); err != nil {
		return fmt.Errorf("failed to write token to file: %w", err)
	}
	return nil
}
// CredentialFileName builds the file name used to persist a Gemini CLI
// credential. Both email and projectID are trimmed of surrounding whitespace.
// When projectID denotes multiple projects (the literal "all" in any case, or a
// comma-separated list), the name is always "gemini-<email>-all.json". This keeps
// web- and CLI-generated files consistent. Otherwise the name is
// "<email>-<project>.json", with a "gemini-" prefix when includeProviderPrefix is set.
func CredentialFileName(email, projectID string, includeProviderPrefix bool) string {
	cleanEmail := strings.TrimSpace(email)
	cleanProject := strings.TrimSpace(projectID)
	multiProject := strings.EqualFold(cleanProject, "all") || strings.Contains(cleanProject, ",")

	switch {
	case multiProject:
		return fmt.Sprintf("gemini-%s-all.json", cleanEmail)
	case includeProviderPrefix:
		return fmt.Sprintf("gemini-%s-%s.json", cleanEmail, cleanProject)
	default:
		return fmt.Sprintf("%s-%s.json", cleanEmail, cleanProject)
	}
}
================================================
FILE: internal/auth/iflow/cookie_helpers.go
================================================
package iflow
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
)
// NormalizeCookie cleans up a raw cookie string pasted for iFlow authentication.
// Every run of whitespace (including newlines) collapses to a single space, and
// the result is guaranteed to end with ';'.
//
// It returns an error for input that is empty or whitespace-only, or that has no
// "BXAuth=" entry.
func NormalizeCookie(raw string) (string, error) {
	fields := strings.Fields(raw)
	if len(fields) == 0 {
		return "", fmt.Errorf("cookie cannot be empty")
	}
	cookie := strings.Join(fields, " ")
	if !strings.Contains(cookie, "BXAuth=") {
		return "", fmt.Errorf("cookie missing BXAuth field")
	}
	if !strings.HasSuffix(cookie, ";") {
		cookie += ";"
	}
	return cookie, nil
}
// SanitizeIFlowFileName turns a user identifier (such as an email address) into
// text that is safe inside a file name. Mask asterisks become 'x'. ASCII letters,
// digits, '_', '@', '.', and '-' are kept. Every other character is dropped,
// including whitespace and non-ASCII runes. An empty input yields "".
func SanitizeIFlowFileName(raw string) string {
	return strings.Map(func(r rune) rune {
		switch {
		case r == '*':
			return 'x'
		case 'a' <= r && r <= 'z', 'A' <= r && r <= 'Z', '0' <= r && r <= '9':
			return r
		case r == '_', r == '@', r == '.', r == '-':
			return r
		default:
			return -1 // drop the rune
		}
	}, raw)
}
// ExtractBXAuth returns the value of the first "BXAuth=" entry in a
// semicolon-separated cookie string. Whitespace around each entry is ignored.
// It returns "" when the cookie has no such entry.
func ExtractBXAuth(cookie string) string {
	const key = "BXAuth="
	remaining := cookie
	for {
		entry, rest, found := strings.Cut(remaining, ";")
		if trimmed := strings.TrimSpace(entry); strings.HasPrefix(trimmed, key) {
			return trimmed[len(key):]
		}
		if !found {
			return ""
		}
		remaining = rest
	}
}
// CheckDuplicateBXAuth reports whether bxAuth is already used by a stored iFlow
// credential. It lists authDir (non-recursively) for files named "iflow-*.json",
// reads each one's "cookie" field, and compares that cookie's BXAuth value with
// bxAuth. It returns the path of the first match, or "" when there is none.
//
// An empty bxAuth or a missing authDir yields ("", nil). Files that cannot be read
// or parsed are skipped silently. Only a failure to list authDir is returned as
// an error.
func CheckDuplicateBXAuth(authDir, bxAuth string) (string, error) {
	if bxAuth == "" {
		return "", nil
	}
	entries, errList := os.ReadDir(authDir)
	switch {
	case errList == nil:
	case os.IsNotExist(errList):
		return "", nil
	default:
		return "", fmt.Errorf("read auth dir failed: %w", errList)
	}

	// storedBXAuth returns the BXAuth recorded in one credential file, or "" when
	// the file is unreadable or is not valid JSON.
	storedBXAuth := func(path string) string {
		raw, errRead := os.ReadFile(path)
		if errRead != nil {
			return ""
		}
		var stored struct {
			Cookie string `json:"cookie"`
		}
		if errDecode := json.Unmarshal(raw, &stored); errDecode != nil {
			return ""
		}
		return ExtractBXAuth(stored.Cookie)
	}

	for _, entry := range entries {
		name := entry.Name()
		if entry.IsDir() || !strings.HasPrefix(name, "iflow-") || !strings.HasSuffix(name, ".json") {
			continue
		}
		candidate := filepath.Join(authDir, name)
		// bxAuth is non-empty here, so a skipped file (reported as "") never matches.
		if storedBXAuth(candidate) == bxAuth {
			return candidate, nil
		}
	}
	return "", nil
}
================================================
FILE: internal/auth/iflow/iflow_auth.go
================================================
package iflow
import (
"compress/gzip"
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
// Endpoints and client credentials for iFlow authentication.
const (
	// OAuth endpoints and client metadata are derived from the reference Python implementation.

	// iFlowOAuthTokenEndpoint receives the POSTed code-exchange and refresh forms (newTokenRequest).
	iFlowOAuthTokenEndpoint = "https://iflow.cn/oauth/token"
	// iFlowOAuthAuthorizeEndpoint is the base of the browser login URL (AuthorizationURL).
	iFlowOAuthAuthorizeEndpoint = "https://iflow.cn/oauth"
	// iFlowUserInfoEndpoint returns account info, including the API key; the access
	// token is passed as the accessToken query parameter (FetchUserInfo).
	iFlowUserInfoEndpoint = "https://iflow.cn/api/oauth/getUserInfo"
	// iFlowSuccessRedirectURL is iFlow's success page, re-exported as SuccessRedirectURL.
	iFlowSuccessRedirectURL = "https://iflow.cn/oauth/success"
	// Cookie authentication endpoints

	// iFlowAPIKeyEndpoint is queried with a browser session cookie (fetchAPIKeyInfo).
	iFlowAPIKeyEndpoint = "https://platform.iflow.cn/api/openapi/apikey"
	// Client credentials provided by iFlow for the Code Assist integration.
	// They are sent both as form fields and as HTTP Basic auth on token requests.
	iFlowOAuthClientID     = "10009311001"
	iFlowOAuthClientSecret = "4Z3YjXycVsQvyGF1etiNlIBB4RsqSDtW"
)

// DefaultAPIBaseURL is the canonical chat completions endpoint.
const DefaultAPIBaseURL = "https://apis.iflow.cn/v1"

// SuccessRedirectURL is exposed for consumers needing the official success page.
const SuccessRedirectURL = iFlowSuccessRedirectURL

// CallbackPort defines the local port used for OAuth callbacks.
const CallbackPort = 11451
// IFlowAuth holds the HTTP client used for iFlow's OAuth and cookie-based
// authentication endpoints.
type IFlowAuth struct {
	httpClient *http.Client
}

// NewIFlowAuth builds an IFlowAuth whose HTTP client has a 30-second timeout.
// When cfg is non-nil, the client is passed through util.SetProxy with
// cfg.SDKConfig so that the configured proxy settings apply. A nil cfg yields a
// direct client instead of panicking.
func NewIFlowAuth(cfg *config.Config) *IFlowAuth {
	client := &http.Client{Timeout: 30 * time.Second}
	if cfg == nil {
		return &IFlowAuth{httpClient: client}
	}
	return &IFlowAuth{httpClient: util.SetProxy(&cfg.SDKConfig, client)}
}
// AuthorizationURL returns the iFlow phone-login authorization URL and the
// localhost redirect URI embedded in it. The redirect points at
// /oauth2callback on the given port, and state is sent as the "state" query
// parameter.
func (ia *IFlowAuth) AuthorizationURL(state string, port int) (authURL, redirectURI string) {
	redirectURI = fmt.Sprintf("http://localhost:%d/oauth2callback", port)
	query := url.Values{
		"loginMethod": {"phone"},
		"type":        {"phone"},
		"redirect":    {redirectURI},
		"state":       {state},
		"client_id":   {iFlowOAuthClientID},
	}
	// Encode sorts keys, so the literal's ordering does not affect the output.
	authURL = iFlowOAuthAuthorizeEndpoint + "?" + query.Encode()
	return authURL, redirectURI
}
// ExchangeCodeForTokens trades an authorization code for iFlow tokens. It uses
// the authorization_code grant with the same redirectURI used at login. The
// returned data also carries the account's API key and email, as filled in by
// doTokenRequest.
func (ia *IFlowAuth) ExchangeCodeForTokens(ctx context.Context, code, redirectURI string) (*IFlowTokenData, error) {
	req, err := ia.newTokenRequest(ctx, url.Values{
		"grant_type":    {"authorization_code"},
		"code":          {code},
		"redirect_uri":  {redirectURI},
		"client_id":     {iFlowOAuthClientID},
		"client_secret": {iFlowOAuthClientSecret},
	})
	if err != nil {
		return nil, err
	}
	return ia.doTokenRequest(ctx, req)
}
// RefreshTokens obtains a new iFlow access token with the refresh_token grant.
// The returned data also carries the account's API key and email, as filled in
// by doTokenRequest.
func (ia *IFlowAuth) RefreshTokens(ctx context.Context, refreshToken string) (*IFlowTokenData, error) {
	req, err := ia.newTokenRequest(ctx, url.Values{
		"grant_type":    {"refresh_token"},
		"refresh_token": {refreshToken},
		"client_id":     {iFlowOAuthClientID},
		"client_secret": {iFlowOAuthClientSecret},
	})
	if err != nil {
		return nil, err
	}
	return ia.doTokenRequest(ctx, req)
}
// newTokenRequest builds a POST to the iFlow token endpoint. The form is sent as
// an application/x-www-form-urlencoded body, the client credentials go in an
// HTTP Basic Authorization header, and the request asks for a JSON response.
func (ia *IFlowAuth) newTokenRequest(ctx context.Context, form url.Values) (*http.Request, error) {
	body := strings.NewReader(form.Encode())
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, iFlowOAuthTokenEndpoint, body)
	if err != nil {
		return nil, fmt.Errorf("iflow token: create request failed: %w", err)
	}
	credentials := []byte(iFlowOAuthClientID + ":" + iFlowOAuthClientSecret)
	headers := req.Header
	headers.Set("Content-Type", "application/x-www-form-urlencoded")
	headers.Set("Accept", "application/json")
	headers.Set("Authorization", "Basic "+base64.StdEncoding.EncodeToString(credentials))
	return req, nil
}
// doTokenRequest sends a prepared token request and turns a successful response
// into IFlowTokenData. After decoding the tokens, it calls FetchUserInfo with the
// new access token to obtain the account's API key and an identifier (email,
// falling back to phone). Expiry is now + expires_in seconds, formatted as RFC 3339.
//
// Each of the following is reported as an error:
//   - a non-200 status or an undecodable body
//   - a missing access token
//   - a failed user-info lookup
//   - a blank API key
//   - no email or phone
func (ia *IFlowAuth) doTokenRequest(ctx context.Context, req *http.Request) (*IFlowTokenData, error) {
	resp, err := ia.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("iflow token: request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("iflow token: read response failed: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		log.Debugf("iflow token request failed: status=%d body=%s", resp.StatusCode, string(raw))
		return nil, fmt.Errorf("iflow token: %d %s", resp.StatusCode, strings.TrimSpace(string(raw)))
	}

	var parsed IFlowTokenResponse
	if err = json.Unmarshal(raw, &parsed); err != nil {
		return nil, fmt.Errorf("iflow token: decode response failed: %w", err)
	}
	expiresAt := time.Now().Add(time.Duration(parsed.ExpiresIn) * time.Second)
	if parsed.AccessToken == "" {
		log.Debug(string(raw))
		return nil, fmt.Errorf("iflow token: missing access token in response")
	}

	info, errAPI := ia.FetchUserInfo(ctx, parsed.AccessToken)
	if errAPI != nil {
		return nil, fmt.Errorf("iflow token: fetch user info failed: %w", errAPI)
	}
	if strings.TrimSpace(info.APIKey) == "" {
		return nil, fmt.Errorf("iflow token: empty api key returned")
	}
	account := strings.TrimSpace(info.Email)
	if account == "" {
		account = strings.TrimSpace(info.Phone)
	}
	if account == "" {
		return nil, fmt.Errorf("iflow token: missing account email/phone in user info")
	}

	return &IFlowTokenData{
		AccessToken:  parsed.AccessToken,
		RefreshToken: parsed.RefreshToken,
		TokenType:    parsed.TokenType,
		Scope:        parsed.Scope,
		Expire:       expiresAt.Format(time.RFC3339),
		APIKey:       info.APIKey,
		Email:        account,
	}, nil
}
// FetchUserInfo looks up the iFlow account behind accessToken. The result
// includes the account's API key, email, and phone. An error is returned in any
// of these cases:
//   - the token is blank
//   - the request fails or the status is not 200
//   - the body cannot be decoded
//   - the response reports success=false
//   - no API key is present
func (ia *IFlowAuth) FetchUserInfo(ctx context.Context, accessToken string) (*userInfoData, error) {
	if strings.TrimSpace(accessToken) == "" {
		return nil, fmt.Errorf("iflow api key: access token is empty")
	}

	// The endpoint takes the access token as a query parameter rather than a header.
	endpoint := iFlowUserInfoEndpoint + "?accessToken=" + url.QueryEscape(accessToken)
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("iflow api key: create request failed: %w", err)
	}
	req.Header.Set("Accept", "application/json")

	resp, err := ia.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("iflow api key: request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	raw, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("iflow api key: read response failed: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		log.Debugf("iflow api key failed: status=%d body=%s", resp.StatusCode, string(raw))
		return nil, fmt.Errorf("iflow api key: %d %s", resp.StatusCode, strings.TrimSpace(string(raw)))
	}

	var envelope userInfoResponse
	if err = json.Unmarshal(raw, &envelope); err != nil {
		return nil, fmt.Errorf("iflow api key: decode body failed: %w", err)
	}
	switch {
	case !envelope.Success:
		return nil, fmt.Errorf("iflow api key: request not successful")
	case envelope.Data.APIKey == "":
		return nil, fmt.Errorf("iflow api key: missing api key in response")
	}
	return &envelope.Data, nil
}
// CreateTokenStorage converts freshly obtained token data into a persistable
// IFlowTokenStorage, stamping LastRefresh with the current time in RFC 3339
// format. It returns nil when data is nil. The Cookie field of data is not copied.
func (ia *IFlowAuth) CreateTokenStorage(data *IFlowTokenData) *IFlowTokenStorage {
	if data == nil {
		return nil
	}
	storage := new(IFlowTokenStorage)
	storage.AccessToken = data.AccessToken
	storage.RefreshToken = data.RefreshToken
	storage.LastRefresh = time.Now().Format(time.RFC3339)
	storage.Expire = data.Expire
	storage.APIKey = data.APIKey
	storage.Email = data.Email
	storage.TokenType = data.TokenType
	storage.Scope = data.Scope
	return storage
}
// UpdateTokenStorage copies refreshed token data into an existing storage and
// stamps LastRefresh with the current time. The API key and email are only
// replaced when data carries non-empty values, so a refresh never blanks them.
// It does nothing when either argument is nil.
func (ia *IFlowAuth) UpdateTokenStorage(storage *IFlowTokenStorage, data *IFlowTokenData) {
	if storage == nil || data == nil {
		return
	}
	storage.AccessToken, storage.RefreshToken = data.AccessToken, data.RefreshToken
	storage.TokenType, storage.Scope = data.TokenType, data.Scope
	storage.Expire = data.Expire
	storage.LastRefresh = time.Now().Format(time.RFC3339)

	// Preserve previously known identity fields when the new data leaves them blank.
	if key := data.APIKey; key != "" {
		storage.APIKey = key
	}
	if email := data.Email; email != "" {
		storage.Email = email
	}
}
type (
	// IFlowTokenResponse mirrors the JSON body returned by the iFlow OAuth token endpoint.
	IFlowTokenResponse struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		ExpiresIn    int    `json:"expires_in"` // token lifetime, interpreted as seconds
		TokenType    string `json:"token_type"`
		Scope        string `json:"scope"`
	}

	// IFlowTokenData is the normalized outcome of an iFlow login, by OAuth or by
	// cookie. Expire is a timestamp string. Email may hold a phone number when no
	// email is known.
	IFlowTokenData struct {
		AccessToken  string
		RefreshToken string
		TokenType    string
		Scope        string
		Expire       string
		APIKey       string
		Email        string
		Cookie       string
	}

	// userInfoResponse is the envelope returned by the user-info endpoint.
	userInfoResponse struct {
		Success bool         `json:"success"`
		Data    userInfoData `json:"data"`
	}

	// userInfoData carries the account fields read from the user-info endpoint.
	userInfoData struct {
		APIKey string `json:"apiKey"`
		Email  string `json:"email"`
		Phone  string `json:"phone"`
	}

	// iFlowAPIKeyResponse is the envelope returned by the cookie-authenticated API key endpoint.
	iFlowAPIKeyResponse struct {
		Success bool         `json:"success"`
		Code    string       `json:"code"`
		Message string       `json:"message"`
		Data    iFlowKeyData `json:"data"`
		Extra   any          `json:"extra"`
	}

	// iFlowKeyData describes one API key as reported by the API key endpoint.
	iFlowKeyData struct {
		HasExpired bool   `json:"hasExpired"`
		ExpireTime string `json:"expireTime"`
		Name       string `json:"name"`
		APIKey     string `json:"apiKey"`
		APIKeyMask string `json:"apiKeyMask"`
	}

	// iFlowRefreshRequest is the JSON body for refreshing an API key by name.
	iFlowRefreshRequest struct {
		Name string `json:"name"`
	}
)
// AuthenticateWithCookie logs in to iFlow with a browser session cookie instead
// of OAuth. It first reads the account's current API key to learn the key's
// name, then refreshes that key. It returns the refreshed key and its expiry
// time; the cookie itself is carried along in the result.
//
// NOTE(review): the key's Name is stored in the Email field. Presumably the
// platform names keys after the account identifier — confirm against the
// RefreshAPIKey response.
func (ia *IFlowAuth) AuthenticateWithCookie(ctx context.Context, cookie string) (*IFlowTokenData, error) {
	if strings.TrimSpace(cookie) == "" {
		return nil, fmt.Errorf("iflow cookie authentication: cookie is empty")
	}

	current, errFetch := ia.fetchAPIKeyInfo(ctx, cookie)
	if errFetch != nil {
		return nil, fmt.Errorf("iflow cookie authentication: fetch initial API key info failed: %w", errFetch)
	}

	refreshed, errRefresh := ia.RefreshAPIKey(ctx, cookie, current.Name)
	if errRefresh != nil {
		return nil, fmt.Errorf("iflow cookie authentication: refresh API key failed: %w", errRefresh)
	}

	return &IFlowTokenData{
		APIKey: refreshed.APIKey,
		Expire: refreshed.ExpireTime,
		Email:  refreshed.Name,
		Cookie: cookie,
	}, nil
}
// fetchAPIKeyInfo retrieves the current API key metadata for the account that
// owns cookie, using a GET request against iFlowAPIKeyEndpoint.
//
// Because this request sets Accept-Encoding explicitly, net/http will not
// decompress the body transparently. Only gzip — the one encoding decoded below —
// is therefore advertised; advertising br/deflate would let the server send a
// body this function cannot read. When the response carries only a masked key,
// APIKeyMask is copied into APIKey.
func (ia *IFlowAuth) fetchAPIKeyInfo(ctx context.Context, cookie string) (*iFlowKeyData, error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, iFlowAPIKeyEndpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("iflow cookie: create GET request failed: %w", err)
	}
	// Set cookie and other headers to mimic browser
	req.Header.Set("Cookie", cookie)
	req.Header.Set("Accept", "application/json, text/plain, */*")
	req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")
	req.Header.Set("Accept-Language", "zh-CN,zh;q=0.9,en;q=0.8")
	req.Header.Set("Accept-Encoding", "gzip")
	req.Header.Set("Connection", "keep-alive")
	req.Header.Set("Sec-Fetch-Dest", "empty")
	req.Header.Set("Sec-Fetch-Mode", "cors")
	req.Header.Set("Sec-Fetch-Site", "same-origin")

	resp, err := ia.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("iflow cookie: GET request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	// Content-Encoding values are case-insensitive tokens.
	var reader io.Reader = resp.Body
	if strings.EqualFold(strings.TrimSpace(resp.Header.Get("Content-Encoding")), "gzip") {
		gzipReader, errGzip := gzip.NewReader(resp.Body)
		if errGzip != nil {
			return nil, fmt.Errorf("iflow cookie: create gzip reader failed: %w", errGzip)
		}
		defer func() { _ = gzipReader.Close() }()
		reader = gzipReader
	}

	body, err := io.ReadAll(reader)
	if err != nil {
		return nil, fmt.Errorf("iflow cookie: read GET response failed: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		log.Debugf("iflow cookie GET request failed: status=%d body=%s", resp.StatusCode, string(body))
		return nil, fmt.Errorf("iflow cookie: GET request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}

	var keyResp iFlowAPIKeyResponse
	if err = json.Unmarshal(body, &keyResp); err != nil {
		return nil, fmt.Errorf("iflow cookie: decode GET response failed: %w", err)
	}
	if !keyResp.Success {
		return nil, fmt.Errorf("iflow cookie: GET request not successful: %s", keyResp.Message)
	}
	// Handle initial response where apiKey field might be apiKeyMask
	if keyResp.Data.APIKey == "" && keyResp.Data.APIKeyMask != "" {
		keyResp.Data.APIKey = keyResp.Data.APIKeyMask
	}
	return &keyResp.Data, nil
}
// RefreshAPIKey asks the iFlow platform to refresh the API key called name for
// the account that owns cookie (POST to iFlowAPIKeyEndpoint) and returns the new
// key data.
//
// Only gzip is advertised in Accept-Encoding, because that is the only
// compression decoded below and net/http does not decompress when the header is
// set manually. A response marked successful but carrying no API key is
// rejected: callers copy APIKey into stored credentials, and an empty value would
// silently wipe the working key.
func (ia *IFlowAuth) RefreshAPIKey(ctx context.Context, cookie, name string) (*iFlowKeyData, error) {
	if strings.TrimSpace(cookie) == "" {
		return nil, fmt.Errorf("iflow cookie refresh: cookie is empty")
	}
	if strings.TrimSpace(name) == "" {
		return nil, fmt.Errorf("iflow cookie refresh: name is empty")
	}

	bodyBytes, err := json.Marshal(iFlowRefreshRequest{Name: name})
	if err != nil {
		return nil, fmt.Errorf("iflow cookie refresh: marshal request failed: %w", err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, iFlowAPIKeyEndpoint, strings.NewReader(string(bodyBytes)))
	if err != nil {
		return nil, fmt.Errorf("iflow cookie refresh: create POST request failed: %w", err)
	}
	// Set cookie and other headers to mimic browser
	req.Header.Set("Cookie", cookie)
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "application/json, text/plain, */*")
	req.Header.Set("User-Agent", "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36")
	req.Header.Set("Accept-Language", "zh-CN,zh;q=0.9,en;q=0.8")
	req.Header.Set("Accept-Encoding", "gzip")
	req.Header.Set("Connection", "keep-alive")
	req.Header.Set("Origin", "https://platform.iflow.cn")
	req.Header.Set("Referer", "https://platform.iflow.cn/")

	resp, err := ia.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("iflow cookie refresh: POST request failed: %w", err)
	}
	defer func() { _ = resp.Body.Close() }()

	// Content-Encoding values are case-insensitive tokens.
	var reader io.Reader = resp.Body
	if strings.EqualFold(strings.TrimSpace(resp.Header.Get("Content-Encoding")), "gzip") {
		gzipReader, errGzip := gzip.NewReader(resp.Body)
		if errGzip != nil {
			return nil, fmt.Errorf("iflow cookie refresh: create gzip reader failed: %w", errGzip)
		}
		defer func() { _ = gzipReader.Close() }()
		reader = gzipReader
	}

	body, err := io.ReadAll(reader)
	if err != nil {
		return nil, fmt.Errorf("iflow cookie refresh: read POST response failed: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		log.Debugf("iflow cookie POST request failed: status=%d body=%s", resp.StatusCode, string(body))
		return nil, fmt.Errorf("iflow cookie refresh: POST request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}

	var keyResp iFlowAPIKeyResponse
	if err = json.Unmarshal(body, &keyResp); err != nil {
		return nil, fmt.Errorf("iflow cookie refresh: decode POST response failed: %w", err)
	}
	if !keyResp.Success {
		return nil, fmt.Errorf("iflow cookie refresh: POST request not successful: %s", keyResp.Message)
	}
	if strings.TrimSpace(keyResp.Data.APIKey) == "" {
		return nil, fmt.Errorf("iflow cookie refresh: response did not contain an API key")
	}
	return &keyResp.Data, nil
}
// ShouldRefreshAPIKey reports whether an API key expiring at expireTime is due
// for refresh, meaning it expires less than 48 hours from now. It also returns
// the time remaining until expiry, which is negative once the key has expired.
//
// expireTime must use the "2006-01-02 15:04" layout. It is parsed with
// time.Parse, so it is interpreted as UTC.
func ShouldRefreshAPIKey(expireTime string) (bool, time.Duration, error) {
	const (
		expireLayout  = "2006-01-02 15:04"
		refreshWindow = 48 * time.Hour
	)

	if strings.TrimSpace(expireTime) == "" {
		return false, 0, fmt.Errorf("iflow cookie: expire time is empty")
	}

	expiresAt, errParse := time.Parse(expireLayout, expireTime)
	if errParse != nil {
		return false, 0, fmt.Errorf("iflow cookie: parse expire time failed: %w", errParse)
	}

	now := time.Now()
	return expiresAt.Before(now.Add(refreshWindow)), expiresAt.Sub(now), nil
}
// CreateCookieTokenStorage builds the persistable IFlowTokenStorage for a
// cookie-based login and returns nil when data is nil.
//
// Only the BXAuth value extracted from data.Cookie is kept, stored as
// "BXAuth=<value>;". If no BXAuth value is found, the stored cookie is empty.
// LastRefresh is stamped with the current time in RFC 3339 format.
func (ia *IFlowAuth) CreateCookieTokenStorage(data *IFlowTokenData) *IFlowTokenStorage {
	if data == nil {
		return nil
	}

	var persistedCookie string
	if bx := ExtractBXAuth(data.Cookie); bx != "" {
		persistedCookie = "BXAuth=" + bx + ";"
	}

	storage := &IFlowTokenStorage{
		Type:        "iflow",
		APIKey:      data.APIKey,
		Email:       data.Email,
		Expire:      data.Expire,
		Cookie:      persistedCookie,
		LastRefresh: time.Now().Format(time.RFC3339),
	}
	return storage
}
// UpdateCookieTokenStorage copies the refreshed key and its expiry from keyData
// into storage and stamps LastRefresh with the current RFC 3339 time. It does
// nothing when either argument is nil.
func (ia *IFlowAuth) UpdateCookieTokenStorage(storage *IFlowTokenStorage, keyData *iFlowKeyData) {
	if storage != nil && keyData != nil {
		storage.APIKey, storage.Expire = keyData.APIKey, keyData.ExpireTime
		storage.LastRefresh = time.Now().Format(time.RFC3339)
	}
}
================================================
FILE: internal/auth/iflow/iflow_token.go
================================================
package iflow
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
)
// IFlowTokenStorage persists iFlow OAuth credentials alongside the derived API key.
// It is serialized to JSON by SaveTokenToFile; the JSON keys are given by the
// struct tags below.
type IFlowTokenStorage struct {
	// AccessToken is the OAuth access token.
	AccessToken string `json:"access_token"`
	// RefreshToken is the OAuth refresh token.
	RefreshToken string `json:"refresh_token"`
	// LastRefresh records when the key data was last written; the cookie helpers
	// in this package store it in RFC 3339 format.
	LastRefresh string `json:"last_refresh"`
	// Expire is the API key expiry string (serialized under the "expired" key).
	Expire string `json:"expired"`
	// APIKey is the iFlow API key.
	APIKey string `json:"api_key"`
	// Email identifies the account. NOTE(review): the cookie login flow stores
	// the API key's name here rather than an e-mail address — confirm intended.
	Email string `json:"email"`
	// TokenType is the OAuth token type.
	TokenType string `json:"token_type"`
	// Scope is the OAuth scope string.
	Scope string `json:"scope"`
	// Cookie is the persisted browser cookie (cookie-based logins only).
	Cookie string `json:"cookie"`
	// Type is the provider tag; SaveTokenToFile always sets it to "iflow".
	Type string `json:"type"`
	// Metadata holds arbitrary key-value pairs injected via hooks.
	// It is not exported to JSON directly to allow flattening during serialization.
	Metadata map[string]any `json:"-"`
}
// SetMetadata stores meta as the storage's extra metadata. The map is kept by
// reference (not copied); SaveTokenToFile merges it into the JSON it writes.
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
	ts.Metadata = meta
}
// SaveTokenToFile serialises the token storage, with Metadata merged in, to
// authFilePath as JSON. The parent directory is created with 0700 permissions.
//
// The metadata merge runs before the file is opened. The original order
// truncated the file first, so a merge failure destroyed any existing
// credentials. New files are created with 0600 permissions because they hold
// secrets. The Close error is checked, so a failed final write is reported
// instead of leaving a silently truncated file.
func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
	misc.LogSavingCredentials(authFilePath)
	ts.Type = "iflow"

	// Merge metadata using helper (before touching the file on disk).
	data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
	if errMerge != nil {
		return fmt.Errorf("failed to merge metadata: %w", errMerge)
	}

	if err := os.MkdirAll(filepath.Dir(authFilePath), 0o700); err != nil {
		return fmt.Errorf("iflow token: create directory failed: %w", err)
	}

	f, err := os.OpenFile(authFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
	if err != nil {
		return fmt.Errorf("iflow token: create file failed: %w", err)
	}

	if err = json.NewEncoder(f).Encode(data); err != nil {
		_ = f.Close() // the encode error is the one worth reporting
		return fmt.Errorf("iflow token: encode token failed: %w", err)
	}
	if err = f.Close(); err != nil {
		return fmt.Errorf("iflow token: close file failed: %w", err)
	}
	return nil
}
================================================
FILE: internal/auth/iflow/oauth_server.go
================================================
package iflow
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// errorRedirectURL is where handleCallback sends the browser when the OAuth
// callback carries an "error" parameter or lacks an authorization code.
const errorRedirectURL = "https://iflow.cn/oauth/error"
// OAuthResult captures the outcome of the local OAuth callback.
// Exactly one of Code or Error is set by handleCallback.
type OAuthResult struct {
	// Code is the authorization code from the "code" query parameter.
	Code string
	// State is the raw "state" query parameter, for the caller to verify.
	State string
	// Error is the provider's "error" parameter, or "missing_code" when no code
	// was supplied.
	Error string
}

// OAuthServer provides a minimal HTTP server for handling the iFlow OAuth callback.
type OAuthServer struct {
	// server is the active HTTP server; nil when not running.
	server *http.Server
	// port is the TCP port the callback listener binds to.
	port int
	// result delivers the callback outcome to WaitForCallback (buffer of 1).
	result chan *OAuthResult
	// errChan delivers a serve error to WaitForCallback (buffer of 1).
	errChan chan error
	// mu guards server and running across Start and Stop.
	mu sync.Mutex
	// running reports whether Start has been called without a matching Stop.
	running bool
}
// NewOAuthServer returns an OAuthServer that will listen on port once Start is
// called. Each result channel has capacity one, so a single callback result (or
// serve error) can be delivered before anyone is waiting for it.
func NewOAuthServer(port int) *OAuthServer {
	srv := new(OAuthServer)
	srv.port = port
	srv.result = make(chan *OAuthResult, 1)
	srv.errChan = make(chan error, 1)
	return srv
}
// Start binds the callback listener on s.port and serves /oauth2callback in a
// background goroutine.
//
// The port is bound synchronously with net.Listen, and the listener is handed
// to http.Server.Serve. This closes the window between a separate "is the port
// free?" probe and the real bind, and it removes the need to sleep while the
// goroutine starts: once Start returns nil, the socket is already accepting
// connections.
//
// A serve error other than http.ErrServerClosed is delivered on errChan (and
// thus to WaitForCallback). If an earlier error is still buffered there, the new
// one is logged and dropped rather than blocking the goroutine forever.
func (s *OAuthServer) Start() error {
	s.mu.Lock()
	defer s.mu.Unlock()

	if s.running {
		return fmt.Errorf("iflow oauth server already running")
	}

	addr := fmt.Sprintf(":%d", s.port)
	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("port %d is already in use: %w", s.port, err)
	}

	mux := http.NewServeMux()
	mux.HandleFunc("/oauth2callback", s.handleCallback)
	srv := &http.Server{
		Addr:         addr,
		Handler:      mux,
		ReadTimeout:  10 * time.Second,
		WriteTimeout: 10 * time.Second,
	}
	s.server = srv
	s.running = true

	// Use the local srv: Stop clears s.server under mu, so reading the field
	// from this goroutine without the lock would be a data race.
	go func() {
		if errServe := srv.Serve(listener); errServe != nil && errServe != http.ErrServerClosed {
			select {
			case s.errChan <- errServe:
			default:
				log.Debugf("iflow oauth server error dropped: %v", errServe)
			}
		}
	}()
	return nil
}
// Stop gracefully shuts the callback listener down via http.Server.Shutdown,
// bounded by ctx. It is a no-op when the server is not running. The server is
// marked stopped, and its reference cleared, even when Shutdown returns an error.
func (s *OAuthServer) Stop(ctx context.Context) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	srv := s.server
	if srv == nil || !s.running {
		return nil
	}

	errShutdown := srv.Shutdown(ctx)
	s.running = false
	s.server = nil
	return errShutdown
}
// WaitForCallback blocks until the callback handler delivers a result, the
// server reports an error, or timeout elapses, whichever comes first.
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
	deadline := time.NewTimer(timeout)
	defer deadline.Stop()

	select {
	case res := <-s.result:
		return res, nil
	case errServe := <-s.errChan:
		return nil, errServe
	case <-deadline.C:
		return nil, fmt.Errorf("timeout waiting for OAuth callback")
	}
}
// handleCallback serves GET /oauth2callback. It forwards the outcome to
// WaitForCallback via sendResult, then redirects the browser:
//   - an "error" query parameter takes precedence and redirects to errorRedirectURL;
//   - a missing or blank "code" yields Error "missing_code" and errorRedirectURL;
//   - otherwise Code and State are reported and the browser goes to SuccessRedirectURL.
//
// Requests with any other method receive 405 Method Not Allowed.
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
	if r.Method != http.MethodGet {
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}

	params := r.URL.Query()
	providerErr := strings.TrimSpace(params.Get("error"))
	code := strings.TrimSpace(params.Get("code"))

	result := &OAuthResult{Code: code, State: params.Get("state")}
	target := SuccessRedirectURL
	switch {
	case providerErr != "":
		result = &OAuthResult{Error: providerErr}
		target = errorRedirectURL
	case code == "":
		result = &OAuthResult{Error: "missing_code"}
		target = errorRedirectURL
	}

	s.sendResult(result)
	http.Redirect(w, r, target, http.StatusFound)
}
// sendResult hands res to WaitForCallback without blocking. If an earlier result
// is still buffered, res is discarded and a debug message is logged.
func (s *OAuthServer) sendResult(res *OAuthResult) {
	select {
	case s.result <- res:
		return
	default:
	}
	log.Debug("iflow oauth result channel full, dropping result")
}
// isPortAvailable probes whether s.port can currently be bound on all
// interfaces by opening, and immediately closing, a TCP listener on it.
func (s *OAuthServer) isPortAvailable() bool {
	ln, errListen := net.Listen("tcp", fmt.Sprintf(":%d", s.port))
	if errListen == nil {
		_ = ln.Close()
	}
	return errListen == nil
}
================================================
FILE: internal/auth/kimi/kimi.go
================================================
// Package kimi provides authentication and token management for Kimi (Moonshot AI) API.
// It handles the RFC 8628 OAuth2 Device Authorization Grant flow for secure authentication.
package kimi
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"runtime"
"strings"
"time"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
	// kimiClientID is Kimi Code's OAuth client ID.
	kimiClientID = "17e5f671-d194-4dfb-9706-5516cb48c098"
	// kimiOAuthHost is the base URL of Kimi's OAuth server.
	kimiOAuthHost = "https://auth.kimi.com"
	// kimiDeviceCodeURL is the endpoint for requesting device codes.
	kimiDeviceCodeURL = kimiOAuthHost + "/api/oauth/device_authorization"
	// kimiTokenURL is the token endpoint, used both to exchange device codes and
	// to redeem refresh tokens.
	kimiTokenURL = kimiOAuthHost + "/api/oauth/token"
	// KimiAPIBaseURL is the base URL for Kimi API requests.
	KimiAPIBaseURL = "https://api.kimi.com/coding"
	// defaultPollInterval is the minimum interval between token polls; a smaller
	// server-supplied interval is raised to this value.
	defaultPollInterval = 5 * time.Second
	// maxPollDuration caps how long PollForToken waits for user authorization,
	// even if the device code itself would live longer.
	maxPollDuration = 15 * time.Minute
	// refreshThresholdSeconds is how long before its expiry (300 s = 5 minutes)
	// a stored token is already treated as expired.
	refreshThresholdSeconds = 300
)
// KimiAuth handles Kimi authentication flow.
type KimiAuth struct {
	// deviceClient performs the OAuth device-flow HTTP requests.
	deviceClient *DeviceFlowClient
	// cfg is the application configuration the client was built from (may be nil).
	cfg *config.Config
}
// NewKimiAuth returns a KimiAuth backed by a new DeviceFlowClient built from cfg.
// That client is given a freshly generated device ID.
func NewKimiAuth(cfg *config.Config) *KimiAuth {
	client := NewDeviceFlowClient(cfg)
	return &KimiAuth{cfg: cfg, deviceClient: client}
}
// StartDeviceFlow begins device authorization by requesting a device code,
// user code and verification URI from Kimi's OAuth server.
func (k *KimiAuth) StartDeviceFlow(ctx context.Context) (*DeviceCodeResponse, error) {
	client := k.deviceClient
	return client.RequestDeviceCode(ctx)
}
// WaitForAuthorization polls until the user approves deviceCode (or polling
// fails). It returns the issued tokens together with the device ID used for the
// flow.
func (k *KimiAuth) WaitForAuthorization(ctx context.Context, deviceCode *DeviceCodeResponse) (*KimiAuthBundle, error) {
	client := k.deviceClient

	token, errPoll := client.PollForToken(ctx, deviceCode)
	if errPoll != nil {
		return nil, errPoll
	}

	bundle := &KimiAuthBundle{TokenData: token, DeviceID: client.deviceID}
	return bundle, nil
}
// CreateTokenStorage converts an auth bundle into a KimiTokenStorage ready to be
// persisted. TokenData.ExpiresAt (Unix seconds) is rendered as an RFC 3339 UTC
// timestamp; a non-positive value leaves Expired empty. The device ID is trimmed
// of surrounding whitespace. bundle and bundle.TokenData must be non-nil.
func (k *KimiAuth) CreateTokenStorage(bundle *KimiAuthBundle) *KimiTokenStorage {
	token := bundle.TokenData

	var expiry string
	if token.ExpiresAt > 0 {
		expiry = time.Unix(token.ExpiresAt, 0).UTC().Format(time.RFC3339)
	}

	storage := &KimiTokenStorage{Type: "kimi", Expired: expiry}
	storage.AccessToken = token.AccessToken
	storage.RefreshToken = token.RefreshToken
	storage.TokenType = token.TokenType
	storage.Scope = token.Scope
	storage.DeviceID = strings.TrimSpace(bundle.DeviceID)
	return storage
}
// DeviceFlowClient handles the OAuth2 device flow for Kimi.
type DeviceFlowClient struct {
	// httpClient sends all OAuth requests (30 s timeout, proxy from cfg if set).
	httpClient *http.Client
	// cfg is the configuration the client was built from (may be nil).
	cfg *config.Config
	// deviceID is sent as X-Msh-Device-Id on every request.
	deviceID string
}
// NewDeviceFlowClient returns a DeviceFlowClient for cfg with a freshly
// generated device ID.
func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
	// An empty ID tells NewDeviceFlowClientWithDeviceID to generate one.
	const generateDeviceID = ""
	return NewDeviceFlowClientWithDeviceID(cfg, generateDeviceID)
}
// NewDeviceFlowClientWithDeviceID returns a DeviceFlowClient that identifies
// itself with deviceID (whitespace-trimmed). When the trimmed ID is empty, a new
// one is generated. The HTTP client has a 30-second timeout and, when cfg is
// non-nil, is passed through util.SetProxy with cfg.SDKConfig.
func NewDeviceFlowClientWithDeviceID(cfg *config.Config, deviceID string) *DeviceFlowClient {
	httpClient := &http.Client{Timeout: 30 * time.Second}
	if cfg != nil {
		// Presumably applies the proxy settings from the SDK config.
		httpClient = util.SetProxy(&cfg.SDKConfig, httpClient)
	}

	id := strings.TrimSpace(deviceID)
	if len(id) == 0 {
		id = getOrCreateDeviceID()
	}

	return &DeviceFlowClient{cfg: cfg, deviceID: id, httpClient: httpClient}
}
// getOrCreateDeviceID returns a new random UUID string that identifies this
// device for the current authentication flow. Despite its name, nothing is
// persisted or reused: every call yields a fresh ID.
func getOrCreateDeviceID() string {
	id := uuid.New()
	return id.String()
}
// getDeviceModel describes the host as "<OS> <arch>". Well-known values of
// runtime.GOOS get a friendly label (macOS, Windows, Linux); any other OS name is
// used verbatim.
func getDeviceModel() string {
	friendly := map[string]string{
		"darwin":  "macOS",
		"windows": "Windows",
		"linux":   "Linux",
	}
	label, known := friendly[runtime.GOOS]
	if !known {
		label = runtime.GOOS
	}
	return fmt.Sprintf("%s %s", label, runtime.GOARCH)
}
// getHostname returns the machine's hostname as reported by os.Hostname, or
// "unknown" when it cannot be determined.
func getHostname() string {
	if name, errHost := os.Hostname(); errHost == nil {
		return name
	}
	return "unknown"
}
// commonHeaders returns the X-Msh-* identification headers that this client
// attaches to every request it sends to Kimi's OAuth endpoints.
func (c *DeviceFlowClient) commonHeaders() map[string]string {
	headers := make(map[string]string, 5)
	headers["X-Msh-Platform"] = "cli-proxy-api"
	headers["X-Msh-Version"] = "1.0.0"
	headers["X-Msh-Device-Name"] = getHostname()
	headers["X-Msh-Device-Model"] = getDeviceModel()
	headers["X-Msh-Device-Id"] = c.deviceID
	return headers
}
// RequestDeviceCode initiates the device flow by requesting a device code from
// Kimi. The request is a form-encoded POST carrying the client ID and the common
// X-Msh-* headers.
//
// A 200 response that lacks device_code is rejected here. Otherwise the flow
// would continue with an empty code and only fail later, with a confusing error,
// during polling.
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
	data := url.Values{}
	data.Set("client_id", kimiClientID)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiDeviceCodeURL, strings.NewReader(data.Encode()))
	if err != nil {
		return nil, fmt.Errorf("kimi: failed to create device code request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")
	for k, v := range c.commonHeaders() {
		req.Header.Set(k, v)
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("kimi: device code request failed: %w", err)
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("kimi device code: close body error: %v", errClose)
		}
	}()

	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("kimi: failed to read device code response: %w", err)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("kimi: device code request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
	}

	var deviceCode DeviceCodeResponse
	if err = json.Unmarshal(bodyBytes, &deviceCode); err != nil {
		return nil, fmt.Errorf("kimi: failed to parse device code response: %w", err)
	}
	if strings.TrimSpace(deviceCode.DeviceCode) == "" {
		return nil, fmt.Errorf("kimi: device code response is missing device_code")
	}
	return &deviceCode, nil
}
// errKimiSlowDown is returned by exchangeDeviceCode, together with
// shouldContinue=true, when the token endpoint answers "slow_down". It is never
// wrapped, so PollForToken recognises it by identity.
var errKimiSlowDown = fmt.Errorf("kimi: authorization server requested slower polling")

// kimiSlowDownStep is added to the polling interval after every "slow_down"
// response, as required by RFC 8628 section 3.5.
const kimiSlowDownStep = 5 * time.Second

// PollForToken polls the token endpoint until the user authorizes the device, an
// unrecoverable error occurs, ctx is cancelled, or the deadline passes.
//
// The polling interval starts at deviceCode.Interval seconds, raised to at least
// defaultPollInterval, and grows by kimiSlowDownStep every time the server
// replies "slow_down". The deadline is maxPollDuration from now, shortened to
// deviceCode.ExpiresIn when that is sooner.
func (c *DeviceFlowClient) PollForToken(ctx context.Context, deviceCode *DeviceCodeResponse) (*KimiTokenData, error) {
	if deviceCode == nil {
		return nil, fmt.Errorf("kimi: device code is nil")
	}

	interval := time.Duration(deviceCode.Interval) * time.Second
	if interval < defaultPollInterval {
		interval = defaultPollInterval
	}

	deadline := time.Now().Add(maxPollDuration)
	if deviceCode.ExpiresIn > 0 {
		codeDeadline := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second)
		if codeDeadline.Before(deadline) {
			deadline = codeDeadline
		}
	}

	ticker := time.NewTicker(interval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return nil, fmt.Errorf("kimi: context cancelled: %w", ctx.Err())
		case <-ticker.C:
			if time.Now().After(deadline) {
				return nil, fmt.Errorf("kimi: device code expired")
			}
			token, pollErr, shouldContinue := c.exchangeDeviceCode(ctx, deviceCode.DeviceCode)
			if token != nil {
				return token, nil
			}
			if !shouldContinue {
				return nil, pollErr
			}
			// Sentinel is returned unwrapped, so identity comparison is sufficient.
			if pollErr == errKimiSlowDown {
				interval += kimiSlowDownStep
				ticker.Reset(interval)
				log.Debugf("kimi: slow_down received, polling interval increased to %s", interval)
			}
			// Otherwise authorization is still pending: keep polling.
		}
	}
}

// exchangeDeviceCode attempts to exchange the device code for an access token.
// Returns (token, error, shouldContinue):
//   - (token, nil, false) on success;
//   - (nil, nil, true) while authorization is pending;
//   - (nil, errKimiSlowDown, true) when the server asks the client to poll more slowly;
//   - (nil, err, false) for any terminal failure.
func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode string) (*KimiTokenData, error, bool) {
	data := url.Values{}
	data.Set("client_id", kimiClientID)
	data.Set("device_code", deviceCode)
	data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiTokenURL, strings.NewReader(data.Encode()))
	if err != nil {
		return nil, fmt.Errorf("kimi: failed to create token request: %w", err), false
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")
	for k, v := range c.commonHeaders() {
		req.Header.Set(k, v)
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("kimi: token request failed: %w", err), false
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("kimi token exchange: close body error: %v", errClose)
		}
	}()

	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("kimi: failed to read token response: %w", err), false
	}

	// Parse response - Kimi returns 200 for both success and pending states
	var oauthResp struct {
		Error            string  `json:"error"`
		ErrorDescription string  `json:"error_description"`
		AccessToken      string  `json:"access_token"`
		RefreshToken     string  `json:"refresh_token"`
		TokenType        string  `json:"token_type"`
		ExpiresIn        float64 `json:"expires_in"`
		Scope            string  `json:"scope"`
	}
	if err = json.Unmarshal(bodyBytes, &oauthResp); err != nil {
		return nil, fmt.Errorf("kimi: failed to parse token response (status %d): %w", resp.StatusCode, err), false
	}

	if oauthResp.Error != "" {
		switch oauthResp.Error {
		case "authorization_pending":
			return nil, nil, true // Continue polling
		case "slow_down":
			return nil, errKimiSlowDown, true // Continue polling with a longer interval
		case "expired_token":
			return nil, fmt.Errorf("kimi: device code expired"), false
		case "access_denied":
			return nil, fmt.Errorf("kimi: access denied by user"), false
		default:
			return nil, fmt.Errorf("kimi: OAuth error: %s - %s", oauthResp.Error, oauthResp.ErrorDescription), false
		}
	}

	if oauthResp.AccessToken == "" {
		return nil, fmt.Errorf("kimi: empty access token in response"), false
	}

	var expiresAt int64
	if oauthResp.ExpiresIn > 0 {
		expiresAt = time.Now().Unix() + int64(oauthResp.ExpiresIn)
	}

	return &KimiTokenData{
		AccessToken:  oauthResp.AccessToken,
		RefreshToken: oauthResp.RefreshToken,
		TokenType:    oauthResp.TokenType,
		ExpiresAt:    expiresAt,
		Scope:        oauthResp.Scope,
	}, nil, false
}
// RefreshToken exchanges refreshToken for a new access token at kimiTokenURL.
//
// A 401 or 403 status is reported as a rejected refresh token. When the response
// omits a new refresh_token, the input refreshToken is returned in its place.
// RFC 6749 §6 allows the server not to rotate it, and returning an empty value
// would make callers overwrite a still-valid stored refresh token. ExpiresAt is
// computed as now plus expires_in, and stays 0 when expires_in is absent.
func (c *DeviceFlowClient) RefreshToken(ctx context.Context, refreshToken string) (*KimiTokenData, error) {
	if strings.TrimSpace(refreshToken) == "" {
		return nil, fmt.Errorf("kimi: refresh token is empty")
	}

	data := url.Values{}
	data.Set("client_id", kimiClientID)
	data.Set("grant_type", "refresh_token")
	data.Set("refresh_token", refreshToken)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiTokenURL, strings.NewReader(data.Encode()))
	if err != nil {
		return nil, fmt.Errorf("kimi: failed to create refresh request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")
	for k, v := range c.commonHeaders() {
		req.Header.Set(k, v)
	}

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("kimi: refresh request failed: %w", err)
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("kimi refresh token: close body error: %v", errClose)
		}
	}()

	bodyBytes, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("kimi: failed to read refresh response: %w", err)
	}
	if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
		return nil, fmt.Errorf("kimi: refresh token rejected (status %d)", resp.StatusCode)
	}
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("kimi: refresh failed with status %d: %s", resp.StatusCode, string(bodyBytes))
	}

	var tokenResp struct {
		AccessToken  string  `json:"access_token"`
		RefreshToken string  `json:"refresh_token"`
		TokenType    string  `json:"token_type"`
		ExpiresIn    float64 `json:"expires_in"`
		Scope        string  `json:"scope"`
	}
	if err = json.Unmarshal(bodyBytes, &tokenResp); err != nil {
		return nil, fmt.Errorf("kimi: failed to parse refresh response: %w", err)
	}
	if tokenResp.AccessToken == "" {
		return nil, fmt.Errorf("kimi: empty access token in refresh response")
	}

	var expiresAt int64
	if tokenResp.ExpiresIn > 0 {
		expiresAt = time.Now().Unix() + int64(tokenResp.ExpiresIn)
	}

	newRefreshToken := tokenResp.RefreshToken
	if newRefreshToken == "" {
		// Server did not rotate the refresh token; keep using the current one.
		newRefreshToken = refreshToken
	}

	return &KimiTokenData{
		AccessToken:  tokenResp.AccessToken,
		RefreshToken: newRefreshToken,
		TokenType:    tokenResp.TokenType,
		ExpiresAt:    expiresAt,
		Scope:        tokenResp.Scope,
	}, nil
}
================================================
FILE: internal/auth/kimi/token.go
================================================
// Package kimi provides authentication and token management functionality
// for Kimi (Moonshot AI) services. It handles OAuth2 device flow token storage,
// serialization, and retrieval for maintaining authenticated sessions with the Kimi API.
package kimi
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
)
// KimiTokenStorage stores OAuth2 token information for Kimi API authentication.
// It is serialized to JSON by SaveTokenToFile using the struct tags below.
type KimiTokenStorage struct {
	// AccessToken is the OAuth2 access token used for authenticating API requests.
	AccessToken string `json:"access_token"`
	// RefreshToken is the OAuth2 refresh token used to obtain new access tokens.
	RefreshToken string `json:"refresh_token"`
	// TokenType is the type of token, typically "Bearer".
	TokenType string `json:"token_type"`
	// Scope is the OAuth2 scope granted to the token.
	Scope string `json:"scope,omitempty"`
	// DeviceID is the device identifier (sent as X-Msh-Device-Id) used for Kimi requests.
	DeviceID string `json:"device_id,omitempty"`
	// Expired is the RFC3339 timestamp when the access token expires; empty means
	// no expiry is known.
	Expired string `json:"expired,omitempty"`
	// Type indicates the authentication provider type, always "kimi" for this storage.
	Type string `json:"type"`
	// Metadata holds arbitrary key-value pairs injected via hooks.
	// It is not exported to JSON directly to allow flattening during serialization.
	Metadata map[string]any `json:"-"`
}
// SetMetadata stores meta as the storage's extra metadata. The map is kept by
// reference (not copied); SaveTokenToFile merges it into the JSON it writes.
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
	ts.Metadata = meta
}
// KimiTokenData holds the OAuth token data obtained from Kimi, with the relative
// expires_in already converted to an absolute Unix timestamp.
type KimiTokenData struct {
	// AccessToken is the OAuth2 access token.
	AccessToken string `json:"access_token"`
	// RefreshToken is the OAuth2 refresh token.
	RefreshToken string `json:"refresh_token"`
	// TokenType is the type of token, typically "Bearer".
	TokenType string `json:"token_type"`
	// ExpiresAt is the Unix timestamp (seconds) when the token expires; 0 if unknown.
	ExpiresAt int64 `json:"expires_at"`
	// Scope is the OAuth2 scope granted to the token.
	Scope string `json:"scope"`
}

// KimiAuthBundle bundles authentication data for storage.
type KimiAuthBundle struct {
	// TokenData contains the OAuth token information.
	TokenData *KimiTokenData
	// DeviceID is the device identifier used during OAuth device flow.
	DeviceID string
}

// DeviceCodeResponse represents Kimi's device code response.
type DeviceCodeResponse struct {
	// DeviceCode is the device verification code.
	DeviceCode string `json:"device_code"`
	// UserCode is the code the user must enter at the verification URI.
	UserCode string `json:"user_code"`
	// VerificationURI is the URL where the user should enter the code.
	VerificationURI string `json:"verification_uri,omitempty"`
	// VerificationURIComplete is the URL with the code pre-filled.
	VerificationURIComplete string `json:"verification_uri_complete"`
	// ExpiresIn is the number of seconds until the device code expires.
	ExpiresIn int `json:"expires_in"`
	// Interval is the minimum number of seconds to wait between polling requests.
	Interval int `json:"interval"`
}
// SaveTokenToFile serializes the Kimi token storage, with Metadata merged in, to
// authFilePath as indented JSON. The parent directory is created with 0700
// permissions.
//
// The metadata merge runs before the file is opened. The original order
// truncated the file first, so a merge failure destroyed any existing
// credentials. New files are created with 0600 permissions because they hold
// secrets. The Close error is checked, so a failed final write is not silently
// ignored.
func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
	misc.LogSavingCredentials(authFilePath)
	ts.Type = "kimi"

	// Merge metadata using helper (before touching the file on disk).
	data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
	if errMerge != nil {
		return fmt.Errorf("failed to merge metadata: %w", errMerge)
	}

	if err := os.MkdirAll(filepath.Dir(authFilePath), 0o700); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}

	f, err := os.OpenFile(authFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
	if err != nil {
		return fmt.Errorf("failed to create token file: %w", err)
	}

	encoder := json.NewEncoder(f)
	encoder.SetIndent("", " ")
	if err = encoder.Encode(data); err != nil {
		_ = f.Close() // the encode error is the one worth reporting
		return fmt.Errorf("failed to write token to file: %w", err)
	}
	if err = f.Close(); err != nil {
		return fmt.Errorf("failed to close token file: %w", err)
	}
	return nil
}
// IsExpired reports whether the access token should be treated as expired.
//
// It is true once less than refreshThresholdSeconds remain before Expired, and
// also when Expired is set but is not valid RFC 3339. An empty Expired means no
// expiry is known, and the token is treated as valid.
func (ts *KimiTokenStorage) IsExpired() bool {
	if len(ts.Expired) == 0 {
		return false
	}
	expiresAt, errParse := time.Parse(time.RFC3339, ts.Expired)
	if errParse != nil {
		return true
	}
	threshold := time.Duration(refreshThresholdSeconds) * time.Second
	return time.Until(expiresAt) < threshold
}
// NeedsRefresh reports whether a refresh should be attempted: a refresh token
// must be present and the access token must be (nearly) expired per IsExpired.
func (ts *KimiTokenStorage) NeedsRefresh() bool {
	return ts.RefreshToken != "" && ts.IsExpired()
}
================================================
FILE: internal/auth/models.go
================================================
// Package auth provides authentication functionality for various AI service providers.
// It includes interfaces and implementations for token storage and authentication methods.
package auth
// TokenStorage defines the interface for storing authentication tokens.
// Implementations of this interface should provide methods to persist
// authentication tokens to a file system location.
type TokenStorage interface {
	// SaveTokenToFile persists authentication tokens to the specified file path.
	//
	// Parameters:
	//   - authFilePath: The file path where the authentication tokens should be saved
	//
	// Returns:
	//   - error: An error if the save operation fails, nil otherwise
	SaveTokenToFile(authFilePath string) error
}
================================================
FILE: internal/auth/qwen/qwen_auth.go
================================================
package qwen
import (
"context"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
const (
	// QwenOAuthDeviceCodeEndpoint is the URL for initiating the OAuth 2.0 device authorization flow.
	QwenOAuthDeviceCodeEndpoint = "https://chat.qwen.ai/api/v1/oauth2/device/code"
	// QwenOAuthTokenEndpoint is the URL for exchanging device codes or refresh tokens for access tokens.
	QwenOAuthTokenEndpoint = "https://chat.qwen.ai/api/v1/oauth2/token"
	// QwenOAuthClientID is the client identifier for the Qwen OAuth 2.0 application.
	QwenOAuthClientID = "f0304373b74a44d2b584a3fb70ca9e56"
	// QwenOAuthScope defines the space-separated permissions requested by the application.
	QwenOAuthScope = "openid profile email model.completion"
	// QwenOAuthGrantType specifies the grant type for the device code flow (RFC 8628).
	QwenOAuthGrantType = "urn:ietf:params:oauth:grant-type:device_code"
)
// QwenTokenData represents the OAuth credentials, including access and refresh tokens.
type QwenTokenData struct {
	// AccessToken is the token used to access protected resources.
	AccessToken string `json:"access_token"`
	// RefreshToken is used to obtain a new access token when the current one expires.
	RefreshToken string `json:"refresh_token,omitempty"`
	// TokenType indicates the type of token, typically "Bearer".
	TokenType string `json:"token_type"`
	// ResourceURL specifies the base URL of the resource server.
	ResourceURL string `json:"resource_url,omitempty"`
	// Expire indicates the expiration date and time of the access token
	// (serialized as "expiry_date"). NOTE(review): string format not visible here — confirm.
	Expire string `json:"expiry_date,omitempty"`
}

// DeviceFlow represents the response from the device authorization endpoint.
type DeviceFlow struct {
	// DeviceCode is the code that the client uses to poll for an access token.
	DeviceCode string `json:"device_code"`
	// UserCode is the code that the user enters at the verification URI.
	UserCode string `json:"user_code"`
	// VerificationURI is the URL where the user can enter the user code to authorize the device.
	VerificationURI string `json:"verification_uri"`
	// VerificationURIComplete is a URI that includes the user_code, which can be used to automatically
	// fill in the code on the verification page.
	VerificationURIComplete string `json:"verification_uri_complete"`
	// ExpiresIn is the time in seconds until the device_code and user_code expire.
	ExpiresIn int `json:"expires_in"`
	// Interval is the minimum time in seconds that the client should wait between polling requests.
	Interval int `json:"interval"`
	// CodeVerifier is the cryptographically random string used in the PKCE flow.
	CodeVerifier string `json:"code_verifier"`
}

// QwenTokenResponse represents the successful token response from the token endpoint.
type QwenTokenResponse struct {
	// AccessToken is the token used to access protected resources.
	AccessToken string `json:"access_token"`
	// RefreshToken is used to obtain a new access token.
	RefreshToken string `json:"refresh_token,omitempty"`
	// TokenType indicates the type of token, typically "Bearer".
	TokenType string `json:"token_type"`
	// ResourceURL specifies the base URL of the resource server.
	ResourceURL string `json:"resource_url,omitempty"`
	// ExpiresIn is the time in seconds until the access token expires.
	ExpiresIn int `json:"expires_in"`
}
// QwenAuth manages authentication and token handling for the Qwen API.
// Construct it with NewQwenAuth so the HTTP client honours the configured proxy.
type QwenAuth struct {
// httpClient is the proxy-aware client built by NewQwenAuth; outbound OAuth
// requests (e.g. RefreshTokens, InitiateDeviceFlow) are sent through it.
httpClient *http.Client
}
// NewQwenAuth returns a QwenAuth whose HTTP client is routed through the proxy
// settings found in cfg.SDKConfig (applied by util.SetProxy).
//
// cfg must be non-nil: the address of its SDKConfig field is taken.
func NewQwenAuth(cfg *config.Config) *QwenAuth {
	proxied := util.SetProxy(&cfg.SDKConfig, new(http.Client))
	return &QwenAuth{httpClient: proxied}
}
// generateCodeVerifier returns a fresh PKCE code verifier: 32 bytes read from
// the random source, encoded as unpadded URL-safe base64 (43 characters).
// The error from the random source is returned unwrapped.
func (qa *QwenAuth) generateCodeVerifier() (string, error) {
	var entropy [32]byte
	if _, err := rand.Read(entropy[:]); err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(entropy[:]), nil
}
// generateCodeChallenge derives the PKCE "S256" code challenge for codeVerifier:
// the SHA-256 digest of the verifier's bytes, encoded as unpadded URL-safe base64.
func (qa *QwenAuth) generateCodeChallenge(codeVerifier string) string {
	digest := sha256.New()
	// hash.Hash.Write never returns an error.
	_, _ = digest.Write([]byte(codeVerifier))
	return base64.RawURLEncoding.EncodeToString(digest.Sum(nil))
}
// generatePKCEPair returns a new PKCE verifier and the S256 challenge derived
// from it. On failure both strings are empty and the verifier error is returned.
func (qa *QwenAuth) generatePKCEPair() (verifier string, challenge string, err error) {
	verifier, err = qa.generateCodeVerifier()
	if err != nil {
		return "", "", err
	}
	challenge = qa.generateCodeChallenge(verifier)
	return verifier, challenge, nil
}
// RefreshTokens exchanges a refresh token for a new access token.
//
// The request is a form-encoded POST to QwenOAuthTokenEndpoint sent through
// the proxy-aware qa.httpClient and bound to ctx.
//
// The server may omit refresh_token from the response (RFC 6749 §6). In that
// case the refresh token passed in remains valid and is carried over into the
// result, so callers that persist the returned data keep the ability to refresh.
//
// Parameters:
//   - ctx: controls cancellation of the HTTP request
//   - refreshToken: the refresh token to exchange
//
// Returns:
//   - *QwenTokenData: the new tokens, with Expire computed from expires_in (RFC 3339)
//   - error: on transport failure, a non-200 status, an unparsable body, or a
//     response without an access_token
func (qa *QwenAuth) RefreshTokens(ctx context.Context, refreshToken string) (*QwenTokenData, error) {
	form := url.Values{}
	form.Set("grant_type", "refresh_token")
	form.Set("refresh_token", refreshToken)
	form.Set("client_id", QwenOAuthClientID)

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, QwenOAuthTokenEndpoint, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, fmt.Errorf("failed to create token request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	resp, err := qa.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("token refresh request failed: %w", err)
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		var errorData map[string]interface{}
		// Use the OAuth error fields only when they are actually present;
		// otherwise the raw body is more informative than "<nil> - <nil>".
		if errUnmarshal := json.Unmarshal(body, &errorData); errUnmarshal == nil && errorData["error"] != nil {
			return nil, fmt.Errorf("token refresh failed: %v - %v", errorData["error"], errorData["error_description"])
		}
		return nil, fmt.Errorf("token refresh failed: %s", string(body))
	}

	var tokenData QwenTokenResponse
	if err = json.Unmarshal(body, &tokenData); err != nil {
		return nil, fmt.Errorf("failed to parse token response: %w", err)
	}
	if tokenData.AccessToken == "" {
		return nil, fmt.Errorf("token refresh failed: response missing access_token")
	}

	// No rotation issued: keep using the refresh token we already have.
	newRefreshToken := tokenData.RefreshToken
	if newRefreshToken == "" {
		newRefreshToken = refreshToken
	}

	return &QwenTokenData{
		AccessToken:  tokenData.AccessToken,
		TokenType:    tokenData.TokenType,
		RefreshToken: newRefreshToken,
		ResourceURL:  tokenData.ResourceURL,
		Expire:       time.Now().Add(time.Duration(tokenData.ExpiresIn) * time.Second).Format(time.RFC3339),
	}, nil
}
// InitiateDeviceFlow starts the OAuth 2.0 device authorization flow and returns the device flow details.
//
// It generates a fresh PKCE pair, POSTs client_id, scope and the S256 code
// challenge to QwenOAuthDeviceCodeEndpoint through qa.httpClient, and decodes
// the JSON response. The PKCE verifier is attached to the result (it is not
// part of the server response) so the caller can pass it to PollForToken.
//
// Parameters:
//   - ctx: controls cancellation of the HTTP request
//
// Returns:
//   - *DeviceFlow: device/user codes, verification URIs and the PKCE verifier
//   - error: on PKCE generation failure, transport failure, a non-200 status,
//     an unparsable body, or a response without a device_code
func (qa *QwenAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceFlow, error) {
	// Generate PKCE code verifier and challenge
	codeVerifier, codeChallenge, err := qa.generatePKCEPair()
	if err != nil {
		return nil, fmt.Errorf("failed to generate PKCE pair: %w", err)
	}

	form := url.Values{}
	form.Set("client_id", QwenOAuthClientID)
	form.Set("scope", QwenOAuthScope)
	form.Set("code_challenge", codeChallenge)
	form.Set("code_challenge_method", "S256")

	req, err := http.NewRequestWithContext(ctx, http.MethodPost, QwenOAuthDeviceCodeEndpoint, strings.NewReader(form.Encode()))
	if err != nil {
		return nil, fmt.Errorf("failed to create device authorization request: %w", err)
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	resp, err := qa.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("device authorization request failed: %w", err)
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read response body: %w", err)
	}

	if resp.StatusCode != http.StatusOK {
		// resp.Status already embeds the code, so build the text from the code
		// instead of printing the number twice.
		return nil, fmt.Errorf("device authorization failed: %d %s. Response: %s", resp.StatusCode, http.StatusText(resp.StatusCode), string(body))
	}

	var result DeviceFlow
	if err = json.Unmarshal(body, &result); err != nil {
		return nil, fmt.Errorf("failed to parse device flow response: %w", err)
	}
	// A 200 without a device_code is still unusable for polling.
	if result.DeviceCode == "" {
		return nil, fmt.Errorf("device authorization failed: device_code not found in response")
	}

	// Add the code_verifier to the result so it can be used later for polling
	result.CodeVerifier = codeVerifier
	return &result, nil
}
// PollForToken polls the token endpoint with the device code to obtain an access token.
//
// It keeps polling until the user approves the request, the server reports a
// terminal error, or the attempt budget runs out. The budget is 60 polls,
// roughly five minutes at the initial 5s interval.
//
// Requests go through qa.httpClient so the proxy configured by NewQwenAuth is
// honoured. A bare http.PostForm would use http.DefaultClient and bypass it.
//
// RFC 8628 polling errors (authorization_pending, slow_down, expired_token,
// access_denied) are recognised from the JSON "error" field regardless of the
// HTTP status, because servers pair them with different codes (e.g. 400 or 429).
//
// Parameters:
//   - deviceCode: the device_code returned by InitiateDeviceFlow
//   - codeVerifier: the PKCE verifier generated alongside that device code
//
// Returns:
//   - *QwenTokenData: the issued tokens, with Expire as an RFC 3339 timestamp
//   - error: on denial, expiry, an unrecognised server error, a malformed or
//     empty token response, or when the attempt budget is exhausted
func (qa *QwenAuth) PollForToken(deviceCode, codeVerifier string) (*QwenTokenData, error) {
	pollInterval := 5 * time.Second
	maxAttempts := 60 // 5 minutes max at the base interval

	form := url.Values{}
	form.Set("grant_type", QwenOAuthGrantType)
	form.Set("client_id", QwenOAuthClientID)
	form.Set("device_code", deviceCode)
	form.Set("code_verifier", codeVerifier)
	encodedForm := form.Encode()

	for attempt := 0; attempt < maxAttempts; attempt++ {
		// Build a fresh request each time: the body reader is consumed by every send.
		req, err := http.NewRequest(http.MethodPost, QwenOAuthTokenEndpoint, strings.NewReader(encodedForm))
		if err != nil {
			return nil, fmt.Errorf("failed to create token request: %w", err)
		}
		req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
		req.Header.Set("Accept", "application/json")

		resp, err := qa.httpClient.Do(req)
		if err != nil {
			fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err)
			time.Sleep(pollInterval)
			continue
		}
		body, err := io.ReadAll(resp.Body)
		_ = resp.Body.Close()
		if err != nil {
			fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err)
			time.Sleep(pollInterval)
			continue
		}

		if resp.StatusCode != http.StatusOK {
			var errorData map[string]interface{}
			if errUnmarshal := json.Unmarshal(body, &errorData); errUnmarshal != nil {
				// Not JSON: fall back to the raw text response.
				return nil, fmt.Errorf("device token poll failed: %d %s. Response: %s", resp.StatusCode, http.StatusText(resp.StatusCode), string(body))
			}
			errorType, _ := errorData["error"].(string)
			switch errorType {
			case "authorization_pending":
				// User has not yet approved the authorization request. Continue polling.
				fmt.Printf("Polling attempt %d/%d...\n\n", attempt+1, maxAttempts)
				time.Sleep(pollInterval)
				continue
			case "slow_down":
				// Client is polling too frequently. Increase poll interval (capped at 10s).
				pollInterval = time.Duration(float64(pollInterval) * 1.5)
				if pollInterval > 10*time.Second {
					pollInterval = 10 * time.Second
				}
				fmt.Printf("Server requested to slow down, increasing poll interval to %v\n\n", pollInterval)
				time.Sleep(pollInterval)
				continue
			case "expired_token":
				return nil, fmt.Errorf("device code expired. Please restart the authentication process")
			case "access_denied":
				return nil, fmt.Errorf("authorization denied by user. Please restart the authentication process")
			}
			// Any other error is terminal; report what the server said.
			errorDesc, _ := errorData["error_description"].(string)
			return nil, fmt.Errorf("device token poll failed: %s - %s", errorType, errorDesc)
		}

		// Success - parse token data
		var response QwenTokenResponse
		if err = json.Unmarshal(body, &response); err != nil {
			return nil, fmt.Errorf("failed to parse token response: %w", err)
		}
		if response.AccessToken == "" {
			return nil, fmt.Errorf("failed to parse token response: access_token missing")
		}

		// Convert to QwenTokenData format
		return &QwenTokenData{
			AccessToken:  response.AccessToken,
			RefreshToken: response.RefreshToken,
			TokenType:    response.TokenType,
			ResourceURL:  response.ResourceURL,
			Expire:       time.Now().Add(time.Duration(response.ExpiresIn) * time.Second).Format(time.RFC3339),
		}, nil
	}
	return nil, fmt.Errorf("authentication timeout. Please restart the authentication process")
}
// RefreshTokensWithRetry attempts to refresh tokens with a specified number of retries upon failure.
//
// Before retry number n (n >= 1) it waits n seconds, returning ctx.Err() if ctx
// is cancelled during the wait. Each failure is logged at warn level.
//
// A maxRetries below 1 is treated as 1, so at least one refresh is always
// attempted. Otherwise the loop would be skipped and the returned error would
// wrap a nil cause.
//
// Returns the first successful result, or an error wrapping the last failure.
func (qa *QwenAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken string, maxRetries int) (*QwenTokenData, error) {
	if maxRetries < 1 {
		maxRetries = 1
	}
	var lastErr error
	for attempt := 0; attempt < maxRetries; attempt++ {
		if attempt > 0 {
			// Linear backoff between attempts, interruptible by ctx.
			select {
			case <-ctx.Done():
				return nil, ctx.Err()
			case <-time.After(time.Duration(attempt) * time.Second):
			}
		}
		tokenData, err := qa.RefreshTokens(ctx, refreshToken)
		if err == nil {
			return tokenData, nil
		}
		lastErr = err
		log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
	}
	return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
}
// CreateTokenStorage converts freshly obtained QwenTokenData into a new
// QwenTokenStorage, stamping LastRefresh with the current time (RFC 3339).
// Email and Type are left empty here; tokenData must be non-nil.
func (qa *QwenAuth) CreateTokenStorage(tokenData *QwenTokenData) *QwenTokenStorage {
	return &QwenTokenStorage{
		AccessToken:  tokenData.AccessToken,
		RefreshToken: tokenData.RefreshToken,
		ResourceURL:  tokenData.ResourceURL,
		Expire:       tokenData.Expire,
		LastRefresh:  time.Now().Format(time.RFC3339),
	}
}
// UpdateTokenStorage updates an existing token storage with new token data.
//
// AccessToken and Expire are always replaced, and LastRefresh is set to now
// (RFC 3339).
//
// RefreshToken and ResourceURL are replaced only when tokenData carries a
// non-empty value. A refresh response may omit them (both are omitempty in
// QwenTokenResponse), and per RFC 6749 §6 an omitted refresh token means the
// existing one stays valid. Blindly copying an empty value would make the
// credential impossible to refresh again.
func (qa *QwenAuth) UpdateTokenStorage(storage *QwenTokenStorage, tokenData *QwenTokenData) {
	storage.AccessToken = tokenData.AccessToken
	if tokenData.RefreshToken != "" {
		storage.RefreshToken = tokenData.RefreshToken
	}
	storage.LastRefresh = time.Now().Format(time.RFC3339)
	if tokenData.ResourceURL != "" {
		storage.ResourceURL = tokenData.ResourceURL
	}
	storage.Expire = tokenData.Expire
}
================================================
FILE: internal/auth/qwen/qwen_token.go
================================================
// Package qwen provides authentication and token management functionality
// for Alibaba's Qwen AI services. It handles OAuth2 token storage, serialization,
// and retrieval for maintaining authenticated sessions with the Qwen API.
package qwen
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
)
// QwenTokenStorage stores OAuth2 token information for Alibaba Qwen API authentication.
// It maintains compatibility with the existing auth system while adding Qwen-specific fields
// for managing access tokens, refresh tokens, and user account information.
// Instances are persisted with SaveTokenToFile.
type QwenTokenStorage struct {
// AccessToken is the OAuth2 access token used for authenticating API requests.
AccessToken string `json:"access_token"`
// RefreshToken is used to obtain new access tokens when the current one expires.
RefreshToken string `json:"refresh_token"`
// LastRefresh is the timestamp of the last token refresh operation.
LastRefresh string `json:"last_refresh"`
// ResourceURL is the base URL for API requests.
ResourceURL string `json:"resource_url"`
// Email is the Qwen account email address associated with this token.
// NOTE(review): not populated by CreateTokenStorage/UpdateTokenStorage; presumably set by the caller.
Email string `json:"email"`
// Type indicates the authentication provider type, always "qwen" for this storage.
// SaveTokenToFile overwrites it with "qwen" before encoding.
Type string `json:"type"`
// Expire is the timestamp when the current access token expires.
// Serialized under the key "expired" (QwenTokenData uses "expiry_date").
Expire string `json:"expired"`
// Metadata holds arbitrary key-value pairs injected via hooks.
// It is not exported to JSON directly to allow flattening during serialization.
Metadata map[string]any `json:"-"`
}
// SetMetadata allows external callers to inject metadata into the storage before saving.
// The map is stored by reference (not copied), so later changes the caller makes
// to it before SaveTokenToFile are reflected in the saved output.
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
ts.Metadata = meta
}
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
//
// It sets Type to "qwen" and merges any injected Metadata into the top-level
// JSON object. It then creates the parent directory (0700) if needed and
// writes the result to authFilePath. A newly created file gets owner-only
// permissions (0600) because it holds credentials; an existing file keeps its mode.
//
// The metadata merge happens before the file is opened, so a merge failure
// leaves any existing credential file untouched instead of truncating it.
// Errors from closing the file are reported too, since a failed close can mean
// the data never reached disk.
//
// Parameters:
//   - authFilePath: The full path where the token file should be saved
//
// Returns:
//   - error: An error if the operation fails, nil otherwise
func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) (err error) {
	misc.LogSavingCredentials(authFilePath)
	ts.Type = "qwen"

	// Merge metadata using helper, before touching the file on disk.
	data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
	if errMerge != nil {
		return fmt.Errorf("failed to merge metadata: %w", errMerge)
	}

	if err = os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
		return fmt.Errorf("failed to create directory: %w", err)
	}
	f, err := os.OpenFile(authFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600)
	if err != nil {
		return fmt.Errorf("failed to create token file: %w", err)
	}
	defer func() {
		// Surface a close failure only if nothing else went wrong first.
		if errClose := f.Close(); errClose != nil && err == nil {
			err = fmt.Errorf("failed to close token file: %w", errClose)
		}
	}()

	if err = json.NewEncoder(f).Encode(data); err != nil {
		return fmt.Errorf("failed to write token to file: %w", err)
	}
	return nil
}
================================================
FILE: internal/auth/vertex/keyutil.go
================================================
package vertex
import (
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"strings"
)
// NormalizeServiceAccountJSON decodes a JSON service account payload, rewrites
// its private_key via NormalizeServiceAccountMap, and re-encodes the result.
//
// Empty input is returned unchanged with a nil error. On any failure (decode,
// normalization or encode) the original bytes are returned together with the error.
func NormalizeServiceAccountJSON(raw []byte) ([]byte, error) {
	if len(raw) == 0 {
		return raw, nil
	}
	fail := func(err error) ([]byte, error) { return raw, err }

	var payload map[string]any
	if err := json.Unmarshal(raw, &payload); err != nil {
		return fail(err)
	}
	cleaned, err := NormalizeServiceAccountMap(payload)
	if err != nil {
		return fail(err)
	}
	encoded, err := json.Marshal(cleaned)
	if err != nil {
		return fail(err)
	}
	return encoded, nil
}
// NormalizeServiceAccountMap returns a shallow copy of sa whose private_key has
// been sanitized into an "RSA PRIVATE KEY" PEM block by sanitizePrivateKey.
//
// It fails for a nil map, for a private_key that is missing, not a string or
// blank, and for any sanitization error. The input map is never modified;
// nested values are shared with the copy.
func NormalizeServiceAccountMap(sa map[string]any) (map[string]any, error) {
	if sa == nil {
		return nil, fmt.Errorf("service account payload is empty")
	}
	rawKey, _ := sa["private_key"].(string)
	if len(strings.TrimSpace(rawKey)) == 0 {
		return nil, fmt.Errorf("service account missing private_key")
	}
	cleanKey, err := sanitizePrivateKey(rawKey)
	if err != nil {
		return nil, err
	}

	out := make(map[string]any, len(sa))
	for key, value := range sa {
		out[key] = value
	}
	out["private_key"] = cleanKey
	return out, nil
}
// sanitizePrivateKey cleans up a user-supplied private key and returns it as a
// PKCS#1 "RSA PRIVATE KEY" PEM string.
//
// The cleanup steps are, in order:
//   - normalise CRLF/CR line endings to LF;
//   - strip terminal escape sequences (stripANSIEscape);
//   - turn literal two-character escapes (\r\n, \n, \r) into real newlines;
//   - drop invalid UTF-8 and trim surrounding whitespace.
//
// Literal escapes appear when the key was JSON-escaped twice (e.g. pasted from
// an environment variable). Without this step filterBase64 would drop the
// backslash but keep the "n", corrupting the base64 payload. Neither PEM armor
// nor base64 ever contains a backslash, so the rewrite cannot damage a
// well-formed key. It runs after ANSI stripping so that an OSC string
// terminator (ESC \) is never mistaken for an escape.
//
// If the result is not decodable PEM, rebuildPEM tries to reconstruct it from
// its markers. The decoded block is then converted by ensureRSAPrivateKey.
func sanitizePrivateKey(raw string) (string, error) {
	pk := strings.ReplaceAll(raw, "\r\n", "\n")
	pk = strings.ReplaceAll(pk, "\r", "\n")
	pk = stripANSIEscape(pk)
	pk = strings.NewReplacer(`\r\n`, "\n", `\n`, "\n", `\r`, "\n").Replace(pk)
	pk = strings.ToValidUTF8(pk, "")
	pk = strings.TrimSpace(pk)

	block, _ := pem.Decode([]byte(pk))
	if block == nil {
		// Attempt to reconstruct from the textual payload.
		rebuilt, err := rebuildPEM(pk)
		if err != nil {
			return "", fmt.Errorf("private_key is not valid pem: %w", err)
		}
		if block, _ = pem.Decode([]byte(rebuilt)); block == nil {
			return "", fmt.Errorf("private_key pem decode failed")
		}
	}

	rsaBlock, err := ensureRSAPrivateKey(block)
	if err != nil {
		return "", err
	}
	return string(pem.EncodeToMemory(rsaBlock)), nil
}
// ensureRSAPrivateKey validates block and returns it as a PKCS#1
// "RSA PRIVATE KEY" PEM block.
//
//   - "RSA PRIVATE KEY": checked with ParsePKCS1PrivateKey and returned unchanged.
//   - "PRIVATE KEY": parsed as PKCS#8; must hold an RSA key, which is re-encoded as PKCS#1.
//   - any other type: the bytes are sniffed as PKCS#1, then as PKCS#8 (RSA only).
//
// A nil block, unparsable bytes, or a non-RSA key yield an error.
func ensureRSAPrivateKey(block *pem.Block) (*pem.Block, error) {
	if block == nil {
		return nil, fmt.Errorf("pem block is nil")
	}
	asPKCS1 := func(k *rsa.PrivateKey) *pem.Block {
		return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)}
	}

	switch block.Type {
	case "RSA PRIVATE KEY":
		if _, err := x509.ParsePKCS1PrivateKey(block.Bytes); err != nil {
			return nil, fmt.Errorf("private_key invalid rsa: %w", err)
		}
		return block, nil
	case "PRIVATE KEY":
		parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
		if err != nil {
			return nil, fmt.Errorf("private_key invalid pkcs8: %w", err)
		}
		rsaKey, isRSA := parsed.(*rsa.PrivateKey)
		if !isRSA {
			return nil, fmt.Errorf("private_key is not an RSA key")
		}
		return asPKCS1(rsaKey), nil
	}

	// Unexpected label: sniff the DER contents, PKCS#1 first.
	if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
		return asPKCS1(rsaKey), nil
	}
	if parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
		if rsaKey, isRSA := parsed.(*rsa.PrivateKey); isRSA {
			return asPKCS1(rsaKey), nil
		}
	}
	return nil, fmt.Errorf("private_key uses unsupported format")
}
// rebuildPEM reconstructs a PEM document from text that pem.Decode rejected,
// for instance because its line breaks were lost.
//
// It locates the BEGIN marker and the first END marker after it. It keeps only
// base64 characters from the text between them, decodes that payload, and
// re-encodes it as a fresh PEM block. The block type is "RSA PRIVATE KEY" if
// that phrase occurs anywhere in raw, otherwise "PRIVATE KEY".
//
// The footer is searched for only after the end of the header. Searching the
// whole string could match a footer that overlaps the header's trailing dashes
// (e.g. "-----BEGIN PRIVATE KEY-----END PRIVATE KEY-----"). That inverted the
// body slice bounds and panicked on malformed input.
func rebuildPEM(raw string) (string, error) {
	kind := "PRIVATE KEY"
	if strings.Contains(raw, "RSA PRIVATE KEY") {
		kind = "RSA PRIVATE KEY"
	}
	header := "-----BEGIN " + kind + "-----"
	footer := "-----END " + kind + "-----"

	start := strings.Index(raw, header)
	if start < 0 {
		return "", fmt.Errorf("missing pem markers")
	}
	bodyStart := start + len(header)
	bodyLen := strings.Index(raw[bodyStart:], footer)
	if bodyLen < 0 {
		return "", fmt.Errorf("missing pem markers")
	}

	payload := filterBase64(raw[bodyStart : bodyStart+bodyLen])
	if payload == "" {
		return "", fmt.Errorf("private_key base64 payload empty")
	}
	der, err := base64.StdEncoding.DecodeString(payload)
	if err != nil {
		return "", fmt.Errorf("private_key base64 decode failed: %w", err)
	}
	block := &pem.Block{Type: kind, Bytes: der}
	return string(pem.EncodeToMemory(block)), nil
}
// filterBase64 returns s with every character outside the standard base64
// alphabet (A-Z, a-z, 0-9, '+', '/', '=') removed. Invalid UTF-8 bytes are dropped too.
func filterBase64(s string) string {
	return strings.Map(func(r rune) rune {
		switch {
		case 'A' <= r && r <= 'Z', 'a' <= r && r <= 'z', '0' <= r && r <= '9':
			return r
		case r == '+', r == '/', r == '=':
			return r
		}
		return -1 // negative result: strings.Map drops the rune
	}, s)
}
// stripANSIEscape removes terminal escape sequences that can end up in a key
// pasted from a terminal. It recognises these forms:
//   - OSC: ESC ] ... terminated by BEL (0x07) or ST (ESC \). An unterminated
//     OSC swallows the rest of the input.
//   - CSI: ESC [ followed by parameter/intermediate bytes and terminated by a
//     final byte in 0x40–0x7E (ECMA-48). This covers SGR colour codes such as
//     ESC[31m and the bracketed-paste markers ESC[200~ / ESC[201~. Previously
//     only ASCII letters ended a CSI, so ESC[200~ consumed everything up to and
//     including the "B" of a following "-----BEGIN" header.
//   - Any other ESC: only the ESC itself is dropped. A trailing ESC is dropped.
//
// The input is processed as runes, so invalid UTF-8 bytes come out as U+FFFD.
func stripANSIEscape(s string) string {
	in := []rune(s)
	out := make([]rune, 0, len(in))
	for i := 0; i < len(in); i++ {
		r := in[i]
		if r != 0x1b {
			out = append(out, r)
			continue
		}
		if i+1 >= len(in) {
			continue
		}
		switch in[i+1] {
		case ']':
			i += 2
			for i < len(in) {
				if in[i] == 0x07 {
					break
				}
				if in[i] == 0x1b && i+1 < len(in) && in[i+1] == '\\' {
					i++ // leave i on the '\' so the outer i++ skips it
					break
				}
				i++
			}
		case '[':
			i += 2
			// Skip parameter and intermediate bytes; stop on the final byte,
			// which the outer loop's i++ then skips.
			for i < len(in) && (in[i] < 0x40 || in[i] > 0x7e) {
				i++
			}
		default:
			// skip single ESC
		}
	}
	return string(out)
}
================================================
FILE: internal/auth/vertex/vertex_credentials.go
================================================
// Package vertex provides token storage for Google Vertex AI Gemini via service account credentials.
// It serialises service account JSON into an auth file that is consumed by the runtime executor.
package vertex
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
log "github.com/sirupsen/logrus"
)
// VertexCredentialStorage stores the service account JSON for Vertex AI access.
// The content is persisted verbatim under the "service_account" key, together with
// helper fields for project, location and email to improve logging and discovery.
// It is written to disk by SaveTokenToFile.
type VertexCredentialStorage struct {
// ServiceAccount holds the parsed service account JSON content.
// SaveTokenToFile refuses to write when this map is nil.
ServiceAccount map[string]any `json:"service_account"`
// ProjectID is derived from the service account JSON (project_id).
ProjectID string `json:"project_id"`
// Email is the client_email from the service account JSON.
Email string `json:"email"`
// Location optionally sets a default region (e.g., us-central1) for Vertex endpoints.
Location string `json:"location,omitempty"`
// Type is the provider identifier stored alongside credentials. Always "vertex".
// SaveTokenToFile sets it before encoding.
Type string `json:"type"`
}
// SaveTokenToFile writes the credential payload to the given file path in JSON format.
//
// It rejects a nil receiver or a nil ServiceAccount, tags the payload with
// Type "vertex", and ensures the parent directory exists (0700). The file holds
// a service-account private key, so a newly created file gets owner-only
// permissions (0600); an existing file keeps its mode.
//
// A failed close is returned as an error (it can mean the data never reached
// disk). If encoding had already failed, the close error is logged instead so
// the encode error is not masked.
func (s *VertexCredentialStorage) SaveTokenToFile(authFilePath string) (err error) {
	misc.LogSavingCredentials(authFilePath)
	if s == nil {
		return fmt.Errorf("vertex credential: storage is nil")
	}
	if s.ServiceAccount == nil {
		return fmt.Errorf("vertex credential: service account content is empty")
	}
	// Ensure we tag the file with the provider type.
	s.Type = "vertex"

	if err = os.MkdirAll(filepath.Dir(authFilePath), 0o700); err != nil {
		return fmt.Errorf("vertex credential: create directory failed: %w", err)
	}
	f, err := os.OpenFile(authFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
	if err != nil {
		return fmt.Errorf("vertex credential: create file failed: %w", err)
	}
	defer func() {
		errClose := f.Close()
		if errClose == nil {
			return
		}
		if err == nil {
			err = fmt.Errorf("vertex credential: close file failed: %w", errClose)
			return
		}
		log.Errorf("vertex credential: failed to close file: %v", errClose)
	}()

	enc := json.NewEncoder(f)
	enc.SetIndent("", " ")
	if err = enc.Encode(s); err != nil {
		return fmt.Errorf("vertex credential: encode failed: %w", err)
	}
	return nil
}
================================================
FILE: internal/browser/browser.go
================================================
// Package browser provides cross-platform functionality for opening URLs in the default web browser.
// It abstracts the underlying operating system commands and provides a simple interface.
package browser
import (
"fmt"
"os/exec"
"runtime"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
)
// OpenURL opens url in the user's default web browser.
//
// It announces the URL on stdout, then tries the cross-platform open-golang
// library. If that fails it falls back to openURLPlatformSpecific, whose
// result becomes the return value.
//
// Parameters:
//   - url: The URL to open.
//
// Returns:
//   - An error if neither strategy could open the URL, otherwise nil.
func OpenURL(url string) error {
	fmt.Printf("Attempting to open URL in browser: %s\n", url)

	if err := open.Run(url); err != nil {
		log.Debugf("open-golang failed: %v, trying platform-specific commands", err)
		return openURLPlatformSpecific(url)
	}
	log.Debug("Successfully opened URL using open-golang library")
	return nil
}
// openURLPlatformSpecific is a helper function that opens a URL using OS-specific commands.
// This serves as a fallback mechanism for OpenURL.
//
// The commands used are: macOS "open"; Windows "rundll32 url.dll,FileProtocolHandler";
// Linux the first of xdg-open, x-www-browser, www-browser, firefox, chromium or
// google-chrome found on PATH. Other systems are unsupported.
//
// The command is started without waiting for it to finish. A background
// goroutine then calls Wait. os/exec requires Wait after a successful Start to
// release the process's resources; skipping it leaves a zombie process on Unix
// until this process exits.
//
// Parameters:
//   - url: The URL to open.
//
// Returns:
//   - An error if no command is available or it cannot be started, otherwise nil.
func openURLPlatformSpecific(url string) error {
	var cmd *exec.Cmd
	switch runtime.GOOS {
	case "darwin": // macOS
		cmd = exec.Command("open", url)
	case "windows":
		cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
	case "linux":
		// Try common Linux browsers in order of preference
		browsers := []string{"xdg-open", "x-www-browser", "www-browser", "firefox", "chromium", "google-chrome"}
		for _, browser := range browsers {
			if _, err := exec.LookPath(browser); err == nil {
				cmd = exec.Command(browser, url)
				break
			}
		}
		if cmd == nil {
			return fmt.Errorf("no suitable browser found on Linux system")
		}
	default:
		return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
	}

	log.Debugf("Running command: %s %v", cmd.Path, cmd.Args[1:])
	if err := cmd.Start(); err != nil {
		return fmt.Errorf("failed to start browser command: %w", err)
	}
	// Reap the child once it exits; its exit status is irrelevant here.
	go func() {
		_ = cmd.Wait()
	}()
	log.Debug("Successfully opened URL using platform-specific command")
	return nil
}
// IsAvailable reports whether the system appears to have a command for
// opening a web browser. It only consults PATH via exec.LookPath and never
// launches anything.
//
// The previous implementation probed with open.Run("about:blank"), which
// actually opened a blank browser window as a side effect of what should be a
// read-only check (GetPlatformInfo inherited that side effect). The PATH
// checks below cover the same commands on darwin, windows and linux.
//
// Returns:
//   - true if a browser can be opened, false otherwise.
func IsAvailable() bool {
	onPath := func(name string) bool {
		_, err := exec.LookPath(name)
		return err == nil
	}
	switch runtime.GOOS {
	case "darwin":
		return onPath("open")
	case "windows":
		return onPath("rundll32")
	case "linux":
		for _, browser := range []string{"xdg-open", "x-www-browser", "www-browser", "firefox", "chromium", "google-chrome"} {
			if onPath(browser) {
				return true
			}
		}
		return false
	default:
		// Other Unix-like systems: OpenURL relies on open-golang there, which
		// presumably shells out to xdg-open (NOTE(review): confirm against the
		// vendored open-golang version).
		return onPath("xdg-open")
	}
}
// GetPlatformInfo describes this platform's browser-opening support.
//
// The map always has "os", "arch" and "available" (the IsAvailable result).
// It has "default_command" on darwin ("open"), on windows ("rundll32"), and on
// linux when a browser was found. On linux it also has "available_browsers":
// the candidates found on PATH in preference order, or a nil slice if none.
//
// Returns:
//   - A map with platform-specific browser support information.
func GetPlatformInfo() map[string]interface{} {
	info := make(map[string]interface{})
	info["os"] = runtime.GOOS
	info["arch"] = runtime.GOARCH
	info["available"] = IsAvailable()

	switch runtime.GOOS {
	case "linux":
		var found []string
		for _, candidate := range []string{"xdg-open", "x-www-browser", "www-browser", "firefox", "chromium", "google-chrome"} {
			if _, err := exec.LookPath(candidate); err == nil {
				found = append(found, candidate)
			}
		}
		info["available_browsers"] = found
		if len(found) > 0 {
			info["default_command"] = found[0]
		}
	case "darwin":
		info["default_command"] = "open"
	case "windows":
		info["default_command"] = "rundll32"
	}
	return info
}
================================================
FILE: internal/buildinfo/buildinfo.go
================================================
// Package buildinfo exposes compile-time metadata shared across the server.
package buildinfo
// The following variables are overridden via ldflags during release builds.
// Defaults cover local development builds.
// Illustrative override (each target must stay a plain string var):
//   go build -ldflags "-X github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo.Version=v1.2.3"
var (
// Version is the semantic version or git describe output of the binary.
Version = "dev"
// Commit is the git commit SHA baked into the binary.
Commit = "none"
// BuildDate records when the binary was built in UTC.
BuildDate = "unknown"
)
================================================
FILE: internal/cache/signature_cache.go
================================================
package cache
import (
"crypto/sha256"
"encoding/hex"
"strings"
"sync"
"time"
)
// SignatureEntry holds a cached thinking signature with timestamp
type SignatureEntry struct {
// Signature is the cached signature string as passed to CacheSignature.
Signature string
// Timestamp is the last write or successful lookup; GetCachedSignature
// refreshes it on each hit (sliding expiry against SignatureCacheTTL).
Timestamp time.Time
}
const (
// SignatureCacheTTL is how long signatures are valid
// (measured from the last write or lookup, see GetCachedSignature).
SignatureCacheTTL = 3 * time.Hour
// SignatureTextHashLen is the length of the hash key (16 hex chars = 64-bit key space)
// hashText truncates the hex SHA-256 digest to this many characters.
SignatureTextHashLen = 16
// MinValidSignatureLen is the minimum length for a signature to be considered valid
// Enforced by CacheSignature and HasValidSignature.
MinValidSignatureLen = 50
// CacheCleanupInterval controls how often stale entries are purged
// (the ticker period used by startCacheCleanup).
CacheCleanupInterval = 10 * time.Minute
)
// signatureCache stores signatures by model group -> textHash -> SignatureEntry
// Keys are group names from GetModelGroup; values are *groupCache buckets
// created by getOrCreateGroupCache.
var signatureCache sync.Map
// cacheCleanupOnce ensures the background cleanup goroutine starts only once
// It is triggered lazily from getOrCreateGroupCache.
var cacheCleanupOnce sync.Once
// groupCache is the inner map type
// One bucket exists per model group.
type groupCache struct {
// mu guards entries; readers in this file also take the write lock because
// lookups may refresh or delete entries.
mu sync.RWMutex
// entries maps hashText(text) to its cached signature.
entries map[string]SignatureEntry
}
// hashText maps arbitrary text to a compact cache key: the first
// SignatureTextHashLen hex characters of its SHA-256 digest. Hashing the raw
// bytes makes the key independent of Unicode content.
func hashText(text string) string {
	digest := sha256.Sum256([]byte(text))
	fullHex := hex.EncodeToString(digest[:])
	return fullHex[:SignatureTextHashLen]
}
// getOrCreateGroupCache returns the bucket for groupKey, creating an empty one
// on first use. The first call also launches the background cleanup goroutine.
// Concurrent creators race through LoadOrStore, so every caller ends up with
// the same stored bucket.
func getOrCreateGroupCache(groupKey string) *groupCache {
	cacheCleanupOnce.Do(startCacheCleanup)

	bucket, found := signatureCache.Load(groupKey)
	if !found {
		bucket, _ = signatureCache.LoadOrStore(groupKey, &groupCache{entries: make(map[string]SignatureEntry)})
	}
	return bucket.(*groupCache)
}
// startCacheCleanup starts a goroutine that calls purgeExpiredCaches every
// CacheCleanupInterval. The goroutine has no stop signal and runs for the
// lifetime of the process.
func startCacheCleanup() {
	go func() {
		tick := time.NewTicker(CacheCleanupInterval)
		defer tick.Stop()
		for {
			<-tick.C
			purgeExpiredCaches()
		}
	}()
}
// purgeExpiredCaches deletes entries older than SignatureCacheTTL from every
// bucket and drops buckets left empty.
//
// An empty bucket is removed with CompareAndDelete while its lock is still
// held, so only that exact bucket is removed. The old unconditional Delete,
// issued after unlocking, could remove a fresh bucket created in between
// (e.g. after a concurrent ClearSignatureCache), discarding its new entries.
//
// NOTE(review): a writer that fetched this bucket before removal may still
// write into it afterwards; that entry is lost. This is acceptable for a
// cache, but worth knowing.
func purgeExpiredCaches() {
	now := time.Now()
	signatureCache.Range(func(key, value any) bool {
		sc := value.(*groupCache)
		sc.mu.Lock()
		defer sc.mu.Unlock()

		// Remove expired entries
		for k, entry := range sc.entries {
			if now.Sub(entry.Timestamp) > SignatureCacheTTL {
				delete(sc.entries, k)
			}
		}
		// Remove cache bucket if empty, but only if the map still points at it.
		if len(sc.entries) == 0 {
			signatureCache.CompareAndDelete(key, sc)
		}
		return true
	})
}
// CacheSignature remembers signature for text under the model's group (see
// GetModelGroup), replacing any previous entry for the same text.
// Calls with empty text, an empty signature, or a signature shorter than
// MinValidSignatureLen are ignored.
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
func CacheSignature(modelName, text, signature string) {
	if text == "" || signature == "" || len(signature) < MinValidSignatureLen {
		return
	}
	key := hashText(text)
	bucket := getOrCreateGroupCache(GetModelGroup(modelName))

	bucket.mu.Lock()
	bucket.entries[key] = SignatureEntry{Signature: signature, Timestamp: time.Now()}
	bucket.mu.Unlock()
}
// GetCachedSignature looks up the signature cached for text under the model's
// group. A hit refreshes the entry's timestamp (sliding expiration). An entry
// older than SignatureCacheTTL is deleted and treated as a miss.
//
// On any miss (empty text, unknown group, unknown text, expired entry) the
// result is "" except for the "gemini" group, which gets the sentinel
// "skip_thought_signature_validator" (presumably honoured downstream as
// "skip validation" — confirm with callers).
func GetCachedSignature(modelName, text string) string {
	groupKey := GetModelGroup(modelName)
	miss := func() string {
		if groupKey == "gemini" {
			return "skip_thought_signature_validator"
		}
		return ""
	}

	if text == "" {
		return miss()
	}
	bucketVal, ok := signatureCache.Load(groupKey)
	if !ok {
		return miss()
	}
	bucket := bucketVal.(*groupCache)
	key := hashText(text)
	now := time.Now()

	bucket.mu.Lock()
	defer bucket.mu.Unlock()

	entry, found := bucket.entries[key]
	switch {
	case !found:
		return miss()
	case now.Sub(entry.Timestamp) > SignatureCacheTTL:
		delete(bucket.entries, key)
		return miss()
	}

	// Refresh TTL on access (sliding expiration).
	entry.Timestamp = now
	bucket.entries[key] = entry
	return entry.Signature
}
// ClearSignatureCache drops the bucket for modelName's group, or every bucket
// when modelName is empty.
func ClearSignatureCache(modelName string) {
	if modelName != "" {
		signatureCache.Delete(GetModelGroup(modelName))
		return
	}
	signatureCache.Range(func(groupKey, _ any) bool {
		signatureCache.Delete(groupKey)
		return true
	})
}
// HasValidSignature reports whether signature is usable. That means either a
// non-empty string of at least MinValidSignatureLen characters, or the
// "skip_thought_signature_validator" sentinel when modelName belongs to the
// "gemini" group.
func HasValidSignature(modelName, signature string) bool {
	if signature != "" && len(signature) >= MinValidSignatureLen {
		return true
	}
	return signature == "skip_thought_signature_validator" && GetModelGroup(modelName) == "gemini"
}
// GetModelGroup returns the signature-cache group for modelName. Names
// containing "gpt", "claude" or "gemini" (checked in that order, so the first
// match wins) map to that family. Any other name is its own group.
func GetModelGroup(modelName string) string {
	for _, family := range [...]string{"gpt", "claude", "gemini"} {
		if strings.Contains(modelName, family) {
			return family
		}
	}
	return modelName
}
================================================
FILE: internal/cache/signature_cache_test.go
================================================
package cache
import (
"testing"
"time"
)
// testModelName is a Claude model name; GetModelGroup maps it to the "claude" group.
const testModelName = "claude-sonnet-4-5"
// TestCacheSignature_BasicStorageAndRetrieval checks that a cached signature
// can be read back for the same model and text.
func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
	// Start from an empty cache so state from other tests cannot leak in.
	ClearSignatureCache("")
	const (
		text      = "This is some thinking text content"
		signature = "abc123validSignature1234567890123456789012345678901234567890"
	)

	CacheSignature(testModelName, text, signature)

	if retrieved := GetCachedSignature(testModelName, text); retrieved != signature {
		t.Errorf("Expected signature '%s', got '%s'", signature, retrieved)
	}
}
// TestCacheSignature_DifferentModelGroups checks that the same text cached for
// two model groups keeps a separate signature per group.
func TestCacheSignature_DifferentModelGroups(t *testing.T) {
	ClearSignatureCache("")
	const text = "Same text across models"
	const geminiModel = "gemini-3-pro-preview"
	cases := []struct {
		model, sig, failMsg string
	}{
		{testModelName, "signature1_1234567890123456789012345678901234567890123456", "Claude signature mismatch"},
		{geminiModel, "signature2_1234567890123456789012345678901234567890123456", "Gemini signature mismatch"},
	}

	for _, c := range cases {
		CacheSignature(c.model, text, c.sig)
	}
	for _, c := range cases {
		if GetCachedSignature(c.model, text) != c.sig {
			t.Error(c.failMsg)
		}
	}
}
// TestCacheSignature_NotFound checks that lookups miss (return "") both for a
// group with no bucket and for text that was never cached in an existing bucket.
func TestCacheSignature_NotFound(t *testing.T) {
	ClearSignatureCache("")

	// No bucket exists yet for this group.
	got := GetCachedSignature(testModelName, "some text")
	if got != "" {
		t.Errorf("Expected empty string for nonexistent session, got '%s'", got)
	}

	// The bucket exists, but only for a different text.
	CacheSignature(testModelName, "text-a", "sigA12345678901234567890123456789012345678901234567890")
	got = GetCachedSignature(testModelName, "text-b")
	if got != "" {
		t.Errorf("Expected empty string for different text, got '%s'", got)
	}
}
// TestCacheSignature_EmptyInputs checks that empty text, an empty signature or
// a too-short signature are silently ignored by CacheSignature.
func TestCacheSignature_EmptyInputs(t *testing.T) {
	ClearSignatureCache("")
	invalid := [][2]string{
		{"", "sig12345678901234567890123456789012345678901234567890"},
		{"text", ""},
		{"text", "short"}, // Too short
	}
	for _, in := range invalid {
		CacheSignature(testModelName, in[0], in[1])
	}

	if got := GetCachedSignature(testModelName, "text"); got != "" {
		t.Errorf("Expected empty after invalid cache attempts, got '%s'", got)
	}
}
func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
ClearSignatureCache("")
text := "Some text"
shortSig := "abc123" // Less than 50 chars
CacheSignature(testModelName, text, shortSig)
if got := GetCachedSignature(testModelName, text); got != "" {
t.Errorf("Short signature should be rejected, got '%s'", got)
}
}
// TestClearSignatureCache_ModelGroup verifies that clearing an unknown session
// key ("session-1") leaves every previously cached signature intact.
func TestClearSignatureCache_ModelGroup(t *testing.T) {
	ClearSignatureCache("")
	sig := "validSig1234567890123456789012345678901234567890123456"
	CacheSignature(testModelName, "text", sig)
	CacheSignature(testModelName, "text-2", sig)

	ClearSignatureCache("session-1")

	if got := GetCachedSignature(testModelName, "text"); got != sig {
		t.Error("signature should remain when clearing unknown session")
	}
	// The second entry is cached as well, so it must also survive the clear.
	if got := GetCachedSignature(testModelName, "text-2"); got != sig {
		t.Error("text-2 signature should remain when clearing unknown session")
	}
}
func TestClearSignatureCache_AllSessions(t *testing.T) {
ClearSignatureCache("")
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature(testModelName, "text", sig)
CacheSignature(testModelName, "text-2", sig)
ClearSignatureCache("")
if got := GetCachedSignature(testModelName, "text"); got != "" {
t.Error("text should be cleared")
}
if got := GetCachedSignature(testModelName, "text-2"); got != "" {
t.Error("text-2 should be cleared")
}
}
// TestHasValidSignature checks HasValidSignature's length boundary (50 chars is
// valid, 49 is not) and the Gemini "skip_thought_signature_validator" sentinel.
func TestHasValidSignature(t *testing.T) {
	tests := []struct {
		name      string
		modelName string
		signature string
		expected  bool
	}{
		{"valid long signature", testModelName, "abc123validSignature1234567890123456789012345678901234567890", true},
		{"exactly 50 chars", testModelName, "12345678901234567890123456789012345678901234567890", true},
		{"49 chars - invalid", testModelName, "1234567890123456789012345678901234567890123456789", false},
		{"empty string", testModelName, "", false},
		{"short signature", testModelName, "abc", false},
		{"gemini sentinel", "gemini-3-pro-preview", "skip_thought_signature_validator", true},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := HasValidSignature(tt.modelName, tt.signature)
			if result != tt.expected {
				// Include the model: validity depends on it (see the gemini sentinel case).
				t.Errorf("HasValidSignature(%q, %q) = %v, expected %v", tt.modelName, tt.signature, result, tt.expected)
			}
		})
	}
}
// TestCacheSignature_TextHashCollisionResistance verifies that two different
// texts cached for the same model keep their own signatures.
func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
	ClearSignatureCache("")

	// Different texts should produce different hashes
	entries := []struct{ text, sig, failure string }{
		{"First thinking text", "signature1_1234567890123456789012345678901234567890123456", "text1 signature mismatch"},
		{"Second thinking text", "signature2_1234567890123456789012345678901234567890123456", "text2 signature mismatch"},
	}
	for _, e := range entries {
		CacheSignature(testModelName, e.text, e.sig)
	}
	for _, e := range entries {
		if GetCachedSignature(testModelName, e.text) != e.sig {
			t.Error(e.failure)
		}
	}
}
func TestCacheSignature_UnicodeText(t *testing.T) {
ClearSignatureCache("")
text := "한글 텍스트와 이모지 🎉 그리고 特殊文字"
sig := "unicodeSig123456789012345678901234567890123456789012345"
CacheSignature(testModelName, text, sig)
if got := GetCachedSignature(testModelName, text); got != sig {
t.Errorf("Unicode text signature retrieval failed, got '%s'", got)
}
}
func TestCacheSignature_Overwrite(t *testing.T) {
ClearSignatureCache("")
text := "Same text"
sig1 := "firstSignature12345678901234567890123456789012345678901"
sig2 := "secondSignature1234567890123456789012345678901234567890"
CacheSignature(testModelName, text, sig1)
CacheSignature(testModelName, text, sig2) // Overwrite
if got := GetCachedSignature(testModelName, text); got != sig2 {
t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got)
}
}
// Note: TTL expiration test is tricky to test without mocking time
// We test the logic path exists but actual expiration would require time manipulation
func TestCacheSignature_ExpirationLogic(t *testing.T) {
ClearSignatureCache("")
// This test verifies the expiration check exists
// In a real scenario, we'd mock time.Now()
text := "text"
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature(testModelName, text, sig)
// Fresh entry should be retrievable
if got := GetCachedSignature(testModelName, text); got != sig {
t.Errorf("Fresh entry should be retrievable, got '%s'", got)
}
// We can't easily test actual expiration without time mocking
// but the logic is verified by the implementation
_ = time.Now() // Acknowledge we're not testing time passage
}
================================================
FILE: internal/cmd/anthropic_login.go
================================================
package cmd
import (
"context"
"errors"
"fmt"
"os"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
// DoClaudeLogin runs the Claude OAuth login through the shared authentication
// manager and reports where the resulting credentials were saved.
//
// Parameters:
//   - cfg: The application configuration
//   - options: Login options (browser behavior, callback port, prompt); nil uses defaults
//
// A *claude.AuthenticationError is logged with its user-friendly message. If its
// type matches claude.ErrPortInUse, the process exits with that error's code.
// Any other failure is printed to stdout.
func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
	opts := options
	if opts == nil {
		opts = &LoginOptions{}
	}
	prompt := opts.Prompt
	if prompt == nil {
		prompt = defaultProjectPrompt()
	}

	_, savedPath, err := newAuthManager().Login(context.Background(), "claude", cfg, &sdkAuth.LoginOptions{
		NoBrowser:    opts.NoBrowser,
		CallbackPort: opts.CallbackPort,
		Metadata:     map[string]string{},
		Prompt:       prompt,
	})
	if err == nil {
		if savedPath != "" {
			fmt.Printf("Authentication saved to %s\n", savedPath)
		}
		fmt.Println("Claude authentication successful!")
		return
	}

	authErr, isAuthErr := errors.AsType[*claude.AuthenticationError](err)
	if !isAuthErr {
		fmt.Printf("Claude authentication failed: %v\n", err)
		return
	}
	log.Error(claude.GetUserFriendlyMessage(authErr))
	if authErr.Type == claude.ErrPortInUse.Type {
		os.Exit(claude.ErrPortInUse.Code)
	}
}
================================================
FILE: internal/cmd/antigravity_login.go
================================================
package cmd
import (
"context"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
// DoAntigravityLogin runs the OAuth flow for the "antigravity" provider through
// the shared authentication manager. It then reports the saved credential path
// and, when known, the authenticated account label.
//
// Parameters:
//   - cfg: The application configuration
//   - options: Login options (browser behavior, callback port, prompt); nil uses defaults
func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) {
	if options == nil {
		options = &LoginOptions{}
	}
	prompt := options.Prompt
	if prompt == nil {
		prompt = defaultProjectPrompt()
	}
	manager := newAuthManager()

	record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, &sdkAuth.LoginOptions{
		NoBrowser:    options.NoBrowser,
		CallbackPort: options.CallbackPort,
		Metadata:     map[string]string{},
		Prompt:       prompt,
	})
	if err != nil {
		log.Errorf("Antigravity authentication failed: %v", err)
		return
	}

	if savedPath != "" {
		fmt.Printf("Authentication saved to %s\n", savedPath)
	}
	if record != nil {
		if label := record.Label; label != "" {
			fmt.Printf("Authenticated as %s\n", label)
		}
	}
	fmt.Println("Antigravity authentication successful!")
}
================================================
FILE: internal/cmd/auth_manager.go
================================================
package cmd
import (
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
)
// newAuthManager creates an authentication manager backed by the token store
// returned by sdkAuth.GetTokenStore. It registers an authenticator for every
// provider the CLI can log in to: Gemini, Codex, Claude, Qwen, iFlow,
// Antigravity, and Kimi.
//
// Returns:
//   - *sdkAuth.Manager: A configured authentication manager instance
func newAuthManager() *sdkAuth.Manager {
	return sdkAuth.NewManager(sdkAuth.GetTokenStore(),
		sdkAuth.NewGeminiAuthenticator(),
		sdkAuth.NewCodexAuthenticator(),
		sdkAuth.NewClaudeAuthenticator(),
		sdkAuth.NewQwenAuthenticator(),
		sdkAuth.NewIFlowAuthenticator(),
		sdkAuth.NewAntigravityAuthenticator(),
		sdkAuth.NewKimiAuthenticator(),
	)
}
================================================
FILE: internal/cmd/iflow_cookie.go
================================================
package cmd
import (
"bufio"
"context"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// DoIFlowCookieAuth performs the iFlow cookie-based authentication.
func DoIFlowCookieAuth(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
promptFn := options.Prompt
if promptFn == nil {
reader := bufio.NewReader(os.Stdin)
promptFn = func(prompt string) (string, error) {
fmt.Print(prompt)
value, err := reader.ReadString('\n')
if err != nil {
return "", err
}
return strings.TrimSpace(value), nil
}
}
// Prompt user for cookie
cookie, err := promptForCookie(promptFn)
if err != nil {
fmt.Printf("Failed to get cookie: %v\n", err)
return
}
// Check for duplicate BXAuth before authentication
bxAuth := iflow.ExtractBXAuth(cookie)
if existingFile, err := iflow.CheckDuplicateBXAuth(cfg.AuthDir, bxAuth); err != nil {
fmt.Printf("Failed to check duplicate: %v\n", err)
return
} else if existingFile != "" {
fmt.Printf("Duplicate BXAuth found, authentication already exists: %s\n", filepath.Base(existingFile))
return
}
// Authenticate with cookie
auth := iflow.NewIFlowAuth(cfg)
ctx := context.Background()
tokenData, err := auth.AuthenticateWithCookie(ctx, cookie)
if err != nil {
fmt.Printf("iFlow cookie authentication failed: %v\n", err)
return
}
// Create token storage
tokenStorage := auth.CreateCookieTokenStorage(tokenData)
// Get auth file path using email in filename
authFilePath := getAuthFilePath(cfg, "iflow", tokenData.Email)
// Save token to file
if err := tokenStorage.SaveTokenToFile(authFilePath); err != nil {
fmt.Printf("Failed to save authentication: %v\n", err)
return
}
fmt.Printf("Authentication successful! API key: %s\n", tokenData.APIKey)
fmt.Printf("Expires at: %s\n", tokenData.Expire)
fmt.Printf("Authentication saved to: %s\n", authFilePath)
}
// promptForCookie asks for the iFlow cookie through promptFn and normalizes it
// with iflow.NormalizeCookie. A prompt failure is wrapped as "failed to read
// cookie: …"; a normalization error is returned as-is.
func promptForCookie(promptFn func(string) (string, error)) (string, error) {
	raw, errPrompt := promptFn("Enter iFlow Cookie (from browser cookies): ")
	if errPrompt != nil {
		return "", fmt.Errorf("failed to read cookie: %w", errPrompt)
	}

	normalized, errNormalize := iflow.NormalizeCookie(raw)
	if errNormalize != nil {
		return "", errNormalize
	}
	return normalized, nil
}
// getAuthFilePath returns the credential file path for provider and email
// inside cfg.AuthDir. The file name is "<provider>-<sanitized email>-<unix
// seconds>.json".
//
// filepath.Join applies the OS path separator and cleans the result. As a
// result, a trailing slash on AuthDir does not produce "//", and an empty
// AuthDir yields a relative file name rather than a path at the filesystem root.
func getAuthFilePath(cfg *config.Config, provider, email string) string {
	fileName := fmt.Sprintf("%s-%s-%d.json", provider, iflow.SanitizeIFlowFileName(email), time.Now().Unix())
	return filepath.Join(cfg.AuthDir, fileName)
}
================================================
FILE: internal/cmd/iflow_login.go
================================================
package cmd
import (
"context"
"errors"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
// DoIFlowLogin runs the iFlow OAuth login through the shared authentication
// manager and reports where the credentials were saved.
//
// Parameters:
//   - cfg: The application configuration
//   - options: Login options (browser behavior, callback port, prompt); nil uses defaults
//
// A *sdkAuth.EmailRequiredError is logged; any other failure is printed to stdout.
func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
	if options == nil {
		options = &LoginOptions{}
	}
	manager := newAuthManager()
	prompt := options.Prompt
	if prompt == nil {
		prompt = defaultProjectPrompt()
	}

	_, savedPath, err := manager.Login(context.Background(), "iflow", cfg, &sdkAuth.LoginOptions{
		NoBrowser:    options.NoBrowser,
		CallbackPort: options.CallbackPort,
		Metadata:     map[string]string{},
		Prompt:       prompt,
	})
	if err != nil {
		if emailErr, isEmailErr := errors.AsType[*sdkAuth.EmailRequiredError](err); isEmailErr {
			log.Error(emailErr.Error())
		} else {
			fmt.Printf("iFlow authentication failed: %v\n", err)
		}
		return
	}

	if savedPath != "" {
		fmt.Printf("Authentication saved to %s\n", savedPath)
	}
	fmt.Println("iFlow authentication successful!")
}
================================================
FILE: internal/cmd/kimi_login.go
================================================
package cmd
import (
"context"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
// DoKimiLogin runs the "kimi" (Moonshot AI) login through the shared
// authentication manager. It then reports the saved credential path and, when
// known, the authenticated account label.
//
// Unlike the other logins in this package, it does not forward CallbackPort,
// and it passes options.Prompt through unchanged (possibly nil).
//
// Parameters:
//   - cfg: The application configuration containing proxy and auth directory settings
//   - options: Login options including browser behavior settings; nil uses defaults
func DoKimiLogin(cfg *config.Config, options *LoginOptions) {
	opts := options
	if opts == nil {
		opts = &LoginOptions{}
	}
	manager := newAuthManager()

	loginOpts := &sdkAuth.LoginOptions{
		NoBrowser: opts.NoBrowser,
		Metadata:  map[string]string{},
		Prompt:    opts.Prompt,
	}
	record, savedPath, err := manager.Login(context.Background(), "kimi", cfg, loginOpts)
	if err != nil {
		log.Errorf("Kimi authentication failed: %v", err)
		return
	}

	if savedPath != "" {
		fmt.Printf("Authentication saved to %s\n", savedPath)
	}
	if record != nil {
		if label := record.Label; label != "" {
			fmt.Printf("Authenticated as %s\n", label)
		}
	}
	fmt.Println("Kimi authentication successful!")
}
================================================
FILE: internal/cmd/login.go
================================================
// Package cmd provides command-line interface functionality for the CLI Proxy API server.
// It includes authentication flows for various AI service providers, service startup,
// and other command-line operations.
package cmd
import (
"bufio"
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"os"
"strconv"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// Base URL and API version of the Gemini CLI backend called by callGeminiCLI.
const (
	geminiCLIEndpoint = "https://cloudcode-pa.googleapis.com"
	geminiCLIVersion  = "v1internal"
)

// projectSelectionRequiredError reports that Gemini CLI onboarding could not
// determine a Google Cloud project, so the user has to pick one explicitly.
type projectSelectionRequiredError struct{}

// Error implements the error interface.
func (*projectSelectionRequiredError) Error() string {
	return "gemini cli: project selection required"
}
// DoLogin handles Google Gemini (Gemini CLI / Code Assist) authentication.
//
// Steps:
//  1. Run the OAuth flow with sdkAuth.NewGeminiAuthenticator and build an
//     authenticated HTTP client for the resulting token storage.
//  2. Choose the project(s). In "Google One" mode, performGeminiCLISetup
//     auto-discovers a project. In the default "Code Assist" mode, the
//     account's GCP projects are listed and the user picks one (or ALL).
//  3. Onboard each selected project via performGeminiCLISetup, skipping
//     duplicate final IDs, and store them comma-joined in storage.ProjectID.
//  4. Unless storage.Checked is already set, make sure the required Cloud API
//     is enabled for every activated project (checkCloudAPIIsEnabled).
//  5. Save the record through the shared token store, using cfg.AuthDir as the
//     base directory when the store supports SetBaseDir.
//
// Parameters:
//   - cfg: The application configuration
//   - projectID: Optional Google Cloud project ID for Gemini services; when set,
//     the login-mode menu is skipped and it is used as the preset selection
//   - options: Login options including browser behavior and prompts; nil means defaults
//
// Any failure is logged and aborts the login; nothing is returned to the caller.
func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
	if options == nil {
		options = &LoginOptions{}
	}
	ctx := context.Background()
	promptFn := options.Prompt
	if promptFn == nil {
		promptFn = defaultProjectPrompt()
	}
	trimmedProjectID := strings.TrimSpace(projectID)
	// The prompt is forwarded to the OAuth step only when a project ID was
	// supplied. Without one, the OAuth step receives a nil prompt and project
	// selection happens interactively further below.
	callbackPrompt := promptFn
	if trimmedProjectID == "" {
		callbackPrompt = nil
	}
	loginOpts := &sdkAuth.LoginOptions{
		NoBrowser:    options.NoBrowser,
		ProjectID:    trimmedProjectID,
		CallbackPort: options.CallbackPort,
		Metadata:     map[string]string{},
		Prompt:       callbackPrompt,
	}
	// The Gemini authenticator is called directly rather than via newAuthManager.
	// The record is saved manually at the end, after project data has been added.
	authenticator := sdkAuth.NewGeminiAuthenticator()
	record, errLogin := authenticator.Login(ctx, cfg, loginOpts)
	if errLogin != nil {
		log.Errorf("Gemini authentication failed: %v", errLogin)
		return
	}
	storage, okStorage := record.Storage.(*gemini.GeminiTokenStorage)
	if !okStorage || storage == nil {
		log.Error("Gemini authentication failed: unsupported token storage")
		return
	}
	// Build an HTTP client authorized with the freshly obtained token storage.
	geminiAuth := gemini.NewGeminiAuth()
	httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{
		NoBrowser:    options.NoBrowser,
		CallbackPort: options.CallbackPort,
		Prompt:       callbackPrompt,
	})
	if errClient != nil {
		log.Errorf("Gemini authentication failed: %v", errClient)
		return
	}
	log.Info("Authentication successful.")
	var activatedProjects []string
	useGoogleOne := false
	// Offer the login-mode menu only when no project ID was supplied. promptFn is
	// always non-nil here because it was defaulted above.
	if trimmedProjectID == "" && promptFn != nil {
		fmt.Println("\nSelect login mode:")
		fmt.Println("  1. Code Assist  (GCP project, manual selection)")
		fmt.Println("  2. Google One   (personal account, auto-discover project)")
		choice, errPrompt := promptFn("Enter choice [1/2] (default: 1): ")
		if errPrompt == nil && strings.TrimSpace(choice) == "2" {
			useGoogleOne = true
		}
	}
	if useGoogleOne {
		// Google One: an empty requested project makes performGeminiCLISetup
		// discover/provision the project itself.
		log.Info("Google One mode: auto-discovering project...")
		if errSetup := performGeminiCLISetup(ctx, httpClient, storage, ""); errSetup != nil {
			log.Errorf("Google One auto-discovery failed: %v", errSetup)
			return
		}
		autoProject := strings.TrimSpace(storage.ProjectID)
		if autoProject == "" {
			log.Error("Google One auto-discovery returned empty project ID")
			return
		}
		log.Infof("Auto-discovered project: %s", autoProject)
		activatedProjects = []string{autoProject}
	} else {
		// Code Assist: list the account's projects and let the user pick one (or ALL).
		projects, errProjects := fetchGCPProjects(ctx, httpClient)
		if errProjects != nil {
			log.Errorf("Failed to get project list: %v", errProjects)
			return
		}
		selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
		projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
		if errSelection != nil {
			log.Errorf("Invalid project selection: %v", errSelection)
			return
		}
		if len(projectSelections) == 0 {
			log.Error("No project selected; aborting login.")
			return
		}
		seenProjects := make(map[string]bool)
		for _, candidateID := range projectSelections {
			log.Infof("Activating project %s", candidateID)
			if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
				if _, ok := errors.AsType[*projectSelectionRequiredError](errSetup); ok {
					log.Error("Failed to start user onboarding: A project ID is required.")
					showProjectSelectionHelp(storage.Email, projects)
					return
				}
				log.Errorf("Failed to complete user setup: %v", errSetup)
				return
			}
			// performGeminiCLISetup writes the project it settled on into
			// storage.ProjectID, which may differ from candidateID; dedupe on that.
			finalID := strings.TrimSpace(storage.ProjectID)
			if finalID == "" {
				finalID = candidateID
			}
			if seenProjects[finalID] {
				log.Infof("Project %s already activated, skipping", finalID)
				continue
			}
			seenProjects[finalID] = true
			activatedProjects = append(activatedProjects, finalID)
		}
	}
	// NOTE(review): Auto is forced to false here, so the !storage.Auto test below
	// is always true and only storage.Checked gates the API check.
	storage.Auto = false
	storage.ProjectID = strings.Join(activatedProjects, ",")
	if !storage.Auto && !storage.Checked {
		for _, pid := range activatedProjects {
			isChecked, errCheck := checkCloudAPIIsEnabled(ctx, httpClient, pid)
			if errCheck != nil {
				log.Errorf("Failed to check if Cloud AI API is enabled for %s: %v", pid, errCheck)
				return
			}
			// NOTE(review): checkCloudAPIIsEnabled returns false only together with a
			// non-nil error, so this branch appears unreachable.
			if !isChecked {
				log.Errorf("Failed to check if Cloud AI API is enabled for project %s. If you encounter an error message, please create an issue.", pid)
				return
			}
		}
		storage.Checked = true
	}
	// Copy the updated storage back into the auth record before persisting it.
	updateAuthRecord(record, storage)
	store := sdkAuth.GetTokenStore()
	if setter, okSetter := store.(interface{ SetBaseDir(string) }); okSetter && cfg != nil {
		setter.SetBaseDir(cfg.AuthDir)
	}
	savedPath, errSave := store.Save(ctx, record)
	if errSave != nil {
		log.Errorf("Failed to save token to file: %v", errSave)
		return
	}
	if savedPath != "" {
		fmt.Printf("Authentication saved to %s\n", savedPath)
	}
	fmt.Println("Gemini authentication successful!")
}
// performGeminiCLISetup runs the Gemini CLI (Code Assist) onboarding handshake
// and stores the resulting project ID in storage.ProjectID.
//
//  1. loadCodeAssist is called, with the requested project if any. Its default
//     allowed tier is used for onboarding, falling back to "legacy-tier". When
//     no project was requested, a project reported in the response is adopted.
//  2. If no project is known yet, onboardUser is called without a project so the
//     backend can provision one. This polls every 2 seconds for at most 30
//     seconds.
//  3. onboardUser is called with the chosen project and polled every 5 seconds
//     until it reports done. If an explicitly requested project differs from the
//     returned one, free/legacy-tier users are asked on stdin which to keep.
//     Everyone else keeps the requested project.
//
// Parameters:
//   - ctx: bounds every HTTP call and both wait loops
//   - httpClient: an OAuth-authenticated client
//   - storage: receives the final project ID
//   - requestedProject: optional explicit project ID; "" triggers discovery
//
// Returns:
//   - *projectSelectionRequiredError when no project can be discovered
//   - ctx.Err() when ctx ends while waiting for onboarding to finish
//   - a wrapped API error for any failed backend call
func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage *gemini.GeminiTokenStorage, requestedProject string) error {
	metadata := map[string]string{
		"ideType":    "IDE_UNSPECIFIED",
		"platform":   "PLATFORM_UNSPECIFIED",
		"pluginType": "GEMINI",
	}
	trimmedRequest := strings.TrimSpace(requestedProject)
	explicitProject := trimmedRequest != ""

	loadReqBody := map[string]any{
		"metadata": metadata,
	}
	if explicitProject {
		loadReqBody["cloudaicompanionProject"] = trimmedRequest
	}
	var loadResp map[string]any
	if errLoad := callGeminiCLI(ctx, httpClient, "loadCodeAssist", loadReqBody, &loadResp); errLoad != nil {
		return fmt.Errorf("load code assist: %w", errLoad)
	}

	tierID := geminiCLIDefaultTierID(loadResp)

	projectID := trimmedRequest
	if projectID == "" {
		projectID = geminiCLICompanionProjectID(loadResp["cloudaicompanionProject"])
	}

	if projectID == "" {
		// Auto-discovery: try onboardUser without specifying a project
		// to let Google auto-provision one (matches Gemini CLI headless behavior
		// and Antigravity's FetchProjectID pattern).
		autoOnboardReq := map[string]any{
			"tierId":   tierID,
			"metadata": metadata,
		}
		autoCtx, autoCancel := context.WithTimeout(ctx, 30*time.Second)
		defer autoCancel()
		for attempt := 1; ; attempt++ {
			var onboardResp map[string]any
			if errOnboard := callGeminiCLI(autoCtx, httpClient, "onboardUser", autoOnboardReq, &onboardResp); errOnboard != nil {
				return fmt.Errorf("auto-discovery onboardUser: %w", errOnboard)
			}
			if done, okDone := onboardResp["done"].(bool); okDone && done {
				if resp, okResp := onboardResp["response"].(map[string]any); okResp {
					projectID = geminiCLICompanionProjectID(resp["cloudaicompanionProject"])
				}
				break
			}
			log.Debugf("Auto-discovery: onboarding in progress, attempt %d...", attempt)
			select {
			case <-autoCtx.Done():
				return &projectSelectionRequiredError{}
			case <-time.After(2 * time.Second):
			}
		}
		if projectID == "" {
			return &projectSelectionRequiredError{}
		}
		log.Infof("Auto-discovered project ID via onboarding: %s", projectID)
	}

	onboardReqBody := map[string]any{
		"tierId":                  tierID,
		"metadata":                metadata,
		"cloudaicompanionProject": projectID,
	}
	// Store the requested project as a fallback in case the response omits it.
	storage.ProjectID = projectID

	for {
		var onboardResp map[string]any
		if errOnboard := callGeminiCLI(ctx, httpClient, "onboardUser", onboardReqBody, &onboardResp); errOnboard != nil {
			return fmt.Errorf("onboard user: %w", errOnboard)
		}
		if done, okDone := onboardResp["done"].(bool); okDone && done {
			responseProjectID := ""
			if resp, okResp := onboardResp["response"].(map[string]any); okResp {
				responseProjectID = geminiCLICompanionProjectID(resp["cloudaicompanionProject"])
			}
			finalProjectID := projectID
			if responseProjectID != "" {
				if explicitProject && !strings.EqualFold(responseProjectID, projectID) {
					if geminiCLIIsFreeUser(projectID, tierID) {
						// Interactive prompt for free users
						fmt.Printf("\nGoogle returned a different project ID:\n")
						fmt.Printf("  Requested (frontend): %s\n", projectID)
						fmt.Printf("  Returned (backend):   %s\n\n", responseProjectID)
						fmt.Printf("  Backend project IDs have access to preview models (gemini-3-*).\n")
						fmt.Printf("  This is normal for free tier users.\n\n")
						fmt.Printf("Which project ID would you like to use?\n")
						fmt.Printf("  [1] Backend (recommended): %s\n", responseProjectID)
						fmt.Printf("  [2] Frontend: %s\n\n", projectID)
						fmt.Printf("Enter choice [1]: ")
						// NOTE(review): this is a second bufio.Reader on os.Stdin. Input
						// already buffered by the caller's prompt reader (for example,
						// defaultProjectPrompt) is not visible here.
						reader := bufio.NewReader(os.Stdin)
						choice, _ := reader.ReadString('\n')
						choice = strings.TrimSpace(choice)
						if choice == "2" {
							log.Infof("Using frontend project ID: %s", projectID)
							fmt.Println(". Warning: Frontend project IDs may not have access to preview models.")
							finalProjectID = projectID
						} else {
							log.Infof("Using backend project ID: %s (recommended)", responseProjectID)
							finalProjectID = responseProjectID
						}
					} else {
						// Pro users: keep requested project ID (original behavior)
						log.Warnf("Gemini onboarding returned project %s instead of requested %s; keeping requested project ID.", responseProjectID, projectID)
					}
				} else {
					finalProjectID = responseProjectID
				}
			}
			storage.ProjectID = strings.TrimSpace(finalProjectID)
			if storage.ProjectID == "" {
				storage.ProjectID = strings.TrimSpace(projectID)
			}
			if storage.ProjectID == "" {
				return fmt.Errorf("onboard user completed without project id")
			}
			log.Infof("Onboarding complete. Using Project ID: %s", storage.ProjectID)
			return nil
		}
		log.Println("Onboarding in progress, waiting 5 seconds...")
		// Unlike a bare time.Sleep, this wait stops as soon as ctx is cancelled.
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-time.After(5 * time.Second):
		}
	}
}

// geminiCLIDefaultTierID returns the trimmed id of the first entry in
// loadResp["allowedTiers"] that has isDefault=true and a non-blank id. It
// returns "legacy-tier" when no entry qualifies.
func geminiCLIDefaultTierID(loadResp map[string]any) string {
	tiers, _ := loadResp["allowedTiers"].([]any)
	for _, rawTier := range tiers {
		tier, okTier := rawTier.(map[string]any)
		if !okTier {
			continue
		}
		if isDefault, _ := tier["isDefault"].(bool); !isDefault {
			continue
		}
		if id, _ := tier["id"].(string); strings.TrimSpace(id) != "" {
			return strings.TrimSpace(id)
		}
	}
	return "legacy-tier"
}

// geminiCLICompanionProjectID extracts a project ID from a
// "cloudaicompanionProject" value. The value may be a plain string or an object
// with a string "id" field; any other shape yields "".
func geminiCLICompanionProjectID(value any) string {
	switch v := value.(type) {
	case string:
		return strings.TrimSpace(v)
	case map[string]any:
		if id, okID := v["id"].(string); okID {
			return strings.TrimSpace(id)
		}
	}
	return ""
}

// geminiCLIIsFreeUser reports whether the project/tier combination looks like a
// free user: a "gen-lang-client-" project, or a free or legacy tier.
//
// Tier IDs are matched case-insensitively in both the "<name>-tier" form (as in
// the "legacy-tier" fallback above) and the bare "FREE"/"LEGACY" form. Comparing
// only against the bare form never matched the "-tier" IDs used by this file.
func geminiCLIIsFreeUser(projectID, tierID string) bool {
	if strings.HasPrefix(projectID, "gen-lang-client-") {
		return true
	}
	switch strings.ToLower(strings.TrimSpace(tierID)) {
	case "free", "legacy", "free-tier", "legacy-tier":
		return true
	}
	return false
}
// callGeminiCLI POSTs body as JSON to a Gemini CLI backend method and decodes a
// 2xx JSON response into result.
//
// The target is <geminiCLIEndpoint>/<geminiCLIVersion>:<endpoint>. For
// "operations/…" endpoints it is <geminiCLIEndpoint>/<endpoint> instead.
// A nil body sends no payload, and a nil result discards the response body.
// Non-2xx responses become errors that carry the status code and response text.
func callGeminiCLI(ctx context.Context, httpClient *http.Client, endpoint string, body any, result any) error {
	var target string
	if strings.HasPrefix(endpoint, "operations/") {
		target = fmt.Sprintf("%s/%s", geminiCLIEndpoint, endpoint)
	} else {
		target = fmt.Sprintf("%s/%s:%s", geminiCLIEndpoint, geminiCLIVersion, endpoint)
	}

	var payload io.Reader
	if body != nil {
		encoded, errMarshal := json.Marshal(body)
		if errMarshal != nil {
			return fmt.Errorf("marshal request body: %w", errMarshal)
		}
		payload = bytes.NewReader(encoded)
	}

	req, errRequest := http.NewRequestWithContext(ctx, http.MethodPost, target, payload)
	if errRequest != nil {
		return fmt.Errorf("create request: %w", errRequest)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))

	resp, errDo := httpClient.Do(req)
	if errDo != nil {
		return fmt.Errorf("execute request: %w", errDo)
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("response body close error: %v", errClose)
		}
	}()

	if success := resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices; !success {
		raw, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("api request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
	}

	if result != nil {
		if errDecode := json.NewDecoder(resp.Body).Decode(result); errDecode != nil {
			return fmt.Errorf("decode response body: %w", errDecode)
		}
		return nil
	}
	_, _ = io.Copy(io.Discard, resp.Body)
	return nil
}
// fetchGCPProjects lists the Google Cloud projects visible to httpClient's
// credentials via the Cloud Resource Manager v1 projects endpoint. Non-2xx
// responses become errors that include the status code and response text.
//
// NOTE(review): only the first response page is decoded; pagination is not
// handled here.
func fetchGCPProjects(ctx context.Context, httpClient *http.Client) ([]interfaces.GCPProjectProjects, error) {
	const projectsURL = "https://cloudresourcemanager.googleapis.com/v1/projects"

	req, errRequest := http.NewRequestWithContext(ctx, http.MethodGet, projectsURL, nil)
	if errRequest != nil {
		return nil, fmt.Errorf("could not create project list request: %w", errRequest)
	}

	resp, errDo := httpClient.Do(req)
	if errDo != nil {
		return nil, fmt.Errorf("failed to execute project list request: %w", errDo)
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("response body close error: %v", errClose)
		}
	}()

	if success := resp.StatusCode >= http.StatusOK && resp.StatusCode < http.StatusMultipleChoices; !success {
		raw, _ := io.ReadAll(resp.Body)
		return nil, fmt.Errorf("project list request failed with status %d: %s", resp.StatusCode, strings.TrimSpace(string(raw)))
	}

	var listing interfaces.GCPProject
	if errDecode := json.NewDecoder(resp.Body).Decode(&listing); errDecode != nil {
		return nil, fmt.Errorf("failed to unmarshal project list: %w", errDecode)
	}
	return listing.Projects, nil
}
// promptForProjectSelection prints the available projects and returns the
// chosen project ID, or "ALL".
//
// Behavior:
//   - With no projects, it returns the trimmed presetID (printing a notice when
//     that is empty).
//   - A preset of "ALL" (any case), or one matching a listed project ID, is
//     returned without prompting. An unknown preset is warned about, and the
//     user is prompted instead.
//   - At the prompt, an empty answer picks the default, which is the first
//     listed project. The user may also type a project ID, a 1-based list index,
//     or ALL; invalid answers re-prompt.
//   - If promptFn fails, the default is returned.
func promptForProjectSelection(projects []interfaces.GCPProjectProjects, presetID string, promptFn func(string) (string, error)) string {
	preset := strings.TrimSpace(presetID)
	if len(projects) == 0 {
		if preset == "" {
			fmt.Println("No Google Cloud projects are available for selection.")
		}
		return preset
	}

	fmt.Println("Available Google Cloud projects:")
	defaultIndex := 0
	for i, p := range projects {
		fmt.Printf("[%d] %s (%s)\n", i+1, p.ProjectID, p.Name)
		if preset != "" && p.ProjectID == preset {
			defaultIndex = i
		}
	}
	fmt.Println("Type 'ALL' to onboard every listed project.")
	defaultID := projects[defaultIndex].ProjectID

	isListed := func(id string) bool {
		for _, p := range projects {
			if p.ProjectID == id {
				return true
			}
		}
		return false
	}

	if preset != "" {
		if strings.EqualFold(preset, "ALL") {
			return "ALL"
		}
		if isListed(preset) {
			return preset
		}
		log.Warnf("Provided project ID %s not found in available projects; please choose from the list.", preset)
	}

	for {
		answer, errPrompt := promptFn(fmt.Sprintf("Enter project ID [%s] or ALL: ", defaultID))
		if errPrompt != nil {
			log.Errorf("Project selection prompt failed: %v", errPrompt)
			return defaultID
		}
		answer = strings.TrimSpace(answer)
		switch {
		case strings.EqualFold(answer, "ALL"):
			return "ALL"
		case answer == "":
			return defaultID
		case isListed(answer):
			return answer
		}
		if n, errAtoi := strconv.Atoi(answer); errAtoi == nil && n >= 1 && n <= len(projects) {
			return projects[n-1].ProjectID
		}
		fmt.Println("Invalid selection, enter a project ID or a number from the list.")
	}
}
// resolveProjectSelections expands a project selection into an ordered,
// de-duplicated list of project IDs.
//
//   - A blank selection yields (nil, nil).
//   - "ALL" (any case) yields every non-blank, unique listed project ID in
//     listing order, or an error if there is none.
//   - Anything else is split on commas. Blank and repeated entries are dropped.
//     When the project list is non-empty, every entry must appear in it.
func resolveProjectSelections(selection string, projects []interfaces.GCPProjectProjects) ([]string, error) {
	trimmed := strings.TrimSpace(selection)
	if trimmed == "" {
		return nil, nil
	}

	known := make(map[string]struct{}, len(projects))
	listed := make([]string, 0, len(projects))
	for _, project := range projects {
		id := strings.TrimSpace(project.ProjectID)
		if id == "" {
			continue
		}
		if _, dup := known[id]; !dup {
			known[id] = struct{}{}
			listed = append(listed, id)
		}
	}

	if strings.EqualFold(trimmed, "ALL") {
		if len(listed) == 0 {
			return nil, fmt.Errorf("no projects available for ALL selection")
		}
		return listed, nil
	}

	parts := strings.Split(trimmed, ",")
	chosen := make([]string, 0, len(parts))
	picked := make(map[string]struct{}, len(parts))
	for _, part := range parts {
		id := strings.TrimSpace(part)
		if id == "" {
			continue
		}
		if _, dup := picked[id]; dup {
			continue
		}
		if _, ok := known[id]; len(known) > 0 && !ok {
			return nil, fmt.Errorf("project %s not found in available projects", id)
		}
		picked[id] = struct{}{}
		chosen = append(chosen, id)
	}
	return chosen, nil
}
// defaultProjectPrompt returns a prompt function that prints its prompt to
// stdout and reads one line from stdin. All calls share one bufio.Reader
// created when defaultProjectPrompt is called.
//
// The returned line is whitespace-trimmed. Reaching EOF is not treated as an
// error: whatever was read before it, possibly nothing, is returned. Other read
// errors are returned with an empty string.
func defaultProjectPrompt() func(string) (string, error) {
	stdin := bufio.NewReader(os.Stdin)
	return func(prompt string) (string, error) {
		fmt.Print(prompt)
		line, errRead := stdin.ReadString('\n')
		if errRead != nil && !errors.Is(errRead, io.EOF) {
			return "", errRead
		}
		return strings.TrimSpace(line), nil
	}
}
// showProjectSelectionHelp explains that onboarding needs an explicit project
// ID. It lists the ID and name of each project returned for the account, then
// prints the command to re-run login with --project_id.
func showProjectSelectionHelp(email string, projects []interfaces.GCPProjectProjects) {
	if email != "" {
		log.Infof("Your account %s needs to specify a project ID.", email)
	} else {
		log.Info("You need to specify a project ID.")
	}

	if len(projects) > 0 {
		fmt.Println("========================================================================")
		for _, p := range projects {
			fmt.Printf("Project ID: %s\n", p.ProjectID)
			fmt.Printf("Project Name: %s\n", p.Name)
			fmt.Println("------------------------------------------------------------------------")
		}
	} else {
		fmt.Println("No active projects were returned for this account.")
	}

	// Show where the project ID goes; previously the command ended in a bare
	// "--project_id " with no value.
	fmt.Printf("Please run this command to login again with a specific project:\n\n%s --login --project_id <project_id>\n", os.Args[0])
}
func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projectID string) (bool, error) {
serviceUsageURL := "https://serviceusage.googleapis.com"
requiredServices := []string{
// "geminicloudassist.googleapis.com", // Gemini Cloud Assist API
"cloudaicompanion.googleapis.com", // Gemini for Google Cloud API
}
for _, service := range requiredServices {
checkUrl := fmt.Sprintf("%s/v1/projects/%s/services/%s", serviceUsageURL, projectID, service)
req, errRequest := http.NewRequestWithContext(ctx, http.MethodGet, checkUrl, nil)
if errRequest != nil {
return false, fmt.Errorf("failed to create request: %w", errRequest)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
resp, errDo := httpClient.Do(req)
if errDo != nil {
return false, fmt.Errorf("failed to execute request: %w", errDo)
}
if resp.StatusCode == http.StatusOK {
bodyBytes, _ := io.ReadAll(resp.Body)
if gjson.GetBytes(bodyBytes, "state").String() == "ENABLED" {
_ = resp.Body.Close()
continue
}
}
_ = resp.Body.Close()
enableUrl := fmt.Sprintf("%s/v1/projects/%s/services/%s:enable", serviceUsageURL, projectID, service)
req, errRequest = http.NewRequestWithContext(ctx, http.MethodPost, enableUrl, strings.NewReader("{}"))
if errRequest != nil {
return false, fmt.Errorf("failed to create request: %w", errRequest)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", misc.GeminiCLIUserAgent(""))
resp, errDo = httpClient.Do(req)
if errDo != nil {
return false, fmt.Errorf("failed to execute request: %w", errDo)
}
bodyBytes, _ := io.ReadAll(resp.Body)
errMessage := string(bodyBytes)
errMessageResult := gjson.GetBytes(bodyBytes, "error.message")
if errMessageResult.Exists() {
errMessage = errMessageResult.String()
}
if resp.StatusCode == http.StatusOK || resp.StatusCode == http.StatusCreated {
_ = resp.Body.Close()
continue
} else if resp.StatusCode == http.StatusBadRequest {
_ = resp.Body.Close()
if strings.Contains(strings.ToLower(errMessage), "already enabled") {
continue
}
}
_ = resp.Body.Close()
return false, fmt.Errorf("project activation required: %s", errMessage)
}
return true, nil
}
func updateAuthRecord(record *cliproxyauth.Auth, storage *gemini.GeminiTokenStorage) {
if record == nil || storage == nil {
return
}
finalName := gemini.CredentialFileName(storage.Email, storage.ProjectID, true)
if record.Metadata == nil {
record.Metadata = make(map[string]any)
}
record.Metadata["email"] = storage.Email
record.Metadata["project_id"] = storage.ProjectID
record.Metadata["auto"] = storage.Auto
record.Metadata["checked"] = storage.Checked
record.ID = finalName
record.FileName = finalName
record.Storage = storage
}
================================================
FILE: internal/cmd/openai_device_login.go
================================================
package cmd
import (
"context"
"errors"
"fmt"
"os"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
const (
codexLoginModeMetadataKey = "codex_login_mode"
codexLoginModeDevice = "device"
)
// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the
// existing codex-login OAuth callback flow intact.
func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{
codexLoginModeMetadataKey: codexLoginModeDevice,
},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
if err != nil {
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
log.Error(codex.GetUserFriendlyMessage(authErr))
if authErr.Type == codex.ErrPortInUse.Type {
os.Exit(codex.ErrPortInUse.Code)
}
return
}
fmt.Printf("Codex device authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("Codex device authentication successful!")
}
================================================
FILE: internal/cmd/openai_login.go
================================================
package cmd
import (
"context"
"errors"
"fmt"
"os"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
// LoginOptions contains options for the login processes.
// It provides configuration for authentication flows including browser behavior
// and interactive prompting capabilities.
type LoginOptions struct {
// NoBrowser indicates whether to skip opening the browser automatically.
NoBrowser bool
// CallbackPort overrides the local OAuth callback port when set (>0).
CallbackPort int
// Prompt allows the caller to provide interactive input when needed.
Prompt func(prompt string) (string, error)
}
// DoCodexLogin triggers the Codex OAuth flow through the shared authentication manager.
// It initiates the OAuth authentication process for OpenAI Codex services and saves
// the authentication tokens to the configured auth directory.
//
// Parameters:
// - cfg: The application configuration
// - options: Login options including browser behavior and prompts
func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
if err != nil {
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
log.Error(codex.GetUserFriendlyMessage(authErr))
if authErr.Type == codex.ErrPortInUse.Type {
os.Exit(codex.ErrPortInUse.Code)
}
return
}
fmt.Printf("Codex authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("Codex authentication successful!")
}
================================================
FILE: internal/cmd/qwen_login.go
================================================
package cmd
import (
"context"
"errors"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
log "github.com/sirupsen/logrus"
)
// DoQwenLogin handles the Qwen device flow using the shared authentication manager.
// It initiates the device-based authentication process for Qwen services and saves
// the authentication tokens to the configured auth directory.
//
// Parameters:
// - cfg: The application configuration
// - options: Login options including browser behavior and prompts
func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
if options == nil {
options = &LoginOptions{}
}
manager := newAuthManager()
promptFn := options.Prompt
if promptFn == nil {
promptFn = func(prompt string) (string, error) {
fmt.Println()
fmt.Println(prompt)
var value string
_, err := fmt.Scanln(&value)
return value, err
}
}
authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser,
CallbackPort: options.CallbackPort,
Metadata: map[string]string{},
Prompt: promptFn,
}
_, savedPath, err := manager.Login(context.Background(), "qwen", cfg, authOpts)
if err != nil {
if emailErr, ok := errors.AsType[*sdkAuth.EmailRequiredError](err); ok {
log.Error(emailErr.Error())
return
}
fmt.Printf("Qwen authentication failed: %v\n", err)
return
}
if savedPath != "" {
fmt.Printf("Authentication saved to %s\n", savedPath)
}
fmt.Println("Qwen authentication successful!")
}
================================================
FILE: internal/cmd/run.go
================================================
// Package cmd provides command-line interface functionality for the CLI Proxy API server.
// It includes authentication flows for various AI service providers, service startup,
// and other command-line operations.
package cmd
import (
"context"
"errors"
"os/signal"
"syscall"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy"
log "github.com/sirupsen/logrus"
)
// StartService builds and runs the proxy service using the exported SDK.
// It creates a new proxy service instance, sets up signal handling for graceful shutdown,
// and starts the service with the provided configuration.
//
// Parameters:
// - cfg: The application configuration
// - configPath: The path to the configuration file
// - localPassword: Optional password accepted for local management requests
func StartService(cfg *config.Config, configPath string, localPassword string) {
builder := cliproxy.NewBuilder().
WithConfig(cfg).
WithConfigPath(configPath).
WithLocalManagementPassword(localPassword)
ctxSignal, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
runCtx := ctxSignal
if localPassword != "" {
var keepAliveCancel context.CancelFunc
runCtx, keepAliveCancel = context.WithCancel(ctxSignal)
builder = builder.WithServerOptions(api.WithKeepAliveEndpoint(10*time.Second, func() {
log.Warn("keep-alive endpoint idle for 10s, shutting down")
keepAliveCancel()
}))
}
service, err := builder.Build()
if err != nil {
log.Errorf("failed to build proxy service: %v", err)
return
}
err = service.Run(runCtx)
if err != nil && !errors.Is(err, context.Canceled) {
log.Errorf("proxy service exited with error: %v", err)
}
}
// StartServiceBackground starts the proxy service in a background goroutine
// and returns a cancel function for shutdown and a done channel.
func StartServiceBackground(cfg *config.Config, configPath string, localPassword string) (cancel func(), done <-chan struct{}) {
builder := cliproxy.NewBuilder().
WithConfig(cfg).
WithConfigPath(configPath).
WithLocalManagementPassword(localPassword)
ctx, cancelFn := context.WithCancel(context.Background())
doneCh := make(chan struct{})
service, err := builder.Build()
if err != nil {
log.Errorf("failed to build proxy service: %v", err)
close(doneCh)
return cancelFn, doneCh
}
go func() {
defer close(doneCh)
if err := service.Run(ctx); err != nil && !errors.Is(err, context.Canceled) {
log.Errorf("proxy service exited with error: %v", err)
}
}()
return cancelFn, doneCh
}
// WaitForCloudDeploy waits indefinitely for shutdown signals in cloud deploy mode
// when no configuration file is available.
func WaitForCloudDeploy() {
// Clarify that we are intentionally idle for configuration and not running the API server.
log.Info("Cloud deploy mode: No config found; standing by for configuration. API server is not started. Press Ctrl+C to exit.")
ctxSignal, cancel := signal.NotifyContext(context.Background(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
// Block until shutdown signal is received
<-ctxSignal.Done()
log.Info("Cloud deploy mode: Shutdown signal received; exiting")
}
================================================
FILE: internal/cmd/vertex_import.go
================================================
// Package cmd contains CLI helpers. This file implements importing a Vertex AI
// service account JSON into the auth store as a dedicated "vertex" credential.
package cmd
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// DoVertexImport imports a Google Cloud service account key JSON and persists
// it as a "vertex" provider credential. The file content is embedded in the auth
// file to allow portable deployment across stores.
func DoVertexImport(cfg *config.Config, keyPath string) {
if cfg == nil {
cfg = &config.Config{}
}
if resolved, errResolve := util.ResolveAuthDir(cfg.AuthDir); errResolve == nil {
cfg.AuthDir = resolved
}
rawPath := strings.TrimSpace(keyPath)
if rawPath == "" {
log.Errorf("vertex-import: missing service account key path")
return
}
data, errRead := os.ReadFile(rawPath)
if errRead != nil {
log.Errorf("vertex-import: read file failed: %v", errRead)
return
}
var sa map[string]any
if errUnmarshal := json.Unmarshal(data, &sa); errUnmarshal != nil {
log.Errorf("vertex-import: invalid service account json: %v", errUnmarshal)
return
}
// Validate and normalize private_key before saving
normalizedSA, errFix := vertex.NormalizeServiceAccountMap(sa)
if errFix != nil {
log.Errorf("vertex-import: %v", errFix)
return
}
sa = normalizedSA
email, _ := sa["client_email"].(string)
projectID, _ := sa["project_id"].(string)
if strings.TrimSpace(projectID) == "" {
log.Errorf("vertex-import: project_id missing in service account json")
return
}
if strings.TrimSpace(email) == "" {
// Keep empty email but warn
log.Warn("vertex-import: client_email missing in service account json")
}
// Default location if not provided by user. Can be edited in the saved file later.
location := "us-central1"
fileName := fmt.Sprintf("vertex-%s.json", sanitizeFilePart(projectID))
// Build auth record
storage := &vertex.VertexCredentialStorage{
ServiceAccount: sa,
ProjectID: projectID,
Email: email,
Location: location,
}
metadata := map[string]any{
"service_account": sa,
"project_id": projectID,
"email": email,
"location": location,
"type": "vertex",
"label": labelForVertex(projectID, email),
}
record := &coreauth.Auth{
ID: fileName,
Provider: "vertex",
FileName: fileName,
Storage: storage,
Metadata: metadata,
}
store := sdkAuth.GetTokenStore()
if setter, ok := store.(interface{ SetBaseDir(string) }); ok {
setter.SetBaseDir(cfg.AuthDir)
}
path, errSave := store.Save(context.Background(), record)
if errSave != nil {
log.Errorf("vertex-import: save credential failed: %v", errSave)
return
}
fmt.Printf("Vertex credentials imported: %s\n", path)
}
func sanitizeFilePart(s string) string {
out := strings.TrimSpace(s)
replacers := []string{"/", "_", "\\", "_", ":", "_", " ", "-"}
for i := 0; i < len(replacers); i += 2 {
out = strings.ReplaceAll(out, replacers[i], replacers[i+1])
}
return out
}
func labelForVertex(projectID, email string) string {
p := strings.TrimSpace(projectID)
e := strings.TrimSpace(email)
if p != "" && e != "" {
return fmt.Sprintf("%s (%s)", p, e)
}
if p != "" {
return p
}
if e != "" {
return e
}
return "vertex"
}
================================================
FILE: internal/config/codex_websocket_header_defaults_test.go
================================================
package config
import (
"os"
"path/filepath"
"testing"
)
func TestLoadConfigOptional_CodexHeaderDefaults(t *testing.T) {
dir := t.TempDir()
configPath := filepath.Join(dir, "config.yaml")
configYAML := []byte(`
codex-header-defaults:
user-agent: " my-codex-client/1.0 "
beta-features: " feature-a,feature-b "
`)
if err := os.WriteFile(configPath, configYAML, 0o600); err != nil {
t.Fatalf("failed to write config: %v", err)
}
cfg, err := LoadConfigOptional(configPath, false)
if err != nil {
t.Fatalf("LoadConfigOptional() error = %v", err)
}
if got := cfg.CodexHeaderDefaults.UserAgent; got != "my-codex-client/1.0" {
t.Fatalf("UserAgent = %q, want %q", got, "my-codex-client/1.0")
}
if got := cfg.CodexHeaderDefaults.BetaFeatures; got != "feature-a,feature-b" {
t.Fatalf("BetaFeatures = %q, want %q", got, "feature-a,feature-b")
}
}
================================================
FILE: internal/config/config.go
================================================
// Package config provides configuration management for the CLI Proxy API server.
// It handles loading and parsing YAML configuration files, and provides structured
// access to application settings including server port, authentication directory,
// debug settings, proxy configuration, and API keys.
package config
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"os"
"strings"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/bcrypt"
"gopkg.in/yaml.v3"
)
const (
DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
DefaultPprofAddr = "127.0.0.1:8316"
)
// Config represents the application's configuration, loaded from a YAML file.
type Config struct {
SDKConfig `yaml:",inline"`
// Host is the network host/interface on which the API server will bind.
// Default is empty ("") to bind all interfaces (IPv4 + IPv6). Use "127.0.0.1" or "localhost" for local-only access.
Host string `yaml:"host" json:"-"`
// Port is the network port on which the API server will listen.
Port int `yaml:"port" json:"-"`
// TLS config controls HTTPS server settings.
TLS TLSConfig `yaml:"tls" json:"tls"`
// RemoteManagement nests management-related options under 'remote-management'.
RemoteManagement RemoteManagement `yaml:"remote-management" json:"-"`
// AuthDir is the directory where authentication token files are stored.
AuthDir string `yaml:"auth-dir" json:"-"`
// Debug enables or disables debug-level logging and other debug features.
Debug bool `yaml:"debug" json:"debug"`
// Pprof config controls the optional pprof HTTP debug server.
Pprof PprofConfig `yaml:"pprof" json:"pprof"`
// CommercialMode disables high-overhead HTTP middleware features to minimize per-request memory usage.
CommercialMode bool `yaml:"commercial-mode" json:"commercial-mode"`
// LoggingToFile controls whether application logs are written to rotating files or stdout.
LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"`
// LogsMaxTotalSizeMB limits the total size (in MB) of log files under the logs directory.
// When exceeded, the oldest log files are deleted until within the limit. Set to 0 to disable.
LogsMaxTotalSizeMB int `yaml:"logs-max-total-size-mb" json:"logs-max-total-size-mb"`
// ErrorLogsMaxFiles limits the number of error log files retained when request logging is disabled.
// When exceeded, the oldest error log files are deleted. Default is 10. Set to 0 to disable cleanup.
ErrorLogsMaxFiles int `yaml:"error-logs-max-files" json:"error-logs-max-files"`
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
// DisableCooling disables quota cooldown scheduling when true.
DisableCooling bool `yaml:"disable-cooling" json:"disable-cooling"`
// RequestRetry defines the retry times when the request failed.
RequestRetry int `yaml:"request-retry" json:"request-retry"`
// MaxRetryCredentials defines the maximum number of credentials to try for a failed request.
// Set to 0 or a negative value to keep trying all available credentials (legacy behavior).
MaxRetryCredentials int `yaml:"max-retry-credentials" json:"max-retry-credentials"`
// MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential.
MaxRetryInterval int `yaml:"max-retry-interval" json:"max-retry-interval"`
// QuotaExceeded defines the behavior when a quota is exceeded.
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"`
// Routing controls credential selection behavior.
Routing RoutingConfig `yaml:"routing" json:"routing"`
// WebsocketAuth enables or disables authentication for the WebSocket API.
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
// GeminiKey defines Gemini API key configurations with optional routing overrides.
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
// Codex defines a list of Codex API key configurations as specified in the YAML configuration file.
CodexKey []CodexKey `yaml:"codex-api-key" json:"codex-api-key"`
// CodexHeaderDefaults configures fallback headers for Codex OAuth model requests.
// These are used only when the client does not send its own headers.
CodexHeaderDefaults CodexHeaderDefaults `yaml:"codex-header-defaults" json:"codex-header-defaults"`
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
// ClaudeHeaderDefaults configures default header values for Claude API requests.
// These are used as fallbacks when the client does not send its own headers.
ClaudeHeaderDefaults ClaudeHeaderDefaults `yaml:"claude-header-defaults" json:"claude-header-defaults"`
// OpenAICompatibility defines OpenAI API compatibility configurations for external providers.
OpenAICompatibility []OpenAICompatibility `yaml:"openai-compatibility" json:"openai-compatibility"`
// VertexCompatAPIKey defines Vertex AI-compatible API key configurations for third-party providers.
// Used for services that use Vertex AI-style paths but with simple API key authentication.
VertexCompatAPIKey []VertexCompatKey `yaml:"vertex-api-key" json:"vertex-api-key"`
// AmpCode contains Amp CLI upstream configuration, management restrictions, and model mappings.
AmpCode AmpCode `yaml:"ampcode" json:"ampcode"`
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
// OAuthModelAlias defines global model name aliases for OAuth/file-backed auth channels.
// These aliases affect both model listing and model routing for supported channels:
// gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
//
// NOTE: This does not apply to existing per-credential model alias features under:
// gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, and ampcode.
OAuthModelAlias map[string][]OAuthModelAlias `yaml:"oauth-model-alias,omitempty" json:"oauth-model-alias,omitempty"`
// Payload defines default and override rules for provider payload parameters.
Payload PayloadConfig `yaml:"payload" json:"payload"`
legacyMigrationPending bool `yaml:"-" json:"-"`
}
// ClaudeHeaderDefaults configures default header values injected into Claude API requests
// when the client does not send them. Update these when Claude Code releases a new version.
type ClaudeHeaderDefaults struct {
UserAgent string `yaml:"user-agent" json:"user-agent"`
PackageVersion string `yaml:"package-version" json:"package-version"`
RuntimeVersion string `yaml:"runtime-version" json:"runtime-version"`
Timeout string `yaml:"timeout" json:"timeout"`
}
// CodexHeaderDefaults configures fallback header values injected into Codex
// model requests for OAuth/file-backed auth when the client omits them.
// UserAgent applies to HTTP and websocket requests; BetaFeatures only applies to websockets.
type CodexHeaderDefaults struct {
UserAgent string `yaml:"user-agent" json:"user-agent"`
BetaFeatures string `yaml:"beta-features" json:"beta-features"`
}
// TLSConfig holds HTTPS server settings.
type TLSConfig struct {
// Enable toggles HTTPS server mode.
Enable bool `yaml:"enable" json:"enable"`
// Cert is the path to the TLS certificate file.
Cert string `yaml:"cert" json:"cert"`
// Key is the path to the TLS private key file.
Key string `yaml:"key" json:"key"`
}
// PprofConfig holds pprof HTTP server settings.
type PprofConfig struct {
// Enable toggles the pprof HTTP debug server.
Enable bool `yaml:"enable" json:"enable"`
// Addr is the host:port address for the pprof HTTP server.
Addr string `yaml:"addr" json:"addr"`
}
// RemoteManagement holds management API configuration under 'remote-management'.
type RemoteManagement struct {
// AllowRemote toggles remote (non-localhost) access to management API.
AllowRemote bool `yaml:"allow-remote"`
// SecretKey is the management key (plaintext or bcrypt hashed). YAML key intentionally 'secret-key'.
SecretKey string `yaml:"secret-key"`
// DisableControlPanel skips serving and syncing the bundled management UI when true.
DisableControlPanel bool `yaml:"disable-control-panel"`
// PanelGitHubRepository overrides the GitHub repository used to fetch the management panel asset.
// Accepts either a repository URL (https://github.com/org/repo) or an API releases endpoint.
PanelGitHubRepository string `yaml:"panel-github-repository"`
}
// QuotaExceeded defines the behavior when API quota limits are exceeded.
// It provides configuration options for automatic failover mechanisms.
type QuotaExceeded struct {
// SwitchProject indicates whether to automatically switch to another project when a quota is exceeded.
SwitchProject bool `yaml:"switch-project" json:"switch-project"`
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
}
// RoutingConfig configures how credentials are selected for requests.
type RoutingConfig struct {
// Strategy selects the credential selection strategy.
// Supported values: "round-robin" (default), "fill-first".
Strategy string `yaml:"strategy,omitempty" json:"strategy,omitempty"`
}
// OAuthModelAlias defines a model ID alias for a specific channel.
// It maps the upstream model name (Name) to the client-visible alias (Alias).
// When Fork is true, the alias is added as an additional model in listings while
// keeping the original model ID available.
type OAuthModelAlias struct {
Name string `yaml:"name" json:"name"`
Alias string `yaml:"alias" json:"alias"`
Fork bool `yaml:"fork,omitempty" json:"fork,omitempty"`
}
// AmpModelMapping defines a model name mapping for Amp CLI requests.
// When Amp requests a model that isn't available locally, this mapping
// allows routing to an alternative model that IS available.
type AmpModelMapping struct {
// From is the model name that Amp CLI requests (e.g., "claude-opus-4.5").
From string `yaml:"from" json:"from"`
// To is the target model name to route to (e.g., "claude-sonnet-4").
// The target model must have available providers in the registry.
To string `yaml:"to" json:"to"`
// Regex indicates whether the 'from' field should be interpreted as a regular
// expression for matching model names. When true, this mapping is evaluated
// after exact matches and in the order provided. Defaults to false (exact match).
Regex bool `yaml:"regex,omitempty" json:"regex,omitempty"`
}
// AmpCode groups Amp CLI integration settings including upstream routing,
// optional overrides, management route restrictions, and model fallback mappings.
type AmpCode struct {
// UpstreamURL defines the upstream Amp control plane used for non-provider calls.
UpstreamURL string `yaml:"upstream-url" json:"upstream-url"`
// UpstreamAPIKey optionally overrides the Authorization header when proxying Amp upstream calls.
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
// UpstreamAPIKeys maps client API keys (from top-level api-keys) to upstream API keys.
// When a client authenticates with a key that matches an entry, that upstream key is used.
// If no match is found, falls back to UpstreamAPIKey (default behavior).
UpstreamAPIKeys []AmpUpstreamAPIKeyEntry `yaml:"upstream-api-keys,omitempty" json:"upstream-api-keys,omitempty"`
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
// to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by
// browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient).
RestrictManagementToLocalhost bool `yaml:"restrict-management-to-localhost" json:"restrict-management-to-localhost"`
// ModelMappings defines model name mappings for Amp CLI requests.
// When Amp requests a model that isn't available locally, these mappings
// allow routing to an alternative model that IS available.
ModelMappings []AmpModelMapping `yaml:"model-mappings" json:"model-mappings"`
// ForceModelMappings when true, model mappings take precedence over local API keys.
// When false (default), local API keys are used first if available.
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
}
// AmpUpstreamAPIKeyEntry maps a set of client API keys to a specific upstream API key.
// When a request is authenticated with one of the APIKeys, the corresponding UpstreamAPIKey
// is used for the upstream Amp request.
type AmpUpstreamAPIKeyEntry struct {
// UpstreamAPIKey is the API key to use when proxying to the Amp upstream.
UpstreamAPIKey string `yaml:"upstream-api-key" json:"upstream-api-key"`
// APIKeys are the client API keys (from top-level api-keys) that map to this upstream key.
APIKeys []string `yaml:"api-keys" json:"api-keys"`
}
// PayloadConfig defines default and override parameter rules applied to provider payloads.
type PayloadConfig struct {
// Default defines rules that only set parameters when they are missing in the payload.
Default []PayloadRule `yaml:"default" json:"default"`
// DefaultRaw defines rules that set raw JSON values only when they are missing.
DefaultRaw []PayloadRule `yaml:"default-raw" json:"default-raw"`
// Override defines rules that always set parameters, overwriting any existing values.
Override []PayloadRule `yaml:"override" json:"override"`
// OverrideRaw defines rules that always set raw JSON values, overwriting any existing values.
OverrideRaw []PayloadRule `yaml:"override-raw" json:"override-raw"`
// Filter defines rules that remove parameters from the payload by JSON path.
Filter []PayloadFilterRule `yaml:"filter" json:"filter"`
}
// PayloadFilterRule describes a rule to remove specific JSON paths from matching model payloads.
type PayloadFilterRule struct {
// Models lists model entries with name pattern and protocol constraint.
Models []PayloadModelRule `yaml:"models" json:"models"`
// Params lists JSON paths (gjson/sjson syntax) to remove from the payload.
Params []string `yaml:"params" json:"params"`
}
// PayloadRule describes a single rule targeting a list of models with parameter updates.
type PayloadRule struct {
// Models lists model entries with name pattern and protocol constraint.
Models []PayloadModelRule `yaml:"models" json:"models"`
// Params maps JSON paths (gjson/sjson syntax) to values written into the payload.
// For *-raw rules, values are treated as raw JSON fragments (strings are used as-is).
Params map[string]any `yaml:"params" json:"params"`
}
// PayloadModelRule ties a model name pattern to a specific translator protocol.
type PayloadModelRule struct {
// Name is the model name or wildcard pattern (e.g., "gpt-*", "*-5", "gemini-*-pro").
Name string `yaml:"name" json:"name"`
// Protocol restricts the rule to a specific translator format (e.g., "gemini", "responses").
Protocol string `yaml:"protocol" json:"protocol"`
}
// CloakConfig configures request cloaking for non-Claude-Code clients.
// Cloaking disguises API requests to appear as originating from the official Claude Code CLI.
type CloakConfig struct {
	// Mode controls cloaking behavior: "auto" (default), "always", or "never".
	//   - "auto": cloak only when client is not Claude Code (based on User-Agent)
	//   - "always": always apply cloaking regardless of client
	//   - "never": never apply cloaking
	Mode string `yaml:"mode,omitempty" json:"mode,omitempty"`
	// StrictMode controls how system prompts are handled when cloaking.
	//   - false (default): prepend Claude Code prompt to user system messages
	//   - true: strip all user system messages, keep only Claude Code prompt
	StrictMode bool `yaml:"strict-mode,omitempty" json:"strict-mode,omitempty"`
	// SensitiveWords is a list of words to obfuscate with zero-width characters.
	// This can help bypass certain content filters.
	SensitiveWords []string `yaml:"sensitive-words,omitempty" json:"sensitive-words,omitempty"`
	// CacheUserID controls whether Claude user_id values are cached per API key.
	// When false, a fresh random user_id is generated for every request.
	// It is a pointer so that "not set" (nil) can be told apart from an explicit false.
	// NOTE(review): how nil is resolved is decided by the consumer — confirm there.
	CacheUserID *bool `yaml:"cache-user-id,omitempty" json:"cache-user-id,omitempty"`
}
// ClaudeKey represents the configuration for a Claude API key,
// including the API key itself and an optional base URL for the API endpoint.
//
// SanitizeClaudeKeys normalizes Prefix, Headers and ExcludedModels at load time;
// it does not trim or validate APIKey, BaseURL or ProxyURL.
type ClaudeKey struct {
	// APIKey is the authentication key for accessing Claude API services.
	APIKey string `yaml:"api-key" json:"api-key"`
	// Priority controls selection preference when multiple credentials match.
	// Higher values are preferred; defaults to 0.
	Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
	// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
	// At load time it is trimmed of whitespace and surrounding slashes; a prefix that
	// still contains "/" is cleared.
	Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
	// BaseURL is the base URL for the Claude API endpoint.
	// If empty, the default Claude API URL will be used.
	BaseURL string `yaml:"base-url" json:"base-url"`
	// ProxyURL overrides the global proxy setting for this API key if provided.
	ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
	// Models defines upstream model names and aliases for request routing.
	Models []ClaudeModel `yaml:"models" json:"models"`
	// Headers optionally adds extra HTTP headers for requests sent with this key.
	Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
	// ExcludedModels lists model IDs that should be excluded for this provider.
	// Entries are lower-cased, trimmed and de-duplicated at load time.
	ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
	// Cloak configures request cloaking for non-Claude-Code clients.
	Cloak *CloakConfig `yaml:"cloak,omitempty" json:"cloak,omitempty"`
}

// GetAPIKey returns the configured Claude API key.
func (k ClaudeKey) GetAPIKey() string { return k.APIKey }

// GetBaseURL returns the configured base URL, which may be empty.
func (k ClaudeKey) GetBaseURL() string { return k.BaseURL }
// ClaudeModel describes a mapping between an alias and the actual upstream model name.
type ClaudeModel struct {
	// Name is the upstream model identifier used when issuing requests.
	Name string `yaml:"name" json:"name"`
	// Alias is the client-facing model name that maps to Name.
	Alias string `yaml:"alias" json:"alias"`
}

// GetName returns the upstream model identifier.
func (m ClaudeModel) GetName() string { return m.Name }

// GetAlias returns the client-facing alias for the model.
func (m ClaudeModel) GetAlias() string { return m.Alias }
// CodexKey represents the configuration for a Codex API key,
// including the API key itself and the base URL for the API endpoint.
//
// SanitizeCodexKeys trims BaseURL and drops entries whose BaseURL is empty, and
// normalizes Prefix, Headers and ExcludedModels for the entries it keeps.
type CodexKey struct {
	// APIKey is the authentication key for accessing Codex API services.
	APIKey string `yaml:"api-key" json:"api-key"`
	// Priority controls selection preference when multiple credentials match.
	// Higher values are preferred; defaults to 0.
	Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
	// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
	Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
	// BaseURL is the base URL for the Codex API endpoint.
	// It is required: entries with an empty (whitespace-only) BaseURL are dropped by
	// SanitizeCodexKeys when the configuration is loaded.
	BaseURL string `yaml:"base-url" json:"base-url"`
	// Websockets enables the Responses API websocket transport for this credential.
	Websockets bool `yaml:"websockets,omitempty" json:"websockets,omitempty"`
	// ProxyURL overrides the global proxy setting for this API key if provided.
	ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
	// Models defines upstream model names and aliases for request routing.
	Models []CodexModel `yaml:"models" json:"models"`
	// Headers optionally adds extra HTTP headers for requests sent with this key.
	Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
	// ExcludedModels lists model IDs that should be excluded for this provider.
	ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
}

// GetAPIKey returns the configured Codex API key.
func (k CodexKey) GetAPIKey() string { return k.APIKey }

// GetBaseURL returns the configured Codex base URL.
func (k CodexKey) GetBaseURL() string { return k.BaseURL }
// CodexModel describes a mapping between an alias and the actual upstream model name.
type CodexModel struct {
	// Name is the upstream model identifier used when issuing requests.
	Name string `yaml:"name" json:"name"`
	// Alias is the client-facing model name that maps to Name.
	Alias string `yaml:"alias" json:"alias"`
}

// GetName returns the upstream model identifier.
func (m CodexModel) GetName() string { return m.Name }

// GetAlias returns the client-facing alias for the model.
func (m CodexModel) GetAlias() string { return m.Alias }
// GeminiKey represents the configuration for a Gemini API key,
// including optional overrides for upstream base URL, proxy routing, and headers.
//
// SanitizeGeminiKeys trims APIKey/BaseURL/ProxyURL, drops entries with an empty
// APIKey, and keeps only the first entry for each distinct APIKey.
type GeminiKey struct {
	// APIKey is the authentication key for accessing Gemini API services.
	APIKey string `yaml:"api-key" json:"api-key"`
	// Priority controls selection preference when multiple credentials match.
	// Higher values are preferred; defaults to 0.
	Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
	// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
	Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
	// BaseURL optionally overrides the Gemini API endpoint.
	BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
	// ProxyURL optionally overrides the global proxy for this API key.
	ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
	// Models defines upstream model names and aliases for request routing.
	Models []GeminiModel `yaml:"models,omitempty" json:"models,omitempty"`
	// Headers optionally adds extra HTTP headers for requests sent with this key.
	Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
	// ExcludedModels lists model IDs that should be excluded for this provider.
	ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
}

// GetAPIKey returns the configured Gemini API key.
func (k GeminiKey) GetAPIKey() string { return k.APIKey }

// GetBaseURL returns the configured base URL override, which may be empty.
func (k GeminiKey) GetBaseURL() string { return k.BaseURL }
// GeminiModel describes a mapping between an alias and the actual upstream model name.
type GeminiModel struct {
	// Name is the upstream model identifier used when issuing requests.
	Name string `yaml:"name" json:"name"`
	// Alias is the client-facing model name that maps to Name.
	Alias string `yaml:"alias" json:"alias"`
}

// GetName returns the upstream model identifier.
func (m GeminiModel) GetName() string { return m.Name }

// GetAlias returns the client-facing alias for the model.
func (m GeminiModel) GetAlias() string { return m.Alias }
// OpenAICompatibility represents the configuration for OpenAI API compatibility
// with external providers, allowing model aliases to be routed through OpenAI API format.
//
// SanitizeOpenAICompatibility trims Name and BaseURL, normalizes Prefix and Headers,
// and drops providers whose BaseURL is empty.
type OpenAICompatibility struct {
	// Name is the identifier for this OpenAI compatibility configuration.
	Name string `yaml:"name" json:"name"`
	// Priority controls selection preference when multiple providers or credentials match.
	// Higher values are preferred; defaults to 0.
	Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
	// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
	Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
	// BaseURL is the base URL for the external OpenAI-compatible API endpoint (required).
	BaseURL string `yaml:"base-url" json:"base-url"`
	// APIKeyEntries defines API keys with optional per-key proxy configuration.
	APIKeyEntries []OpenAICompatibilityAPIKey `yaml:"api-key-entries,omitempty" json:"api-key-entries,omitempty"`
	// Models defines the model configurations including aliases for routing.
	Models []OpenAICompatibilityModel `yaml:"models" json:"models"`
	// Headers optionally adds extra HTTP headers for requests sent to this provider.
	Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
}
// OpenAICompatibilityAPIKey represents an API key configuration with optional proxy setting.
// Neither field is trimmed or validated by SanitizeOpenAICompatibility.
type OpenAICompatibilityAPIKey struct {
	// APIKey is the authentication key for accessing the external API services.
	APIKey string `yaml:"api-key" json:"api-key"`
	// ProxyURL overrides the global proxy setting for this API key if provided.
	ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
}
// OpenAICompatibilityModel represents a model configuration for OpenAI compatibility,
// including the actual model name and its alias for API routing.
type OpenAICompatibilityModel struct {
	// Name is the actual model name used by the external provider.
	Name string `yaml:"name" json:"name"`
	// Alias is the model name alias that clients will use to reference this model.
	Alias string `yaml:"alias" json:"alias"`
}

// GetName returns the provider-side model name.
func (m OpenAICompatibilityModel) GetName() string { return m.Name }

// GetAlias returns the client-facing alias for the model.
func (m OpenAICompatibilityModel) GetAlias() string { return m.Alias }
// LoadConfig reads the YAML configuration file at configFile, applies defaults for
// absent keys, sanitizes the provider and payload sections, and returns the result.
//
// Unlike LoadConfigOptional with optional=true, any failure to read or parse the
// file is returned as an error. If remote-management.secret-key is stored in
// plaintext, it is bcrypt-hashed and the hash is written back into configFile.
//
// Parameters:
//   - configFile: The path to the YAML configuration file
//
// Returns:
//   - *Config: The loaded configuration
//   - error: An error if the configuration could not be read, parsed, or the
//     remote management key could not be hashed
func LoadConfig(configFile string) (*Config, error) {
	return LoadConfigOptional(configFile, false)
}
// LoadConfigOptional reads and parses the YAML configuration at configFile.
//
// When optional is true, a missing file, a path that is a directory, a file that is
// empty or whitespace-only, or a file that fails to parse all yield an empty
// &Config{} (no defaults applied) instead of an error; a parse failure is logged.
// When optional is false, any read or parse failure is returned as an error.
//
// After a successful parse, defaults are kept for absent keys, numeric limits are
// clamped, and the provider/payload sections are sanitized. A plaintext
// remote-management.secret-key is bcrypt-hashed; the hash is written back into
// configFile on a best-effort basis (a write failure is logged, not returned).
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
	// Read the entire configuration file into memory.
	data, err := os.ReadFile(configFile)
	if err != nil {
		if optional && (os.IsNotExist(err) || errors.Is(err, syscall.EISDIR)) {
			// Missing and optional: return empty config (cloud deploy standby).
			return &Config{}, nil
		}
		return nil, fmt.Errorf("failed to read config file: %w", err)
	}
	// In cloud deploy mode (optional=true), a file that is empty or contains only
	// whitespace yields an empty config, exactly like a missing file.
	if optional && len(bytes.TrimSpace(data)) == 0 {
		return &Config{}, nil
	}

	// Unmarshal the YAML data into the Config struct.
	var cfg Config
	// Set defaults before unmarshal so that absent keys keep defaults.
	cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6)
	cfg.LoggingToFile = false
	cfg.LogsMaxTotalSizeMB = 0
	cfg.ErrorLogsMaxFiles = 10
	cfg.UsageStatisticsEnabled = false
	cfg.DisableCooling = false
	cfg.Pprof.Enable = false
	cfg.Pprof.Addr = DefaultPprofAddr
	cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
	cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
	if err = yaml.Unmarshal(data, &cfg); err != nil {
		if optional {
			// In cloud deploy mode, if YAML parsing fails, return empty config instead of error,
			// but leave a trace so a broken file does not go unnoticed.
			log.Warnf("failed to parse config file %s, using empty config: %v", configFile, err)
			return &Config{}, nil
		}
		return nil, fmt.Errorf("failed to parse config file: %w", err)
	}

	// NOTE: Startup legacy key migration is intentionally disabled.
	// Reason: avoid mutating config.yaml during server startup.
	// Re-enable the block below if automatic startup migration is needed again.
	// var legacy legacyConfigData
	// if errLegacy := yaml.Unmarshal(data, &legacy); errLegacy == nil {
	// 	if cfg.migrateLegacyGeminiKeys(legacy.LegacyGeminiKeys) {
	// 		cfg.legacyMigrationPending = true
	// 	}
	// 	if cfg.migrateLegacyOpenAICompatibilityKeys(legacy.OpenAICompat) {
	// 		cfg.legacyMigrationPending = true
	// 	}
	// 	if cfg.migrateLegacyAmpConfig(&legacy) {
	// 		cfg.legacyMigrationPending = true
	// 	}
	// }

	// Hash remote management key if plaintext is detected (nested)
	// We consider a value to be already hashed if it looks like a bcrypt hash ($2a$, $2b$, or $2y$ prefix).
	if cfg.RemoteManagement.SecretKey != "" && !looksLikeBcrypt(cfg.RemoteManagement.SecretKey) {
		hashed, errHash := hashSecret(cfg.RemoteManagement.SecretKey)
		if errHash != nil {
			return nil, fmt.Errorf("failed to hash remote management key: %w", errHash)
		}
		cfg.RemoteManagement.SecretKey = hashed
		// Persist the hashed value back to the config file to avoid re-hashing on next startup.
		// Preserve YAML comments and ordering; update only the nested key.
		// Best effort: the in-memory config already holds the hash, so startup continues,
		// but the operator is told the plaintext key is still on disk.
		if errSave := SaveConfigPreserveCommentsUpdateNestedScalar(configFile, []string{"remote-management", "secret-key"}, hashed); errSave != nil {
			log.Warnf("failed to persist hashed remote management key to %s: %v", configFile, errSave)
		}
	}

	cfg.RemoteManagement.PanelGitHubRepository = strings.TrimSpace(cfg.RemoteManagement.PanelGitHubRepository)
	if cfg.RemoteManagement.PanelGitHubRepository == "" {
		cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
	}
	cfg.Pprof.Addr = strings.TrimSpace(cfg.Pprof.Addr)
	if cfg.Pprof.Addr == "" {
		cfg.Pprof.Addr = DefaultPprofAddr
	}
	if cfg.LogsMaxTotalSizeMB < 0 {
		cfg.LogsMaxTotalSizeMB = 0
	}
	if cfg.ErrorLogsMaxFiles < 0 {
		cfg.ErrorLogsMaxFiles = 10
	}
	if cfg.MaxRetryCredentials < 0 {
		cfg.MaxRetryCredentials = 0
	}

	// Sanitize Gemini API key configuration and migrate legacy entries.
	cfg.SanitizeGeminiKeys()
	// Sanitize Vertex-compatible API keys.
	cfg.SanitizeVertexCompatKeys()
	// Sanitize Codex keys: drop entries without base-url
	cfg.SanitizeCodexKeys()
	// Sanitize Codex header defaults.
	cfg.SanitizeCodexHeaderDefaults()
	// Sanitize Claude key headers
	cfg.SanitizeClaudeKeys()
	// Sanitize OpenAI compatibility providers: drop entries without base-url
	cfg.SanitizeOpenAICompatibility()
	// Normalize OAuth provider model exclusion map.
	cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
	// Normalize global OAuth model name aliases.
	cfg.SanitizeOAuthModelAlias()
	// Validate raw payload rules and drop invalid entries.
	cfg.SanitizePayloadRules()

	// NOTE: Legacy migration persistence is intentionally disabled together with
	// startup legacy migration to keep startup read-only for config.yaml.
	// Re-enable the block below if automatic startup migration is needed again.
	// if cfg.legacyMigrationPending {
	// 	fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
	// 	if !optional && configFile != "" {
	// 		if err := SaveConfigPreserveComments(configFile, &cfg); err != nil {
	// 			return nil, fmt.Errorf("failed to persist migrated legacy config: %w", err)
	// 		}
	// 		fmt.Println("Legacy configuration normalized and persisted.")
	// 	} else {
	// 		fmt.Println("Legacy configuration normalized in memory; persistence skipped.")
	// 	}
	// }

	// Return the populated configuration struct.
	return &cfg, nil
}
// SanitizePayloadRules filters the default-raw and override-raw payload rule lists,
// removing every rule whose raw params are not valid JSON (see sanitizePayloadRawRules).
// It is a no-op on a nil receiver.
func (cfg *Config) SanitizePayloadRules() {
	if cfg == nil {
		return
	}
	// Evaluate default-raw first so warnings are emitted in the same order as before.
	defaults := sanitizePayloadRawRules(cfg.Payload.DefaultRaw, "default-raw")
	overrides := sanitizePayloadRawRules(cfg.Payload.OverrideRaw, "override-raw")
	cfg.Payload.DefaultRaw, cfg.Payload.OverrideRaw = defaults, overrides
}
// sanitizePayloadRawRules returns the subset of rules that are usable as raw rules.
//
// A rule is dropped when it has no params, or when any param value that is a string
// or []byte is blank or not valid JSON once surrounding whitespace is removed. Param
// values of other types are not checked. Each dropped invalid rule logs one warning
// naming the section, its 1-based index and one offending param. An empty input is
// returned unchanged.
func sanitizePayloadRawRules(rules []PayloadRule, section string) []PayloadRule {
	if len(rules) == 0 {
		return rules
	}
	kept := make([]PayloadRule, 0, len(rules))
	for idx, rule := range rules {
		if len(rule.Params) == 0 {
			continue
		}
		valid := true
		for param, v := range rule.Params {
			raw, isRaw := payloadRawString(v)
			if !isRaw {
				continue
			}
			if body := bytes.TrimSpace(raw); len(body) > 0 && json.Valid(body) {
				continue
			}
			log.WithFields(log.Fields{
				"section":    section,
				"rule_index": idx + 1,
				"param":      param,
			}).Warn("payload rule dropped: invalid raw JSON")
			valid = false
			break
		}
		if valid {
			kept = append(kept, rule)
		}
	}
	return kept
}
// payloadRawString extracts the raw bytes of a payload param value.
// Strings are converted to a new []byte, and []byte values are returned as-is (not
// copied). Any other type yields (nil, false).
func payloadRawString(value any) ([]byte, bool) {
	if s, isString := value.(string); isString {
		return []byte(s), true
	}
	if b, isBytes := value.([]byte); isBytes {
		return b, true
	}
	return nil, false
}
// SanitizeCodexHeaderDefaults strips leading and trailing whitespace from the
// configured Codex fallback User-Agent and beta-features values.
// A nil receiver is ignored.
func (cfg *Config) SanitizeCodexHeaderDefaults() {
	if cfg != nil {
		cfg.CodexHeaderDefaults.UserAgent = strings.TrimSpace(cfg.CodexHeaderDefaults.UserAgent)
		cfg.CodexHeaderDefaults.BetaFeatures = strings.TrimSpace(cfg.CodexHeaderDefaults.BetaFeatures)
	}
}
// SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases.
//
// Channel keys are trimmed and lower-cased, and empty channels are dropped. Raw keys
// that normalize to the same channel (e.g. "Gemini" and " gemini ") are merged into
// one list instead of overwriting each other. Within an entry, Name and Alias are
// trimmed. An entry is dropped when:
//   - either field is empty;
//   - the alias equals the name (case-insensitively);
//   - the alias, compared case-insensitively, was already used in the same channel.
//
// Multiple aliases per upstream name are allowed. Channels left with no entries are
// removed.
func (cfg *Config) SanitizeOAuthModelAlias() {
	if cfg == nil || len(cfg.OAuthModelAlias) == 0 {
		return
	}
	out := make(map[string][]OAuthModelAlias, len(cfg.OAuthModelAlias))
	// seenByChannel tracks lower-cased aliases per normalized channel. It must be shared
	// across raw keys that collapse to the same channel so dedup spans all of them.
	seenByChannel := make(map[string]map[string]struct{}, len(cfg.OAuthModelAlias))
	for rawChannel, aliases := range cfg.OAuthModelAlias {
		channel := strings.ToLower(strings.TrimSpace(rawChannel))
		if channel == "" || len(aliases) == 0 {
			continue
		}
		seenAlias, ok := seenByChannel[channel]
		if !ok {
			seenAlias = make(map[string]struct{}, len(aliases))
			seenByChannel[channel] = seenAlias
		}
		// Start from whatever an earlier raw key already contributed to this channel.
		clean := out[channel]
		for _, entry := range aliases {
			name := strings.TrimSpace(entry.Name)
			alias := strings.TrimSpace(entry.Alias)
			if name == "" || alias == "" {
				continue
			}
			if strings.EqualFold(name, alias) {
				continue
			}
			aliasKey := strings.ToLower(alias)
			if _, dup := seenAlias[aliasKey]; dup {
				continue
			}
			seenAlias[aliasKey] = struct{}{}
			clean = append(clean, OAuthModelAlias{Name: name, Alias: alias, Fork: entry.Fork})
		}
		if len(clean) > 0 {
			out[channel] = clean
		}
	}
	cfg.OAuthModelAlias = out
}
// SanitizeOpenAICompatibility drops OpenAI-compatibility providers that have no
// base URL after whitespace trimming, and normalizes the survivors:
//   - Name is trimmed;
//   - Prefix goes through normalizeModelPrefix;
//   - Headers go through NormalizeHeaders.
//
// The relative order of the kept providers is preserved. It is a no-op on a nil
// receiver or an empty provider list.
func (cfg *Config) SanitizeOpenAICompatibility() {
	if cfg == nil || len(cfg.OpenAICompatibility) == 0 {
		return
	}
	kept := make([]OpenAICompatibility, 0, len(cfg.OpenAICompatibility))
	for _, provider := range cfg.OpenAICompatibility {
		provider.BaseURL = strings.TrimSpace(provider.BaseURL)
		if provider.BaseURL == "" {
			// No base-url: the provider is treated as removed.
			continue
		}
		provider.Name = strings.TrimSpace(provider.Name)
		provider.Prefix = normalizeModelPrefix(provider.Prefix)
		provider.Headers = NormalizeHeaders(provider.Headers)
		kept = append(kept, provider)
	}
	cfg.OpenAICompatibility = kept
}
// SanitizeCodexKeys keeps only Codex credentials with a non-blank base URL.
// For each kept entry:
//   - BaseURL is trimmed;
//   - Prefix is normalized;
//   - Headers and ExcludedModels are cleaned.
//
// Entry order is preserved. A nil receiver or empty list is left untouched.
func (cfg *Config) SanitizeCodexKeys() {
	if cfg == nil || len(cfg.CodexKey) == 0 {
		return
	}
	kept := make([]CodexKey, 0, len(cfg.CodexKey))
	for _, key := range cfg.CodexKey {
		key.BaseURL = strings.TrimSpace(key.BaseURL)
		if key.BaseURL == "" {
			continue
		}
		key.Prefix = normalizeModelPrefix(key.Prefix)
		key.Headers = NormalizeHeaders(key.Headers)
		key.ExcludedModels = NormalizeExcludedModels(key.ExcludedModels)
		kept = append(kept, key)
	}
	cfg.CodexKey = kept
}
// SanitizeClaudeKeys normalizes, in place, the Prefix, Headers and ExcludedModels of
// every Claude credential. No entries are removed. A nil receiver is ignored.
func (cfg *Config) SanitizeClaudeKeys() {
	if cfg == nil {
		return
	}
	keys := cfg.ClaudeKey // shares the backing array, so writes land in cfg.ClaudeKey
	for idx := range keys {
		keys[idx].Prefix = normalizeModelPrefix(keys[idx].Prefix)
		keys[idx].Headers = NormalizeHeaders(keys[idx].Headers)
		keys[idx].ExcludedModels = NormalizeExcludedModels(keys[idx].ExcludedModels)
	}
}
// SanitizeGeminiKeys filters and normalizes Gemini credentials in place.
//
// Entries whose trimmed APIKey is empty are dropped. Only the first entry for each
// distinct trimmed APIKey is kept. For kept entries:
//   - Prefix is normalized;
//   - BaseURL and ProxyURL are trimmed;
//   - Headers and ExcludedModels are cleaned.
//
// The backing array of cfg.GeminiKey is reused. A nil receiver is ignored.
func (cfg *Config) SanitizeGeminiKeys() {
	if cfg == nil {
		return
	}
	seenKeys := make(map[string]struct{}, len(cfg.GeminiKey))
	// In-place filter: the write index never passes the read index, so no unread entry is overwritten.
	filtered := cfg.GeminiKey[:0]
	for _, key := range cfg.GeminiKey {
		key.APIKey = strings.TrimSpace(key.APIKey)
		if key.APIKey == "" {
			continue
		}
		if _, dup := seenKeys[key.APIKey]; dup {
			continue
		}
		seenKeys[key.APIKey] = struct{}{}
		key.Prefix = normalizeModelPrefix(key.Prefix)
		key.BaseURL = strings.TrimSpace(key.BaseURL)
		key.ProxyURL = strings.TrimSpace(key.ProxyURL)
		key.Headers = NormalizeHeaders(key.Headers)
		key.ExcludedModels = NormalizeExcludedModels(key.ExcludedModels)
		filtered = append(filtered, key)
	}
	cfg.GeminiKey = filtered
}
// normalizeModelPrefix canonicalizes a credential model prefix.
//
// Surrounding whitespace and slashes are removed, including whitespace that sits
// inside the outer slashes (e.g. "/ teamA /" becomes "teamA"). A prefix that is empty
// afterwards, or that still contains "/" (a nested namespace), yields "".
func normalizeModelPrefix(prefix string) string {
	// Trim whitespace both before and after stripping slashes so inputs like
	// "/ teamA /" do not leave stray spaces inside the prefix.
	trimmed := strings.TrimSpace(strings.Trim(strings.TrimSpace(prefix), "/"))
	if trimmed == "" || strings.Contains(trimmed, "/") {
		return ""
	}
	return trimmed
}
// looksLikeBcrypt reports whether s is longer than four bytes and begins with one of
// the bcrypt version markers "$2a$", "$2b$" or "$2y$". It is a cheap prefix check
// only; the rest of the hash is not validated.
func looksLikeBcrypt(s string) bool {
	if len(s) <= 4 {
		return false
	}
	for _, marker := range [...]string{"$2a$", "$2b$", "$2y$"} {
		if strings.HasPrefix(s, marker) {
			return true
		}
	}
	return false
}
// NormalizeHeaders returns a new map holding the header pairs from headers with
// surrounding whitespace trimmed from both key and value. Pairs whose key or value is
// empty after trimming are dropped. It returns nil when the input is empty or no pair
// survives. Keys are not case-folded.
func NormalizeHeaders(headers map[string]string) map[string]string {
	if len(headers) == 0 {
		return nil
	}
	result := make(map[string]string, len(headers))
	for rawKey, rawVal := range headers {
		k, v := strings.TrimSpace(rawKey), strings.TrimSpace(rawVal)
		if k != "" && v != "" {
			result[k] = v
		}
	}
	if len(result) > 0 {
		return result
	}
	return nil
}
// NormalizeExcludedModels trims, lowercases, and deduplicates model exclusion patterns.
// It preserves the order of first occurrences and drops empty entries. It returns nil
// when the input is empty or no entry survives.
func NormalizeExcludedModels(models []string) []string {
	if len(models) == 0 {
		return nil
	}
	seen := make(map[string]struct{}, len(models))
	out := make([]string, 0, len(models))
	for _, raw := range models {
		trimmed := strings.ToLower(strings.TrimSpace(raw))
		if trimmed == "" {
			continue
		}
		if _, exists := seen[trimmed]; exists {
			continue
		}
		seen[trimmed] = struct{}{}
		out = append(out, trimmed)
	}
	if len(out) == 0 {
		return nil
	}
	return out
}

// NormalizeOAuthExcludedModels cleans a provider -> excluded models mapping.
//
// Provider keys are trimmed and lower-cased; empty keys are dropped. Each model list
// is normalized with NormalizeExcludedModels, and providers left with no models are
// dropped. Raw keys that normalize to the same provider (e.g. "Gemini" and " gemini ")
// are merged into one de-duplicated list rather than one silently replacing the other.
// It returns nil when the input is empty or nothing survives.
func NormalizeOAuthExcludedModels(entries map[string][]string) map[string][]string {
	if len(entries) == 0 {
		return nil
	}
	out := make(map[string][]string, len(entries))
	for provider, models := range entries {
		key := strings.ToLower(strings.TrimSpace(provider))
		if key == "" {
			continue
		}
		combined := models
		if existing, ok := out[key]; ok {
			// Copy into a fresh slice so neither the caller's list nor the previous
			// result's backing array is written through.
			combined = make([]string, 0, len(existing)+len(models))
			combined = append(combined, existing...)
			combined = append(combined, models...)
		}
		normalized := NormalizeExcludedModels(combined)
		if len(normalized) == 0 {
			continue
		}
		out[key] = normalized
	}
	if len(out) == 0 {
		return nil
	}
	return out
}
// hashSecret returns the bcrypt hash of secret, computed with bcrypt.DefaultCost.
// On failure it returns "" together with the error from bcrypt.
func hashSecret(secret string) (hash string, err error) {
	var digest []byte
	if digest, err = bcrypt.GenerateFromPassword([]byte(secret), bcrypt.DefaultCost); err == nil {
		hash = string(digest)
	}
	return hash, err
}
// SaveConfigPreserveComments writes cfg back to configFile as YAML while preserving the
// existing file's comments and key ordering.
//
// The original file is loaded into a yaml.Node tree, legacy sections are removed, and
// the freshly marshalled cfg is merged into that tree in place. Keys not present in the
// file are added only when they are not zero/known-default values.
//
// The merged document is fully encoded in memory before configFile is opened for
// writing. os.Create/os.WriteFile truncate immediately, so encoding afterwards would
// leave an empty config on an encoding error. The file is rewritten in place rather
// than via rename so bind-mounted single files (e.g. in Docker) keep working.
//
// It returns an error if the file cannot be read or parsed, either document root is
// not a mapping, or encoding/writing fails.
func SaveConfigPreserveComments(configFile string, cfg *Config) error {
	// Load original YAML as a node tree to preserve comments and ordering.
	data, err := os.ReadFile(configFile)
	if err != nil {
		return err
	}
	var original yaml.Node
	if err = yaml.Unmarshal(data, &original); err != nil {
		return err
	}
	if original.Kind != yaml.DocumentNode || len(original.Content) == 0 {
		return fmt.Errorf("invalid yaml document structure")
	}
	if original.Content[0] == nil || original.Content[0].Kind != yaml.MappingNode {
		return fmt.Errorf("expected root mapping node")
	}

	// Marshal the current cfg to YAML, then unmarshal to a yaml.Node we can merge from.
	rendered, err := yaml.Marshal(cfg)
	if err != nil {
		return err
	}
	var generated yaml.Node
	if err = yaml.Unmarshal(rendered, &generated); err != nil {
		return err
	}
	if generated.Kind != yaml.DocumentNode || len(generated.Content) == 0 || generated.Content[0] == nil {
		return fmt.Errorf("invalid generated yaml structure")
	}
	if generated.Content[0].Kind != yaml.MappingNode {
		return fmt.Errorf("expected generated root mapping node")
	}

	// Remove deprecated sections before merging back the sanitized config.
	removeLegacyAuthBlock(original.Content[0])
	removeLegacyOpenAICompatAPIKeys(original.Content[0])
	removeLegacyAmpKeys(original.Content[0])
	removeLegacyGenerativeLanguageKeys(original.Content[0])
	pruneMappingToGeneratedKeys(original.Content[0], generated.Content[0], "oauth-excluded-models")
	pruneMappingToGeneratedKeys(original.Content[0], generated.Content[0], "oauth-model-alias")

	// Merge generated into original in-place, preserving comments/order of existing nodes.
	mergeMappingPreserve(original.Content[0], generated.Content[0])
	normalizeCollectionNodeStyles(original.Content[0])

	// Encode before touching the file so a failure cannot leave it truncated.
	var buf bytes.Buffer
	enc := yaml.NewEncoder(&buf)
	enc.SetIndent(2)
	if err = enc.Encode(&original); err != nil {
		_ = enc.Close()
		return err
	}
	if err = enc.Close(); err != nil {
		return err
	}
	// os.WriteFile also reports the Close error, which a deferred Close would drop.
	// The permission only applies if the file has to be recreated.
	return os.WriteFile(configFile, NormalizeCommentIndentation(buf.Bytes()), 0o644)
}
// SaveConfigPreserveCommentsUpdateNestedScalar sets the value at a nested key path such
// as []string{"remote-management", "secret-key"} to the string scalar value and
// rewrites configFile, preserving comments and positions of everything else.
//
// Missing keys along the path are created. An intermediate value that is not a mapping
// is replaced by an empty mapping (its previous children are discarded) before
// descending. The document is fully encoded in memory before configFile is opened for
// writing, so an encoding failure never truncates the file.
//
// It returns an error if:
//   - path is empty;
//   - the file cannot be read or parsed;
//   - the document root is not a mapping;
//   - encoding or writing fails.
func SaveConfigPreserveCommentsUpdateNestedScalar(configFile string, path []string, value string) error {
	if len(path) == 0 {
		return fmt.Errorf("empty key path")
	}
	data, err := os.ReadFile(configFile)
	if err != nil {
		return err
	}
	var root yaml.Node
	if err = yaml.Unmarshal(data, &root); err != nil {
		return err
	}
	if root.Kind != yaml.DocumentNode || len(root.Content) == 0 {
		return fmt.Errorf("invalid yaml document structure")
	}
	if root.Content[0] == nil || root.Content[0].Kind != yaml.MappingNode {
		return fmt.Errorf("expected root mapping node")
	}

	// Descend through the intermediate keys. getOrCreateMapValue resets a non-mapping
	// node (including its children) before using it as a mapping. A sequence is
	// therefore never reinterpreted as key/value pairs.
	node := root.Content[0]
	last := len(path) - 1
	for _, key := range path[:last] {
		node = getOrCreateMapValue(node, key)
	}

	// Overwrite the leaf with a plain string scalar, dropping any previous children or
	// alias target in case the key used to hold a collection or alias.
	leaf := getOrCreateMapValue(node, path[last])
	leaf.Kind = yaml.ScalarNode
	leaf.Tag = "!!str"
	leaf.Value = value
	leaf.Content = nil
	leaf.Alias = nil

	// Encode before touching the file so a failure cannot leave it truncated.
	var buf bytes.Buffer
	enc := yaml.NewEncoder(&buf)
	enc.SetIndent(2)
	if err = enc.Encode(&root); err != nil {
		_ = enc.Close()
		return err
	}
	if err = enc.Close(); err != nil {
		return err
	}
	// os.WriteFile also reports the Close error; the permission only applies if the file
	// has to be recreated.
	return os.WriteFile(configFile, NormalizeCommentIndentation(buf.Bytes()), 0o644)
}
// NormalizeCommentIndentation left-aligns standalone comment lines: any line whose
// first non-space/tab byte is '#' has its leading spaces and tabs removed. Trailing
// comments after content ("key: v # c") are untouched. When no line changes, data is
// returned as-is; otherwise a newly joined slice is returned.
//
// NOTE(review): the scan is purely line-based and does not track YAML context. A line
// inside a block scalar (| or >) that starts with '#' after indentation would also be
// de-indented. Confirm that encoded configs never contain such content.
func NormalizeCommentIndentation(data []byte) []byte {
	lines := bytes.Split(data, []byte("\n"))
	modified := false
	for idx := range lines {
		stripped := bytes.TrimLeft(lines[idx], " \t")
		isIndentedComment := len(stripped) > 0 && stripped[0] == '#' && len(stripped) < len(lines[idx])
		if !isIndentedComment {
			continue
		}
		lines[idx] = append([]byte(nil), stripped...)
		modified = true
	}
	if !modified {
		return data
	}
	return bytes.Join(lines, []byte("\n"))
}
// getOrCreateMapValue returns the value node stored under key in mapNode. When the key
// is absent, it appends a new key/value pair whose value is an empty "!!str" scalar and
// returns that value.
//
// If mapNode is not a mapping, it is first converted in place into an empty mapping
// (its previous children and scalar value are discarded). A key whose value slot is nil
// is repaired with a fresh empty scalar, so callers always receive a non-nil node.
// mapNode itself must be non-nil.
func getOrCreateMapValue(mapNode *yaml.Node, key string) *yaml.Node {
	if mapNode.Kind != yaml.MappingNode {
		mapNode.Kind = yaml.MappingNode
		mapNode.Tag = "!!map"
		mapNode.Value = ""
		mapNode.Content = nil
	}
	// Reuse the shared lookup, which skips nil key nodes instead of dereferencing them.
	if idx := findMapKeyIndex(mapNode, key); idx >= 0 {
		if mapNode.Content[idx+1] == nil {
			mapNode.Content[idx+1] = &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str"}
		}
		return mapNode.Content[idx+1]
	}
	// Append a new key/value pair.
	val := &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: ""}
	mapNode.Content = append(mapNode.Content,
		&yaml.Node{Kind: yaml.ScalarNode, Tag: "!!str", Value: key},
		val,
	)
	return val
}
// mergeMappingPreserve folds the key/value pairs of the src mapping into the dst
// mapping, reusing dst's existing nodes so their order and comments survive.
//
// The optional path argument (only its first element is used) is the dotted-key path of
// dst relative to the document root; it is extended per key so default detection can
// match full paths such as "pprof.addr". Keys already in dst are always merged via
// mergeNodePreserve, even to zero values. A key missing from dst is deep-copied, its
// known-default descendants are pruned, and it is appended only if the result is not a
// zero value or known default.
//
// If either node is not a mapping, dst is overwritten with a copy of src via
// copyNodeShallow (keeping the dst pointer). Nil arguments are ignored.
func mergeMappingPreserve(dst, src *yaml.Node, path ...[]string) {
	if dst == nil || src == nil {
		return
	}
	var base []string
	if len(path) > 0 {
		base = path[0]
	}
	if !(dst.Kind == yaml.MappingNode && src.Kind == yaml.MappingNode) {
		// Keep the dst node object so comments attached at the parent level stay put.
		copyNodeShallow(dst, src)
		return
	}
	for i := 0; i+1 < len(src.Content); i += 2 {
		keyNode, valNode := src.Content[i], src.Content[i+1]
		keyPath := appendPath(base, keyNode.Value)
		if at := findMapKeyIndex(dst, keyNode.Value); at >= 0 {
			mergeNodePreserve(dst.Content[at+1], valNode, keyPath)
			continue
		}
		fresh := deepCopyNode(valNode)
		pruneKnownDefaultsInNewNode(keyPath, fresh)
		if !isKnownDefaultValue(keyPath, fresh) {
			dst.Content = append(dst.Content, deepCopyNode(keyNode), fresh)
		}
	}
}
// mergeNodePreserve merges src into dst for scalars, mappings and sequences while
// reusing destination nodes to keep comments and anchors. For sequences, it updates
// in-place by index.
//
// Behavior by src.Kind:
//   - MappingNode: a non-mapping dst is first replaced by a copy of src, then the two
//     mappings are merged with mergeMappingPreserve.
//   - SequenceNode: a null scalar dst stays null when src is empty. Otherwise dst becomes
//     a sequence (cleared if it was not one) and reorderSequenceForMerge is called.
//     NOTE(review): that helper is defined elsewhere; it presumably aligns dst items with
//     src items before the index-wise merge — confirm. Overlapping indices are merged
//     recursively. For mapping-vs-mapping items, keys missing from src are removed via
//     pruneMissingMapKeys. Extra src items are appended as deep copies, and extra dst
//     items are truncated.
//   - ScalarNode/AliasNode: Kind, Tag and Value are copied; dst.Style is kept so the
//     original quoting is preserved.
//   - Kind 0: nothing happens.
//   - Any other kind: dst is replaced via copyNodeShallow.
//
// The optional path argument (first element only) is the dotted-key path of dst; it is
// passed unchanged to sequence items and mapping merges for default detection. Nil
// arguments are ignored.
func mergeNodePreserve(dst, src *yaml.Node, path ...[]string) {
	var currentPath []string
	if len(path) > 0 {
		currentPath = path[0]
	}
	if dst == nil || src == nil {
		return
	}
	switch src.Kind {
	case yaml.MappingNode:
		if dst.Kind != yaml.MappingNode {
			copyNodeShallow(dst, src)
		}
		mergeMappingPreserve(dst, src, currentPath)
	case yaml.SequenceNode:
		// Preserve explicit null style if dst was null and src is empty sequence
		if dst.Kind == yaml.ScalarNode && dst.Tag == "!!null" && len(src.Content) == 0 {
			// Keep as null to preserve original style
			return
		}
		if dst.Kind != yaml.SequenceNode {
			dst.Kind = yaml.SequenceNode
			dst.Tag = "!!seq"
			dst.Content = nil
		}
		reorderSequenceForMerge(dst, src)
		// Update elements in place
		minContent := len(dst.Content)
		if len(src.Content) < minContent {
			minContent = len(src.Content)
		}
		for i := 0; i < minContent; i++ {
			if dst.Content[i] == nil {
				dst.Content[i] = deepCopyNode(src.Content[i])
				continue
			}
			mergeNodePreserve(dst.Content[i], src.Content[i], currentPath)
			// Mapping items: drop keys that no longer exist in the generated item.
			if dst.Content[i] != nil && src.Content[i] != nil &&
				dst.Content[i].Kind == yaml.MappingNode && src.Content[i].Kind == yaml.MappingNode {
				pruneMissingMapKeys(dst.Content[i], src.Content[i])
			}
		}
		// Append any extra items from src
		for i := len(dst.Content); i < len(src.Content); i++ {
			dst.Content = append(dst.Content, deepCopyNode(src.Content[i]))
		}
		// Truncate if dst has extra items not in src
		if len(src.Content) < len(dst.Content) {
			dst.Content = dst.Content[:len(src.Content)]
		}
	case yaml.ScalarNode, yaml.AliasNode:
		// For scalars, update Tag and Value but keep Style from dst to preserve quoting
		dst.Kind = src.Kind
		dst.Tag = src.Tag
		dst.Value = src.Value
		// Keep dst.Style as-is intentionally
	case 0:
		// Unknown/empty kind; do nothing
	default:
		// Fallback: replace shallowly
		copyNodeShallow(dst, src)
	}
}
// findMapKeyIndex returns the position in mapNode.Content of the key node whose Value
// equals key (so the matching value lives at the returned index + 1). It returns -1
// when mapNode is nil, is not a mapping, or has no such key. Nil key nodes are skipped.
func findMapKeyIndex(mapNode *yaml.Node, key string) int {
	if mapNode != nil && mapNode.Kind == yaml.MappingNode {
		pairs := mapNode.Content
		for idx := 0; idx+1 < len(pairs); idx += 2 {
			if k := pairs[idx]; k != nil && k.Value == key {
				return idx
			}
		}
	}
	return -1
}
// appendPath returns a freshly allocated slice holding path followed by key. The input
// slice is never written to, so sibling calls on the same base path cannot clobber each
// other.
func appendPath(path []string, key string) []string {
	out := make([]string, 0, len(path)+1)
	out = append(out, path...)
	return append(out, key)
}
// isKnownDefaultValue reports whether node, located at the dotted key path, holds a
// value that should not be written as a new key: either any zero value (per
// isZeroValueNode) or one of these known non-zero defaults:
//   - pprof.addr == DefaultPprofAddr
//   - remote-management.panel-github-repository == DefaultPanelGitHubRepository
//   - routing.strategy == "round-robin"
//   - error-logs-max-files == 10
func isKnownDefaultValue(path []string, node *yaml.Node) bool {
	if isZeroValueNode(node) {
		return true
	}
	// Non-zero defaults are only ever scalars addressed by a full dotted path.
	if len(path) == 0 || node.Kind != yaml.ScalarNode {
		return false
	}
	key := strings.Join(path, ".")
	switch node.Tag {
	case "!!str":
		switch key {
		case "pprof.addr":
			return node.Value == DefaultPprofAddr
		case "remote-management.panel-github-repository":
			return node.Value == DefaultPanelGitHubRepository
		case "routing.strategy":
			return node.Value == "round-robin"
		}
	case "!!int":
		if key == "error-logs-max-files" {
			return node.Value == "10"
		}
	}
	return false
}
// pruneKnownDefaultsInNewNode strips default-valued descendants from a node that is
// about to be appended into the destination YAML tree.
//
// For mappings, a pair is removed when its value is a zero/known default at its dotted
// path, or when the value is a mapping/sequence left empty after recursive pruning. For
// sequences, each item is pruned recursively using the same path (items do not extend
// it). Other node kinds and nil are left alone.
func pruneKnownDefaultsInNewNode(path []string, node *yaml.Node) {
	if node == nil {
		return
	}
	if node.Kind == yaml.SequenceNode {
		for _, item := range node.Content {
			pruneKnownDefaultsInNewNode(path, item)
		}
		return
	}
	if node.Kind != yaml.MappingNode {
		return
	}
	kept := make([]*yaml.Node, 0, len(node.Content))
	for i := 0; i+1 < len(node.Content); i += 2 {
		k, v := node.Content[i], node.Content[i+1]
		if k == nil || v == nil {
			continue
		}
		sub := appendPath(path, k.Value)
		if isKnownDefaultValue(sub, v) {
			continue
		}
		pruneKnownDefaultsInNewNode(sub, v)
		isCollection := v.Kind == yaml.MappingNode || v.Kind == yaml.SequenceNode
		if isCollection && len(v.Content) == 0 {
			continue
		}
		kept = append(kept, k, v)
	}
	node.Content = kept
}
// isZeroValueNode reports whether node carries nothing but zero values, in
// which case it should not be written as a new key.
//
//   - a nil node and a !!null scalar are zero;
//   - a !!bool scalar is zero when "false", an !!int/!!float scalar when "0"
//     or "0.0", and a !!str scalar when empty;
//   - a sequence is zero when every element is zero (so an empty one is);
//   - a mapping is zero when every value is zero (keys are ignored; an empty
//     mapping is zero).
//
// Scalars with any other tag, and any other node kind, are non-zero.
func isZeroValueNode(node *yaml.Node) bool {
	if node == nil {
		return true
	}
	switch node.Kind {
	case yaml.ScalarNode:
		v := node.Value
		switch node.Tag {
		case "!!null":
			return true
		case "!!bool":
			return v == "false"
		case "!!int", "!!float":
			return v == "0" || v == "0.0"
		case "!!str":
			return v == ""
		default:
			return false
		}
	case yaml.SequenceNode:
		for _, elem := range node.Content {
			if !isZeroValueNode(elem) {
				return false
			}
		}
		return true
	case yaml.MappingNode:
		// Values live at the odd indexes of Content.
		for i := 1; i < len(node.Content); i += 2 {
			if !isZeroValueNode(node.Content[i]) {
				return false
			}
		}
		return true
	default:
		return false
	}
}
// deepCopyNode returns a copy of n in which every node reachable through
// Content is duplicated recursively, so editing the copy's children never
// affects n. Other pointer fields (such as Alias) are copied as-is and still
// refer to the original nodes. A nil input yields nil.
func deepCopyNode(n *yaml.Node) *yaml.Node {
	if n == nil {
		return nil
	}
	dup := new(yaml.Node)
	*dup = *n
	if len(n.Content) == 0 {
		return dup
	}
	children := make([]*yaml.Node, 0, len(n.Content))
	for _, child := range n.Content {
		children = append(children, deepCopyNode(child))
	}
	dup.Content = children
	return dup
}
// copyNodeShallow overwrites dst's Kind, Tag, Value and Content with src's
// while keeping the dst pointer itself, so parents referencing dst stay
// connected. Fields not listed (comments, style, anchors, positions) keep
// their dst values. Content becomes deep copies of src's children, or nil
// when src has none. A nil dst or src makes the call a no-op.
func copyNodeShallow(dst, src *yaml.Node) {
	if dst == nil || src == nil {
		return
	}
	dst.Kind, dst.Tag, dst.Value = src.Kind, src.Tag, src.Value
	if len(src.Content) == 0 {
		dst.Content = nil
		return
	}
	dst.Content = make([]*yaml.Node, len(src.Content))
	for i, child := range src.Content {
		dst.Content[i] = deepCopyNode(child)
	}
}
// reorderSequenceForMerge rearranges dst's elements so that dst.Content[i]
// holds the existing destination element matched (by matchSequenceElement)
// to src.Content[i]. Reusing matched nodes keeps their comments attached when
// src reorders a list.
//
// Afterwards len(dst.Content) == len(src.Content). Destination elements that
// match nothing in src are dropped, and src positions without a match are
// left as nil entries. NOTE(review): callers presumably fill those nil slots
// from src before encoding — confirm at the call site.
//
// Nothing happens when either node is nil or either sequence is empty.
func reorderSequenceForMerge(dst, src *yaml.Node) {
	if dst == nil || src == nil || len(dst.Content) == 0 || len(src.Content) == 0 {
		return
	}
	candidates := make([]*yaml.Node, len(dst.Content))
	copy(candidates, dst.Content)
	claimed := make([]bool, len(candidates))
	result := make([]*yaml.Node, len(src.Content))
	for pos, want := range src.Content {
		idx := matchSequenceElement(candidates, claimed, want)
		if idx < 0 {
			continue
		}
		result[pos] = candidates[idx]
		claimed[idx] = true
	}
	dst.Content = result
}
// matchSequenceElement returns the index of the first unclaimed, non-nil node
// in original that corresponds to target, or -1 when there is none. used[i]
// marks original[i] as already claimed.
//
// Matching happens in two passes:
//  1. By key. A mapping target with a non-empty sequenceElementIdentity
//     matches mapping candidates with the same identity. A scalar target with
//     a non-blank value matches scalar candidates whose trimmed value is equal.
//  2. Otherwise, by nodesStructurallyEqual, so elements without an explicit
//     identifier can still be paired with an identical existing node.
func matchSequenceElement(original []*yaml.Node, used []bool, target *yaml.Node) int {
	if target == nil {
		return -1
	}
	available := func(i int) bool { return !used[i] && original[i] != nil }

	var sameKey func(cand *yaml.Node) bool
	if target.Kind == yaml.MappingNode {
		if id := sequenceElementIdentity(target); id != "" {
			sameKey = func(cand *yaml.Node) bool {
				return cand.Kind == yaml.MappingNode && sequenceElementIdentity(cand) == id
			}
		}
	} else if target.Kind == yaml.ScalarNode {
		if val := strings.TrimSpace(target.Value); val != "" {
			sameKey = func(cand *yaml.Node) bool {
				return cand.Kind == yaml.ScalarNode && strings.TrimSpace(cand.Value) == val
			}
		}
	}
	if sameKey != nil {
		for i, cand := range original {
			if available(i) && sameKey(cand) {
				return i
			}
		}
	}

	// Fallback to structural equality to preserve nodes lacking explicit identifiers.
	for i, cand := range original {
		if available(i) && nodesStructurallyEqual(cand, target) {
			return i
		}
	}
	return -1
}
// sequenceElementIdentity derives a "key=value" identity string for a mapping
// element of a sequence; it is used to pair list entries across merges.
//
// Well-known identifying keys are tried first, in priority order (id, name,
// alias, api-key, api_key, apikey, key, provider, model), looked up
// case-insensitively through mappingScalarValue. Failing that, the first entry
// with a non-blank scalar value is used, with its key lower-cased. The value
// part is always whitespace-trimmed. Returns "" for non-mapping nodes or when
// no entry qualifies.
func sequenceElementIdentity(node *yaml.Node) string {
	if node == nil || node.Kind != yaml.MappingNode {
		return ""
	}
	for _, candidate := range [...]string{"id", "name", "alias", "api-key", "api_key", "apikey", "key", "provider", "model"} {
		if v := mappingScalarValue(node, candidate); v != "" {
			return candidate + "=" + v
		}
	}
	pairs := node.Content
	for len(pairs) >= 2 {
		k, v := pairs[0], pairs[1]
		pairs = pairs[2:]
		if k == nil || v == nil || v.Kind != yaml.ScalarNode {
			continue
		}
		if trimmed := strings.TrimSpace(v.Value); trimmed != "" {
			return strings.ToLower(strings.TrimSpace(k.Value)) + "=" + trimmed
		}
	}
	return ""
}
// mappingScalarValue returns the whitespace-trimmed value of the first entry
// in the mapping node whose key matches key. Case is ignored on both sides,
// and the node's key is also whitespace-trimmed. Entries whose value is not a
// scalar are skipped. Returns "" when node is not a mapping or nothing
// matches; an entry with an empty value is indistinguishable from a missing one.
func mappingScalarValue(node *yaml.Node, key string) string {
	if node == nil || node.Kind != yaml.MappingNode {
		return ""
	}
	want := strings.ToLower(key)
	for i := 1; i < len(node.Content); i += 2 {
		k, v := node.Content[i-1], node.Content[i]
		if k == nil || v == nil || v.Kind != yaml.ScalarNode {
			continue
		}
		if strings.ToLower(strings.TrimSpace(k.Value)) != want {
			continue
		}
		return strings.TrimSpace(v.Value)
	}
	return ""
}
// nodesStructurallyEqual reports whether a and b describe the same YAML shape.
// They must have the same Kind. Collections must have the same number of
// children, compared pairwise in order, so mapping key order matters. Scalars
// (and any other kind) compare by whitespace-trimmed Value. Alias nodes
// compare their targets. Tags, styles and comments are ignored. Two nil nodes
// are equal; nil versus non-nil is not.
func nodesStructurallyEqual(a, b *yaml.Node) bool {
	if a == nil || b == nil {
		return a == b
	}
	if a.Kind != b.Kind {
		return false
	}
	switch a.Kind {
	case yaml.MappingNode, yaml.SequenceNode:
		if len(a.Content) != len(b.Content) {
			return false
		}
		n := len(a.Content)
		if a.Kind == yaml.MappingNode {
			// Only complete key/value pairs are compared; a dangling key is ignored.
			n -= n % 2
		}
		for i := 0; i < n; i++ {
			if !nodesStructurallyEqual(a.Content[i], b.Content[i]) {
				return false
			}
		}
		return true
	case yaml.AliasNode:
		return nodesStructurallyEqual(a.Alias, b.Alias)
	default:
		return strings.TrimSpace(a.Value) == strings.TrimSpace(b.Value)
	}
}
// removeMapKey deletes the first key/value pair in mapNode whose key node's
// Value equals key exactly (no trimming or case folding). It is a no-op when
// mapNode is nil or not a mapping, key is empty, or no entry matches.
func removeMapKey(mapNode *yaml.Node, key string) {
	if mapNode == nil || mapNode.Kind != yaml.MappingNode || key == "" {
		return
	}
	pairs := mapNode.Content
	for i := 0; i+1 < len(pairs); i += 2 {
		if k := pairs[i]; k == nil || k.Value != key {
			continue
		}
		mapNode.Content = append(pairs[:i], pairs[i+2:]...)
		return
	}
}
// pruneMappingToGeneratedKeys makes the entry stored under key in dstRoot
// follow the entry under the same key in srcRoot. Where possible it reuses the
// existing destination mapping node, so that node's comments survive.
//
// Both roots must be mapping nodes, otherwise nothing happens. Then:
//   - key absent from dstRoot (or lacking a value node): no-op.
//   - key absent from srcRoot: the key is removed from dstRoot. The exception
//     is "oauth-model-alias", which is replaced by an explicit empty mapping
//     so deleting its last channel persists across hot reloads and restarts.
//   - key present in srcRoot but lacking a value node: no-op.
//   - src value is a nil node: dst value becomes an explicit null scalar.
//   - src or dst value is not a mapping: dst value becomes a deep copy of src.
//   - both values are mappings: dst keys missing from src are pruned; values
//     of the remaining keys are not modified here.
func pruneMappingToGeneratedKeys(dstRoot, srcRoot *yaml.Node, key string) {
	if key == "" || dstRoot == nil || srcRoot == nil {
		return
	}
	if dstRoot.Kind != yaml.MappingNode || srcRoot.Kind != yaml.MappingNode {
		return
	}
	dstIdx := findMapKeyIndex(dstRoot, key)
	if dstIdx < 0 || dstIdx+1 >= len(dstRoot.Content) {
		return
	}
	srcIdx := findMapKeyIndex(srcRoot, key)
	if srcIdx < 0 {
		// Keep an explicit empty mapping for oauth-model-alias when it was previously present.
		// When users delete the last channel from oauth-model-alias via the management API,
		// we want that deletion to persist across hot reloads and restarts.
		if key == "oauth-model-alias" {
			dstRoot.Content[dstIdx+1] = &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
			return
		}
		removeMapKey(dstRoot, key)
		return
	}
	if srcIdx+1 >= len(srcRoot.Content) {
		return
	}
	srcVal := srcRoot.Content[srcIdx+1]
	dstVal := dstRoot.Content[dstIdx+1]
	switch {
	case srcVal == nil:
		// Never store a nil *yaml.Node in Content. The yaml.v3 encoder
		// dereferences every child node, so a nil entry would panic when the
		// config is saved. An explicit null scalar keeps the key and encodes
		// as "null".
		dstRoot.Content[dstIdx+1] = &yaml.Node{Kind: yaml.ScalarNode, Tag: "!!null", Value: "null"}
	case srcVal.Kind != yaml.MappingNode, dstVal == nil, dstVal.Kind != yaml.MappingNode:
		// Shapes differ (or dst is missing): adopt src wholesale.
		dstRoot.Content[dstIdx+1] = deepCopyNode(srcVal)
	default:
		pruneMissingMapKeys(dstVal, srcVal)
	}
}
// pruneMissingMapKeys removes from dstMap every entry whose key, compared
// after trimming whitespace, is not among srcMap's non-blank keys. Surviving
// entries keep their original order and their comments. Entries in dstMap
// with a nil key node are kept, and values are not inspected or merged. Both
// arguments must be mapping nodes; otherwise nothing happens.
func pruneMissingMapKeys(dstMap, srcMap *yaml.Node) {
	if dstMap == nil || srcMap == nil || dstMap.Kind != yaml.MappingNode || srcMap.Kind != yaml.MappingNode {
		return
	}
	allowed := make(map[string]struct{}, len(srcMap.Content)/2)
	for i := 0; i+1 < len(srcMap.Content); i += 2 {
		k := srcMap.Content[i]
		if k == nil {
			continue
		}
		if name := strings.TrimSpace(k.Value); name != "" {
			allowed[name] = struct{}{}
		}
	}
	content := dstMap.Content
	for i := 0; i+1 < len(content); {
		if k := content[i]; k != nil {
			if _, ok := allowed[strings.TrimSpace(k.Value)]; !ok {
				// Drop the key and its value; re-examine the same index.
				content = append(content[:i], content[i+2:]...)
				continue
			}
		}
		i += 2
	}
	dstMap.Content = content
}
// normalizeCollectionNodeStyles walks mappings and sequences starting at node
// and resets their style to block notation (Style 0), so lists and maps
// render one entry per line. Empty sequences are the exception: they get
// yaml.FlowStyle so they render compactly as []. Scalars keep their existing
// style, which preserves quoting. Only mapping and sequence nodes are
// descended into; any other kind (e.g. a document node) is left as-is
// together with its children.
func normalizeCollectionNodeStyles(node *yaml.Node) {
	if node == nil {
		return
	}
	if node.Kind != yaml.MappingNode && node.Kind != yaml.SequenceNode {
		return
	}
	var style yaml.Style
	if node.Kind == yaml.SequenceNode && len(node.Content) == 0 {
		style = yaml.FlowStyle
	}
	node.Style = style
	for _, child := range node.Content {
		normalizeCollectionNodeStyles(child)
	}
}
// Legacy migration helpers (move deprecated config keys into structured fields).

// legacyConfigData captures deprecated top-level config keys so the migrate*
// helpers can move their values into the current structured fields. Each
// field is decoded from the YAML key named in its tag.
type legacyConfigData struct {
	// LegacyGeminiKeys is the old flat list of Gemini API keys.
	LegacyGeminiKeys []string `yaml:"generative-language-api-key"`

	// OpenAICompat mirrors openai-compatibility entries, capturing the legacy
	// per-provider api-keys string list.
	OpenAICompat []legacyOpenAICompatibility `yaml:"openai-compatibility"`

	// AmpUpstreamURL is the legacy top-level Amp upstream URL.
	AmpUpstreamURL string `yaml:"amp-upstream-url"`

	// AmpUpstreamAPIKey is the legacy top-level Amp upstream API key.
	AmpUpstreamAPIKey string `yaml:"amp-upstream-api-key"`

	// AmpRestrictManagement is a pointer so that "absent" (nil) can be told
	// apart from an explicit false.
	AmpRestrictManagement *bool `yaml:"amp-restrict-management-to-localhost"`

	// AmpModelMappings is the legacy top-level Amp model mapping list.
	AmpModelMappings []AmpModelMapping `yaml:"amp-model-mappings"`
}

// legacyOpenAICompatibility is the subset of an openai-compatibility entry
// needed to locate its current counterpart (Name, BaseURL) and migrate its
// deprecated plain api-keys list.
type legacyOpenAICompatibility struct {
	Name    string   `yaml:"name"`
	BaseURL string   `yaml:"base-url"`
	APIKeys []string `yaml:"api-keys"`
}
// migrateLegacyGeminiKeys appends keys from the deprecated
// generative-language-api-key list to cfg.GeminiKey. Keys are
// whitespace-trimmed. Blank keys, keys already present in cfg.GeminiKey, and
// repeats within legacy are skipped. Each new key becomes a GeminiKey with
// only APIKey set, in legacy order.
// It reports whether at least one key was appended.
func (cfg *Config) migrateLegacyGeminiKeys(legacy []string) bool {
	if cfg == nil || len(legacy) == 0 {
		return false
	}
	known := make(map[string]struct{}, len(cfg.GeminiKey))
	for i := range cfg.GeminiKey {
		if k := strings.TrimSpace(cfg.GeminiKey[i].APIKey); k != "" {
			known[k] = struct{}{}
		}
	}
	added := 0
	for _, raw := range legacy {
		k := strings.TrimSpace(raw)
		if k == "" {
			continue
		}
		if _, dup := known[k]; dup {
			continue
		}
		cfg.GeminiKey = append(cfg.GeminiKey, GeminiKey{APIKey: k})
		known[k] = struct{}{}
		added++
	}
	return added > 0
}
// migrateLegacyOpenAICompatibilityKeys merges each legacy entry's plain
// api-keys list into the matching cfg.OpenAICompatibility entry (located by
// findOpenAICompatTarget). Legacy entries without keys, or without a matching
// current entry, are ignored. Nothing happens when cfg is nil or either list
// is empty. It reports whether any key was added.
func (cfg *Config) migrateLegacyOpenAICompatibilityKeys(legacy []legacyOpenAICompatibility) bool {
	if cfg == nil || len(cfg.OpenAICompatibility) == 0 || len(legacy) == 0 {
		return false
	}
	migrated := false
	for i := range legacy {
		entry := &legacy[i]
		if len(entry.APIKeys) == 0 {
			continue
		}
		target := findOpenAICompatTarget(cfg.OpenAICompatibility, entry.Name, entry.BaseURL)
		if target != nil && mergeLegacyOpenAICompatAPIKeys(target, entry.APIKeys) {
			migrated = true
		}
	}
	return migrated
}
// mergeLegacyOpenAICompatAPIKeys appends each trimmed, non-blank key from keys
// to entry.APIKeyEntries, unless an entry with that key already exists or it
// was added earlier in the same call. New entries carry only APIKey.
// It reports whether anything was appended; a nil entry or empty keys yields false.
func mergeLegacyOpenAICompatAPIKeys(entry *OpenAICompatibility, keys []string) bool {
	if entry == nil || len(keys) == 0 {
		return false
	}
	present := make(map[string]struct{}, len(entry.APIKeyEntries))
	for i := range entry.APIKeyEntries {
		if k := strings.TrimSpace(entry.APIKeyEntries[i].APIKey); k != "" {
			present[k] = struct{}{}
		}
	}
	appended := false
	for _, raw := range keys {
		k := strings.TrimSpace(raw)
		if k == "" {
			continue
		}
		if _, dup := present[k]; dup {
			continue
		}
		entry.APIKeyEntries = append(entry.APIKeyEntries, OpenAICompatibilityAPIKey{APIKey: k})
		present[k] = struct{}{}
		appended = true
	}
	return appended
}
// findOpenAICompatTarget returns a pointer into entries for the entry that
// best matches a legacy provider. Comparisons ignore case and surrounding
// whitespace. Preference order:
//  1. both name and base URL match (only when both are given),
//  2. base URL matches,
//  3. name matches.
//
// It returns nil when no entry qualifies.
func findOpenAICompatTarget(entries []OpenAICompatibility, legacyName, legacyBase string) *OpenAICompatibility {
	norm := func(s string) string { return strings.ToLower(strings.TrimSpace(s)) }
	wantName, wantBase := norm(legacyName), norm(legacyBase)

	firstWhere := func(match func(e *OpenAICompatibility) bool) *OpenAICompatibility {
		for i := range entries {
			if match(&entries[i]) {
				return &entries[i]
			}
		}
		return nil
	}
	sameName := func(e *OpenAICompatibility) bool { return norm(e.Name) == wantName }
	sameBase := func(e *OpenAICompatibility) bool { return norm(e.BaseURL) == wantBase }

	if wantName != "" && wantBase != "" {
		if hit := firstWhere(func(e *OpenAICompatibility) bool { return sameName(e) && sameBase(e) }); hit != nil {
			return hit
		}
	}
	if wantBase != "" {
		if hit := firstWhere(sameBase); hit != nil {
			return hit
		}
	}
	if wantName != "" {
		return firstWhere(sameName)
	}
	return nil
}
// migrateLegacyAmpConfig copies deprecated top-level amp-* settings into
// cfg.AmpCode:
//   - UpstreamURL and UpstreamAPIKey are filled from the trimmed legacy values
//     only when currently empty;
//   - RestrictManagementToLocalhost is overwritten whenever the legacy key
//     was present (non-nil);
//   - ModelMappings is copied only when the current list is empty.
//
// It reports whether anything was applied. A nil cfg or legacy yields false.
func (cfg *Config) migrateLegacyAmpConfig(legacy *legacyConfigData) bool {
	if cfg == nil || legacy == nil {
		return false
	}
	changed := false
	// fillIfEmpty stores the trimmed legacy value into dst when dst is unset.
	fillIfEmpty := func(dst *string, legacyVal string) {
		if *dst != "" {
			return
		}
		if v := strings.TrimSpace(legacyVal); v != "" {
			*dst = v
			changed = true
		}
	}
	fillIfEmpty(&cfg.AmpCode.UpstreamURL, legacy.AmpUpstreamURL)
	fillIfEmpty(&cfg.AmpCode.UpstreamAPIKey, legacy.AmpUpstreamAPIKey)
	if flag := legacy.AmpRestrictManagement; flag != nil {
		cfg.AmpCode.RestrictManagementToLocalhost = *flag
		changed = true
	}
	if len(cfg.AmpCode.ModelMappings) == 0 && len(legacy.AmpModelMappings) > 0 {
		cfg.AmpCode.ModelMappings = append([]AmpModelMapping(nil), legacy.AmpModelMappings...)
		changed = true
	}
	return changed
}
// removeLegacyOpenAICompatAPIKeys deletes the deprecated "api-keys" entry from
// every mapping element of the root's "openai-compatibility" sequence. Nothing
// happens when root is not a mapping or that key is absent or not a sequence.
func removeLegacyOpenAICompatAPIKeys(root *yaml.Node) {
	if root == nil || root.Kind != yaml.MappingNode {
		return
	}
	idx := findMapKeyIndex(root, "openai-compatibility")
	if idx < 0 || idx+1 >= len(root.Content) {
		return
	}
	providers := root.Content[idx+1]
	if providers == nil || providers.Kind != yaml.SequenceNode {
		return
	}
	for _, provider := range providers.Content {
		if provider == nil || provider.Kind != yaml.MappingNode {
			continue
		}
		removeMapKey(provider, "api-keys")
	}
}
// removeLegacyAmpKeys deletes the deprecated top-level amp-* keys from a root
// mapping node. Non-mapping or nil roots are ignored.
func removeLegacyAmpKeys(root *yaml.Node) {
	if root == nil || root.Kind != yaml.MappingNode {
		return
	}
	for _, legacyKey := range [...]string{
		"amp-upstream-url",
		"amp-upstream-api-key",
		"amp-restrict-management-to-localhost",
		"amp-model-mappings",
	} {
		removeMapKey(root, legacyKey)
	}
}
// removeLegacyGenerativeLanguageKeys deletes the deprecated top-level
// "generative-language-api-key" entry from a root mapping node.
func removeLegacyGenerativeLanguageKeys(root *yaml.Node) {
	if root != nil && root.Kind == yaml.MappingNode {
		removeMapKey(root, "generative-language-api-key")
	}
}
// removeLegacyAuthBlock deletes the legacy top-level "auth" entry from a root
// mapping node.
func removeLegacyAuthBlock(root *yaml.Node) {
	if root != nil && root.Kind == yaml.MappingNode {
		removeMapKey(root, "auth")
	}
}
================================================
FILE: internal/config/oauth_model_alias_test.go
================================================
package config
import "testing"
// TestSanitizeOAuthModelAlias_PreservesForkFlag checks that sanitizing
// normalizes the channel key (" CoDeX " becomes "codex") and trims the alias
// fields, while carrying each entry's Fork flag through unchanged.
func TestSanitizeOAuthModelAlias_PreservesForkFlag(t *testing.T) {
	cfg := &Config{OAuthModelAlias: map[string][]OAuthModelAlias{
		" CoDeX ": {
			{Name: " gpt-5 ", Alias: " g5 ", Fork: true},
			{Name: "gpt-6", Alias: "g6"},
		},
	}}
	cfg.SanitizeOAuthModelAlias()

	got := cfg.OAuthModelAlias["codex"]
	if n := len(got); n != 2 {
		t.Fatalf("expected 2 sanitized aliases, got %d", n)
	}
	first, second := got[0], got[1]
	if !(first.Name == "gpt-5" && first.Alias == "g5" && first.Fork) {
		t.Fatalf("expected first alias to be gpt-5->g5 fork=true, got name=%q alias=%q fork=%v", first.Name, first.Alias, first.Fork)
	}
	if !(second.Name == "gpt-6" && second.Alias == "g6" && !second.Fork) {
		t.Fatalf("expected second alias to be gpt-6->g6 fork=false, got name=%q alias=%q fork=%v", second.Name, second.Alias, second.Fork)
	}
}
func TestSanitizeOAuthModelAlias_AllowsMultipleAliasesForSameName(t *testing.T) {
cfg := &Config{
OAuthModelAlias: map[string][]OAuthModelAlias{
"antigravity": {
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101-thinking", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5", Fork: true},
},
},
}
cfg.SanitizeOAuthModelAlias()
aliases := cfg.OAuthModelAlias["antigravity"]
expected := []OAuthModelAlias{
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5-20251101-thinking", Fork: true},
{Name: "gemini-claude-opus-4-5-thinking", Alias: "claude-opus-4-5", Fork: true},
}
if len(aliases) != len(expected) {
t.Fatalf("expected %d sanitized aliases, got %d", len(expected), len(aliases))
}
for i, exp := range expected {
if aliases[i].Name != exp.Name || aliases[i].Alias != exp.Alias || aliases[i].Fork != exp.Fork {
t.Fatalf("expected alias %d to be name=%q alias=%q fork=%v, got name=%q alias=%q fork=%v", i, exp.Name, exp.Alias, exp.Fork, aliases[i].Name, aliases[i].Alias, aliases[i].Fork)
}
}
}
================================================
FILE: internal/config/sdk_config.go
================================================
// Package config provides configuration management for the CLI Proxy API server.
// It handles loading and parsing YAML configuration files, and provides structured
// access to application settings including server port, authentication directory,
// debug settings, proxy configuration, and API keys.
package config
// SDKConfig represents the application's configuration, loaded from a YAML file.
// Each field is read from the YAML key in its yaml tag and exposed under the
// same name through its json tag.
type SDKConfig struct {
	// ProxyURL is the URL of an optional proxy server to use for outbound requests.
	ProxyURL string `yaml:"proxy-url" json:"proxy-url"`

	// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
	// to target prefixed credentials. When false, unprefixed model requests may use prefixed
	// credentials as well.
	ForceModelPrefix bool `yaml:"force-model-prefix" json:"force-model-prefix"`

	// RequestLog enables or disables detailed request logging functionality.
	RequestLog bool `yaml:"request-log" json:"request-log"`

	// APIKeys is a list of keys for authenticating clients to this proxy server.
	APIKeys []string `yaml:"api-keys" json:"api-keys"`

	// PassthroughHeaders controls whether upstream response headers are forwarded to downstream clients.
	// Default is false (disabled).
	PassthroughHeaders bool `yaml:"passthrough-headers" json:"passthrough-headers"`

	// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
	Streaming StreamingConfig `yaml:"streaming" json:"streaming"`

	// NonStreamKeepAliveInterval controls how often blank lines are emitted for non-streaming responses.
	// <= 0 disables keep-alives. Value is in seconds.
	// Unlike the other fields here it is tagged omitempty, so a zero value is
	// left out when the struct is marshaled.
	NonStreamKeepAliveInterval int `yaml:"nonstream-keepalive-interval,omitempty" json:"nonstream-keepalive-interval,omitempty"`
}
// StreamingConfig holds server streaming behavior configuration.
// Both fields are tagged omitempty, so zero values are omitted when marshaled.
type StreamingConfig struct {
	// KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n").
	// <= 0 disables keep-alives. Default is 0.
	KeepAliveSeconds int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"`

	// BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent,
	// to allow auth rotation / transient recovery.
	// <= 0 disables bootstrap retries. Default is 0.
	BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
}
================================================
FILE: internal/config/vertex_compat.go
================================================
package config
import "strings"
// VertexCompatKey configures one credential for a Vertex AI-compatible
// service. Such a service is a third-party endpoint that follows Vertex
// AI-style paths (/publishers/google/models/{model}:streamGenerateContent) but
// authenticates with a plain API key instead of Google Cloud service-account
// credentials.
//
// Example services: zenmux.ai and similar Vertex-compatible providers.
type VertexCompatKey struct {
	// APIKey authenticates requests to the Vertex-compatible API and is sent
	// in the x-goog-api-key header.
	APIKey string `yaml:"api-key" json:"api-key"`

	// Priority ranks this credential when several match; higher values win.
	// The zero value (0) is the default.
	Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`

	// Prefix optionally namespaces this credential's model aliases
	// (e.g., "teamA/vertex-pro").
	Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`

	// BaseURL optionally overrides the Vertex-compatible endpoint; the
	// executor appends "/v1/publishers/google/models/{model}:action" to it.
	// When empty, requests use the default Vertex API base URL.
	BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`

	// ProxyURL optionally overrides the global proxy for this key.
	ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`

	// Headers adds extra HTTP headers to requests made with this key, such as
	// cookies, user-agent or other authentication headers.
	Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`

	// Models lists upstream model names and the aliases clients route by.
	Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"`

	// ExcludedModels lists model IDs that must not be served by this credential.
	ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
}

// GetAPIKey returns the credential's API key.
func (k VertexCompatKey) GetAPIKey() string {
	return k.APIKey
}

// GetBaseURL returns the endpoint override, or "" when none is configured.
func (k VertexCompatKey) GetBaseURL() string {
	return k.BaseURL
}

// VertexCompatModel pairs the model name used by the external provider with
// the alias clients use to address it.
type VertexCompatModel struct {
	// Name is the model identifier sent to the external provider.
	Name string `yaml:"name" json:"name"`

	// Alias is the client-facing model name used for routing.
	Alias string `yaml:"alias" json:"alias"`
}

// GetName returns the upstream model name.
func (m VertexCompatModel) GetName() string {
	return m.Name
}

// GetAlias returns the client-facing alias.
func (m VertexCompatModel) GetAlias() string {
	return m.Alias
}
// SanitizeVertexCompatKeys normalizes cfg.VertexCompatAPIKey in place:
//   - entries whose API key is blank after trimming are dropped;
//   - Prefix is normalized with normalizeModelPrefix, BaseURL and ProxyURL
//     are trimmed, and Headers / ExcludedModels go through NormalizeHeaders /
//     NormalizeExcludedModels;
//   - models whose name or alias is blank after trimming are dropped;
//   - entries are deduplicated on the (APIKey, BaseURL) pair, keeping the
//     first occurrence.
//
// A nil cfg is a no-op.
func (cfg *Config) SanitizeVertexCompatKeys() {
	if cfg == nil {
		return
	}
	keys := cfg.VertexCompatAPIKey
	seen := make(map[string]struct{}, len(keys))
	kept := keys[:0] // filter in place, reusing the backing array
	for i := range keys {
		entry := keys[i]
		entry.APIKey = strings.TrimSpace(entry.APIKey)
		if entry.APIKey == "" {
			continue
		}
		entry.Prefix = normalizeModelPrefix(entry.Prefix)
		entry.BaseURL = strings.TrimSpace(entry.BaseURL)
		entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
		entry.Headers = NormalizeHeaders(entry.Headers)
		entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)

		models := make([]VertexCompatModel, 0, len(entry.Models))
		for _, m := range entry.Models {
			m.Alias, m.Name = strings.TrimSpace(m.Alias), strings.TrimSpace(m.Name)
			if m.Alias == "" || m.Name == "" {
				continue
			}
			models = append(models, m)
		}
		entry.Models = models

		dedupKey := entry.APIKey + "|" + entry.BaseURL
		if _, dup := seen[dedupKey]; dup {
			continue
		}
		seen[dedupKey] = struct{}{}
		kept = append(kept, entry)
	}
	cfg.VertexCompatAPIKey = kept
}
================================================
FILE: internal/constant/constant.go
================================================
// Package constant defines provider name constants used throughout the CLI Proxy API.
// These constants identify different AI service providers and their variants,
// ensuring consistent naming across the application.
package constant
// Provider identifiers. The string values are compared against provider names
// elsewhere, so they must not change.
const (
	// Gemini represents the Google Gemini provider identifier.
	Gemini = "gemini"
	// GeminiCLI represents the Google Gemini CLI provider identifier.
	GeminiCLI = "gemini-cli"
	// Codex represents the OpenAI Codex provider identifier.
	Codex = "codex"
	// Claude represents the Anthropic Claude provider identifier.
	Claude = "claude"
	// OpenAI represents the OpenAI provider identifier.
	OpenAI = "openai"
	// OpenaiResponse represents the OpenAI response format identifier.
	OpenaiResponse = "openai-response"
	// Antigravity represents the Antigravity provider identifier.
	Antigravity = "antigravity"
)
================================================
FILE: internal/interfaces/api_handler.go
================================================
// Package interfaces defines the core interfaces and shared structures for the CLI Proxy API server.
// These interfaces provide a common contract for different components of the application,
// such as AI service clients, API handlers, and data models.
package interfaces
// APIHandler defines the interface that all API handlers must implement.
// It exposes a handler's type identifier and the models it supports.
type APIHandler interface {
	// HandlerType returns the type identifier for this API handler.
	// This is used to determine which request/response translators to use.
	HandlerType() string

	// Models returns a list of supported models for this API handler.
	// Each model is represented as a map containing model metadata; the map
	// keys are not constrained by this interface.
	Models() []map[string]any
}
================================================
FILE: internal/interfaces/client_models.go
================================================
// Package interfaces defines the core interfaces and shared structures for the CLI Proxy API server.
// These interfaces provide a common contract for different components of the application,
// such as AI service clients, API handlers, and data models.
package interfaces
import (
"time"
)
// GCPProject represents the response structure for a Google Cloud project list request.
// This structure is used when fetching available projects for a Google Cloud account.
type GCPProject struct {
	// Projects is a list of Google Cloud projects accessible by the user.
	Projects []GCPProjectProjects `json:"projects"`
}

// GCPProjectLabels defines the labels associated with a GCP project.
type GCPProjectLabels struct {
	// GenerativeLanguage holds the value of the project's "generative-language"
	// label. NOTE(review): presumably indicates generative language API
	// enablement — confirm against the label's producer.
	GenerativeLanguage string `json:"generative-language"`
}

// GCPProjectProjects contains details about a single Google Cloud project,
// as decoded from a project list response.
type GCPProjectProjects struct {
	// ProjectNumber is the unique numeric identifier for the project, encoded as a string.
	ProjectNumber string `json:"projectNumber"`
	// ProjectID is the unique string identifier for the project.
	ProjectID string `json:"projectId"`
	// LifecycleState indicates the current state of the project (e.g., "ACTIVE").
	LifecycleState string `json:"lifecycleState"`
	// Name is the human-readable name of the project.
	Name string `json:"name"`
	// Labels contains metadata labels associated with the project.
	Labels GCPProjectLabels `json:"labels"`
	// CreateTime is the timestamp when the project was created; it is decoded
	// by time.Time's JSON unmarshaler, which expects RFC 3339 text.
	CreateTime time.Time `json:"createTime"`
}
// Content represents a single message in a conversation, with a role and parts.
// This structure models a message exchange between a user and an AI model.
type Content struct {
	// Role indicates who sent the message ("user", "model", or "tool").
	Role string `json:"role"`
	// Parts is a collection of content parts that make up the message.
	Parts []Part `json:"parts"`
}

// Part represents a distinct piece of content within a message.
// A part can be text, inline data (like an image), a function call, or a function response.
type Part struct {
	// Thought flags this part as a "thought" part. NOTE(review): presumably
	// marks model reasoning output (see GenerationConfigThinkingConfig) — confirm.
	Thought bool `json:"thought,omitempty"`
	// Text contains plain text content.
	Text string `json:"text,omitempty"`
	// InlineData contains base64-encoded data with its MIME type (e.g., images).
	InlineData *InlineData `json:"inlineData,omitempty"`
	// ThoughtSignature is a provider-required signature that accompanies certain parts.
	ThoughtSignature string `json:"thoughtSignature,omitempty"`
	// FunctionCall represents a tool call requested by the model.
	FunctionCall *FunctionCall `json:"functionCall,omitempty"`
	// FunctionResponse represents the result of a tool execution.
	FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
}

// InlineData represents base64-encoded data with its MIME type.
// This is typically used for embedding images or other binary data in requests.
type InlineData struct {
	// MimeType specifies the media type of the embedded data (e.g., "image/png").
	// Note that its JSON key is snake_case ("mime_type"), unlike most keys here.
	MimeType string `json:"mime_type,omitempty"`
	// Data contains the base64-encoded binary data.
	Data string `json:"data,omitempty"`
}

// FunctionCall represents a tool call requested by the model.
// It includes the function name and its arguments that the model wants to execute.
type FunctionCall struct {
	// ID optionally identifies this particular call, so a FunctionResponse
	// can refer back to it; omitted from JSON when empty.
	ID string `json:"id,omitempty"`
	// Name is the identifier of the function to be called.
	Name string `json:"name"`
	// Args contains the arguments to pass to the function.
	Args map[string]interface{} `json:"args"`
}

// FunctionResponse represents the result of a tool execution.
// This is sent back to the model after a tool call has been processed.
type FunctionResponse struct {
	// ID optionally identifies the call this response answers (presumably
	// matching FunctionCall.ID); omitted from JSON when empty.
	ID string `json:"id,omitempty"`
	// Name is the identifier of the function that was called.
	Name string `json:"name"`
	// Response contains the result data from the function execution.
	Response map[string]interface{} `json:"response"`
}
// GenerateContentRequest is the top-level request structure for the streamGenerateContent endpoint.
// This structure defines all the parameters needed for generating content from an AI model.
type GenerateContentRequest struct {
	// SystemInstruction provides system-level instructions that guide the model's behavior.
	SystemInstruction *Content `json:"systemInstruction,omitempty"`
	// Contents is the conversation history between the user and the model.
	Contents []Content `json:"contents"`
	// Tools defines the available tools/functions that the model can call.
	Tools []ToolDeclaration `json:"tools,omitempty"`
	// GenerationConfig contains parameters that control the model's generation behavior.
	// Because this embedded field has a name in its json tag, encoding/json
	// encodes it as a nested "generationConfig" object rather than promoting
	// its fields to the top level.
	GenerationConfig `json:"generationConfig"`
}

// GenerationConfig defines parameters that control the model's generation behavior.
// These parameters affect the creativity, randomness, and reasoning of the model's responses.
type GenerationConfig struct {
	// ThinkingConfig specifies configuration for the model's "thinking" process.
	// NOTE: omitempty has no effect on a non-pointer struct field in
	// encoding/json, so "thinkingConfig" is always emitted (possibly as {}).
	ThinkingConfig GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"`
	// Temperature controls the randomness of the model's responses.
	// Values closer to 0 make responses more deterministic, while values closer to 1 increase randomness.
	// A zero value is omitted from JSON.
	Temperature float64 `json:"temperature,omitempty"`
	// TopP controls nucleus sampling, which affects the diversity of responses.
	// It limits the model to consider only the top P% of probability mass.
	TopP float64 `json:"topP,omitempty"`
	// TopK limits the model to consider only the top K most likely tokens.
	// This can help control the quality and diversity of generated text.
	TopK float64 `json:"topK,omitempty"`
}

// GenerationConfigThinkingConfig specifies configuration for the model's "thinking" process.
// This controls whether the model should output its reasoning process along with the final answer.
type GenerationConfigThinkingConfig struct {
	// IncludeThoughts determines whether the model should output its reasoning process.
	// When enabled, the model will include its step-by-step thinking in the response.
	IncludeThoughts bool `json:"include_thoughts,omitempty"`
}

// ToolDeclaration defines the structure for declaring tools (like functions)
// that the model can call during content generation.
type ToolDeclaration struct {
	// FunctionDeclarations is a list of available functions that the model can call.
	// Elements are untyped and are encoded as-is.
	FunctionDeclarations []interface{} `json:"functionDeclarations"`
}
================================================
FILE: internal/interfaces/error_message.go
================================================
// Package interfaces defines the core interfaces and shared structures for the CLI Proxy API server.
// These interfaces provide a common contract for different components of the application,
// such as AI service clients, API handlers, and data models.
package interfaces
import "net/http"
// ErrorMessage encapsulates an error with an associated HTTP status code.
// This structure is used to provide detailed error information including
// both the HTTP status and the underlying error. It is a plain struct and does
// not itself implement the error interface.
type ErrorMessage struct {
	// StatusCode is the HTTP status code returned by the API.
	StatusCode int
	// Error is the underlying error that occurred.
	Error error
	// Addon contains additional headers to be added to the response.
	Addon http.Header
}
================================================
FILE: internal/interfaces/types.go
================================================
// Package interfaces provides type aliases for backwards compatibility with translator functions.
// It defines common interface types used throughout the CLI Proxy API for request and response
// transformation operations, maintaining compatibility with the SDK translator package.
package interfaces
import sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
// Backwards compatible aliases for translator function types.
// These are type aliases (=), not new types, so values are interchangeable
// with the corresponding sdktranslator types.
type TranslateRequestFunc = sdktranslator.RequestTransform
type TranslateResponseFunc = sdktranslator.ResponseStreamTransform
type TranslateResponseNonStreamFunc = sdktranslator.ResponseNonStreamTransform
type TranslateResponse = sdktranslator.ResponseTransform
================================================
FILE: internal/logging/gin_logger.go
================================================
// Package logging provides Gin middleware for HTTP request logging and panic recovery.
// It integrates Gin web framework with logrus for structured logging of HTTP requests,
// responses, and error handling with panic recovery capabilities.
package logging
import (
"errors"
"fmt"
"net/http"
"runtime/debug"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
)
// aiAPIPrefixes defines path prefixes for AI API requests that should have request ID tracking.
// isAIAPIPath performs a plain string-prefix match against these entries.
// Because "/v1beta/models/" ends in a slash, it matches "/v1beta/models/<name>..."
// but not the bare "/v1beta/models".
var aiAPIPrefixes = []string{
	"/v1/chat/completions",
	"/v1/completions",
	"/v1/messages",
	"/v1/responses",
	"/v1beta/models/",
	"/api/provider/",
}

// skipGinLogKey is the gin.Context key that SkipGinRequestLogging sets to true
// and shouldSkipGinRequestLogging reads to suppress the access-log line.
const skipGinLogKey = "__gin_skip_request_logging__"
// GinLogrusLogger returns Gin middleware that emits one logrus line per request
// after the rest of the handler chain has run.
//
// The line contains:
//   - the status code;
//   - the latency, truncated to milliseconds (or to seconds above one minute);
//   - the client IP;
//   - the method and path, with the query string masked by util.MaskSensitiveQuery;
//   - any private Gin errors.
//
// Responses >= 500 log at error level, >= 400 at warn, everything else at info.
//
// Only paths accepted by isAIAPIPath get a freshly generated request ID. That ID
// is also stored on the Gin context and on the request's context.Context. Other
// requests are logged with the placeholder ID "--------". Handlers can suppress
// the line entirely by calling SkipGinRequestLogging.
//
// The ID is attached as the "request_id" logrus field. With LogFormatter
// installed, a line therefore renders roughly as:
//
//	[2025-12-23 20:14:10] [a1b2c3d4] [info ] [file.go:123] 200 |       23.559s |       127.0.0.1 | POST    "/v1/messages"
//
// Returns:
//   - gin.HandlerFunc: A middleware handler for request logging
func GinLogrusLogger() gin.HandlerFunc {
	return func(c *gin.Context) {
		startedAt := time.Now()
		requestPath := c.Request.URL.Path
		maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)

		// Request IDs are minted only for AI API traffic.
		requestID := ""
		if isAIAPIPath(requestPath) {
			requestID = GenerateRequestID()
			SetGinRequestID(c, requestID)
			c.Request = c.Request.WithContext(WithRequestID(c.Request.Context(), requestID))
		}

		c.Next()

		if shouldSkipGinRequestLogging(c) {
			return
		}

		if maskedQuery != "" {
			requestPath += "?" + maskedQuery
		}

		elapsed := time.Since(startedAt)
		precision := time.Millisecond
		if elapsed > time.Minute {
			precision = time.Second
		}
		elapsed = elapsed.Truncate(precision)

		status := c.Writer.Status()
		if requestID == "" {
			requestID = "--------"
		}

		line := fmt.Sprintf("%3d | %13v | %15s | %-7s \"%s\"", status, elapsed, c.ClientIP(), c.Request.Method, requestPath)
		if privateErrors := c.Errors.ByType(gin.ErrorTypePrivate).String(); privateErrors != "" {
			line += " | " + privateErrors
		}

		logger := log.WithField("request_id", requestID)
		switch {
		case status >= http.StatusInternalServerError:
			logger.Error(line)
		case status >= http.StatusBadRequest:
			logger.Warn(line)
		default:
			logger.Info(line)
		}
	}
}
// isAIAPIPath reports whether path begins with any entry of aiAPIPrefixes,
// i.e. whether GinLogrusLogger should generate a request ID for it.
func isAIAPIPath(path string) bool {
	matched := false
	for i := 0; i < len(aiAPIPrefixes) && !matched; i++ {
		matched = strings.HasPrefix(path, aiAPIPrefixes[i])
	}
	return matched
}
// GinLogrusRecovery returns Gin middleware that turns handler panics into a
// logged error plus a 500 response.
//
// The logrus entry carries the panic value, the stack trace captured inside the
// recovery callback, and the request path. Panics with http.ErrAbortHandler
// (matched via errors.Is) are deliberately re-raised with the exact sentinel.
// net/http then aborts the connection quietly instead of this middleware
// logging a stack trace.
//
// Returns:
//   - gin.HandlerFunc: A middleware handler for panic recovery
func GinLogrusRecovery() gin.HandlerFunc {
	onPanic := func(c *gin.Context, recovered interface{}) {
		if panicErr, isErr := recovered.(error); isErr && errors.Is(panicErr, http.ErrAbortHandler) {
			// Hand the abort back to net/http untouched.
			panic(http.ErrAbortHandler)
		}

		fields := log.Fields{
			"panic": recovered,
			"stack": string(debug.Stack()),
			"path":  c.Request.URL.Path,
		}
		log.WithFields(fields).Error("recovered from panic")
		c.AbortWithStatus(http.StatusInternalServerError)
	}
	return gin.CustomRecovery(onPanic)
}
// SkipGinRequestLogging flags the Gin context so that GinLogrusLogger emits no
// access-log line for this request. It stores true under skipGinLogKey; a nil
// context is ignored.
func SkipGinRequestLogging(c *gin.Context) {
	if c != nil {
		c.Set(skipGinLogKey, true)
	}
}
// shouldSkipGinRequestLogging reports whether SkipGinRequestLogging was called
// for this context. It returns false for a nil context, a missing key, or a
// value under skipGinLogKey that is not a bool.
func shouldSkipGinRequestLogging(c *gin.Context) bool {
	// GetBool yields false when the key is absent or holds a non-bool value.
	return c != nil && c.GetBool(skipGinLogKey)
}
================================================
FILE: internal/logging/gin_logger_test.go
================================================
package logging
import (
"errors"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
// TestGinLogrusRecoveryRepanicsErrAbortHandler verifies that the recovery
// middleware re-raises http.ErrAbortHandler with the exact sentinel value
// instead of converting it into a 500 response.
func TestGinLogrusRecoveryRepanicsErrAbortHandler(t *testing.T) {
	gin.SetMode(gin.TestMode)

	router := gin.New()
	router.Use(GinLogrusRecovery())
	router.GET("/abort", func(*gin.Context) {
		panic(http.ErrAbortHandler)
	})

	defer func() {
		got := recover()
		if got == nil {
			t.Fatalf("expected panic, got nil")
		}
		err, isErr := got.(error)
		if !isErr {
			t.Fatalf("expected error panic, got %T", got)
		}
		if !errors.Is(err, http.ErrAbortHandler) {
			t.Fatalf("expected ErrAbortHandler, got %v", err)
		}
		// Identity check on purpose: net/http compares against the sentinel itself.
		if err != http.ErrAbortHandler {
			t.Fatalf("expected exact ErrAbortHandler sentinel, got %v", err)
		}
	}()

	router.ServeHTTP(httptest.NewRecorder(), httptest.NewRequest(http.MethodGet, "/abort", nil))
}
// TestGinLogrusRecoveryHandlesRegularPanic verifies that an ordinary panic is
// recovered and answered with HTTP 500.
func TestGinLogrusRecoveryHandlesRegularPanic(t *testing.T) {
	gin.SetMode(gin.TestMode)

	router := gin.New()
	router.Use(GinLogrusRecovery())
	router.GET("/panic", func(*gin.Context) {
		panic("boom")
	})

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/panic", nil))

	if got := rec.Code; got != http.StatusInternalServerError {
		t.Fatalf("expected 500, got %d", got)
	}
}
================================================
FILE: internal/logging/global_logger.go
================================================
package logging
import (
"bytes"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"sync"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"gopkg.in/natefinch/lumberjack.v2"
)
var (
	// setupOnce makes the body of SetupBaseLogger run only on the first call.
	setupOnce sync.Once
	// writerMu is held by ConfigureLogOutput and closeLogOutputs while they
	// replace or close the writers below and (re)configure the log-dir cleaner.
	writerMu sync.Mutex
	// logWriter is the rotating main.log writer while file logging is enabled;
	// it is nil while output goes to stdout.
	logWriter *lumberjack.Logger
	// ginInfoWriter and ginErrorWriter are the logrus pipe writers that
	// SetupBaseLogger installs as gin.DefaultWriter and gin.DefaultErrorWriter.
	ginInfoWriter  *io.PipeWriter
	ginErrorWriter *io.PipeWriter
)
// LogFormatter defines a custom log format for logrus.
// Each entry is rendered on one line as
//
//	[timestamp] [request_id] [level] [file:line] message key=value ...
//
// Example:
//
//	[2025-12-23 20:14:04] [a1b2c3d4] [debug] [manager.go:524] Use API key sk-9...0RHO for model gpt-5.2
//
// The rendering rules are:
//   - The [file:line] segment is omitted when the entry has no caller information.
//   - "--------" stands in for a missing request ID.
//   - Only the fields listed in logFieldOrder are appended after the message.
type LogFormatter struct{}
// logFieldOrder defines the display order for common log fields.
// LogFormatter.Format appends only these entry.Data keys, as " key=value", in
// this order. Any other keys are not printed. request_id is the exception: it
// is rendered separately in its own bracketed column.
var logFieldOrder = []string{"provider", "model", "mode", "budget", "level", "original_mode", "original_value", "min", "max", "clamped_to", "error"}
// Format renders entry as a single newline-terminated line:
//
//	[YYYY-MM-DD hh:mm:ss] [request-id] [level] [file:line] message key=value ...
//
// Formatting rules:
//   - The request ID is taken from the "request_id" string field, or is
//     "--------" when that field is absent or empty.
//   - "warning" is shortened to "warn", and the level is left-aligned to five
//     columns.
//   - The [file:line] segment appears only when entry.Caller is set.
//   - Trailing CR/LF characters are stripped from the message.
//   - Only keys from logFieldOrder are appended, in that order.
//
// Output is appended to entry.Buffer when logrus supplies one.
func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
	out := entry.Buffer
	if out == nil {
		out = &bytes.Buffer{}
	}

	reqID, _ := entry.Data["request_id"].(string)
	if reqID == "" {
		reqID = "--------"
	}

	level := entry.Level.String()
	if level == "warning" {
		level = "warn"
	}

	fmt.Fprintf(out, "[%s] [%s] [%-5s] ", entry.Time.Format("2006-01-02 15:04:05"), reqID, level)
	if caller := entry.Caller; caller != nil {
		fmt.Fprintf(out, "[%s:%d] ", filepath.Base(caller.File), caller.Line)
	}
	out.WriteString(strings.TrimRight(entry.Message, "\r\n"))

	for _, key := range logFieldOrder {
		if value, present := entry.Data[key]; present {
			fmt.Fprintf(out, " %s=%v", key, value)
		}
	}
	out.WriteByte('\n')

	return out.Bytes(), nil
}
// SetupBaseLogger configures the shared logrus instance and Gin writers.
// The first call does all of the following; every later call is a no-op:
//   - points the standard logrus logger at stdout with caller reporting and
//     LogFormatter;
//   - routes gin.DefaultWriter (info) and gin.DefaultErrorWriter (error level)
//     through logrus pipe writers;
//   - sends Gin debug output to logrus at info level;
//   - registers closeLogOutputs as a logrus exit handler.
func SetupBaseLogger() {
	setupOnce.Do(initBaseLogger)
}

// initBaseLogger performs the one-time wiring for SetupBaseLogger.
func initBaseLogger() {
	std := log.StandardLogger()

	log.SetOutput(os.Stdout)
	log.SetReportCaller(true)
	log.SetFormatter(&LogFormatter{})

	ginInfoWriter = std.Writer()
	gin.DefaultWriter = ginInfoWriter
	ginErrorWriter = std.WriterLevel(log.ErrorLevel)
	gin.DefaultErrorWriter = ginErrorWriter

	gin.DebugPrintFunc = func(format string, args ...interface{}) {
		// Gin's debug formats carry their own trailing newline; logrus adds one.
		std.Infof(strings.TrimRight(format, "\r\n"), args...)
	}

	log.RegisterExitHandler(closeLogOutputs)
}
// isDirWritable reports whether dir exists, is a directory, and accepts new
// files.
//
// The check creates and immediately removes a uniquely named probe file. A
// unique name (os.CreateTemp) is used instead of a fixed one for two reasons:
// a pre-existing file in dir is never truncated or deleted, and concurrent
// probes of the same directory cannot interfere with each other.
func isDirWritable(dir string) bool {
	info, err := os.Stat(dir)
	if err != nil || !info.IsDir() {
		return false
	}
	probe, err := os.CreateTemp(dir, ".perm_test-*")
	if err != nil {
		return false
	}
	probePath := probe.Name()
	// Best-effort cleanup: the directory already proved writable.
	_ = probe.Close()
	_ = os.Remove(probePath)
	return true
}
// ResolveLogDirectory determines the directory used for application logs.
// Candidates are checked in this order:
//  1. "<util.WritablePath()>/logs", when WritablePath returns a non-empty base;
//  2. "logs" (relative), when cfg is nil or "logs" is an existing writable
//     directory;
//  3. "<resolved auth-dir>/logs", when cfg.AuthDir resolves to a non-empty path
//     (a resolution error is logged as a warning);
//  4. "logs" as the final fallback.
func ResolveLogDirectory(cfg *config.Config) string {
	if base := util.WritablePath(); base != "" {
		return filepath.Join(base, "logs")
	}

	const defaultDir = "logs"
	if cfg == nil || isDirWritable(defaultDir) {
		return defaultDir
	}

	authDir, err := util.ResolveAuthDir(cfg.AuthDir)
	if err != nil {
		log.Warnf("Failed to resolve auth-dir %q for log directory: %v", cfg.AuthDir, err)
	}
	if authDir == "" {
		return defaultDir
	}
	return filepath.Join(authDir, "logs")
}
// ConfigureLogOutput switches the global log destination between rotating files and stdout.
//
// File output is used when cfg.LoggingToFile is set. It writes a rotating
// main.log (10 MB per file) inside ResolveLogDirectory(cfg); otherwise output
// goes to stdout.
//
// When cfg.LogsMaxTotalSizeMB > 0, a background cleaner removes the oldest log
// files in the log directory until the total size is within the limit. The
// active main.log is never removed. Any previously running cleaner is stopped
// first.
//
// A nil cfg is treated as "stdout, no size limit" instead of panicking;
// ResolveLogDirectory already accepts nil.
//
// The new destination is installed with log.SetOutput before the previous file
// writer is closed. Otherwise a concurrent log call could land on the
// just-closed lumberjack.Logger and make it reopen main.log behind our back,
// leaking that file handle.
func ConfigureLogOutput(cfg *config.Config) error {
	SetupBaseLogger()

	writerMu.Lock()
	defer writerMu.Unlock()

	logDir := ResolveLogDirectory(cfg)
	loggingToFile := cfg != nil && cfg.LoggingToFile
	maxTotalSizeMB := 0
	if cfg != nil {
		maxTotalSizeMB = cfg.LogsMaxTotalSizeMB
	}

	previous := logWriter
	protectedPath := ""
	if loggingToFile {
		if err := os.MkdirAll(logDir, 0o755); err != nil {
			return fmt.Errorf("logging: failed to create log directory: %w", err)
		}
		protectedPath = filepath.Join(logDir, "main.log")
		logWriter = &lumberjack.Logger{
			Filename:   protectedPath,
			MaxSize:    10,
			MaxBackups: 0,
			MaxAge:     0,
			Compress:   false,
		}
		log.SetOutput(logWriter)
	} else {
		logWriter = nil
		log.SetOutput(os.Stdout)
	}

	// Safe only now: logrus swaps Out under its own mutex, so no in-flight write
	// can still target the previous writer.
	if previous != nil {
		_ = previous.Close()
	}

	configureLogDirCleanerLocked(logDir, maxTotalSizeMB, protectedPath)
	return nil
}
// closeLogOutputs releases every logging resource owned by this package:
//   - it stops the log-dir cleaner;
//   - it closes the rotating file writer;
//   - it closes both Gin pipe writers.
//
// Each closed reference is reset to nil. SetupBaseLogger registers it as a
// logrus exit handler. It runs under writerMu.
func closeLogOutputs() {
	writerMu.Lock()
	defer writerMu.Unlock()

	stopLogDirCleanerLocked()

	if logWriter != nil {
		_ = logWriter.Close()
		logWriter = nil
	}

	for _, pipe := range []**io.PipeWriter{&ginInfoWriter, &ginErrorWriter} {
		if *pipe == nil {
			continue
		}
		_ = (*pipe).Close()
		*pipe = nil
	}
}
================================================
FILE: internal/logging/log_dir_cleaner.go
================================================
package logging
import (
"context"
"os"
"path/filepath"
"sort"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
// logDirCleanerInterval is how often runLogDirCleaner re-checks the log
// directory after its initial pass.
const logDirCleanerInterval = time.Minute

// logDirCleanerCancel stops the currently running cleaner goroutine. It is nil
// when no cleaner is running. It is touched only by the *Locked helpers, whose
// callers are expected to hold writerMu (as ConfigureLogOutput and
// closeLogOutputs do).
var logDirCleanerCancel context.CancelFunc
// configureLogDirCleanerLocked stops any running cleaner. It then starts a new
// runLogDirCleaner goroutine, but only when maxTotalSizeMB is positive (and the
// byte limit is still positive after conversion) and logDir is not blank. The
// directory is passed trimmed and cleaned; protectedPath is passed trimmed.
// The caller must hold writerMu.
func configureLogDirCleanerLocked(logDir string, maxTotalSizeMB int, protectedPath string) {
	stopLogDirCleanerLocked()

	const bytesPerMB = 1024 * 1024
	maxBytes := int64(maxTotalSizeMB) * bytesPerMB
	dir := strings.TrimSpace(logDir)
	// The maxBytes check guards against the multiplication overflowing.
	if maxTotalSizeMB <= 0 || maxBytes <= 0 || dir == "" {
		return
	}

	ctx, cancel := context.WithCancel(context.Background())
	logDirCleanerCancel = cancel
	go runLogDirCleaner(ctx, filepath.Clean(dir), maxBytes, strings.TrimSpace(protectedPath))
}
// stopLogDirCleanerLocked cancels the running cleaner goroutine, if any, and
// clears logDirCleanerCancel. It is a no-op when nothing is running. The caller
// must hold writerMu.
func stopLogDirCleanerLocked() {
	if cancel := logDirCleanerCancel; cancel != nil {
		logDirCleanerCancel = nil
		cancel()
	}
}
// runLogDirCleaner enforces the size limit once immediately and then once per
// logDirCleanerInterval until ctx is canceled. Failures are logged as warnings;
// successful deletions are logged at debug level.
func runLogDirCleaner(ctx context.Context, logDir string, maxBytes int64, protectedPath string) {
	ticker := time.NewTicker(logDirCleanerInterval)
	defer ticker.Stop()

	for {
		deleted, errClean := enforceLogDirSizeLimit(logDir, maxBytes, protectedPath)
		switch {
		case errClean != nil:
			log.WithError(errClean).Warn("logging: failed to enforce log directory size limit")
		case deleted > 0:
			log.Debugf("logging: removed %d old log file(s) to enforce log directory size limit", deleted)
		}

		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}
	}
}
// enforceLogDirSizeLimit deletes the oldest log files in logDir until the
// combined size of its log files is at most maxBytes. It returns how many files
// were removed.
//
// Which files count:
//   - only regular files directly inside logDir whose names satisfy
//     isLogFileName;
//   - protectedPath (compared after filepath.Clean) is never deleted.
//
// Error handling:
//   - A missing directory, maxBytes <= 0, or a blank logDir is a no-op.
//   - Failing to read the directory is returned as an error.
//   - Failing to stat or remove an individual file is skipped (removals are
//     logged) rather than returned.
func enforceLogDirSizeLimit(logDir string, maxBytes int64, protectedPath string) (int, error) {
	dir := strings.TrimSpace(logDir)
	if maxBytes <= 0 || dir == "" {
		return 0, nil
	}
	dir = filepath.Clean(dir)

	entries, errRead := os.ReadDir(dir)
	switch {
	case os.IsNotExist(errRead):
		return 0, nil
	case errRead != nil:
		return 0, errRead
	}

	protected := strings.TrimSpace(protectedPath)
	if protected != "" {
		protected = filepath.Clean(protected)
	}

	type candidate struct {
		path    string
		size    int64
		modTime time.Time
	}
	candidates := make([]candidate, 0, len(entries))
	var total int64
	for _, entry := range entries {
		if entry.IsDir() || !isLogFileName(entry.Name()) {
			continue
		}
		info, errInfo := entry.Info()
		if errInfo != nil || !info.Mode().IsRegular() {
			continue
		}
		candidates = append(candidates, candidate{
			path:    filepath.Join(dir, entry.Name()),
			size:    info.Size(),
			modTime: info.ModTime(),
		})
		total += info.Size()
	}
	if total <= maxBytes {
		return 0, nil
	}

	// Oldest first, so the least recent files are removed before newer ones.
	sort.Slice(candidates, func(i, j int) bool {
		return candidates[i].modTime.Before(candidates[j].modTime)
	})

	deleted := 0
	for i := 0; i < len(candidates) && total > maxBytes; i++ {
		victim := candidates[i]
		if protected != "" && filepath.Clean(victim.path) == protected {
			continue
		}
		if errRemove := os.Remove(victim.path); errRemove != nil {
			log.WithError(errRemove).Warnf("logging: failed to remove old log file: %s", filepath.Base(victim.path))
			continue
		}
		total -= victim.size
		deleted++
	}
	return deleted, nil
}
// isLogFileName reports whether name looks like a log file. Matching is
// case-insensitive, ignores surrounding whitespace, and accepts names ending in
// ".log" or ".log.gz". Blank names never match.
func isLogFileName(name string) bool {
	normalized := strings.ToLower(strings.TrimSpace(name))
	if normalized == "" {
		return false
	}
	for _, suffix := range [...]string{".log", ".log.gz"} {
		if strings.HasSuffix(normalized, suffix) {
			return true
		}
	}
	return false
}
================================================
FILE: internal/logging/log_dir_cleaner_test.go
================================================
package logging
import (
"os"
"path/filepath"
"testing"
"time"
)
// TestEnforceLogDirSizeLimitDeletesOldest checks the deletion order. With three
// 60-byte files and a 120-byte limit, only the oldest file is removed; the
// newer file and the protected main.log remain.
func TestEnforceLogDirSizeLimitDeletesOldest(t *testing.T) {
	dir := t.TempDir()
	oldPath := filepath.Join(dir, "old.log")
	midPath := filepath.Join(dir, "mid.log")
	protected := filepath.Join(dir, "main.log")
	writeLogFile(t, oldPath, 60, time.Unix(1, 0))
	writeLogFile(t, midPath, 60, time.Unix(2, 0))
	writeLogFile(t, protected, 60, time.Unix(3, 0))

	deleted, err := enforceLogDirSizeLimit(dir, 120, protected)
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case deleted != 1:
		t.Fatalf("expected 1 deleted file, got %d", deleted)
	}

	if _, errStat := os.Stat(oldPath); !os.IsNotExist(errStat) {
		t.Fatalf("expected old.log to be removed, stat error: %v", errStat)
	}
	survivors := []struct{ path, format string }{
		{midPath, "expected mid.log to remain, stat error: %v"},
		{protected, "expected protected main.log to remain, stat error: %v"},
	}
	for _, s := range survivors {
		if _, errStat := os.Stat(s.path); errStat != nil {
			t.Fatalf(s.format, errStat)
		}
	}
}
// TestEnforceLogDirSizeLimitSkipsProtected checks that the protected main.log
// survives even though it is the oldest file and alone exceeds the limit; the
// cleaner must delete the other file instead.
func TestEnforceLogDirSizeLimitSkipsProtected(t *testing.T) {
	dir := t.TempDir()
	protected := filepath.Join(dir, "main.log")
	other := filepath.Join(dir, "other.log")
	writeLogFile(t, protected, 200, time.Unix(1, 0))
	writeLogFile(t, other, 50, time.Unix(2, 0))

	deleted, err := enforceLogDirSizeLimit(dir, 100, protected)
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case deleted != 1:
		t.Fatalf("expected 1 deleted file, got %d", deleted)
	}

	if _, errStat := os.Stat(protected); errStat != nil {
		t.Fatalf("expected protected main.log to remain, stat error: %v", errStat)
	}
	if _, errStat := os.Stat(other); !os.IsNotExist(errStat) {
		t.Fatalf("expected other.log to be removed, stat error: %v", errStat)
	}
}
// writeLogFile creates path filled with size zero bytes and sets both its
// access and modification times to modTime. It fails the test on any error.
func writeLogFile(t *testing.T, path string, size int, modTime time.Time) {
	t.Helper()

	if errWrite := os.WriteFile(path, make([]byte, size), 0o644); errWrite != nil {
		t.Fatalf("write file: %v", errWrite)
	}
	if errTimes := os.Chtimes(path, modTime, modTime); errTimes != nil {
		t.Fatalf("set times: %v", errTimes)
	}
}
================================================
FILE: internal/logging/request_logger.go
================================================
// Package logging provides request logging functionality for the CLI Proxy API server.
// It handles capturing and storing detailed HTTP request and response data when enabled
// through configuration, supporting both regular and streaming responses.
package logging
import (
"bytes"
"compress/flate"
"compress/gzip"
"fmt"
"io"
"os"
"path/filepath"
"regexp"
"sort"
"strings"
"sync/atomic"
"time"
"github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
log "github.com/sirupsen/logrus"
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
// requestLogID is a process-wide counter. generateFilename uses its next value
// as the filename suffix when no request ID is supplied. atomic.Uint64 keeps
// concurrent increments safe.
var requestLogID atomic.Uint64
// RequestLogger defines the interface for logging HTTP requests and responses.
// It provides methods for logging both regular and streaming HTTP request/response cycles.
type RequestLogger interface {
	// LogRequest logs a complete non-streaming request/response cycle.
	//
	// Parameters:
	//   - url: The request URL
	//   - method: The HTTP method
	//   - requestHeaders: The request headers
	//   - body: The request body
	//   - statusCode: The response status code
	//   - responseHeaders: The response headers
	//   - response: The raw response data
	//   - apiRequest: The API request data
	//   - apiResponse: The API response data
	//   - apiResponseErrors: Error messages (status code, error, extra headers)
	//     associated with the API response, to be included in the log
	//   - requestID: Optional request ID for log file naming
	//   - requestTimestamp: When the request was received
	//   - apiResponseTimestamp: When the API response was received
	//
	// Returns:
	//   - error: An error if logging fails, nil otherwise
	LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error

	// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
	//
	// Parameters:
	//   - url: The request URL
	//   - method: The HTTP method
	//   - headers: The request headers
	//   - body: The request body
	//   - requestID: Optional request ID for log file naming
	//
	// Returns:
	//   - StreamingLogWriter: A writer for streaming response chunks
	//   - error: An error if logging initialization fails, nil otherwise
	LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error)

	// IsEnabled returns whether request logging is currently enabled.
	//
	// Returns:
	//   - bool: True if logging is enabled, false otherwise
	IsEnabled() bool
}
// StreamingLogWriter handles real-time logging of streaming response chunks.
// A typical call order, per the method docs below, is:
//  1. WriteAPIRequest
//  2. WriteStatus
//  3. WriteChunkAsync (repeatedly)
//  4. WriteAPIResponse
//  5. Close
type StreamingLogWriter interface {
	// WriteChunkAsync writes a response chunk asynchronously (non-blocking).
	// NOTE(review): FileStreamingLogWriter feeds a buffered channel (capacity
	// 100, see LogStreamingRequest); whether a full buffer blocks or drops is
	// decided by its implementation — confirm before relying on "non-blocking".
	//
	// Parameters:
	//   - chunk: The response chunk to write
	WriteChunkAsync(chunk []byte)

	// WriteStatus writes the response status and headers to the log.
	//
	// Parameters:
	//   - status: The response status code
	//   - headers: The response headers
	//
	// Returns:
	//   - error: An error if writing fails, nil otherwise
	WriteStatus(status int, headers map[string][]string) error

	// WriteAPIRequest writes the upstream API request details to the log.
	// This should be called before WriteStatus to maintain proper log ordering.
	//
	// Parameters:
	//   - apiRequest: The API request data (typically includes URL, headers, body sent upstream)
	//
	// Returns:
	//   - error: An error if writing fails, nil otherwise
	WriteAPIRequest(apiRequest []byte) error

	// WriteAPIResponse writes the upstream API response details to the log.
	// This should be called after the streaming response is complete.
	//
	// Parameters:
	//   - apiResponse: The API response data
	//
	// Returns:
	//   - error: An error if writing fails, nil otherwise
	WriteAPIResponse(apiResponse []byte) error

	// SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received.
	//
	// Parameters:
	//   - timestamp: The time when first response chunk was received
	SetFirstChunkTimestamp(timestamp time.Time)

	// Close finalizes the log file and cleans up resources.
	//
	// Returns:
	//   - error: An error if closing fails, nil otherwise
	Close() error
}
// FileRequestLogger implements RequestLogger using file-based storage.
// Each logged request becomes one file inside logsDir. Forced error logs,
// written while regular logging is disabled, are named with an "error-" prefix.
//
// NOTE(review): enabled and errorLogsMaxFiles are plain fields. SetEnabled and
// SetErrorLogsMaxFiles write them, and request paths read them, without any
// synchronization. If the setters can run while requests are in flight, this is
// a data race — confirm callers serialize them, or switch to atomics.
type FileRequestLogger struct {
	// enabled indicates whether request logging is currently enabled.
	enabled bool
	// logsDir is the directory where log files are stored.
	logsDir string
	// errorLogsMaxFiles limits how many "error-" log files cleanupOldErrorLogs
	// keeps; values <= 0 disable that cleanup.
	errorLogsMaxFiles int
}
// NewFileRequestLogger creates a new file-based request logger.
//
// A relative logsDir is joined onto configDir when configDir is non-empty. An
// absolute logsDir, or an empty configDir, leaves logsDir unchanged. The
// directory itself is not created here; that happens lazily when something is
// logged.
//
// Parameters:
//   - enabled: Whether request logging should be enabled
//   - logsDir: The directory where log files should be stored (can be relative)
//   - configDir: The directory of the configuration file, used as the base for a relative logsDir
//   - errorLogsMaxFiles: Maximum number of error log files to retain (0 = no cleanup)
//
// Returns:
//   - *FileRequestLogger: A new file-based request logger instance
func NewFileRequestLogger(enabled bool, logsDir string, configDir string, errorLogsMaxFiles int) *FileRequestLogger {
	resolvedDir := logsDir
	if configDir != "" && !filepath.IsAbs(logsDir) {
		resolvedDir = filepath.Join(configDir, logsDir)
	}

	l := new(FileRequestLogger)
	l.enabled = enabled
	l.logsDir = resolvedDir
	l.errorLogsMaxFiles = errorLogsMaxFiles
	return l
}
// IsEnabled returns whether request logging is currently enabled.
// Even when this is false, LogRequestWithOptions with force=true still writes
// an "error-" prefixed log file.
//
// Returns:
//   - bool: True if logging is enabled, false otherwise
func (l *FileRequestLogger) IsEnabled() bool {
	return l.enabled
}
// SetEnabled updates the request logging enabled state.
// This method allows dynamic enabling/disabling of request logging.
// NOTE(review): this is an unsynchronized write to a field that request
// goroutines read; confirm it is not called concurrently with logging.
//
// Parameters:
//   - enabled: Whether request logging should be enabled
func (l *FileRequestLogger) SetEnabled(enabled bool) {
	l.enabled = enabled
}
// SetErrorLogsMaxFiles updates the maximum number of error log files to retain.
// A value <= 0 disables cleanup of old "error-" log files. The new value takes
// effect on the next forced error log.
func (l *FileRequestLogger) SetErrorLogsMaxFiles(maxFiles int) {
	l.errorLogsMaxFiles = maxFiles
}
// LogRequest logs a complete non-streaming request/response cycle to a file.
// It is LogRequestWithOptions with force=false, so nothing is written while
// request logging is disabled.
//
// Parameters:
//   - url: The request URL
//   - method: The HTTP method
//   - requestHeaders: The request headers
//   - body: The request body
//   - statusCode: The response status code
//   - responseHeaders: The response headers
//   - response: The raw response data
//   - apiRequest: The API request data
//   - apiResponse: The API response data
//   - apiResponseErrors: Error messages associated with the API response, included in the log
//   - requestID: Optional request ID for log file naming
//   - requestTimestamp: When the request was received
//   - apiResponseTimestamp: When the API response was received
//
// Returns:
//   - error: An error if logging fails, nil otherwise
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
	const force = false
	return l.LogRequestWithOptions(
		url, method, requestHeaders, body,
		statusCode, responseHeaders, response,
		apiRequest, apiResponse, apiResponseErrors,
		force, requestID, requestTimestamp, apiResponseTimestamp,
	)
}
// LogRequestWithOptions logs a request with optional forced logging behavior.
// The force flag allows writing error logs even when regular request logging is disabled.
//   - When force is true and logging is disabled, the file name carries an
//     "error-" prefix. After a successful write, "error-" files beyond
//     errorLogsMaxFiles are pruned.
//   - When force is false, it behaves exactly like LogRequest.
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
	return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
}
// logRequest is the shared implementation behind LogRequest and
// LogRequestWithOptions. It writes one log file containing:
//   - the client request;
//   - the upstream API request;
//   - any API error messages;
//   - the upstream API response;
//   - the client response, decompressed when possible.
//
// Behaviour by mode:
//   - Logging disabled and force false: nothing is written.
//   - Logging disabled and force true ("error-only" mode): the file gets an
//     "error-" prefix, and after a successful write older error logs beyond
//     errorLogsMaxFiles are pruned.
//
// The enabled flag is read exactly once, so the filename choice and the pruning
// decision cannot disagree if SetEnabled runs mid-call. Exactly one filename is
// generated, so no sequential ID is wasted.
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
	enabled := l.enabled
	if !enabled && !force {
		return nil
	}
	errorOnly := force && !enabled

	// Ensure logs directory exists
	if errEnsure := l.ensureLogsDir(); errEnsure != nil {
		return fmt.Errorf("failed to create logs directory: %w", errEnsure)
	}

	var filename string
	if errorOnly {
		filename = l.generateErrorFilename(url, requestID)
	} else {
		filename = l.generateFilename(url, requestID)
	}
	filePath := filepath.Join(l.logsDir, filename)

	// Spool the request body to a temp file. If that fails, requestBodyPath stays
	// empty and the writer is handed the in-memory body instead (see warning text).
	requestBodyPath, errTemp := l.writeRequestBodyTempFile(body)
	if errTemp != nil {
		log.WithError(errTemp).Warn("failed to create request body temp file, falling back to direct write")
	}
	if requestBodyPath != "" {
		defer func() {
			if errRemove := os.Remove(requestBodyPath); errRemove != nil {
				log.WithError(errRemove).Warn("failed to remove request body temp file")
			}
		}()
	}

	responseToWrite, decompressErr := l.decompressResponse(responseHeaders, response)
	if decompressErr != nil {
		// If decompression fails, continue with original response and annotate the log output.
		responseToWrite = response
	}

	logFile, errOpen := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
	if errOpen != nil {
		return fmt.Errorf("failed to create log file: %w", errOpen)
	}

	writeErr := l.writeNonStreamingLog(
		logFile,
		url,
		method,
		requestHeaders,
		body,
		requestBodyPath,
		apiRequest,
		apiResponse,
		apiResponseErrors,
		statusCode,
		responseHeaders,
		responseToWrite,
		decompressErr,
		requestTimestamp,
		apiResponseTimestamp,
	)
	if errClose := logFile.Close(); errClose != nil {
		log.WithError(errClose).Warn("failed to close request log file")
		if writeErr == nil {
			// Wrap like every other failure path so callers get consistent context.
			return fmt.Errorf("failed to close log file: %w", errClose)
		}
	}
	if writeErr != nil {
		return fmt.Errorf("failed to write log file: %w", writeErr)
	}

	if errorOnly {
		if errCleanup := l.cleanupOldErrorLogs(); errCleanup != nil {
			log.WithError(errCleanup).Warn("failed to clean up old error logs")
		}
	}
	return nil
}
// LogStreamingRequest initiates logging for a streaming request.
//
// While logging is disabled it returns a NoOpStreamingLogWriter and a nil
// error. Otherwise it:
//   - ensures the logs directory exists;
//   - deep-copies the request headers, so later caller mutations do not reach
//     the log;
//   - spools the request body to a "request-body-*.tmp" file;
//   - opens a "response-body-*.tmp" file for the response;
//   - starts the writer's asyncWriter goroutine.
//
// If opening the response temp file fails, the request temp file is removed
// before returning the error.
//
// Parameters:
//   - url: The request URL
//   - method: The HTTP method
//   - headers: The request headers
//   - body: The request body
//   - requestID: Optional request ID for log file naming
//
// Returns:
//   - StreamingLogWriter: A writer for streaming response chunks
//   - error: An error if logging initialization fails, nil otherwise
func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error) {
	if !l.enabled {
		return &NoOpStreamingLogWriter{}, nil
	}
	if errDir := l.ensureLogsDir(); errDir != nil {
		return nil, fmt.Errorf("failed to create logs directory: %w", errDir)
	}
	logPath := filepath.Join(l.logsDir, l.generateFilename(url, requestID))

	headersCopy := make(map[string][]string, len(headers))
	for name, vals := range headers {
		dup := make([]string, len(vals))
		copy(dup, vals)
		headersCopy[name] = dup
	}

	reqBodyPath, errReqBody := l.writeRequestBodyTempFile(body)
	if errReqBody != nil {
		return nil, fmt.Errorf("failed to create request body temp file: %w", errReqBody)
	}
	respBodyFile, errRespBody := os.CreateTemp(l.logsDir, "response-body-*.tmp")
	if errRespBody != nil {
		_ = os.Remove(reqBodyPath)
		return nil, fmt.Errorf("failed to create response body temp file: %w", errRespBody)
	}

	sw := &FileStreamingLogWriter{
		logFilePath:      logPath,
		url:              url,
		method:           method,
		timestamp:        time.Now(),
		requestHeaders:   headersCopy,
		requestBodyPath:  reqBodyPath,
		responseBodyPath: respBodyFile.Name(),
		responseBodyFile: respBodyFile,
		chunkChan:        make(chan []byte, 100), // buffered channel for async writes
		closeChan:        make(chan struct{}),
		errorChan:        make(chan error, 1),
	}
	go sw.asyncWriter()
	return sw, nil
}
// generateErrorFilename returns generateFilename's result prefixed with
// "error-". cleanupOldErrorLogs selects forced error logs by that prefix.
func (l *FileRequestLogger) generateErrorFilename(url string, requestID ...string) string {
	const errorPrefix = "error-"
	return errorPrefix + l.generateFilename(url, requestID...)
}
// ensureLogsDir creates the logs directory (and any missing parents) if it
// doesn't exist.
//
// os.MkdirAll already returns nil for an existing directory, so no separate
// Stat is needed. Dropping the pre-check also means that a logsDir which exists
// as a regular file, or cannot be inspected, is reported here. Previously such
// cases were silently accepted and only failed later when opening a log file.
//
// Returns:
//   - error: An error if directory creation fails, nil otherwise
func (l *FileRequestLogger) ensureLogsDir() error {
	return os.MkdirAll(l.logsDir, 0755)
}
// generateFilename creates a sanitized filename from the URL path and current timestamp.
// Format: v1-responses-2025-12-23T195811-a1b2c3d4.log
//
// The name is built in three steps:
//  1. The query string and one leading slash are dropped, and the path goes
//     through sanitizeForFilename.
//  2. The suffix is the first requestID when it is non-empty; otherwise it is
//     the next value of the process-wide requestLogID counter.
//  3. A caller-supplied request ID is sanitized as well, because it becomes part
//     of a filesystem path. Plain IDs such as hex strings or UUIDs pass through
//     unchanged.
//
// Parameters:
//   - url: The request URL
//   - requestID: Optional request ID to include in filename
//
// Returns:
//   - string: A sanitized filename for the log file
func (l *FileRequestLogger) generateFilename(url string, requestID ...string) string {
	path, _, _ := strings.Cut(url, "?")
	path = strings.TrimPrefix(path, "/")
	sanitized := l.sanitizeForFilename(path)

	timestamp := time.Now().Format("2006-01-02T150405")

	var idPart string
	if len(requestID) > 0 && requestID[0] != "" {
		idPart = l.sanitizeForFilename(requestID[0])
	} else {
		idPart = fmt.Sprintf("%d", requestLogID.Add(1))
	}

	return fmt.Sprintf("%s-%s-%s.log", sanitized, timestamp, idPart)
}
// filenameUnsafeRun matches a run of characters that must not survive into a
// log filename: both path separators ('/' and '\\'), characters reserved on
// Windows (<>:"|?*), whitespace, ASCII control characters, and '-'. Including
// '-' means existing hyphen runs collapse to one as well. It is compiled once
// here rather than on every request.
var filenameUnsafeRun = regexp.MustCompile(`[-/\\<>:"|?*\s\x00-\x1f\x7f]+`)

// sanitizeForFilename replaces characters that are not safe for filenames.
// The steps are:
//  1. Every run of unsafe characters becomes a single "-".
//  2. Leading and trailing hyphens are trimmed.
//  3. An empty result becomes "root".
//
// Parameters:
//   - path: The path to sanitize
//
// Returns:
//   - string: A sanitized filename
func (l *FileRequestLogger) sanitizeForFilename(path string) string {
	sanitized := strings.Trim(filenameUnsafeRun.ReplaceAllString(path, "-"), "-")
	if sanitized == "" {
		return "root"
	}
	return sanitized
}
// cleanupOldErrorLogs keeps only the newest errorLogsMaxFiles forced error log
// files in logsDir and removes the rest.
//
// Details:
//   - Only non-directory entries named "error-*.log" are considered; "newest"
//     is judged by modification time.
//   - A non-positive limit disables cleanup.
//   - Failure to read the directory is returned as an error.
//   - Failure to stat or remove an individual file is logged and skipped.
func (l *FileRequestLogger) cleanupOldErrorLogs() error {
	keep := l.errorLogsMaxFiles
	if keep <= 0 {
		return nil
	}

	entries, errRead := os.ReadDir(l.logsDir)
	if errRead != nil {
		return errRead
	}

	type errorLog struct {
		name    string
		modTime time.Time
	}
	candidates := make([]errorLog, 0, len(entries))
	for _, entry := range entries {
		name := entry.Name()
		if entry.IsDir() || !strings.HasPrefix(name, "error-") || !strings.HasSuffix(name, ".log") {
			continue
		}
		info, errInfo := entry.Info()
		if errInfo != nil {
			log.WithError(errInfo).Warn("failed to read error log info")
			continue
		}
		candidates = append(candidates, errorLog{name: name, modTime: info.ModTime()})
	}
	if len(candidates) <= keep {
		return nil
	}

	// Newest first, so everything from index keep onward is surplus.
	sort.Slice(candidates, func(i, j int) bool {
		return candidates[j].modTime.Before(candidates[i].modTime)
	})
	for _, stale := range candidates[keep:] {
		if errRemove := os.Remove(filepath.Join(l.logsDir, stale.name)); errRemove != nil {
			log.WithError(errRemove).Warnf("failed to remove old error log: %s", stale.name)
		}
	}
	return nil
}
// writeRequestBodyTempFile spools body into a new "request-body-*.tmp" file inside logsDir.
//
// On any failure after the file was created, the partial file is removed before returning.
//
// Parameters:
//   - body: The request body bytes to persist
//
// Returns:
//   - string: The path of the created temp file
//   - error: An error if creating, writing, or closing the file fails
func (l *FileRequestLogger) writeRequestBodyTempFile(body []byte) (string, error) {
	f, err := os.CreateTemp(l.logsDir, "request-body-*.tmp")
	if err != nil {
		return "", err
	}
	path := f.Name()
	discard := func(cause error) (string, error) {
		_ = os.Remove(path)
		return "", cause
	}
	if _, err = io.Copy(f, bytes.NewReader(body)); err != nil {
		_ = f.Close()
		return discard(err)
	}
	if err = f.Close(); err != nil {
		return discard(err)
	}
	return path, nil
}
// writeNonStreamingLog writes a complete non-streaming request log to w.
//
// Sections are written in order: REQUEST INFO, API REQUEST, API ERROR RESPONSE(s),
// API RESPONSE, RESPONSE. The first failing write aborts the log and its error is returned.
// A zero requestTimestamp is replaced with the current time.
func (l *FileRequestLogger) writeNonStreamingLog(
	w io.Writer,
	url, method string,
	requestHeaders map[string][]string,
	requestBody []byte,
	requestBodyPath string,
	apiRequest []byte,
	apiResponse []byte,
	apiResponseErrors []*interfaces.ErrorMessage,
	statusCode int,
	responseHeaders map[string][]string,
	response []byte,
	decompressErr error,
	requestTimestamp time.Time,
	apiResponseTimestamp time.Time,
) error {
	if requestTimestamp.IsZero() {
		requestTimestamp = time.Now()
	}
	sections := []func() error{
		func() error {
			return writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp)
		},
		func() error {
			return writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{})
		},
		func() error {
			return writeAPIErrorResponses(w, apiResponseErrors)
		},
		func() error {
			return writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp)
		},
		func() error {
			return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
		},
	}
	for _, section := range sections {
		if err := section(); err != nil {
			return err
		}
	}
	return nil
}
// writeRequestInfoWithBody writes the REQUEST INFO, HEADERS and REQUEST BODY sections to w.
//
// Header values are passed through util.MaskSensitiveHeaderValue before being written. When
// bodyPath is non-empty the body is streamed from that file; otherwise body is written as-is.
// A failure to close the body file is only logged.
func writeRequestInfoWithBody(
	w io.Writer,
	url, method string,
	headers map[string][]string,
	body []byte,
	bodyPath string,
	timestamp time.Time,
) error {
	preamble := []string{
		"=== REQUEST INFO ===\n",
		fmt.Sprintf("Version: %s\n", buildinfo.Version),
		fmt.Sprintf("URL: %s\n", url),
		fmt.Sprintf("Method: %s\n", method),
		fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano)),
		"\n",
		"=== HEADERS ===\n",
	}
	for _, line := range preamble {
		if _, err := io.WriteString(w, line); err != nil {
			return err
		}
	}
	for name, vals := range headers {
		for _, v := range vals {
			if _, err := fmt.Fprintf(w, "%s: %s\n", name, util.MaskSensitiveHeaderValue(name, v)); err != nil {
				return err
			}
		}
	}
	for _, line := range []string{"\n", "=== REQUEST BODY ===\n"} {
		if _, err := io.WriteString(w, line); err != nil {
			return err
		}
	}
	if bodyPath == "" {
		if _, err := w.Write(body); err != nil {
			return err
		}
	} else {
		f, err := os.Open(bodyPath)
		if err != nil {
			return err
		}
		_, errCopy := io.Copy(w, f)
		errClose := f.Close()
		if errCopy != nil {
			return errCopy
		}
		if errClose != nil {
			log.WithError(errClose).Warn("failed to close request body temp file")
		}
	}
	_, err := io.WriteString(w, "\n\n")
	return err
}
// writeAPISection writes an upstream API request/response payload as a log section.
//
// Payloads that already start with sectionPrefix are treated as preformatted. They are written
// verbatim and get a newline only if they lack one. Other payloads are preceded by
// sectionHeader and an optional "Timestamp:" line (omitted for a zero timestamp), and are
// always followed by a newline. Every non-empty section ends with one blank line. An empty
// payload writes nothing.
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error {
	if len(payload) == 0 {
		return nil
	}
	preformatted := bytes.HasPrefix(payload, []byte(sectionPrefix))
	if !preformatted {
		if _, err := io.WriteString(w, sectionHeader); err != nil {
			return err
		}
		if !timestamp.IsZero() {
			if _, err := fmt.Fprintf(w, "Timestamp: %s\n", timestamp.Format(time.RFC3339Nano)); err != nil {
				return err
			}
		}
	}
	if _, err := w.Write(payload); err != nil {
		return err
	}
	if !preformatted || !bytes.HasSuffix(payload, []byte("\n")) {
		if _, err := io.WriteString(w, "\n"); err != nil {
			return err
		}
	}
	_, err := io.WriteString(w, "\n")
	return err
}
// writeAPIErrorResponses writes one "=== API ERROR RESPONSE ===" section per non-nil entry.
//
// Each section contains the HTTP status and, when present, the error text, followed by a
// blank line. Nil entries are skipped. The first failing write aborts and its error is returned.
func writeAPIErrorResponses(w io.Writer, apiResponseErrors []*interfaces.ErrorMessage) error {
	for _, apiErr := range apiResponseErrors {
		if apiErr == nil {
			continue
		}
		if _, err := io.WriteString(w, "=== API ERROR RESPONSE ===\n"); err != nil {
			return err
		}
		if _, err := fmt.Fprintf(w, "HTTP Status: %d\n", apiErr.StatusCode); err != nil {
			return err
		}
		if apiErr.Error != nil {
			if _, err := io.WriteString(w, apiErr.Error.Error()); err != nil {
				return err
			}
		}
		if _, err := io.WriteString(w, "\n\n"); err != nil {
			return err
		}
	}
	return nil
}
// writeResponseSection writes the "=== RESPONSE ===" section to w.
//
// The section contains the status line (only when statusWritten), the response headers, a
// blank line, the body copied from responseReader (if non-nil), a decompression error note
// (if decompressErr is non-nil), and a final newline when trailingNewline is set.
func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, responseHeaders map[string][]string, responseReader io.Reader, decompressErr error, trailingNewline bool) error {
	if _, err := io.WriteString(w, "=== RESPONSE ===\n"); err != nil {
		return err
	}
	if statusWritten {
		if _, err := fmt.Fprintf(w, "Status: %d\n", statusCode); err != nil {
			return err
		}
	}
	// Ranging over a nil map is a no-op, so no separate nil check is needed.
	for name, vals := range responseHeaders {
		for _, v := range vals {
			if _, err := fmt.Fprintf(w, "%s: %s\n", name, v); err != nil {
				return err
			}
		}
	}
	if _, err := io.WriteString(w, "\n"); err != nil {
		return err
	}
	if responseReader != nil {
		if _, err := io.Copy(w, responseReader); err != nil {
			return err
		}
	}
	if decompressErr != nil {
		if _, err := fmt.Fprintf(w, "\n[DECOMPRESSION ERROR: %v]", decompressErr); err != nil {
			return err
		}
	}
	if !trailingNewline {
		return nil
	}
	_, err := io.WriteString(w, "\n")
	return err
}
// formatLogContent creates the complete log content for non-streaming requests.
//
// Nil entries in apiResponseErrors, and entries without an Error, are handled the same way as
// in writeAPIErrorResponses: nil entries are skipped and a missing error text is omitted. They
// no longer cause a nil-pointer panic.
//
// Parameters:
//   - url: The request URL
//   - method: The HTTP method
//   - headers: The request headers
//   - body: The request body
//   - apiRequest: The API request data
//   - apiResponse: The API response data
//   - response: The raw response data
//   - status: The response status code
//   - responseHeaders: The response headers
//   - apiResponseErrors: Upstream error responses to include (nil entries are skipped)
//
// Returns:
//   - string: The formatted log content
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, apiRequest, apiResponse, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string {
	var content strings.Builder
	// Request info
	content.WriteString(l.formatRequestInfo(url, method, headers, body))
	if len(apiRequest) > 0 {
		// A payload that already carries its own section header is written verbatim.
		if bytes.HasPrefix(apiRequest, []byte("=== API REQUEST")) {
			content.Write(apiRequest)
			if !bytes.HasSuffix(apiRequest, []byte("\n")) {
				content.WriteString("\n")
			}
		} else {
			content.WriteString("=== API REQUEST ===\n")
			content.Write(apiRequest)
			content.WriteString("\n")
		}
		content.WriteString("\n")
	}
	for _, apiErr := range apiResponseErrors {
		if apiErr == nil {
			continue
		}
		content.WriteString("=== API ERROR RESPONSE ===\n")
		content.WriteString(fmt.Sprintf("HTTP Status: %d\n", apiErr.StatusCode))
		if apiErr.Error != nil {
			content.WriteString(apiErr.Error.Error())
		}
		content.WriteString("\n\n")
	}
	if len(apiResponse) > 0 {
		if bytes.HasPrefix(apiResponse, []byte("=== API RESPONSE")) {
			content.Write(apiResponse)
			if !bytes.HasSuffix(apiResponse, []byte("\n")) {
				content.WriteString("\n")
			}
		} else {
			content.WriteString("=== API RESPONSE ===\n")
			content.Write(apiResponse)
			content.WriteString("\n")
		}
		content.WriteString("\n")
	}
	// Response section
	content.WriteString("=== RESPONSE ===\n")
	content.WriteString(fmt.Sprintf("Status: %d\n", status))
	for key, values := range responseHeaders {
		for _, value := range values {
			content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
		}
	}
	content.WriteString("\n")
	content.Write(response)
	content.WriteString("\n")
	return content.String()
}
// decompressResponse decompresses response data based on Content-Encoding header.
//
// The header name is matched case-insensitively. Codings are trimmed and lower-cased, and
// "x-gzip" is accepted as an alias for "gzip". When several codings are listed (for example
// "gzip, br"), they are undone in reverse order of application, as RFC 9110 specifies. Empty
// and "identity" codings are skipped. Decoding stops at the first unsupported coding and the
// data decoded so far is returned.
//
// Parameters:
//   - responseHeaders: The response headers
//   - response: The response data to decompress
//
// Returns:
//   - []byte: The decompressed response data
//   - error: An error if decompression fails, nil otherwise
func (l *FileRequestLogger) decompressResponse(responseHeaders map[string][]string, response []byte) ([]byte, error) {
	if responseHeaders == nil || len(response) == 0 {
		return response, nil
	}
	// Collect every Content-Encoding value; repeated header lines form one comma-separated list.
	var contentEncoding string
	for key, values := range responseHeaders {
		if strings.EqualFold(key, "content-encoding") && len(values) > 0 {
			contentEncoding = strings.Join(values, ",")
			break
		}
	}
	codings := strings.Split(contentEncoding, ",")
	data := response
	for i := len(codings) - 1; i >= 0; i-- {
		if len(data) == 0 {
			return data, nil
		}
		var err error
		switch strings.ToLower(strings.TrimSpace(codings[i])) {
		case "", "identity":
			continue
		case "gzip", "x-gzip":
			data, err = l.decompressGzip(data)
		case "deflate":
			data, err = l.decompressDeflate(data)
		case "br":
			data, err = l.decompressBrotli(data)
		case "zstd":
			data, err = l.decompressZstd(data)
		default:
			// No compression or unsupported compression: keep what has been decoded so far.
			return data, nil
		}
		if err != nil {
			return nil, err
		}
	}
	return data, nil
}
// decompressGzip decompresses gzip-encoded data.
//
// The gzip reader is closed after reading. A close failure is only logged and never replaces
// the read result.
//
// Parameters:
//   - data: The gzip-encoded data to decompress
//
// Returns:
//   - []byte: The decompressed data
//   - error: An error if decompression fails, nil otherwise
func (l *FileRequestLogger) decompressGzip(data []byte) ([]byte, error) {
	gz, errOpen := gzip.NewReader(bytes.NewReader(data))
	if errOpen != nil {
		return nil, fmt.Errorf("failed to create gzip reader: %w", errOpen)
	}
	plain, errRead := io.ReadAll(gz)
	if errClose := gz.Close(); errClose != nil {
		log.WithError(errClose).Warn("failed to close gzip reader in request logger")
	}
	if errRead != nil {
		return nil, fmt.Errorf("failed to decompress gzip data: %w", errRead)
	}
	return plain, nil
}
// decompressDeflate decompresses data sent with Content-Encoding: deflate.
//
// RFC 9110 defines the HTTP "deflate" coding as a zlib stream (RFC 1950) that wraps raw DEFLATE
// data (RFC 1951). Some servers send raw DEFLATE instead. When the payload starts with a valid
// zlib header, that 2-byte header is skipped and the wrapped stream is inflated first. If that
// fails, or no zlib header is present, the whole payload is inflated as raw DEFLATE.
//
// Parameters:
//   - data: The deflate-encoded data to decompress
//
// Returns:
//   - []byte: The decompressed data
//   - error: An error if decompression fails, nil otherwise
func (l *FileRequestLogger) decompressDeflate(data []byte) ([]byte, error) {
	if hasZlibStreamHeader(data) {
		// The inflater stops at the final DEFLATE block. The trailing Adler-32 checksum is
		// therefore left unread and not verified.
		if decompressed, err := inflateRawDeflate(data[2:]); err == nil {
			return decompressed, nil
		}
	}
	decompressed, err := inflateRawDeflate(data)
	if err != nil {
		return nil, fmt.Errorf("failed to decompress deflate data: %w", err)
	}
	return decompressed, nil
}

// hasZlibStreamHeader reports whether data begins with an RFC 1950 header that uses DEFLATE
// (CM=8, window size <= 32K), has a valid FCHECK, and declares no preset dictionary.
func hasZlibStreamHeader(data []byte) bool {
	if len(data) < 2 {
		return false
	}
	cmf, flg := data[0], data[1]
	return cmf&0x0f == 8 && cmf>>4 <= 7 && flg&0x20 == 0 && (uint16(cmf)<<8|uint16(flg))%31 == 0
}

// inflateRawDeflate inflates a raw (headerless) DEFLATE stream.
func inflateRawDeflate(data []byte) ([]byte, error) {
	reader := flate.NewReader(bytes.NewReader(data))
	decompressed, err := io.ReadAll(reader)
	// The flate reader holds no OS resources. Close only re-reports the decompressor's read
	// error, which io.ReadAll has already returned.
	_ = reader.Close()
	if err != nil {
		return nil, err
	}
	return decompressed, nil
}
// decompressBrotli decompresses brotli-encoded data.
//
// Parameters:
//   - data: The brotli-encoded data to decompress
//
// Returns:
//   - []byte: The decompressed data
//   - error: An error if decompression fails, nil otherwise
func (l *FileRequestLogger) decompressBrotli(data []byte) ([]byte, error) {
	plain, errRead := io.ReadAll(brotli.NewReader(bytes.NewReader(data)))
	if errRead != nil {
		return nil, fmt.Errorf("failed to decompress brotli data: %w", errRead)
	}
	return plain, nil
}
// decompressZstd decompresses zstd-encoded data.
//
// The decoder is always closed before returning.
//
// Parameters:
//   - data: The zstd-encoded data to decompress
//
// Returns:
//   - []byte: The decompressed data
//   - error: An error if decompression fails, nil otherwise
func (l *FileRequestLogger) decompressZstd(data []byte) ([]byte, error) {
	dec, errInit := zstd.NewReader(bytes.NewReader(data))
	if errInit != nil {
		return nil, fmt.Errorf("failed to create zstd reader: %w", errInit)
	}
	defer dec.Close()
	plain, errRead := io.ReadAll(dec)
	if errRead != nil {
		return nil, fmt.Errorf("failed to decompress zstd data: %w", errRead)
	}
	return plain, nil
}
// formatRequestInfo creates the request information section of the log.
//
// The timestamp is taken at call time. Header values are passed through
// util.MaskSensitiveHeaderValue.
//
// Parameters:
//   - url: The request URL
//   - method: The HTTP method
//   - headers: The request headers
//   - body: The request body
//
// Returns:
//   - string: The formatted request information
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte) string {
	var sb strings.Builder
	fmt.Fprintf(&sb, "=== REQUEST INFO ===\nVersion: %s\nURL: %s\nMethod: %s\nTimestamp: %s\n\n",
		buildinfo.Version, url, method, time.Now().Format(time.RFC3339Nano))
	sb.WriteString("=== HEADERS ===\n")
	for name, vals := range headers {
		for _, v := range vals {
			fmt.Fprintf(&sb, "%s: %s\n", name, util.MaskSensitiveHeaderValue(name, v))
		}
	}
	sb.WriteString("\n=== REQUEST BODY ===\n")
	sb.Write(body)
	sb.WriteString("\n\n")
	return sb.String()
}
// FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs.
// It spools streaming response chunks to a temporary file to avoid retaining large responses in memory.
// The final log file is assembled when Close is called.
//
// Data flow: WriteChunkAsync copies each chunk and performs a non-blocking send on chunkChan.
// A chunk is dropped when the channel is full. The asyncWriter goroutine appends received
// chunks to responseBodyFile. Status, headers, and the upstream API request/response payloads
// are copied into memory and only written when Close assembles the final file.
type FileStreamingLogWriter struct {
	// logFilePath is the final log file path. Close writes no log when it is empty.
	logFilePath string
	// url is the request URL (masked upstream in middleware).
	url string
	// method is the HTTP method.
	method string
	// timestamp is captured when the streaming log is initialized. It is written as the
	// REQUEST INFO timestamp of the final log.
	timestamp time.Time
	// requestHeaders stores the request headers. Values are masked when the log is written.
	requestHeaders map[string][]string
	// requestBodyPath is a temporary file path holding the request body. It is removed by cleanupTempFiles.
	requestBodyPath string
	// responseBodyPath is a temporary file path holding the streaming response body. It is removed by cleanupTempFiles.
	responseBodyPath string
	// responseBodyFile is the temp file where chunks are appended by the async writer.
	// asyncWriter closes it and sets it to nil on a write error or once chunkChan is closed.
	responseBodyFile *os.File
	// chunkChan is a channel for receiving response chunks to spool.
	// Close closes it to stop asyncWriter. A nil channel makes WriteChunkAsync a no-op.
	chunkChan chan []byte
	// closeChan is closed by asyncWriter when it exits. Close waits on it before assembling the log.
	closeChan chan struct{}
	// errorChan receives spooling errors from asyncWriter via non-blocking sends; errors are
	// dropped when it is full. Close checks it once before writing the log.
	errorChan chan error
	// responseStatus stores the HTTP status code.
	responseStatus int
	// statusWritten indicates whether a non-zero status was recorded.
	statusWritten bool
	// responseHeaders stores a copy of the response headers.
	responseHeaders map[string][]string
	// apiRequest stores a copy of the upstream API request data.
	apiRequest []byte
	// apiResponse stores a copy of the upstream API response data.
	apiResponse []byte
	// apiResponseTimestamp captures when the API response was received. It is set via
	// SetFirstChunkTimestamp; a zero value means no timestamp line is written.
	apiResponseTimestamp time.Time
}
// WriteChunkAsync queues a copy of a response chunk for spooling without ever blocking.
//
// The chunk is copied so the caller may reuse its buffer. When the spool channel is nil the
// call is a no-op. When the channel is full the chunk is dropped rather than stalling the
// response stream.
//
// Parameters:
//   - chunk: The response chunk to write
func (w *FileStreamingLogWriter) WriteChunkAsync(chunk []byte) {
	queue := w.chunkChan
	if queue == nil {
		return
	}
	owned := make([]byte, len(chunk))
	copy(owned, chunk)
	select {
	case queue <- owned:
	default:
		// Spool is saturated; dropping keeps the client path non-blocking.
	}
}
// WriteStatus records the response status and a deep copy of the headers for the final log.
//
// A zero status is ignored entirely. A nil headers map leaves previously recorded headers in
// place.
//
// Parameters:
//   - status: The response status code
//   - headers: The response headers
//
// Returns:
//   - error: Always returns nil (buffering cannot fail)
func (w *FileStreamingLogWriter) WriteStatus(status int, headers map[string][]string) error {
	if status != 0 {
		w.responseStatus = status
		if headers != nil {
			snapshot := make(map[string][]string, len(headers))
			for name, vals := range headers {
				snapshot[name] = append(make([]string, 0, len(vals)), vals...)
			}
			w.responseHeaders = snapshot
		}
		w.statusWritten = true
	}
	return nil
}
// WriteAPIRequest stores a private copy of the upstream API request for the final log.
//
// Empty input is ignored, so a previously recorded request is kept.
//
// Parameters:
//   - apiRequest: The API request data (typically includes URL, headers, body sent upstream)
//
// Returns:
//   - error: Always returns nil (buffering cannot fail)
func (w *FileStreamingLogWriter) WriteAPIRequest(apiRequest []byte) error {
	if len(apiRequest) > 0 {
		w.apiRequest = bytes.Clone(apiRequest)
	}
	return nil
}
// WriteAPIResponse stores a private copy of the upstream API response for the final log.
//
// Empty input is ignored, so a previously recorded response is kept.
//
// Parameters:
//   - apiResponse: The API response data
//
// Returns:
//   - error: Always returns nil (buffering cannot fail)
func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error {
	if len(apiResponse) > 0 {
		w.apiResponse = bytes.Clone(apiResponse)
	}
	return nil
}
// SetFirstChunkTimestamp records when the first upstream response chunk arrived.
// It is written as the "Timestamp:" line of the API RESPONSE section. A zero time is ignored.
func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
	if timestamp.IsZero() {
		return
	}
	w.apiResponseTimestamp = timestamp
}
// Close finalizes the log file and cleans up resources.
//
// It stops the async writer and waits for it to drain pending chunks. It then assembles the
// final file in order: REQUEST INFO -> API REQUEST -> API RESPONSE -> RESPONSE (status, headers,
// spooled body). Temporary files are removed on every path.
//
// Close is safe to call more than than once. The chunk channel reference is dropped once the
// writer goroutine has exited, so a repeated Close or a late WriteChunkAsync cannot close or
// send on an already-closed channel. The destination path is also cleared when finalization
// starts, so a repeated Close cannot truncate and rewrite the finished log after its temp files
// are gone.
//
// Returns:
//   - error: A spooling error reported by the async writer, or an error creating, writing, or
//     closing the log file; nil otherwise
func (w *FileStreamingLogWriter) Close() error {
	if w.chunkChan != nil {
		close(w.chunkChan)
	}
	// Wait for async writer to finish spooling chunks. The channel reference is dropped only
	// afterwards, because asyncWriter reads w.chunkChan when it begins ranging over it.
	if w.closeChan != nil {
		<-w.closeChan
	}
	w.chunkChan = nil

	// Claim the destination so that any later Close takes the no-op path below.
	logFilePath := w.logFilePath
	w.logFilePath = ""

	select {
	case errWrite := <-w.errorChan:
		w.cleanupTempFiles()
		return errWrite
	default:
	}
	if logFilePath == "" {
		w.cleanupTempFiles()
		return nil
	}
	logFile, errOpen := os.OpenFile(logFilePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
	if errOpen != nil {
		w.cleanupTempFiles()
		return fmt.Errorf("failed to create log file: %w", errOpen)
	}
	writeErr := w.writeFinalLog(logFile)
	if errClose := logFile.Close(); errClose != nil {
		log.WithError(errClose).Warn("failed to close request log file")
		if writeErr == nil {
			writeErr = errClose
		}
	}
	w.cleanupTempFiles()
	return writeErr
}
// asyncWriter drains chunkChan and appends each chunk to responseBodyFile.
// It is intended to run in its own goroutine and closes closeChan when it returns.
//
// After the first write error the file is closed and released. Remaining chunks are still
// drained but discarded. Write and close errors are offered to errorChan with non-blocking
// sends, so they are dropped if the channel is full. Once chunkChan is closed, the file is
// closed if it is still open.
func (w *FileStreamingLogWriter) asyncWriter() {
	defer close(w.closeChan)
	report := func(err error) {
		select {
		case w.errorChan <- err:
		default:
		}
	}
	release := func() {
		if errClose := w.responseBodyFile.Close(); errClose != nil {
			report(errClose)
		}
		w.responseBodyFile = nil
	}
	for chunk := range w.chunkChan {
		if w.responseBodyFile == nil {
			continue
		}
		if _, errWrite := w.responseBodyFile.Write(chunk); errWrite != nil {
			report(errWrite)
			release()
		}
	}
	if w.responseBodyFile != nil {
		release()
	}
}
// writeFinalLog assembles the streaming log into logFile.
//
// Order: REQUEST INFO (body streamed from requestBodyPath), API REQUEST, API RESPONSE, and
// finally the RESPONSE section with the body streamed from responseBodyPath. Closing the
// response temp file is deferred, and a close failure is only logged.
func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
	if err := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); err != nil {
		return err
	}
	apiSections := []struct {
		header    string
		prefix    string
		payload   []byte
		timestamp time.Time
	}{
		{"=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}},
		{"=== API RESPONSE ===\n", "=== API RESPONSE", w.apiResponse, w.apiResponseTimestamp},
	}
	for _, s := range apiSections {
		if err := writeAPISection(logFile, s.header, s.prefix, s.payload, s.timestamp); err != nil {
			return err
		}
	}
	body, err := os.Open(w.responseBodyPath)
	if err != nil {
		return err
	}
	defer func() {
		if errClose := body.Close(); errClose != nil {
			log.WithError(errClose).Warn("failed to close response body temp file")
		}
	}()
	return writeResponseSection(logFile, w.responseStatus, w.statusWritten, w.responseHeaders, body, nil, false)
}
// cleanupTempFiles removes the request and response body temp files, if any, and clears
// their paths so repeated calls are no-ops. Removal failures are logged, not returned.
func (w *FileStreamingLogWriter) cleanupTempFiles() {
	targets := []struct {
		path *string
		msg  string
	}{
		{&w.requestBodyPath, "failed to remove request body temp file"},
		{&w.responseBodyPath, "failed to remove response body temp file"},
	}
	for _, target := range targets {
		if *target.path == "" {
			continue
		}
		if errRemove := os.Remove(*target.path); errRemove != nil {
			log.WithError(errRemove).Warn(target.msg)
		}
		*target.path = ""
	}
}
// NoOpStreamingLogWriter is a no-operation implementation for when logging is disabled.
// It implements the StreamingLogWriter interface but performs no actual logging operations.
type NoOpStreamingLogWriter struct{}

// WriteChunkAsync discards the chunk.
func (*NoOpStreamingLogWriter) WriteChunkAsync(_ []byte) {}

// WriteStatus discards the status and headers and always returns nil.
func (*NoOpStreamingLogWriter) WriteStatus(_ int, _ map[string][]string) error { return nil }

// WriteAPIRequest discards the upstream request data and always returns nil.
func (*NoOpStreamingLogWriter) WriteAPIRequest(_ []byte) error { return nil }

// WriteAPIResponse discards the upstream response data and always returns nil.
func (*NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error { return nil }

// SetFirstChunkTimestamp discards the timestamp.
func (*NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {}

// Close does nothing and always returns nil.
func (*NoOpStreamingLogWriter) Close() error { return nil }
================================================
FILE: internal/logging/requestid.go
================================================
package logging
import (
"context"
"crypto/rand"
"encoding/hex"
"github.com/gin-gonic/gin"
)
// ginRequestIDKey is the key under which the request ID is stored in a Gin context.
const ginRequestIDKey = "__request_id__"

// requestIDKey is the unexported context-key type for request IDs. A dedicated type keeps
// the key from colliding with context keys defined in other packages.
type requestIDKey struct{}
// GenerateRequestID creates a new 8-character hex request ID from 4 cryptographically random
// bytes. It falls back to "00000000" if the random source fails.
func GenerateRequestID() string {
	var raw [4]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "00000000"
	}
	return hex.EncodeToString(raw[:])
}
// WithRequestID returns a new context with the request ID attached.
//
// A nil ctx is treated as context.Background(), mirroring GetRequestID's tolerance of nil.
// Without this, context.WithValue would panic on a nil parent.
func WithRequestID(ctx context.Context, requestID string) context.Context {
	if ctx == nil {
		ctx = context.Background()
	}
	return context.WithValue(ctx, requestIDKey{}, requestID)
}
// GetRequestID retrieves the request ID from the context.
// It returns "" for a nil context, a missing ID, or a value that is not a string.
func GetRequestID(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	id, _ := ctx.Value(requestIDKey{}).(string)
	return id
}
// SetGinRequestID stores the request ID in the Gin context under ginRequestIDKey.
// A nil context is ignored.
func SetGinRequestID(c *gin.Context, requestID string) {
	if c == nil {
		return
	}
	c.Set(ginRequestIDKey, requestID)
}
// GetGinRequestID retrieves the request ID from the Gin context.
// It returns "" for a nil context, a missing key, or a value that is not a string.
func GetGinRequestID(c *gin.Context) string {
	if c == nil {
		return ""
	}
	raw, exists := c.Get(ginRequestIDKey)
	if !exists {
		return ""
	}
	id, _ := raw.(string)
	return id
}
================================================
FILE: internal/managementasset/updater.go
================================================
package managementasset
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
)
const (
	// managementAssetName is the release asset name and the local file name of the control panel.
	managementAssetName = "management.html"
	// ManagementFileName exposes the control panel asset filename.
	ManagementFileName = managementAssetName

	// defaultManagementReleaseURL is the GitHub API endpoint queried when no panel repository is configured.
	defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
	// defaultManagementFallbackURL is downloaded when the release cannot be used and no local copy exists.
	defaultManagementFallbackURL = "https://cpamc.router-for.me/"
	// httpUserAgent is sent with release-metadata and asset requests.
	httpUserAgent = "CLIProxyAPI-management-updater"

	// managementSyncMinInterval is the minimum spacing between two sync attempts.
	managementSyncMinInterval = 30 * time.Second
	// updateCheckInterval is the auto-updater's tick period.
	updateCheckInterval = 3 * time.Hour
)
// Sync throttle state; lastUpdateCheckTime is guarded by lastUpdateCheckMu.
var (
	lastUpdateCheckMu   sync.Mutex
	lastUpdateCheckTime time.Time
)

// Auto-updater state: the latest config snapshot, the one-shot scheduler guard, and the
// config file path (stored as a string) used to resolve the static directory.
var (
	currentConfigPtr    atomic.Pointer[config.Config]
	schedulerOnce       sync.Once
	schedulerConfigPath atomic.Value
)

// sfGroup coalesces concurrent sync attempts keyed by the local asset path.
var sfGroup singleflight.Group
// SetCurrentConfig stores the latest configuration snapshot for management asset decisions.
// Passing nil clears the snapshot, which makes the auto-updater skip its runs.
func SetCurrentConfig(cfg *config.Config) {
	// Storing a nil *config.Config is identical to an explicit Store(nil).
	currentConfigPtr.Store(cfg)
}
// StartAutoUpdater launches a background goroutine that periodically ensures the management asset is up to date.
// It respects the disable-control-panel flag on every iteration and supports hot-reloaded configurations.
//
// Every call with a non-empty path updates the config path the updater uses. Only the first
// such call starts the goroutine, so only that call's ctx controls its lifetime.
func StartAutoUpdater(ctx context.Context, configFilePath string) {
	path := strings.TrimSpace(configFilePath)
	if path == "" {
		log.Debug("management asset auto-updater skipped: empty config path")
		return
	}
	schedulerConfigPath.Store(path)
	schedulerOnce.Do(func() { go runAutoUpdater(ctx) })
}
// runAutoUpdater syncs the management asset immediately and then on every
// updateCheckInterval tick until ctx is cancelled. A nil ctx runs forever.
// Runs are skipped while no config snapshot exists or the control panel is disabled.
func runAutoUpdater(ctx context.Context) {
	if ctx == nil {
		ctx = context.Background()
	}
	ticker := time.NewTicker(updateCheckInterval)
	defer ticker.Stop()
	runOnce := func() {
		cfg := currentConfigPtr.Load()
		switch {
		case cfg == nil:
			log.Debug("management asset auto-updater skipped: config not yet available")
		case cfg.RemoteManagement.DisableControlPanel:
			log.Debug("management asset auto-updater skipped: control panel disabled")
		default:
			configPath, _ := schedulerConfigPath.Load().(string)
			EnsureLatestManagementHTML(ctx, StaticDir(configPath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
		}
	}
	// Run once up front, then once after every tick.
	for runOnce(); ; runOnce() {
		select {
		case <-ctx.Done():
			return
		case <-ticker.C:
		}
	}
}
func newHTTPClient(proxyURL string) *http.Client {
client := &http.Client{Timeout: 15 * time.Second}
sdkCfg := &sdkconfig.SDKConfig{ProxyURL: strings.TrimSpace(proxyURL)}
util.SetProxy(sdkCfg, client)
return client
}
type (
	// releaseAsset is the subset of a GitHub release asset used by the updater.
	releaseAsset struct {
		Name               string `json:"name"`
		BrowserDownloadURL string `json:"browser_download_url"`
		Digest             string `json:"digest"`
	}

	// releaseResponse is the subset of the GitHub "latest release" payload used by the updater.
	releaseResponse struct {
		Assets []releaseAsset `json:"assets"`
	}
)
// StaticDir resolves the directory that stores the management control panel asset.
//
// Precedence:
//  1. MANAGEMENT_STATIC_PATH. A value naming management.html itself resolves to its parent.
//  2. "<util.WritablePath()>/static" when a writable path is available.
//  3. "static" next to the config file, or inside it when configFilePath is a directory.
//
// It returns "" when none of these applies.
func StaticDir(configFilePath string) string {
	if override := strings.TrimSpace(os.Getenv("MANAGEMENT_STATIC_PATH")); override != "" {
		dir := filepath.Clean(override)
		if strings.EqualFold(filepath.Base(dir), managementAssetName) {
			dir = filepath.Dir(dir)
		}
		return dir
	}
	if writable := util.WritablePath(); writable != "" {
		return filepath.Join(writable, "static")
	}
	trimmed := strings.TrimSpace(configFilePath)
	if trimmed == "" {
		return ""
	}
	base := filepath.Dir(trimmed)
	if info, err := os.Stat(trimmed); err == nil && info.IsDir() {
		base = trimmed
	}
	return filepath.Join(base, "static")
}
// FilePath resolves the absolute path to the management control panel asset.
//
// A MANAGEMENT_STATIC_PATH that already names management.html is returned as-is (cleaned).
// Otherwise the asset name is joined onto the override or onto StaticDir(configFilePath).
// It returns "" when no directory can be resolved.
func FilePath(configFilePath string) string {
	var dir string
	if override := strings.TrimSpace(os.Getenv("MANAGEMENT_STATIC_PATH")); override != "" {
		cleaned := filepath.Clean(override)
		if strings.EqualFold(filepath.Base(cleaned), managementAssetName) {
			return cleaned
		}
		dir = cleaned
	} else if dir = StaticDir(configFilePath); dir == "" {
		return ""
	}
	return filepath.Join(dir, ManagementFileName)
}
// EnsureLatestManagementHTML checks the latest management.html asset and updates the local copy when needed.
// It coalesces concurrent sync attempts and returns whether the asset exists after the sync attempt.
//
// Attempts are throttled process-wide to one per managementSyncMinInterval. When the local copy
// is missing and the release cannot be fetched, downloaded, or verified, the fallback page is
// tried instead. A downloaded asset whose SHA-256 does not match the digest published with the
// release is never installed.
//
// Parameters:
//   - ctx: Context for the outbound HTTP requests; nil is treated as context.Background()
//   - staticDir: Directory that holds management.html; an empty value disables the sync
//   - proxyURL: Optional proxy for the outbound requests
//   - panelRepository: Optional repository URL overriding the default release endpoint
//
// Returns:
//   - bool: true when management.html exists in staticDir after the attempt
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) bool {
	if ctx == nil {
		ctx = context.Background()
	}
	staticDir = strings.TrimSpace(staticDir)
	if staticDir == "" {
		log.Debug("management asset sync skipped: empty static directory")
		return false
	}
	localPath := filepath.Join(staticDir, managementAssetName)
	// Concurrent callers for the same path share a single attempt.
	_, _, _ = sfGroup.Do(localPath, func() (interface{}, error) {
		if reserveManagementSyncAttempt() {
			syncManagementHTML(ctx, staticDir, localPath, proxyURL, panelRepository)
		}
		return nil, nil
	})
	_, errStat := os.Stat(localPath)
	return errStat == nil
}

// reserveManagementSyncAttempt records a new sync attempt and reports whether it may proceed.
// It refuses when the previous attempt started less than managementSyncMinInterval ago.
func reserveManagementSyncAttempt() bool {
	lastUpdateCheckMu.Lock()
	now := time.Now()
	sinceLast := now.Sub(lastUpdateCheckTime)
	if !lastUpdateCheckTime.IsZero() && sinceLast < managementSyncMinInterval {
		lastUpdateCheckMu.Unlock()
		log.Debugf(
			"management asset sync skipped by throttle: last attempt %v ago (interval %v)",
			sinceLast.Round(time.Second),
			managementSyncMinInterval,
		)
		return false
	}
	lastUpdateCheckTime = now
	lastUpdateCheckMu.Unlock()
	return true
}

// syncManagementHTML performs one fetch/compare/download/install cycle for the asset at localPath.
func syncManagementHTML(ctx context.Context, staticDir, localPath, proxyURL, panelRepository string) {
	localFileMissing := false
	if _, errStat := os.Stat(localPath); errStat != nil {
		if errors.Is(errStat, os.ErrNotExist) {
			localFileMissing = true
		} else {
			log.WithError(errStat).Debug("failed to stat local management asset")
		}
	}
	if errMkdirAll := os.MkdirAll(staticDir, 0o755); errMkdirAll != nil {
		log.WithError(errMkdirAll).Warn("failed to prepare static directory for management asset")
		return
	}
	releaseURL := resolveReleaseURL(panelRepository)
	client := newHTTPClient(proxyURL)
	localHash, err := fileSHA256(localPath)
	if err != nil {
		if !errors.Is(err, os.ErrNotExist) {
			log.WithError(err).Debug("failed to read local management asset hash")
		}
		localHash = ""
	}
	asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
	if err != nil {
		if localFileMissing {
			log.WithError(err).Warn("failed to fetch latest management release information, trying fallback page")
			ensureFallbackManagementHTML(ctx, client, localPath)
			return
		}
		log.WithError(err).Warn("failed to fetch latest management release information")
		return
	}
	if remoteHash != "" && localHash != "" && strings.EqualFold(remoteHash, localHash) {
		log.Debug("management asset is already up to date")
		return
	}
	data, downloadedHash, err := downloadAsset(ctx, client, asset.BrowserDownloadURL)
	if err != nil {
		if localFileMissing {
			log.WithError(err).Warn("failed to download management asset, trying fallback page")
			ensureFallbackManagementHTML(ctx, client, localPath)
			return
		}
		log.WithError(err).Warn("failed to download management asset")
		return
	}
	if remoteHash != "" && !strings.EqualFold(remoteHash, downloadedHash) {
		// The payload is corrupted or was altered in transit. Never install it; keep the
		// existing copy, or try the fallback page when there is none.
		log.Warnf("remote digest mismatch for management asset: expected %s got %s", remoteHash, downloadedHash)
		if localFileMissing {
			ensureFallbackManagementHTML(ctx, client, localPath)
		}
		return
	}
	if err = atomicWriteFile(localPath, data); err != nil {
		log.WithError(err).Warn("failed to update management asset on disk")
		return
	}
	log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
}
// ensureFallbackManagementHTML downloads the fallback control panel page and atomically
// writes it to localPath. It reports whether the page was installed; failures are logged.
func ensureFallbackManagementHTML(ctx context.Context, client *http.Client, localPath string) bool {
	page, hash, errDownload := downloadAsset(ctx, client, defaultManagementFallbackURL)
	if errDownload != nil {
		log.WithError(errDownload).Warn("failed to download fallback management control panel page")
		return false
	}
	if errWrite := atomicWriteFile(localPath, page); errWrite != nil {
		log.WithError(errWrite).Warn("failed to persist fallback management control panel page")
		return false
	}
	log.Infof("management asset updated from fallback page successfully (hash=%s)", hash)
	return true
}
// resolveReleaseURL maps a configured panel repository to its GitHub "latest release" API URL.
//
// Accepted forms:
//   - https://api.github.com/repos/<owner>/<repo>[/releases/latest], used as-is with the
//     "/releases/latest" suffix added if missing
//   - https://github.com/<owner>/<repo>[.git][/...], also with the www. host
//   - the same forms without a scheme, e.g. "github.com/owner/repo"
//
// Anything else, including an empty value, yields defaultManagementReleaseURL.
func resolveReleaseURL(repo string) string {
	repo = strings.TrimSpace(repo)
	if repo == "" {
		return defaultManagementReleaseURL
	}
	// Without a scheme, url.Parse puts the host into Path and the value would be silently
	// ignored. Protocol-relative "//host/..." already parses with a host, so leave it alone.
	if !strings.Contains(repo, "://") && !strings.HasPrefix(repo, "//") {
		repo = "https://" + repo
	}
	parsed, err := url.Parse(repo)
	if err != nil || parsed.Host == "" {
		return defaultManagementReleaseURL
	}
	parsed.Path = strings.TrimSuffix(parsed.Path, "/")
	switch strings.ToLower(parsed.Host) {
	case "api.github.com":
		if !strings.HasSuffix(strings.ToLower(parsed.Path), "/releases/latest") {
			parsed.Path += "/releases/latest"
		}
		return parsed.String()
	case "github.com", "www.github.com":
		parts := strings.Split(strings.Trim(parsed.Path, "/"), "/")
		if len(parts) >= 2 && parts[0] != "" {
			// Reject "/owner/.git", which would otherwise produce an empty repository name.
			if repoName := strings.TrimSuffix(parts[1], ".git"); repoName != "" {
				return fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", parts[0], repoName)
			}
		}
	}
	return defaultManagementReleaseURL
}
// fetchLatestAsset queries a GitHub "latest release" API endpoint and returns
// the asset whose name matches managementAssetName (case-insensitively),
// together with that asset's digest as normalized by parseDigest.
//
// An empty releaseURL falls back to defaultManagementReleaseURL. When
// GITSTORE_GIT_TOKEN is set and GITSTORE_GIT_URL refers to GitHub, the token is
// sent as a bearer credential — but only when the request targets
// api.github.com, so a release URL pointing at any other host never receives it.
func fetchLatestAsset(ctx context.Context, client *http.Client, releaseURL string) (*releaseAsset, string, error) {
	if strings.TrimSpace(releaseURL) == "" {
		releaseURL = defaultManagementReleaseURL
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, releaseURL, nil)
	if err != nil {
		return nil, "", fmt.Errorf("create release request: %w", err)
	}
	req.Header.Set("Accept", "application/vnd.github+json")
	req.Header.Set("User-Agent", httpUserAgent)

	gitURL := strings.ToLower(strings.TrimSpace(os.Getenv("GITSTORE_GIT_URL")))
	tok := strings.TrimSpace(os.Getenv("GITSTORE_GIT_TOKEN"))
	// Scope the credential to the GitHub API host itself; the release URL can
	// come from configuration and must not be able to harvest the token.
	if tok != "" && strings.Contains(gitURL, "github.com") && strings.EqualFold(req.URL.Hostname(), "api.github.com") {
		req.Header.Set("Authorization", "Bearer "+tok)
	}

	resp, err := client.Do(req)
	if err != nil {
		return nil, "", fmt.Errorf("execute release request: %w", err)
	}
	defer func() {
		_ = resp.Body.Close()
	}()
	if resp.StatusCode != http.StatusOK {
		// Include a bounded snippet of the response body to aid diagnosis.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, 1024))
		return nil, "", fmt.Errorf("unexpected release status %d: %s", resp.StatusCode, strings.TrimSpace(string(body)))
	}

	var release releaseResponse
	if err = json.NewDecoder(resp.Body).Decode(&release); err != nil {
		return nil, "", fmt.Errorf("decode release response: %w", err)
	}
	for i := range release.Assets {
		asset := &release.Assets[i]
		if strings.EqualFold(asset.Name, managementAssetName) {
			return asset, parseDigest(asset.Digest), nil
		}
	}
	return nil, "", fmt.Errorf("management asset %s not found in latest release", managementAssetName)
}
// downloadAsset fetches downloadURL with a GET request and returns the full
// response body along with its lowercase hex-encoded SHA-256 digest.
// A blank URL, a transport failure, a non-200 status (reported with up to
// 1 KiB of the body), or a read failure is returned as an error.
func downloadAsset(ctx context.Context, client *http.Client, downloadURL string) ([]byte, string, error) {
	if len(strings.TrimSpace(downloadURL)) == 0 {
		return nil, "", fmt.Errorf("empty download url")
	}

	request, errReq := http.NewRequestWithContext(ctx, http.MethodGet, downloadURL, nil)
	if errReq != nil {
		return nil, "", fmt.Errorf("create download request: %w", errReq)
	}
	request.Header.Set("User-Agent", httpUserAgent)

	response, errDo := client.Do(request)
	if errDo != nil {
		return nil, "", fmt.Errorf("execute download request: %w", errDo)
	}
	defer func() { _ = response.Body.Close() }()

	if response.StatusCode != http.StatusOK {
		snippet, _ := io.ReadAll(io.LimitReader(response.Body, 1024))
		return nil, "", fmt.Errorf("unexpected download status %d: %s", response.StatusCode, strings.TrimSpace(string(snippet)))
	}

	payload, errRead := io.ReadAll(response.Body)
	if errRead != nil {
		return nil, "", fmt.Errorf("read download body: %w", errRead)
	}
	digest := sha256.Sum256(payload)
	return payload, hex.EncodeToString(digest[:]), nil
}
// fileSHA256 streams the file at path through SHA-256 and returns the digest
// as a lowercase hex string. Open and read errors are returned unchanged.
func fileSHA256(path string) (string, error) {
	f, errOpen := os.Open(path)
	if errOpen != nil {
		return "", errOpen
	}
	defer func() { _ = f.Close() }()

	hasher := sha256.New()
	if _, errCopy := io.Copy(hasher, f); errCopy != nil {
		return "", errCopy
	}
	sum := hasher.Sum(nil)
	return hex.EncodeToString(sum), nil
}
// atomicWriteFile replaces path with data by writing a temporary file in the
// same directory, setting its mode to 0644, flushing it to stable storage,
// and renaming it over path. On any failure the temporary file is closed and
// removed, and path is left untouched.
func atomicWriteFile(path string, data []byte) error {
	tmpFile, err := os.CreateTemp(filepath.Dir(path), "management-*.html")
	if err != nil {
		return err
	}
	tmpName := tmpFile.Name()
	committed := false
	defer func() {
		if committed {
			return
		}
		// Best-effort cleanup; the file may already be closed at this point.
		_ = tmpFile.Close()
		_ = os.Remove(tmpName)
	}()

	if _, err = tmpFile.Write(data); err != nil {
		return err
	}
	if err = tmpFile.Chmod(0o644); err != nil {
		return err
	}
	// Flush the contents before the rename; otherwise a crash right after the
	// rename can leave an empty or truncated file under the final name.
	if err = tmpFile.Sync(); err != nil {
		return err
	}
	if err = tmpFile.Close(); err != nil {
		return err
	}
	if err = os.Rename(tmpName, path); err != nil {
		return err
	}
	committed = true
	return nil
}
// parseDigest normalizes a release-asset digest such as "sha256:ABC…" into
// the lowercase value after the first colon. Surrounding whitespace is
// ignored, and a value without a colon is simply lowercased.
func parseDigest(digest string) string {
	value := strings.TrimSpace(digest)
	if _, afterColon, found := strings.Cut(value, ":"); found {
		value = afterColon
	}
	return strings.ToLower(strings.TrimSpace(value))
}
================================================
FILE: internal/misc/claude_code_instructions.go
================================================
// Package misc provides miscellaneous utility functions and embedded data for the CLI Proxy API.
// This package contains general-purpose helpers and embedded resources that do not fit into
// more specific domain packages. It includes embedded instructional text for Claude Code-related operations.
package misc
import _ "embed"
// ClaudeCodeInstructions holds the raw contents of claude_code_instructions.txt,
// embedded into the application binary at compile time.
//
// The shipped file is a JSON array holding a single "text" content block
// (marked with an ephemeral cache_control entry), not free-form prose.
// NOTE(review): callers presumably splice this array into a request's system
// prompt verbatim; confirm at the call sites before changing its format.
//
//go:embed claude_code_instructions.txt
var ClaudeCodeInstructions string
================================================
FILE: internal/misc/claude_code_instructions.txt
================================================
[{"type":"text","text":"You are a Claude agent, built on Anthropic's Claude Agent SDK.","cache_control":{"type":"ephemeral","ttl":"1h"}}]
================================================
FILE: internal/misc/copy-example-config.go
================================================
package misc
import (
"io"
"os"
"path/filepath"
log "github.com/sirupsen/logrus"
)
// CopyConfigTemplate copies the file at src to dst. It creates dst's parent
// directories (mode 0700) as needed and creates dst with mode 0600, truncating
// it if it already exists. The written data is fsynced before returning.
//
// A failure to close dst is returned, because for a written file it can mean
// the copy did not fully persist. A failure to close src is only logged.
func CopyConfigTemplate(src, dst string) (err error) {
	in, err := os.Open(src)
	if err != nil {
		return err
	}
	defer func() {
		if errClose := in.Close(); errClose != nil {
			log.WithError(errClose).Warn("failed to close source config file")
		}
	}()
	if err = os.MkdirAll(filepath.Dir(dst), 0o700); err != nil {
		return err
	}
	out, err := os.OpenFile(dst, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600)
	if err != nil {
		return err
	}
	defer func() {
		if errClose := out.Close(); errClose != nil {
			log.WithError(errClose).Warn("failed to close destination config file")
			// Surface the close failure unless an earlier error is already being returned.
			if err == nil {
				err = errClose
			}
		}
	}()
	if _, err = io.Copy(out, in); err != nil {
		return err
	}
	return out.Sync()
}
================================================
FILE: internal/misc/credentials.go
================================================
package misc
import (
"encoding/json"
"fmt"
"path/filepath"
"strings"
log "github.com/sirupsen/logrus"
)
// credentialSeparator is a 67-character run of hyphens that
// LogCredentialSeparator emits at debug level to visually group related
// auth/key processing log lines.
var credentialSeparator = strings.Repeat("-", 67)
// LogSavingCredentials prints "Saving credentials to <path>" to standard
// output (via fmt, not the logger) so users can see where auth material is
// written. The path is normalized with filepath.Clean, so redundant
// separators do not change the message. An empty path prints nothing.
func LogSavingCredentials(path string) {
	if path != "" {
		fmt.Printf("Saving credentials to %s\n", filepath.Clean(path))
	}
}
// LogCredentialSeparator writes credentialSeparator at debug level through
// logrus's standard logger, visually grouping related auth/key processing
// log lines (subject to the logger's configured level).
func LogCredentialSeparator() {
	logger := log.StandardLogger()
	logger.Debug(credentialSeparator)
}
// MergeMetadata returns a new map holding source's fields with metadata
// entries layered on top (metadata wins on key collisions).
//
// A map[string]any source is shallow-copied so the caller's map is never
// mutated. Any other source is round-tripped through JSON, so its json
// struct tags decide the keys. A nil source with nil metadata yields a nil
// map. Marshal or unmarshal failures are returned as wrapped errors.
func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) {
	var merged map[string]any

	switch src := source.(type) {
	case map[string]any:
		merged = make(map[string]any, len(src)+len(metadata))
		for key, value := range src {
			merged[key] = value
		}
	default:
		raw, errMarshal := json.Marshal(src)
		if errMarshal != nil {
			return nil, fmt.Errorf("failed to marshal source: %w", errMarshal)
		}
		if errUnmarshal := json.Unmarshal(raw, &merged); errUnmarshal != nil {
			return nil, fmt.Errorf("failed to unmarshal to map: %w", errUnmarshal)
		}
	}

	if metadata == nil {
		return merged, nil
	}
	if merged == nil {
		// A JSON "null" source leaves merged nil; allocate before writing.
		merged = make(map[string]any)
	}
	for key, value := range metadata {
		merged[key] = value
	}
	return merged, nil
}
================================================
FILE: internal/misc/header_utils.go
================================================
// Package misc provides miscellaneous utility functions for the CLI Proxy API server.
// It includes helper functions for HTTP header manipulation and other common operations
// that don't fit into more specific packages.
package misc
import (
"fmt"
"net/http"
"runtime"
"strings"
)
const (
// GeminiCLIVersion is the Gemini CLI version string embedded in the User-Agent
// built by GeminiCLIUserAgent ("GeminiCLI/<version>/<model> (<os>; <arch>)").
GeminiCLIVersion = "0.31.0"
// GeminiCLIApiClientHeader is the value for the X-Goog-Api-Client header sent to the Gemini CLI upstream.
// NOTE(review): the SDK/Node versions here presumably should track the CLI
// release named by GeminiCLIVersion; update both together when bumping.
GeminiCLIApiClientHeader = "google-genai-sdk/1.41.0 gl-node/v22.19.0"
)
// geminiCLIOS reports the current OS using Node.js platform naming, as Gemini
// CLI does: Go's "windows" becomes "win32"; every other GOOS passes through.
func geminiCLIOS() string {
	if goos := runtime.GOOS; goos != "windows" {
		return goos
	}
	return "win32"
}
// geminiCLIArch reports the current CPU architecture using Node.js naming, as
// Gemini CLI does: "amd64" becomes "x64" and "386" becomes "x86"; any other
// GOARCH passes through unchanged.
func geminiCLIArch() string {
	arch := runtime.GOARCH
	if arch == "amd64" {
		return "x64"
	}
	if arch == "386" {
		return "x86"
	}
	return arch
}
// GeminiCLIUserAgent returns a User-Agent string that matches the Gemini CLI format.
// The model parameter is included in the UA; pass "" or "unknown" when the model is not applicable.
func GeminiCLIUserAgent(model string) string {
if model == "" {
model = "unknown"
}
return fmt.Sprintf("GeminiCLI/%s/%s (%s; %s)", GeminiCLIVersion, model, geminiCLIOS(), geminiCLIArch())
}
// ScrubProxyAndFingerprintHeaders deletes, in place, the request headers that
// could reveal proxy infrastructure, client identity, or browser fingerprints,
// so the upstream request looks like it came directly from a native client
// rather than a third-party client behind a reverse proxy. A nil request is
// ignored.
func ScrubProxyAndFingerprintHeaders(req *http.Request) {
	if req == nil {
		return
	}
	scrubbed := [...]string{
		// Proxy tracing headers.
		"X-Forwarded-For",
		"X-Forwarded-Host",
		"X-Forwarded-Proto",
		"X-Forwarded-Port",
		"X-Real-IP",
		"Forwarded",
		"Via",
		// Client identity headers.
		"X-Title",
		"X-Stainless-Lang",
		"X-Stainless-Package-Version",
		"X-Stainless-Os",
		"X-Stainless-Arch",
		"X-Stainless-Runtime",
		"X-Stainless-Runtime-Version",
		"Http-Referer",
		"Referer",
		// Browser / Chromium fingerprint headers: sent by Electron-based
		// clients (e.g. CherryStudio) using the Fetch API, but NOT by the
		// Node.js https module (which Antigravity uses).
		"Sec-Ch-Ua",
		"Sec-Ch-Ua-Mobile",
		"Sec-Ch-Ua-Platform",
		"Sec-Fetch-Mode",
		"Sec-Fetch-Site",
		"Sec-Fetch-Dest",
		"Priority",
		// Encoding negotiation: Antigravity (Node.js) sends "gzip, deflate, br"
		// by default; Electron-based clients may add "zstd", which is a
		// fingerprint mismatch.
		"Accept-Encoding",
	}
	headers := req.Header
	for _, name := range scrubbed {
		headers.Del(name)
	}
}
// EnsureHeader resolves a value for key in target using this priority order:
//  1. a non-blank value in source (trimmed) — this overwrites whatever target holds;
//  2. an existing non-blank value already in target — left untouched;
//  3. defaultValue (trimmed), if it is non-blank.
//
// If none of these yields a value, target is not modified. A nil target is
// ignored, and source may be nil.
//
// Parameters:
//   - target: the header map to modify
//   - source: the header map consulted first (can be nil)
//   - key: the header name to resolve
//   - defaultValue: the fallback used when neither map supplies a value
func EnsureHeader(target http.Header, source http.Header, key, defaultValue string) {
	if target == nil {
		return
	}

	var fromSource string
	if source != nil {
		fromSource = strings.TrimSpace(source.Get(key))
	}

	switch {
	case fromSource != "":
		target.Set(key, fromSource)
	case strings.TrimSpace(target.Get(key)) != "":
		// Target already carries a usable value; keep it.
	default:
		if fallback := strings.TrimSpace(defaultValue); fallback != "" {
			target.Set(key, fallback)
		}
	}
}
================================================
FILE: internal/misc/mime-type.go
================================================
// Package misc provides miscellaneous utility functions and embedded data for the CLI Proxy API.
// This package contains general-purpose helpers and embedded resources that do not fit into
// more specific domain packages. It includes a comprehensive MIME type mapping for file operations.
package misc
// MimeTypes is a comprehensive map of file extensions to their corresponding MIME types.
// This map is used to determine the Content-Type header for file uploads and other
// operations where the MIME type needs to be identified from a file extension.
// The list is extensive to cover a wide range of common and uncommon file formats.
var MimeTypes = map[string]string{
"ez": "application/andrew-inset",
"aw": "application/applixware",
"atom": "application/atom+xml",
"atomcat": "application/atomcat+xml",
"atomsvc": "application/atomsvc+xml",
"ccxml": "application/ccxml+xml",
"cdmia": "application/cdmi-capability",
"cdmic": "application/cdmi-container",
"cdmid": "application/cdmi-domain",
"cdmio": "application/cdmi-object",
"cdmiq": "application/cdmi-queue",
"cu": "application/cu-seeme",
"davmount": "application/davmount+xml",
"dbk": "application/docbook+xml",
"dssc": "application/dssc+der",
"xdssc": "application/dssc+xml",
"ecma": "application/ecmascript",
"emma": "application/emma+xml",
"epub": "application/epub+zip",
"exi": "application/exi",
"pfr": "application/font-tdpfr",
"gml": "application/gml+xml",
"gpx": "application/gpx+xml",
"gxf": "application/gxf",
"stk": "application/hyperstudio",
"ink": "application/inkml+xml",
"ipfix": "application/ipfix",
"jar": "application/java-archive",
"ser": "application/java-serialized-object",
"class": "application/java-vm",
"js": "application/javascript",
"json": "application/json",
"jsonml": "application/jsonml+json",
"lostxml": "application/lost+xml",
"hqx": "application/mac-binhex40",
"cpt": "application/mac-compactpro",
"mads": "application/mads+xml",
"mrc": "application/marc",
"mrcx": "application/marcxml+xml",
"ma": "application/mathematica",
"mathml": "application/mathml+xml",
"mbox": "application/mbox",
"mscml": "application/mediaservercontrol+xml",
"metalink": "application/metalink+xml",
"meta4": "application/metalink4+xml",
"mets": "application/mets+xml",
"mods": "application/mods+xml",
"m21": "application/mp21",
"mp4s": "application/mp4",
"doc": "application/msword",
"mxf": "application/mxf",
"bin": "application/octet-stream",
"oda": "application/oda",
"opf": "application/oebps-package+xml",
"ogx": "application/ogg",
"omdoc": "application/omdoc+xml",
"onepkg": "application/onenote",
"oxps": "application/oxps",
"xer": "application/patch-ops-error+xml",
"pdf": "application/pdf",
"pgp": "application/pgp-encrypted",
"asc": "application/pgp-signature",
"prf": "application/pics-rules",
"p10": "application/pkcs10",
"p7c": "application/pkcs7-mime",
"p7s": "application/pkcs7-signature",
"p8": "application/pkcs8",
"ac": "application/pkix-attr-cert",
"cer": "application/pkix-cert",
"crl": "application/pkix-crl",
"pkipath": "application/pkix-pkipath",
"pki": "application/pkixcmp",
"pls": "application/pls+xml",
"ai": "application/postscript",
"cww": "application/prs.cww",
"pskcxml": "application/pskc+xml",
"rdf": "application/rdf+xml",
"rif": "application/reginfo+xml",
"rnc": "application/relax-ng-compact-syntax",
"rld": "application/resource-lists-diff+xml",
"rl": "application/resource-lists+xml",
"rs": "application/rls-services+xml",
"gbr": "application/rpki-ghostbusters",
"mft": "application/rpki-manifest",
"roa": "application/rpki-roa",
"rsd": "application/rsd+xml",
"rss": "application/rss+xml",
"rtf": "application/rtf",
"sbml": "application/sbml+xml",
"scq": "application/scvp-cv-request",
"scs": "application/scvp-cv-response",
"spq": "application/scvp-vp-request",
"spp": "application/scvp-vp-response",
"sdp": "application/sdp",
"setpay": "application/set-payment-initiation",
"setreg": "application/set-registration-initiation",
"shf": "application/shf+xml",
"smi": "application/smil+xml",
"rq": "application/sparql-query",
"srx": "application/sparql-results+xml",
"gram": "application/srgs",
"grxml": "application/srgs+xml",
"sru": "application/sru+xml",
"ssdl": "application/ssdl+xml",
"ssml": "application/ssml+xml",
"tei": "application/tei+xml",
"tfi": "application/thraud+xml",
"tsd": "application/timestamped-data",
"plb": "application/vnd.3gpp.pic-bw-large",
"psb": "application/vnd.3gpp.pic-bw-small",
"pvb": "application/vnd.3gpp.pic-bw-var",
"tcap": "application/vnd.3gpp2.tcap",
"pwn": "application/vnd.3m.post-it-notes",
"aso": "application/vnd.accpac.simply.aso",
"imp": "application/vnd.accpac.simply.imp",
"acu": "application/vnd.acucobol",
"acutc": "application/vnd.acucorp",
"air": "application/vnd.adobe.air-application-installer-package+zip",
"fcdt": "application/vnd.adobe.formscentral.fcdt",
"fxp": "application/vnd.adobe.fxp",
"xdp": "application/vnd.adobe.xdp+xml",
"xfdf": "application/vnd.adobe.xfdf",
"ahead": "application/vnd.ahead.space",
"azf": "application/vnd.airzip.filesecure.azf",
"azs": "application/vnd.airzip.filesecure.azs",
"azw": "application/vnd.amazon.ebook",
"acc": "application/vnd.americandynamics.acc",
"ami": "application/vnd.amiga.ami",
"apk": "application/vnd.android.package-archive",
"cii": "application/vnd.anser-web-certificate-issue-initiation",
"fti": "application/vnd.anser-web-funds-transfer-initiation",
"atx": "application/vnd.antix.game-component",
"mpkg": "application/vnd.apple.installer+xml",
"m3u8": "application/vnd.apple.mpegurl",
"swi": "application/vnd.aristanetworks.swi",
"iota": "application/vnd.astraea-software.iota",
"aep": "application/vnd.audiograph",
"mpm": "application/vnd.blueice.multipass",
"bmi": "application/vnd.bmi",
"rep": "application/vnd.businessobjects",
"cdxml": "application/vnd.chemdraw+xml",
"mmd": "application/vnd.chipnuts.karaoke-mmd",
"cdy": "application/vnd.cinderella",
"cla": "application/vnd.claymore",
"rp9": "application/vnd.cloanto.rp9",
"c4d": "application/vnd.clonk.c4group",
"c11amc": "application/vnd.cluetrust.cartomobile-config",
"c11amz": "application/vnd.cluetrust.cartomobile-config-pkg",
"csp": "application/vnd.commonspace",
"cdbcmsg": "application/vnd.contact.cmsg",
"cmc": "application/vnd.cosmocaller",
"clkx": "application/vnd.crick.clicker",
"clkk": "application/vnd.crick.clicker.keyboard",
"clkp": "application/vnd.crick.clicker.palette",
"clkt": "application/vnd.crick.clicker.template",
"clkw": "application/vnd.crick.clicker.wordbank",
"wbs": "application/vnd.criticaltools.wbs+xml",
"pml": "application/vnd.ctc-posml",
"ppd": "application/vnd.cups-ppd",
"car": "application/vnd.curl.car",
"pcurl": "application/vnd.curl.pcurl",
"dart": "application/vnd.dart",
"rdz": "application/vnd.data-vision.rdz",
"uvd": "application/vnd.dece.data",
"fe_launch": "application/vnd.denovo.fcselayout-link",
"dna": "application/vnd.dna",
"mlp": "application/vnd.dolby.mlp",
"dpg": "application/vnd.dpgraph",
"dfac": "application/vnd.dreamfactory",
"kpxx": "application/vnd.ds-keypoint",
"ait": "application/vnd.dvb.ait",
"svc": "application/vnd.dvb.service",
"geo": "application/vnd.dynageo",
"mag": "application/vnd.ecowin.chart",
"nml": "application/vnd.enliven",
"esf": "application/vnd.epson.esf",
"msf": "application/vnd.epson.msf",
"qam": "application/vnd.epson.quickanime",
"slt": "application/vnd.epson.salt",
"ssf": "application/vnd.epson.ssf",
"es3": "application/vnd.eszigno3+xml",
"ez2": "application/vnd.ezpix-album",
"ez3": "application/vnd.ezpix-package",
"fdf": "application/vnd.fdf",
"mseed": "application/vnd.fdsn.mseed",
"dataless": "application/vnd.fdsn.seed",
"gph": "application/vnd.flographit",
"ftc": "application/vnd.fluxtime.clip",
"book": "application/vnd.framemaker",
"fnc": "application/vnd.frogans.fnc",
"ltf": "application/vnd.frogans.ltf",
"fsc": "application/vnd.fsc.weblaunch",
"oas": "application/vnd.fujitsu.oasys",
"oa2": "application/vnd.fujitsu.oasys2",
"oa3": "application/vnd.fujitsu.oasys3",
"fg5": "application/vnd.fujitsu.oasysgp",
"bh2": "application/vnd.fujitsu.oasysprs",
"ddd": "application/vnd.fujixerox.ddd",
"xdw": "application/vnd.fujixerox.docuworks",
"xbd": "application/vnd.fujixerox.docuworks.binder",
"fzs": "application/vnd.fuzzysheet",
"txd": "application/vnd.genomatix.tuxedo",
"ggb": "application/vnd.geogebra.file",
"ggt": "application/vnd.geogebra.tool",
"gex": "application/vnd.geometry-explorer",
"gxt": "application/vnd.geonext",
"g2w": "application/vnd.geoplan",
"g3w": "application/vnd.geospace",
"gmx": "application/vnd.gmx",
"kml": "application/vnd.google-earth.kml+xml",
"kmz": "application/vnd.google-earth.kmz",
"gqf": "application/vnd.grafeq",
"gac": "application/vnd.groove-account",
"ghf": "application/vnd.groove-help",
"gim": "application/vnd.groove-identity-message",
"grv": "application/vnd.groove-injector",
"gtm": "application/vnd.groove-tool-message",
"tpl": "application/vnd.groove-tool-template",
"vcg": "application/vnd.groove-vcard",
"hal": "application/vnd.hal+xml",
"zmm": "application/vnd.handheld-entertainment+xml",
"hbci": "application/vnd.hbci",
"les": "application/vnd.hhe.lesson-player",
"hpgl": "application/vnd.hp-hpgl",
"hpid": "application/vnd.hp-hpid",
"hps": "application/vnd.hp-hps",
"jlt": "application/vnd.hp-jlyt",
"pcl": "application/vnd.hp-pcl",
"pclxl": "application/vnd.hp-pclxl",
"sfd-hdstx": "application/vnd.hydrostatix.sof-data",
"mpy": "application/vnd.ibm.minipay",
"afp": "application/vnd.ibm.modcap",
"irm": "application/vnd.ibm.rights-management",
"sc": "application/vnd.ibm.secure-container",
"icc": "application/vnd.iccprofile",
"igl": "application/vnd.igloader",
"ivp": "application/vnd.immervision-ivp",
"ivu": "application/vnd.immervision-ivu",
"igm": "application/vnd.insors.igm",
"xpw": "application/vnd.intercon.formnet",
"i2g": "application/vnd.intergeo",
"qbo": "application/vnd.intu.qbo",
"qfx": "application/vnd.intu.qfx",
"rcprofile": "application/vnd.ipunplugged.rcprofile",
"irp": "application/vnd.irepository.package+xml",
"xpr": "application/vnd.is-xpr",
"fcs": "application/vnd.isac.fcs",
"jam": "application/vnd.jam",
"rms": "application/vnd.jcp.javame.midlet-rms",
"jisp": "application/vnd.jisp",
"joda": "application/vnd.joost.joda-archive",
"ktr": "application/vnd.kahootz",
"karbon": "application/vnd.kde.karbon",
"chrt": "application/vnd.kde.kchart",
"kfo": "application/vnd.kde.kformula",
"flw": "application/vnd.kde.kivio",
"kon": "application/vnd.kde.kontour",
"kpr": "application/vnd.kde.kpresenter",
"ksp": "application/vnd.kde.kspread",
"kwd": "application/vnd.kde.kword",
"htke": "application/vnd.kenameaapp",
"kia": "application/vnd.kidspiration",
"kne": "application/vnd.kinar",
"skd": "application/vnd.koan",
"sse": "application/vnd.kodak-descriptor",
"lasxml": "application/vnd.las.las+xml",
"lbd": "application/vnd.llamagraphics.life-balance.desktop",
"lbe": "application/vnd.llamagraphics.life-balance.exchange+xml",
"123": "application/vnd.lotus-1-2-3",
"apr": "application/vnd.lotus-approach",
"pre": "application/vnd.lotus-freelance",
"nsf": "application/vnd.lotus-notes",
"org": "application/vnd.lotus-organizer",
"scm": "application/vnd.lotus-screencam",
"lwp": "application/vnd.lotus-wordpro",
"portpkg": "application/vnd.macports.portpkg",
"mcd": "application/vnd.mcd",
"mc1": "application/vnd.medcalcdata",
"cdkey": "application/vnd.mediastation.cdkey",
"mwf": "application/vnd.mfer",
"mfm": "application/vnd.mfmp",
"flo": "application/vnd.micrografx.flo",
"igx": "application/vnd.micrografx.igx",
"mif": "application/vnd.mif",
"daf": "application/vnd.mobius.daf",
"dis": "application/vnd.mobius.dis",
"mbk": "application/vnd.mobius.mbk",
"mqy": "application/vnd.mobius.mqy",
"msl": "application/vnd.mobius.msl",
"plc": "application/vnd.mobius.plc",
"txf": "application/vnd.mobius.txf",
"mpn": "application/vnd.mophun.application",
"mpc": "application/vnd.mophun.certificate",
"xul": "application/vnd.mozilla.xul+xml",
"cil": "application/vnd.ms-artgalry",
"cab": "application/vnd.ms-cab-compressed",
"xls": "application/vnd.ms-excel",
"xlam": "application/vnd.ms-excel.addin.macroenabled.12",
"xlsb": "application/vnd.ms-excel.sheet.binary.macroenabled.12",
"xlsm": "application/vnd.ms-excel.sheet.macroenabled.12",
"xltm": "application/vnd.ms-excel.template.macroenabled.12",
"eot": "application/vnd.ms-fontobject",
"chm": "application/vnd.ms-htmlhelp",
"ims": "application/vnd.ms-ims",
"lrm": "application/vnd.ms-lrm",
"thmx": "application/vnd.ms-officetheme",
"cat": "application/vnd.ms-pki.seccat",
"stl": "application/vnd.ms-pki.stl",
"ppt": "application/vnd.ms-powerpoint",
"ppam": "application/vnd.ms-powerpoint.addin.macroenabled.12",
"pptm": "application/vnd.ms-powerpoint.presentation.macroenabled.12",
"sldm": "application/vnd.ms-powerpoint.slide.macroenabled.12",
"ppsm": "application/vnd.ms-powerpoint.slideshow.macroenabled.12",
"potm": "application/vnd.ms-powerpoint.template.macroenabled.12",
"mpp": "application/vnd.ms-project",
"docm": "application/vnd.ms-word.document.macroenabled.12",
"dotm": "application/vnd.ms-word.template.macroenabled.12",
"wps": "application/vnd.ms-works",
"wpl": "application/vnd.ms-wpl",
"xps": "application/vnd.ms-xpsdocument",
"mseq": "application/vnd.mseq",
"mus": "application/vnd.musician",
"msty": "application/vnd.muvee.style",
"taglet": "application/vnd.mynfc",
"nlu": "application/vnd.neurolanguage.nlu",
"nitf": "application/vnd.nitf",
"nnd": "application/vnd.noblenet-directory",
"nns": "application/vnd.noblenet-sealer",
"nnw": "application/vnd.noblenet-web",
"ngdat": "application/vnd.nokia.n-gage.data",
"n-gage": "application/vnd.nokia.n-gage.symbian.install",
"rpst": "application/vnd.nokia.radio-preset",
"rpss": "application/vnd.nokia.radio-presets",
"edm": "application/vnd.novadigm.edm",
"edx": "application/vnd.novadigm.edx",
"ext": "application/vnd.novadigm.ext",
"odc": "application/vnd.oasis.opendocument.chart",
"otc": "application/vnd.oasis.opendocument.chart-template",
"odb": "application/vnd.oasis.opendocument.database",
"odf": "application/vnd.oasis.opendocument.formula",
"odft": "application/vnd.oasis.opendocument.formula-template",
"odg": "application/vnd.oasis.opendocument.graphics",
"otg": "application/vnd.oasis.opendocument.graphics-template",
"odi": "application/vnd.oasis.opendocument.image",
"oti": "application/vnd.oasis.opendocument.image-template",
"odp": "application/vnd.oasis.opendocument.presentation",
"otp": "application/vnd.oasis.opendocument.presentation-template",
"ods": "application/vnd.oasis.opendocument.spreadsheet",
"ots": "application/vnd.oasis.opendocument.spreadsheet-template",
"odt": "application/vnd.oasis.opendocument.text",
"odm": "application/vnd.oasis.opendocument.text-master",
"ott": "application/vnd.oasis.opendocument.text-template",
"oth": "application/vnd.oasis.opendocument.text-web",
"xo": "application/vnd.olpc-sugar",
"dd2": "application/vnd.oma.dd2+xml",
"oxt": "application/vnd.openofficeorg.extension",
"pptx": "application/vnd.openxmlformats-officedocument.presentationml.presentation",
"sldx": "application/vnd.openxmlformats-officedocument.presentationml.slide",
"ppsx": "application/vnd.openxmlformats-officedocument.presentationml.slideshow",
"potx": "application/vnd.openxmlformats-officedocument.presentationml.template",
"xlsx": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
"xltx": "application/vnd.openxmlformats-officedocument.spreadsheetml.template",
"docx": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
"dotx": "application/vnd.openxmlformats-officedocument.wordprocessingml.template",
"mgp": "application/vnd.osgeo.mapguide.package",
"dp": "application/vnd.osgi.dp",
"esa": "application/vnd.osgi.subsystem",
"oprc": "application/vnd.palm",
"paw": "application/vnd.pawaafile",
"str": "application/vnd.pg.format",
"ei6": "application/vnd.pg.osasli",
"efif": "application/vnd.picsel",
"wg": "application/vnd.pmi.widget",
"plf": "application/vnd.pocketlearn",
"pbd": "application/vnd.powerbuilder6",
"box": "application/vnd.previewsystems.box",
"mgz": "application/vnd.proteus.magazine",
"qps": "application/vnd.publishare-delta-tree",
"ptid": "application/vnd.pvi.ptid1",
"qwd": "application/vnd.quark.quarkxpress",
"bed": "application/vnd.realvnc.bed",
"mxl": "application/vnd.recordare.musicxml",
"musicxml": "application/vnd.recordare.musicxml+xml",
"cryptonote": "application/vnd.rig.cryptonote",
"cod": "application/vnd.rim.cod",
"rm": "application/vnd.rn-realmedia",
"rmvb": "application/vnd.rn-realmedia-vbr",
"link66": "application/vnd.route66.link66+xml",
"st": "application/vnd.sailingtracker.track",
"see": "application/vnd.seemail",
"sema": "application/vnd.sema",
"semd": "application/vnd.semd",
"semf": "application/vnd.semf",
"ifm": "application/vnd.shana.informed.formdata",
"itp": "application/vnd.shana.informed.formtemplate",
"iif": "application/vnd.shana.informed.interchange",
"ipk": "application/vnd.shana.informed.package",
"twd": "application/vnd.simtech-mindmapper",
"mmf": "application/vnd.smaf",
"teacher": "application/vnd.smart.teacher",
"sdkd": "application/vnd.solent.sdkm+xml",
"dxp": "application/vnd.spotfire.dxp",
"sfs": "application/vnd.spotfire.sfs",
"sdc": "application/vnd.stardivision.calc",
"sda": "application/vnd.stardivision.draw",
"sdd": "application/vnd.stardivision.impress",
"smf": "application/vnd.stardivision.math",
"sdw": "application/vnd.stardivision.writer",
"sgl": "application/vnd.stardivision.writer-global",
"smzip": "application/vnd.stepmania.package",
"sm": "application/vnd.stepmania.stepchart",
"sxc": "application/vnd.sun.xml.calc",
"stc": "application/vnd.sun.xml.calc.template",
"sxd": "application/vnd.sun.xml.draw",
"std": "application/vnd.sun.xml.draw.template",
"sxi": "application/vnd.sun.xml.impress",
"sti": "application/vnd.sun.xml.impress.template",
"sxm": "application/vnd.sun.xml.math",
"sxw": "application/vnd.sun.xml.writer",
"sxg": "application/vnd.sun.xml.writer.global",
"stw": "application/vnd.sun.xml.writer.template",
"sus": "application/vnd.sus-calendar",
"svd": "application/vnd.svd",
"sis": "application/vnd.symbian.install",
"bdm": "application/vnd.syncml.dm+wbxml",
"xdm": "application/vnd.syncml.dm+xml",
"xsm": "application/vnd.syncml+xml",
"tao": "application/vnd.tao.intent-module-archive",
"cap": "application/vnd.tcpdump.pcap",
"tmo": "application/vnd.tmobile-livetv",
"tpt": "application/vnd.trid.tpt",
"mxs": "application/vnd.triscape.mxs",
"tra": "application/vnd.trueapp",
"ufd": "application/vnd.ufdl",
"utz": "application/vnd.uiq.theme",
"umj": "application/vnd.umajin",
"unityweb": "application/vnd.unity",
"uoml": "application/vnd.uoml+xml",
"vcx": "application/vnd.vcx",
"vss": "application/vnd.visio",
"vis": "application/vnd.visionary",
"vsf": "application/vnd.vsf",
"wbxml": "application/vnd.wap.wbxml",
"wmlc": "application/vnd.wap.wmlc",
"wmlsc": "application/vnd.wap.wmlscriptc",
"wtb": "application/vnd.webturbo",
"nbp": "application/vnd.wolfram.player",
"wpd": "application/vnd.wordperfect",
"wqd": "application/vnd.wqd",
"stf": "application/vnd.wt.stf",
"xar": "application/vnd.xara",
"xfdl": "application/vnd.xfdl",
"hvd": "application/vnd.yamaha.hv-dic",
"hvs": "application/vnd.yamaha.hv-script",
"hvp": "application/vnd.yamaha.hv-voice",
"osf": "application/vnd.yamaha.openscoreformat",
"osfpvg": "application/vnd.yamaha.openscoreformat.osfpvg+xml",
"saf": "application/vnd.yamaha.smaf-audio",
"spf": "application/vnd.yamaha.smaf-phrase",
"cmp": "application/vnd.yellowriver-custom-menu",
"zir": "application/vnd.zul",
"zaz": "application/vnd.zzazz.deck+xml",
"vxml": "application/voicexml+xml",
"wgt": "application/widget",
"hlp": "application/winhlp",
"wsdl": "application/wsdl+xml",
"wspolicy": "application/wspolicy+xml",
"7z": "application/x-7z-compressed",
"abw": "application/x-abiword",
"ace": "application/x-ace-compressed",
"dmg": "application/x-apple-diskimage",
"aab": "application/x-authorware-bin",
"aam": "application/x-authorware-map",
"aas": "application/x-authorware-seg",
"bcpio": "application/x-bcpio",
"torrent": "application/x-bittorrent",
"blb": "application/x-blorb",
"bz": "application/x-bzip",
"bz2": "application/x-bzip2",
"cbr": "application/x-cbr",
"vcd": "application/x-cdlink",
"cfs": "application/x-cfs-compressed",
"chat": "application/x-chat",
"pgn": "application/x-chess-pgn",
"nsc": "application/x-conference",
"cpio": "application/x-cpio",
"csh": "application/x-csh",
"deb": "application/x-debian-package",
"dgc": "application/x-dgc-compressed",
"cct": "application/x-director",
"wad": "application/x-doom",
"ncx": "application/x-dtbncx+xml",
"dtb": "application/x-dtbook+xml",
"res": "application/x-dtbresource+xml",
"dvi": "application/x-dvi",
"evy": "application/x-envoy",
"eva": "application/x-eva",
"bdf": "application/x-font-bdf",
"gsf": "application/x-font-ghostscript",
"psf": "application/x-font-linux-psf",
"pcf": "application/x-font-pcf",
"snf": "application/x-font-snf",
"afm": "application/x-font-type1",
"arc": "application/x-freearc",
"spl": "application/x-futuresplash",
"gca": "application/x-gca-compressed",
"ulx": "application/x-glulx",
"gnumeric": "application/x-gnumeric",
"gramps": "application/x-gramps-xml",
"gtar": "application/x-gtar",
"hdf": "application/x-hdf",
"install": "application/x-install-instructions",
"iso": "application/x-iso9660-image",
"jnlp": "application/x-java-jnlp-file",
"latex": "application/x-latex",
"lzh": "application/x-lzh-compressed",
"mie": "application/x-mie",
"mobi": "application/x-mobipocket-ebook",
"application": "application/x-ms-application",
"lnk": "application/x-ms-shortcut",
"wmd": "application/x-ms-wmd",
"wmz": "application/x-ms-wmz",
"xbap": "application/x-ms-xbap",
"mdb": "application/x-msaccess",
"obd": "application/x-msbinder",
"crd": "application/x-mscardfile",
"clp": "application/x-msclip",
"mny": "application/x-msmoney",
"pub": "application/x-mspublisher",
"scd": "application/x-msschedule",
"trm": "application/x-msterminal",
"wri": "application/x-mswrite",
"nzb": "application/x-nzb",
"p12": "application/x-pkcs12",
"p7b": "application/x-pkcs7-certificates",
"p7r": "application/x-pkcs7-certreqresp",
"rar": "application/x-rar-compressed",
"ris": "application/x-research-info-systems",
"sh": "application/x-sh",
"shar": "application/x-shar",
"swf": "application/x-shockwave-flash",
"xap": "application/x-silverlight-app",
"sql": "application/x-sql",
"sit": "application/x-stuffit",
"sitx": "application/x-stuffitx",
"srt": "application/x-subrip",
"sv4cpio": "application/x-sv4cpio",
"sv4crc": "application/x-sv4crc",
"t3": "application/x-t3vm-image",
"gam": "application/x-tads",
"tar": "application/x-tar",
"tcl": "application/x-tcl",
"tex": "application/x-tex",
"tfm": "application/x-tex-tfm",
"texi": "application/x-texinfo",
"obj": "application/x-tgif",
"ustar": "application/x-ustar",
"src": "application/x-wais-source",
"crt": "application/x-x509-ca-cert",
"fig": "application/x-xfig",
"xlf": "application/x-xliff+xml",
"xpi": "application/x-xpinstall",
"xz": "application/x-xz",
"xaml": "application/xaml+xml",
"xdf": "application/xcap-diff+xml",
"xenc": "application/xenc+xml",
"xhtml": "application/xhtml+xml",
"xml": "application/xml",
"dtd": "application/xml-dtd",
"xop": "application/xop+xml",
"xpl": "application/xproc+xml",
"xslt": "application/xslt+xml",
"xspf": "application/xspf+xml",
"mxml": "application/xv+xml",
"yang": "application/yang",
"yin": "application/yin+xml",
"zip": "application/zip",
"adp": "audio/adpcm",
"au": "audio/basic",
"mid": "audio/midi",
"m4a": "audio/mp4",
"mp3": "audio/mpeg",
"ogg": "audio/ogg",
"s3m": "audio/s3m",
"sil": "audio/silk",
"uva": "audio/vnd.dece.audio",
"eol": "audio/vnd.digital-winds",
"dra": "audio/vnd.dra",
"dts": "audio/vnd.dts",
"dtshd": "audio/vnd.dts.hd",
"lvp": "audio/vnd.lucent.voice",
"pya": "audio/vnd.ms-playready.media.pya",
"ecelp4800": "audio/vnd.nuera.ecelp4800",
"ecelp7470": "audio/vnd.nuera.ecelp7470",
"ecelp9600": "audio/vnd.nuera.ecelp9600",
"rip": "audio/vnd.rip",
"weba": "audio/webm",
"aac": "audio/x-aac",
"aiff": "audio/x-aiff",
"caf": "audio/x-caf",
"flac": "audio/x-flac",
"mka": "audio/x-matroska",
"m3u": "audio/x-mpegurl",
"wax": "audio/x-ms-wax",
"wma": "audio/x-ms-wma",
"rmp": "audio/x-pn-realaudio-plugin",
"wav": "audio/x-wav",
"xm": "audio/xm",
"cdx": "chemical/x-cdx",
"cif": "chemical/x-cif",
"cmdf": "chemical/x-cmdf",
"cml": "chemical/x-cml",
"csml": "chemical/x-csml",
"xyz": "chemical/x-xyz",
"ttc": "font/collection",
"otf": "font/otf",
"ttf": "font/ttf",
"woff": "font/woff",
"woff2": "font/woff2",
"bmp": "image/bmp",
"cgm": "image/cgm",
"g3": "image/g3fax",
"gif": "image/gif",
"ief": "image/ief",
"jpg": "image/jpeg",
"ktx": "image/ktx",
"png": "image/png",
"btif": "image/prs.btif",
"sgi": "image/sgi",
"svg": "image/svg+xml",
"tiff": "image/tiff",
"psd": "image/vnd.adobe.photoshop",
"dwg": "image/vnd.dwg",
"dxf": "image/vnd.dxf",
"fbs": "image/vnd.fastbidsheet",
"fpx": "image/vnd.fpx",
"fst": "image/vnd.fst",
"mmr": "image/vnd.fujixerox.edmics-mmr",
"rlc": "image/vnd.fujixerox.edmics-rlc",
"mdi": "image/vnd.ms-modi",
"wdp": "image/vnd.ms-photo",
"npx": "image/vnd.net-fpx",
"wbmp": "image/vnd.wap.wbmp",
"xif": "image/vnd.xiff",
"webp": "image/webp",
"3ds": "image/x-3ds",
"ras": "image/x-cmu-raster",
"cmx": "image/x-cmx",
"ico": "image/x-icon",
"sid": "image/x-mrsid-image",
"pcx": "image/x-pcx",
"pnm": "image/x-portable-anymap",
"pbm": "image/x-portable-bitmap",
"pgm": "image/x-portable-graymap",
"ppm": "image/x-portable-pixmap",
"rgb": "image/x-rgb",
"tga": "image/x-tga",
"xbm": "image/x-xbitmap",
"xpm": "image/x-xpixmap",
"xwd": "image/x-xwindowdump",
"dae": "model/vnd.collada+xml",
"dwf": "model/vnd.dwf",
"gdl": "model/vnd.gdl",
"gtw": "model/vnd.gtw",
"mts": "model/vnd.mts",
"vtu": "model/vnd.vtu",
"appcache": "text/cache-manifest",
"ics": "text/calendar",
"css": "text/css",
"csv": "text/csv",
"html": "text/html",
"n3": "text/n3",
"txt": "text/plain",
"dsc": "text/prs.lines.tag",
"rtx": "text/richtext",
"tsv": "text/tab-separated-values",
"ttl": "text/turtle",
"vcard": "text/vcard",
"curl": "text/vnd.curl",
"dcurl": "text/vnd.curl.dcurl",
"mcurl": "text/vnd.curl.mcurl",
"scurl": "text/vnd.curl.scurl",
"sub": "text/vnd.dvb.subtitle",
"fly": "text/vnd.fly",
"flx": "text/vnd.fmi.flexstor",
"gv": "text/vnd.graphviz",
"3dml": "text/vnd.in3d.3dml",
"spot": "text/vnd.in3d.spot",
"jad": "text/vnd.sun.j2me.app-descriptor",
"wml": "text/vnd.wap.wml",
"wmls": "text/vnd.wap.wmlscript",
"asm": "text/x-asm",
"c": "text/x-c",
"java": "text/x-java-source",
"nfo": "text/x-nfo",
"opml": "text/x-opml",
"pas": "text/x-pascal",
"etx": "text/x-setext",
"sfv": "text/x-sfv",
"uu": "text/x-uuencode",
"vcs": "text/x-vcalendar",
"vcf": "text/x-vcard",
"3gp": "video/3gpp",
"3g2": "video/3gpp2",
"h261": "video/h261",
"h263": "video/h263",
"h264": "video/h264",
"jpgv": "video/jpeg",
"mp4": "video/mp4",
"mpeg": "video/mpeg",
"ogv": "video/ogg",
"dvb": "video/vnd.dvb.file",
"fvt": "video/vnd.fvt",
"pyv": "video/vnd.ms-playready.media.pyv",
"viv": "video/vnd.vivo",
"webm": "video/webm",
"f4v": "video/x-f4v",
"fli": "video/x-fli",
"flv": "video/x-flv",
"m4v": "video/x-m4v",
"mkv": "video/x-matroska",
"mng": "video/x-mng",
"asf": "video/x-ms-asf",
"vob": "video/x-ms-vob",
"wm": "video/x-ms-wm",
"wmv": "video/x-ms-wmv",
"wmx": "video/x-ms-wmx",
"wvx": "video/x-ms-wvx",
"avi": "video/x-msvideo",
"movie": "video/x-sgi-movie",
"smv": "video/x-smv",
"ice": "x-conference/x-cooltalk",
}
================================================
FILE: internal/misc/oauth.go
================================================
package misc
import (
"crypto/rand"
"encoding/hex"
"fmt"
"net/url"
"strings"
)
// GenerateRandomState returns a fresh OAuth2 "state" value: 16 bytes read from
// crypto/rand, hex-encoded into a 32-character lowercase string. The value is
// meant to be sent with the authorization request and compared on callback to
// guard against CSRF.
//
// An error is returned only when the system randomness source cannot be read;
// the string is empty in that case.
func GenerateRandomState() (string, error) {
	var raw [16]byte
	if _, err := rand.Read(raw[:]); err != nil {
		return "", fmt.Errorf("failed to generate random bytes: %w", err)
	}
	return hex.EncodeToString(raw[:]), nil
}
// OAuthCallback captures the parsed OAuth callback parameters.
type OAuthCallback struct {
	// Code is the authorization code ("code" parameter).
	Code string
	// State is the state value echoed back by the provider ("state" parameter).
	State string
	// Error is the OAuth "error" parameter, or the error description when only
	// "error_description" was present.
	Error string
	// ErrorDescription is the "error_description" parameter (empty when it was
	// promoted into Error).
	ErrorDescription string
}

// ParseOAuthCallback extracts OAuth parameters from a callback URL, typically
// one pasted by a user. It returns (nil, nil) when the input is empty or
// whitespace-only.
//
// Accepted forms include:
//   - a full URL ("http://localhost:8085/cb?code=...&state=...")
//   - a scheme-less URL ("localhost:8085/cb?code=...")
//   - a bare query string, with or without a leading "?" ("code=...&state=...")
//   - a fragment ("#code=...&state=...")
//
// Parameters are read from the query first; any still missing are taken from
// the fragment. A code of the form "<code>#<state>" with no separate state is
// split on the first "#". If only error_description is present it is promoted
// to Error.
//
// An error is returned when the input cannot be interpreted as a URL or when it
// carries neither a code nor an error.
func ParseOAuthCallback(input string) (*OAuthCallback, error) {
	trimmed := strings.TrimSpace(input)
	if trimmed == "" {
		return nil, nil
	}
	candidate := trimmed
	if !strings.Contains(candidate, "://") {
		// Classify scheme-less input by the text before its first "=". A URL
		// always has one of "/?#:" before its query begins, whereas a bare
		// query key never does. Inspecting only the key keeps parameter values
		// that contain "/" or ":" (e.g. Google authorization codes such as
		// "4/0A...") from being misread as a host and path.
		key, _, hasEquals := strings.Cut(candidate, "=")
		switch {
		case strings.HasPrefix(candidate, "?"):
			candidate = "http://localhost" + candidate
		case hasEquals && !strings.ContainsAny(key, "/?#:"):
			candidate = "http://localhost/?" + candidate
		case strings.ContainsAny(candidate, "/?#:"):
			candidate = "http://" + candidate
		default:
			return nil, fmt.Errorf("invalid callback URL")
		}
	}
	parsedURL, err := url.Parse(candidate)
	if err != nil {
		return nil, err
	}
	query := parsedURL.Query()
	code := strings.TrimSpace(query.Get("code"))
	state := strings.TrimSpace(query.Get("state"))
	errCode := strings.TrimSpace(query.Get("error"))
	errDesc := strings.TrimSpace(query.Get("error_description"))
	if parsedURL.Fragment != "" {
		// Parse the escaped form: url.Parse has already percent-decoded
		// Fragment, so parsing it directly would split values on a decoded
		// "%26" ("&") and turn a decoded "%2B" ("+") into a space.
		if fragQuery, errFrag := url.ParseQuery(parsedURL.EscapedFragment()); errFrag == nil {
			if code == "" {
				code = strings.TrimSpace(fragQuery.Get("code"))
			}
			if state == "" {
				state = strings.TrimSpace(fragQuery.Get("state"))
			}
			if errCode == "" {
				errCode = strings.TrimSpace(fragQuery.Get("error"))
			}
			if errDesc == "" {
				errDesc = strings.TrimSpace(fragQuery.Get("error_description"))
			}
		}
	}
	if code != "" && state == "" {
		// Accept a code with the state appended after "#" ("<code>#<state>").
		if before, after, found := strings.Cut(code, "#"); found {
			code, state = before, after
		}
	}
	if errCode == "" && errDesc != "" {
		errCode = errDesc
		errDesc = ""
	}
	if code == "" && errCode == "" {
		return nil, fmt.Errorf("callback URL missing code")
	}
	return &OAuthCallback{
		Code:             code,
		State:            state,
		Error:            errCode,
		ErrorDescription: errDesc,
	}, nil
}
================================================
FILE: internal/registry/model_definitions.go
================================================
// Package registry provides model definitions and lookup helpers for various AI providers.
// Static model metadata is loaded from the embedded models.json file and can be refreshed from network.
package registry
import (
"strings"
)
// staticModelsJSON mirrors the top-level structure of models.json.
//
// Each field holds the model list for one channel; the json tag is the
// top-level key in models.json. Codex definitions are split per plan tier
// (free, team, plus, pro) rather than stored under a single "codex" key.
type staticModelsJSON struct {
	Claude      []*ModelInfo `json:"claude"`
	Gemini      []*ModelInfo `json:"gemini"`
	Vertex      []*ModelInfo `json:"vertex"`
	GeminiCLI   []*ModelInfo `json:"gemini-cli"`
	AIStudio    []*ModelInfo `json:"aistudio"`
	CodexFree   []*ModelInfo `json:"codex-free"`
	CodexTeam   []*ModelInfo `json:"codex-team"`
	CodexPlus   []*ModelInfo `json:"codex-plus"`
	CodexPro    []*ModelInfo `json:"codex-pro"`
	Qwen        []*ModelInfo `json:"qwen"`
	IFlow       []*ModelInfo `json:"iflow"`
	Kimi        []*ModelInfo `json:"kimi"`
	Antigravity []*ModelInfo `json:"antigravity"`
}
// GetClaudeModels returns a fresh deep copy of the static Claude model definitions.
func GetClaudeModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.Claude)
}

// GetGeminiModels returns a fresh deep copy of the static Gemini model definitions.
func GetGeminiModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.Gemini)
}

// GetGeminiVertexModels returns a fresh deep copy of the Gemini definitions used for Vertex AI.
func GetGeminiVertexModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.Vertex)
}

// GetGeminiCLIModels returns a fresh deep copy of the Gemini definitions used for the Gemini CLI.
func GetGeminiCLIModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.GeminiCLI)
}

// GetAIStudioModels returns a fresh deep copy of the AI Studio model definitions.
func GetAIStudioModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.AIStudio)
}

// GetCodexFreeModels returns a fresh deep copy of the Codex free-plan model definitions.
func GetCodexFreeModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.CodexFree)
}

// GetCodexTeamModels returns a fresh deep copy of the Codex team-plan model definitions.
func GetCodexTeamModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.CodexTeam)
}

// GetCodexPlusModels returns a fresh deep copy of the Codex plus-plan model definitions.
func GetCodexPlusModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.CodexPlus)
}

// GetCodexProModels returns a fresh deep copy of the Codex pro-plan model definitions.
func GetCodexProModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.CodexPro)
}

// GetQwenModels returns a fresh deep copy of the static Qwen model definitions.
func GetQwenModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.Qwen)
}

// GetIFlowModels returns a fresh deep copy of the static iFlow model definitions.
func GetIFlowModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.IFlow)
}

// GetKimiModels returns a fresh deep copy of the static Kimi (Moonshot AI) model definitions.
func GetKimiModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.Kimi)
}

// GetAntigravityModels returns a fresh deep copy of the static Antigravity model definitions.
func GetAntigravityModels() []*ModelInfo {
	catalog := getModels()
	return cloneModelInfos(catalog.Antigravity)
}
// cloneModelInfos returns a new slice holding a deep copy of every element of
// models, in the same order. Nil elements stay nil. A nil or empty input
// yields nil rather than an empty slice.
func cloneModelInfos(models []*ModelInfo) []*ModelInfo {
	if len(models) == 0 {
		return nil
	}
	cloned := make([]*ModelInfo, 0, len(models))
	for _, model := range models {
		cloned = append(cloned, cloneModelInfo(model))
	}
	return cloned
}
// GetStaticModelDefinitionsByChannel returns static model definitions for a given channel/provider.
// The channel is matched case-insensitively after trimming surrounding whitespace.
// It returns nil when the channel is unknown.
//
// Supported channels:
//   - claude
//   - gemini
//   - vertex
//   - gemini-cli
//   - aistudio
//   - codex (returns the same list as codex-pro)
//   - codex-free, codex-team, codex-plus, codex-pro (per-plan Codex tiers,
//     matching the top-level keys in models.json)
//   - qwen
//   - iflow
//   - kimi
//   - antigravity
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
	key := strings.ToLower(strings.TrimSpace(channel))
	switch key {
	case "claude":
		return GetClaudeModels()
	case "gemini":
		return GetGeminiModels()
	case "vertex":
		return GetGeminiVertexModels()
	case "gemini-cli":
		return GetGeminiCLIModels()
	case "aistudio":
		return GetAIStudioModels()
	case "codex", "codex-pro":
		// The generic "codex" channel keeps its historical meaning (pro tier).
		return GetCodexProModels()
	case "codex-plus":
		return GetCodexPlusModels()
	case "codex-team":
		return GetCodexTeamModels()
	case "codex-free":
		return GetCodexFreeModels()
	case "qwen":
		return GetQwenModels()
	case "iflow":
		return GetIFlowModels()
	case "kimi":
		return GetKimiModels()
	case "antigravity":
		return GetAntigravityModels()
	default:
		return nil
	}
}
// LookupStaticModelInfo searches all static model definitions for a model by ID
// and returns a deep copy of the first match. Returns nil if modelID is empty
// or no matching model is found.
//
// Channel lists are scanned in a fixed order and the first match wins. All
// four Codex plan tiers are searched: Pro in its original position, and the
// Plus, Team and Free tiers after every other channel. Models that exist only
// in a lower tier are therefore found, while any ID already resolvable keeps
// resolving to the same entry as before.
func LookupStaticModelInfo(modelID string) *ModelInfo {
	if modelID == "" {
		return nil
	}
	data := getModels()
	allModels := [][]*ModelInfo{
		data.Claude,
		data.Gemini,
		data.Vertex,
		data.GeminiCLI,
		data.AIStudio,
		data.CodexPro,
		data.Qwen,
		data.IFlow,
		data.Kimi,
		data.Antigravity,
		// Lower Codex tiers come last so they never shadow an earlier match.
		data.CodexPlus,
		data.CodexTeam,
		data.CodexFree,
	}
	for _, models := range allModels {
		for _, m := range models {
			if m != nil && m.ID == modelID {
				return cloneModelInfo(m)
			}
		}
	}
	return nil
}
================================================
FILE: internal/registry/model_registry.go
================================================
// Package registry provides centralized model management for all AI service providers.
// It implements a dynamic model registry with reference counting to track active clients
// and automatically hide models when no clients are available or when quota is exceeded.
package registry
import (
"context"
"fmt"
"sort"
"strings"
"sync"
"time"
misc "github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
log "github.com/sirupsen/logrus"
)
// ModelInfo represents information about an available model.
//
// The same struct is used for in-memory registry metadata and, via its json
// tags, for the entries decoded from models.json (see staticModelsJSON). Tag
// spellings deliberately mix snake_case and camelCase keys; they define the
// JSON shape, so keep them exactly as written.
type ModelInfo struct {
	// ID is the unique identifier for the model
	ID string `json:"id"`
	// Object type for the model (typically "model")
	Object string `json:"object"`
	// Created timestamp when the model was created
	Created int64 `json:"created"`
	// OwnedBy indicates the organization that owns the model
	OwnedBy string `json:"owned_by"`
	// Type indicates the model type (e.g., "claude", "gemini", "openai")
	Type string `json:"type"`
	// DisplayName is the human-readable name for the model
	DisplayName string `json:"display_name,omitempty"`
	// Name is used for Gemini-style model names
	Name string `json:"name,omitempty"`
	// Version is the model version
	Version string `json:"version,omitempty"`
	// Description provides detailed information about the model
	Description string `json:"description,omitempty"`
	// InputTokenLimit is the maximum input token limit
	InputTokenLimit int `json:"inputTokenLimit,omitempty"`
	// OutputTokenLimit is the maximum output token limit
	OutputTokenLimit int `json:"outputTokenLimit,omitempty"`
	// SupportedGenerationMethods lists supported generation methods
	SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
	// ContextLength is the context window size
	ContextLength int `json:"context_length,omitempty"`
	// MaxCompletionTokens is the maximum completion tokens
	MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
	// SupportedParameters lists supported parameters
	SupportedParameters []string `json:"supported_parameters,omitempty"`
	// SupportedInputModalities lists supported input modalities (e.g., TEXT, IMAGE, VIDEO, AUDIO)
	SupportedInputModalities []string `json:"supportedInputModalities,omitempty"`
	// SupportedOutputModalities lists supported output modalities (e.g., TEXT, IMAGE)
	SupportedOutputModalities []string `json:"supportedOutputModalities,omitempty"`
	// Thinking holds provider-specific reasoning/thinking budget capabilities.
	// This is optional and currently used for Gemini thinking budget normalization.
	// It is a pointer, so copies must clone it to avoid sharing (see cloneModelInfo).
	Thinking *ThinkingSupport `json:"thinking,omitempty"`
	// UserDefined indicates this model was defined through config file's models[]
	// array (e.g., openai-compatibility.*.models[], *-api-key.models[]).
	// UserDefined models have thinking configuration passed through without validation.
	// The "-" tag excludes it from JSON encoding and decoding.
	UserDefined bool `json:"-"`
}
// availableModelsCacheEntry is one cached snapshot stored in
// ModelRegistry.availableModelsCache. The whole cache is cleared whenever
// client registrations change.
type availableModelsCacheEntry struct {
	// models is the cached model list, one map per model.
	models []map[string]any
	// expiresAt is when this snapshot stops being valid.
	// NOTE(review): presumably compared against time.Now() by GetAvailableModels — confirm.
	expiresAt time.Time
}
// ThinkingSupport describes a model family's supported internal reasoning budget range.
// Values are interpreted in provider-native token units.
// A model uses either a numeric budget range (Min/Max and the *Allowed flags)
// or discrete Levels.
type ThinkingSupport struct {
	// Min is the minimum allowed thinking budget (inclusive).
	Min int `json:"min,omitempty"`
	// Max is the maximum allowed thinking budget (inclusive).
	Max int `json:"max,omitempty"`
	// ZeroAllowed indicates whether 0 is a valid value (to disable thinking).
	ZeroAllowed bool `json:"zero_allowed,omitempty"`
	// DynamicAllowed indicates whether -1 is a valid value (dynamic thinking budget).
	DynamicAllowed bool `json:"dynamic_allowed,omitempty"`
	// Levels defines discrete reasoning effort levels (e.g., "low", "medium", "high").
	// When set, the model uses level-based reasoning instead of token budgets.
	Levels []string `json:"levels,omitempty"`
}
// ModelRegistration tracks a model's availability across all registered clients.
// ModelRegistry.models holds one registration per model ID; the registration
// is deleted once Count drops to zero.
type ModelRegistration struct {
	// Info contains the model metadata; it is overwritten with the most recently
	// registered client's copy.
	Info *ModelInfo
	// InfoByProvider maps provider identifiers to specific ModelInfo to support differing capabilities.
	InfoByProvider map[string]*ModelInfo
	// Count is the number of active client references that can provide this model.
	// A client that lists the same model ID more than once contributes one
	// reference per occurrence.
	Count int
	// LastUpdated tracks when this registration was last modified
	LastUpdated time.Time
	// QuotaExceededClients tracks which clients have exceeded quota for this model, keyed by client ID.
	// NOTE(review): the *time.Time value is presumably when the quota was hit — confirm in SetModelQuotaExceeded.
	QuotaExceededClients map[string]*time.Time
	// Providers counts client references per provider identifier
	Providers map[string]int
	// SuspendedClients tracks temporarily disabled clients keyed by client ID.
	// NOTE(review): the string value is presumably a suspension reason — confirm at the write site.
	SuspendedClients map[string]string
}
// ModelRegistryHook provides optional callbacks for external integrations to track model list changes.
// Hook implementations must be non-blocking and resilient; calls are executed asynchronously and panics are recovered.
// Each call receives a context that is cancelled after defaultModelRegistryHookTimeout.
type ModelRegistryHook interface {
	// OnModelsRegistered is called after a client registers or reconciles its
	// models. provider is lower-cased. models is a de-duplicated deep copy
	// (nil and empty-ID entries removed), so the hook may keep or mutate it.
	OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo)
	// OnModelsUnregistered is called after a client's models are removed.
	// provider is empty when the client had no recorded provider.
	OnModelsUnregistered(ctx context.Context, provider, clientID string)
}
// ModelRegistry manages the global registry of available models.
//
// Every field is guarded by mutex. The *Locked helpers and
// unregisterClientInternal take no lock themselves and must be called with
// mutex already held for writing.
type ModelRegistry struct {
	// models maps model ID to registration information
	models map[string]*ModelRegistration
	// clientModels maps client ID to the models it provides. It holds one entry
	// per reference, so duplicates supplied at registration are kept.
	clientModels map[string][]string
	// clientModelInfos maps client ID to a map of model ID -> ModelInfo
	// This preserves the original model info provided by each client
	clientModelInfos map[string]map[string]*ModelInfo
	// clientProviders maps client ID to its lower-cased provider identifier.
	// It has no entry for clients registered with an empty provider.
	clientProviders map[string]string
	// mutex ensures thread-safe access to the registry
	mutex *sync.RWMutex
	// availableModelsCache stores per-handler snapshots for GetAvailableModels.
	// It is cleared whenever client registrations change.
	availableModelsCache map[string]availableModelsCacheEntry
	// hook is an optional callback sink for model registration changes
	hook ModelRegistryHook
}
var (
	// globalRegistry is the process-wide model registry; it is populated on the
	// first call to GetGlobalRegistry.
	globalRegistry *ModelRegistry
	// registryOnce ensures globalRegistry is constructed exactly once.
	registryOnce sync.Once
)

// GetGlobalRegistry returns the shared ModelRegistry. On first use it creates
// the registry with empty maps and a fresh RWMutex. Concurrent first calls are
// safe because construction is serialized by registryOnce.
func GetGlobalRegistry() *ModelRegistry {
	registryOnce.Do(func() {
		reg := &ModelRegistry{mutex: new(sync.RWMutex)}
		reg.models = make(map[string]*ModelRegistration)
		reg.clientModels = make(map[string][]string)
		reg.clientModelInfos = make(map[string]map[string]*ModelInfo)
		reg.clientProviders = make(map[string]string)
		reg.availableModelsCache = make(map[string]availableModelsCacheEntry)
		globalRegistry = reg
	})
	return globalRegistry
}
// ensureAvailableModelsCacheLocked lazily allocates availableModelsCache so it
// can be written to. The caller must hold r.mutex for writing.
func (r *ModelRegistry) ensureAvailableModelsCacheLocked() {
	if r.availableModelsCache != nil {
		return
	}
	r.availableModelsCache = make(map[string]availableModelsCacheEntry)
}
// invalidateAvailableModelsCacheLocked drops every cached available-models
// snapshot while keeping the map allocated for reuse. The caller must hold
// r.mutex for writing.
func (r *ModelRegistry) invalidateAvailableModelsCacheLocked() {
	if len(r.availableModelsCache) > 0 {
		clear(r.availableModelsCache)
	}
}
// LookupModelInfo resolves model metadata by ID and returns a deep copy.
//
// The dynamic global registry is consulted first. The optional provider (only
// the first value is used; it is trimmed and lower-cased) is forwarded to
// GetModelInfo so provider-specific metadata can take precedence. When the
// registry has no entry, the static definitions are searched. A blank modelID
// yields nil, and so does an ID found in neither source.
func LookupModelInfo(modelID string, provider ...string) *ModelInfo {
	id := strings.TrimSpace(modelID)
	if id == "" {
		return nil
	}
	var providerKey string
	if len(provider) != 0 {
		providerKey = strings.ToLower(strings.TrimSpace(provider[0]))
	}
	info := GetGlobalRegistry().GetModelInfo(id, providerKey)
	if info == nil {
		info = LookupStaticModelInfo(id)
	}
	return cloneModelInfo(info)
}
// SetHook installs (or, with nil, removes) the hook notified of model
// registration changes. It is a no-op on a nil registry. The hook field is
// written while holding the registry's write lock.
func (r *ModelRegistry) SetHook(hook ModelRegistryHook) {
	if r != nil {
		r.mutex.Lock()
		r.hook = hook
		r.mutex.Unlock()
	}
}
const (
	// defaultModelRegistryHookTimeout bounds the context handed to each
	// asynchronous ModelRegistryHook invocation.
	defaultModelRegistryHookTimeout = 5 * time.Second
	// modelQuotaExceededWindow is a fixed five-minute duration.
	// NOTE(review): presumably how long a quota-exceeded mark is honoured — confirm at its use sites.
	modelQuotaExceededWindow = 5 * time.Minute
)
// triggerModelsRegistered notifies the configured hook, if any, that clientID
// now provides models.
//
// The models are de-duplicated and deep-cloned before the goroutine starts, so
// the hook never observes later registry mutations. The callback runs in its own
// goroutine with a context limited to defaultModelRegistryHookTimeout, and a
// panic inside it is recovered and logged. r.hook is read without locking; the
// callers in this file invoke this while holding r.mutex.
func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) {
	hook := r.hook
	if hook == nil {
		return
	}
	snapshot := cloneModelInfosUnique(models)
	go func(h ModelRegistryHook, registered []*ModelInfo) {
		defer func() {
			if p := recover(); p != nil {
				log.Errorf("model registry hook OnModelsRegistered panic: %v", p)
			}
		}()
		ctx, cancel := context.WithTimeout(context.Background(), defaultModelRegistryHookTimeout)
		defer cancel()
		h.OnModelsRegistered(ctx, provider, clientID, registered)
	}(hook, snapshot)
}
// triggerModelsUnregistered notifies the configured hook, if any, that
// clientID's models were removed.
//
// The callback runs asynchronously with a context limited to
// defaultModelRegistryHookTimeout, and a panic inside it is recovered and
// logged. r.hook is read without locking; the caller in this file
// (unregisterClientInternal) runs with r.mutex held.
func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) {
	if hook := r.hook; hook != nil {
		go func() {
			defer func() {
				if p := recover(); p != nil {
					log.Errorf("model registry hook OnModelsUnregistered panic: %v", p)
				}
			}()
			ctx, cancel := context.WithTimeout(context.Background(), defaultModelRegistryHookTimeout)
			defer cancel()
			hook.OnModelsUnregistered(ctx, provider, clientID)
		}()
	}
}
// RegisterClient registers a client and its supported models, or reconciles
// the registry with a new model list for a client that is already registered.
// It holds r.mutex for writing for the whole call.
//
// Parameters:
//   - clientID: Unique identifier for the client
//   - clientProvider: Provider name (e.g., "gemini", "claude", "openai"); it is
//     lower-cased before being stored or counted
//   - models: List of models that this client can provide. Nil entries and
//     entries with an empty ID are ignored. Repeated IDs are kept: each
//     occurrence adds one reference to the model's Count and provider counter,
//     and the first ModelInfo seen for an ID is the one stored.
//
// Behavior:
//   - With no usable model left, any existing registration for clientID is removed.
//   - On first registration, every model reference is added under the new provider.
//   - Otherwise, old and new per-model reference counts are diffed:
//     - surplus old references are removed under the old provider;
//     - missing new references are added under the new provider;
//     - when the provider changed, overlapping references are moved from the
//       old provider's counters to the new one.
//
// For a non-empty model list, the client's ModelInfo copies are replaced, the
// available-models cache is invalidated, and the hook's OnModelsRegistered is
// triggered asynchronously.
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
	r.mutex.Lock()
	defer r.mutex.Unlock()
	r.ensureAvailableModelsCacheLocked()
	provider := strings.ToLower(clientProvider)
	// uniqueModelIDs: usable IDs in first-seen order, one entry each.
	// rawModelIDs: every usable entry, duplicates included (the reference list).
	// newModels: first ModelInfo seen per ID.
	// newCounts: occurrences of each ID in models.
	uniqueModelIDs := make([]string, 0, len(models))
	rawModelIDs := make([]string, 0, len(models))
	newModels := make(map[string]*ModelInfo, len(models))
	newCounts := make(map[string]int, len(models))
	for _, model := range models {
		if model == nil || model.ID == "" {
			continue
		}
		rawModelIDs = append(rawModelIDs, model.ID)
		newCounts[model.ID]++
		if _, exists := newModels[model.ID]; exists {
			continue
		}
		newModels[model.ID] = model
		uniqueModelIDs = append(uniqueModelIDs, model.ID)
	}
	if len(uniqueModelIDs) == 0 {
		// No models supplied; unregister existing client state if present.
		// NOTE(review): unregisterClientInternal already deletes these three map
		// entries and, when the client existed, logs its own credential
		// separator, so the deletes are redundant and two separators are logged.
		r.unregisterClientInternal(clientID)
		delete(r.clientModels, clientID)
		delete(r.clientModelInfos, clientID)
		delete(r.clientProviders, clientID)
		r.invalidateAvailableModelsCacheLocked()
		misc.LogCredentialSeparator()
		return
	}
	now := time.Now()
	oldModels, hadExisting := r.clientModels[clientID]
	oldProvider := r.clientProviders[clientID]
	providerChanged := oldProvider != provider
	if !hadExisting {
		// Pure addition path.
		// Iterate rawModelIDs so each duplicate occurrence adds its own reference.
		for _, modelID := range rawModelIDs {
			model := newModels[modelID]
			r.addModelRegistration(modelID, provider, model, now)
		}
		r.clientModels[clientID] = append([]string(nil), rawModelIDs...)
		// Store client's own model infos
		clientInfos := make(map[string]*ModelInfo, len(newModels))
		for id, m := range newModels {
			clientInfos[id] = cloneModelInfo(m)
		}
		r.clientModelInfos[clientID] = clientInfos
		if provider != "" {
			r.clientProviders[clientID] = provider
		} else {
			delete(r.clientProviders, clientID)
		}
		r.invalidateAvailableModelsCacheLocked()
		r.triggerModelsRegistered(provider, clientID, models)
		log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
		misc.LogCredentialSeparator()
		return
	}
	// Reconciliation path: count the references the client held before this call.
	oldCounts := make(map[string]int, len(oldModels))
	for _, id := range oldModels {
		oldCounts[id]++
	}
	// added: IDs the client did not provide before (used for logging and to skip
	// the provider-overlap move below, since addModelRegistration already
	// counted them under the new provider).
	added := make([]string, 0)
	for _, id := range uniqueModelIDs {
		if oldCounts[id] == 0 {
			added = append(added, id)
		}
	}
	// removed: IDs the client no longer provides at all.
	removed := make([]string, 0)
	for id := range oldCounts {
		if newCounts[id] == 0 {
			removed = append(removed, id)
		}
	}
	// Handle provider change for overlapping models before modifications.
	// For each ID kept across the change, subtract min(oldCount, newCount)
	// references from the old provider's counter; the same number is added to
	// the new provider in the metadata loop further down.
	if providerChanged && oldProvider != "" {
		for id, newCount := range newCounts {
			if newCount == 0 {
				continue
			}
			oldCount := oldCounts[id]
			if oldCount == 0 {
				continue
			}
			toRemove := newCount
			if oldCount < toRemove {
				toRemove = oldCount
			}
			if reg, ok := r.models[id]; ok && reg.Providers != nil {
				if count, okProv := reg.Providers[oldProvider]; okProv {
					if count <= toRemove {
						delete(reg.Providers, oldProvider)
						if reg.InfoByProvider != nil {
							delete(reg.InfoByProvider, oldProvider)
						}
					} else {
						reg.Providers[oldProvider] = count - toRemove
					}
				}
			}
		}
	}
	// Apply removals first to keep counters accurate.
	// Fully dropped IDs release every old reference under the old provider.
	for _, id := range removed {
		oldCount := oldCounts[id]
		for i := 0; i < oldCount; i++ {
			r.removeModelRegistration(clientID, id, oldProvider, now)
		}
	}
	// Kept IDs whose duplicate count shrank release only the surplus references.
	for id, oldCount := range oldCounts {
		newCount := newCounts[id]
		if newCount == 0 || oldCount <= newCount {
			continue
		}
		overage := oldCount - newCount
		for i := 0; i < overage; i++ {
			r.removeModelRegistration(clientID, id, oldProvider, now)
		}
	}
	// Apply additions.
	// IDs whose count grew (including brand-new IDs) gain the missing references
	// under the new provider.
	for id, newCount := range newCounts {
		oldCount := oldCounts[id]
		if newCount <= oldCount {
			continue
		}
		model := newModels[id]
		diff := newCount - oldCount
		for i := 0; i < diff; i++ {
			r.addModelRegistration(id, provider, model, now)
		}
	}
	// Update metadata for models that remain associated with the client.
	addedSet := make(map[string]struct{}, len(added))
	for _, id := range added {
		addedSet[id] = struct{}{}
	}
	for _, id := range uniqueModelIDs {
		model := newModels[id]
		if reg, ok := r.models[id]; ok {
			// The shared Info is overwritten with this client's copy.
			reg.Info = cloneModelInfo(model)
			if provider != "" {
				if reg.InfoByProvider == nil {
					reg.InfoByProvider = make(map[string]*ModelInfo)
				}
				reg.InfoByProvider[provider] = cloneModelInfo(model)
			}
			reg.LastUpdated = now
			// Re-registering an existing client/model binding starts a fresh registry
			// snapshot for that binding. Cooldown and suspension are transient
			// scheduling state and must not survive this reconciliation step.
			if reg.QuotaExceededClients != nil {
				delete(reg.QuotaExceededClients, clientID)
			}
			if reg.SuspendedClients != nil {
				delete(reg.SuspendedClients, clientID)
			}
			// Second half of the provider move: credit the overlapping
			// references (min(oldCount, newCount)) to the new provider.
			if providerChanged && provider != "" {
				if _, newlyAdded := addedSet[id]; newlyAdded {
					continue
				}
				overlapCount := newCounts[id]
				if oldCount := oldCounts[id]; oldCount < overlapCount {
					overlapCount = oldCount
				}
				if overlapCount <= 0 {
					continue
				}
				if reg.Providers == nil {
					reg.Providers = make(map[string]int)
				}
				reg.Providers[provider] += overlapCount
			}
		}
	}
	// Update client bookkeeping.
	if len(rawModelIDs) > 0 {
		r.clientModels[clientID] = append([]string(nil), rawModelIDs...)
	}
	// Update client's own model infos
	clientInfos := make(map[string]*ModelInfo, len(newModels))
	for id, m := range newModels {
		clientInfos[id] = cloneModelInfo(m)
	}
	r.clientModelInfos[clientID] = clientInfos
	if provider != "" {
		r.clientProviders[clientID] = provider
	} else {
		delete(r.clientProviders, clientID)
	}
	r.invalidateAvailableModelsCacheLocked()
	r.triggerModelsRegistered(provider, clientID, models)
	if len(added) == 0 && len(removed) == 0 && !providerChanged {
		// Only metadata (e.g., display name) changed; skip separator when no log output.
		return
	}
	// NOTE(review): this logs the lower-cased provider, whereas the first-registration
	// log above prints clientProvider as given.
	log.Debugf("Reconciled client %s (provider %s) models: +%d, -%d", clientID, provider, len(added), len(removed))
	misc.LogCredentialSeparator()
}
// addModelRegistration adds one client reference to modelID. The caller must
// hold r.mutex for writing.
//
// For an unseen model it creates a registration with Count 1 and initialized
// quota, suspension and per-provider maps. For a known model it increments
// Count and replaces Info with a clone of model. When provider is non-empty,
// that provider's counter is incremented and its InfoByProvider entry is set.
// A nil model or empty modelID is ignored.
func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *ModelInfo, now time.Time) {
	if modelID == "" || model == nil {
		return
	}
	reg, found := r.models[modelID]
	if !found {
		reg = &ModelRegistration{
			Info:                 cloneModelInfo(model),
			InfoByProvider:       make(map[string]*ModelInfo),
			Count:                1,
			LastUpdated:          now,
			QuotaExceededClients: make(map[string]*time.Time),
			SuspendedClients:     make(map[string]string),
		}
		if provider != "" {
			reg.Providers = map[string]int{provider: 1}
			reg.InfoByProvider[provider] = cloneModelInfo(model)
		}
		r.models[modelID] = reg
		log.Debugf("Registered new model %s from provider %s", modelID, provider)
		return
	}

	reg.Count++
	reg.LastUpdated = now
	reg.Info = cloneModelInfo(model)
	if reg.SuspendedClients == nil {
		reg.SuspendedClients = make(map[string]string)
	}
	if reg.InfoByProvider == nil {
		reg.InfoByProvider = make(map[string]*ModelInfo)
	}
	if provider != "" {
		if reg.Providers == nil {
			reg.Providers = make(map[string]int)
		}
		reg.Providers[provider]++
		reg.InfoByProvider[provider] = cloneModelInfo(model)
	}
	log.Debugf("Incremented count for model %s, now %d clients", modelID, reg.Count)
}
// removeModelRegistration releases one client reference to modelID. The caller
// must hold r.mutex for writing.
//
// It decrements Count (never below zero) and clears clientID's quota and
// suspension entries. When provider is non-empty, that provider's counter is
// decremented; if it reaches zero, the counter and its InfoByProvider entry are
// dropped. Once Count is zero the registration is deleted entirely. Unknown
// model IDs are ignored.
func (r *ModelRegistry) removeModelRegistration(clientID, modelID, provider string, now time.Time) {
	reg, found := r.models[modelID]
	if !found {
		return
	}
	reg.LastUpdated = now
	reg.Count--
	if reg.Count < 0 {
		reg.Count = 0
	}
	// delete is a no-op on nil maps, so no nil guards are needed here.
	delete(reg.QuotaExceededClients, clientID)
	delete(reg.SuspendedClients, clientID)
	if provider != "" {
		// Indexing a nil Providers map reports ok == false.
		if n, ok := reg.Providers[provider]; ok {
			if n > 1 {
				reg.Providers[provider] = n - 1
			} else {
				delete(reg.Providers, provider)
				delete(reg.InfoByProvider, provider)
			}
		}
	}
	log.Debugf("Decremented count for model %s, now %d clients", modelID, reg.Count)
	if reg.Count <= 0 {
		delete(r.models, modelID)
		log.Debugf("Removed model %s as no clients remain", modelID)
	}
}
func cloneModelInfo(model *ModelInfo) *ModelInfo {
if model == nil {
return nil
}
copyModel := *model
if len(model.SupportedGenerationMethods) > 0 {
copyModel.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
}
if len(model.SupportedParameters) > 0 {
copyModel.SupportedParameters = append([]string(nil), model.SupportedParameters...)
}
if len(model.SupportedInputModalities) > 0 {
copyModel.SupportedInputModalities = append([]string(nil), model.SupportedInputModalities...)
}
if len(model.SupportedOutputModalities) > 0 {
copyModel.SupportedOutputModalities = append([]string(nil), model.SupportedOutputModalities...)
}
if model.Thinking != nil {
copyThinking := *model.Thinking
if len(model.Thinking.Levels) > 0 {
copyThinking.Levels = append([]string(nil), model.Thinking.Levels...)
}
copyModel.Thinking = ©Thinking
}
return ©Model
}
// cloneModelInfosUnique deep-copies models, skipping nil entries, entries with
// an empty ID, and any ID already seen (the first occurrence wins). It returns
// nil for a nil or empty input. Otherwise it returns a possibly empty, non-nil
// slice.
func cloneModelInfosUnique(models []*ModelInfo) []*ModelInfo {
	if len(models) == 0 {
		return nil
	}
	seen := make(map[string]struct{}, len(models))
	out := make([]*ModelInfo, 0, len(models))
	for _, m := range models {
		if m == nil || m.ID == "" {
			continue
		}
		if _, dup := seen[m.ID]; !dup {
			seen[m.ID] = struct{}{}
			out = append(out, cloneModelInfo(m))
		}
	}
	return out
}
// UnregisterClient removes a client and releases its references to every
// model it provided. Models left with no references are dropped, and the
// available-models cache is invalidated.
//
// Parameters:
//   - clientID: Unique identifier for the client to remove
func (r *ModelRegistry) UnregisterClient(clientID string) {
	mu := r.mutex
	mu.Lock()
	defer mu.Unlock()

	r.unregisterClientInternal(clientID)
	r.invalidateAvailableModelsCacheLocked()
}
// unregisterClientInternal removes clientID from the registry without taking
// the mutex; callers must already hold r.mutex for writing.
//
// For each model ID the client registered it:
//   - decrements the shared Count and refreshes LastUpdated,
//   - drops the client's quota-exceeded and suspension markers,
//   - decrements the client's provider counter, deleting it (and that
//     provider's InfoByProvider entry) when it would reach zero,
//   - deletes the model entirely once Count drops to zero or below.
//
// It then clears the client's bookkeeping maps, logs, emits the credential
// separator and fires the unregister hook. A client with no model list only
// has a stale provider mapping removed (if present); no hook fires.
func (r *ModelRegistry) unregisterClientInternal(clientID string) {
	provider, hasProvider := r.clientProviders[clientID]
	modelIDs, registered := r.clientModels[clientID]
	if !registered {
		delete(r.clientProviders, clientID)
		return
	}

	stamp := time.Now()
	for _, id := range modelIDs {
		reg, found := r.models[id]
		if !found {
			continue
		}
		reg.Count--
		reg.LastUpdated = stamp
		// The client no longer serves this model, so its markers are stale.
		delete(reg.QuotaExceededClients, clientID)
		if reg.SuspendedClients != nil {
			delete(reg.SuspendedClients, clientID)
		}
		if hasProvider && reg.Providers != nil {
			if n, tracked := reg.Providers[provider]; tracked {
				switch {
				case n > 1:
					reg.Providers[provider] = n - 1
				default:
					delete(reg.Providers, provider)
					if reg.InfoByProvider != nil {
						delete(reg.InfoByProvider, provider)
					}
				}
			}
		}
		log.Debugf("Decremented count for model %s, now %d clients", id, reg.Count)
		if reg.Count > 0 {
			continue
		}
		delete(r.models, id)
		log.Debugf("Removed model %s as no clients remain", id)
	}

	delete(r.clientModels, clientID)
	delete(r.clientModelInfos, clientID)
	delete(r.clientProviders, clientID)
	log.Debugf("Unregistered client %s", clientID)
	// Visual separator in the log once this client's summary line is written.
	misc.LogCredentialSeparator()
	r.triggerModelsUnregistered(provider, clientID)
}
// SetModelQuotaExceeded records, at the current time, that clientID exceeded
// its quota for modelID. Other registry methods treat the client as cooling
// down for modelQuotaExceededWindow after that timestamp. Unknown models are
// ignored. The available-models cache is invalidated when a marker is set.
//
// Parameters:
//   - clientID: The client that exceeded quota
//   - modelID: The model that exceeded quota
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
	r.mutex.Lock()
	defer r.mutex.Unlock()
	r.ensureAvailableModelsCacheLocked()
	registration, exists := r.models[modelID]
	if !exists || registration == nil {
		return
	}
	// Assigning into a nil map panics; allocate lazily, mirroring how
	// SuspendClientModel treats SuspendedClients.
	if registration.QuotaExceededClients == nil {
		registration.QuotaExceededClients = make(map[string]*time.Time)
	}
	now := time.Now()
	registration.QuotaExceededClients[clientID] = &now
	r.invalidateAvailableModelsCacheLocked()
	log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
}
// ClearModelQuotaExceeded removes the quota-exceeded marker for clientID on
// modelID. The available-models cache is invalidated only when a marker was
// actually removed. If nothing changed, the cached listings are still correct,
// so invalidating would needlessly discard them. This matters because the
// method may be called far more often than a marker is set (NOTE(review):
// e.g. after successful requests; confirm with callers). ResumeClientModel
// follows the same rule.
//
// Parameters:
//   - clientID: The client to clear quota status for
//   - modelID: The model to clear quota status for
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
	r.mutex.Lock()
	defer r.mutex.Unlock()
	r.ensureAvailableModelsCacheLocked()
	registration, exists := r.models[modelID]
	if !exists || registration == nil {
		return
	}
	if _, marked := registration.QuotaExceededClients[clientID]; !marked {
		return
	}
	delete(registration.QuotaExceededClients, clientID)
	r.invalidateAvailableModelsCacheLocked()
}
// SuspendClientModel marks clientID as unavailable for modelID until
// ResumeClientModel is called. Empty IDs and unknown models are ignored, and
// suspending an already-suspended client is a no-op (its original reason is
// kept). A reason of "quota" is treated by the availability listings as a
// cooldown; any other reason counts against availability.
//
// Parameters:
//   - clientID: The client to suspend
//   - modelID: The model affected by the suspension
//   - reason: Optional description for observability
func (r *ModelRegistry) SuspendClientModel(clientID, modelID, reason string) {
	if clientID == "" || modelID == "" {
		return
	}
	r.mutex.Lock()
	defer r.mutex.Unlock()
	r.ensureAvailableModelsCacheLocked()

	reg := r.models[modelID]
	if reg == nil {
		return
	}
	if reg.SuspendedClients == nil {
		reg.SuspendedClients = make(map[string]string)
	} else if _, suspended := reg.SuspendedClients[clientID]; suspended {
		return
	}

	reg.SuspendedClients[clientID] = reason
	reg.LastUpdated = time.Now()
	r.invalidateAvailableModelsCacheLocked()
	if reason == "" {
		log.Debugf("Suspended client %s for model %s", clientID, modelID)
		return
	}
	log.Debugf("Suspended client %s for model %s: %s", clientID, modelID, reason)
}
// ResumeClientModel lifts a suspension previously placed by
// SuspendClientModel so the client counts toward availability again. It does
// nothing (and leaves the cache intact) when either ID is empty, the model is
// unknown, or the client is not currently suspended.
//
// Parameters:
//   - clientID: The client to resume
//   - modelID: The model being resumed
func (r *ModelRegistry) ResumeClientModel(clientID, modelID string) {
	if clientID == "" || modelID == "" {
		return
	}
	r.mutex.Lock()
	defer r.mutex.Unlock()
	r.ensureAvailableModelsCacheLocked()

	reg := r.models[modelID]
	if reg == nil {
		return
	}
	// Indexing a nil map is safe and reports "not present".
	if _, suspended := reg.SuspendedClients[clientID]; !suspended {
		return
	}

	delete(reg.SuspendedClients, clientID)
	reg.LastUpdated = time.Now()
	r.invalidateAvailableModelsCacheLocked()
	log.Debugf("Resumed client %s for model %s", clientID, modelID)
}
// ClientSupportsModel reports whether clientID registered modelID. Both
// arguments are whitespace-trimmed (an empty result yields false), and model
// IDs are compared case-insensitively against the client's trimmed entries.
func (r *ModelRegistry) ClientSupportsModel(clientID, modelID string) bool {
	wantClient := strings.TrimSpace(clientID)
	wantModel := strings.TrimSpace(modelID)
	if wantClient == "" || wantModel == "" {
		return false
	}

	r.mutex.RLock()
	defer r.mutex.RUnlock()
	for _, candidate := range r.clientModels[wantClient] {
		if strings.EqualFold(strings.TrimSpace(candidate), wantModel) {
			return true
		}
	}
	return false
}
// GetAvailableModels returns all models that have at least one available
// client, rendered in the shape expected by handlerType.
//
// Listings are cached per handlerType. A cached listing is served while it
// has no expiry, or until its expiry: the earliest pending quota recovery seen
// when it was built. Registry mutations invalidate it earlier. Callers always
// receive a deep copy, so mutating the result never affects the cache.
//
// Parameters:
//   - handlerType: The handler type to filter models for (e.g., "openai", "claude", "gemini")
//
// Returns:
//   - []map[string]any: List of available models in the requested format
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
	now := time.Now()
	usable := func(entry availableModelsCacheEntry) bool {
		return entry.expiresAt.IsZero() || now.Before(entry.expiresAt)
	}

	// Fast path: serve a still-valid snapshot under the read lock.
	r.mutex.RLock()
	entry, cached := r.availableModelsCache[handlerType]
	if cached && usable(entry) {
		snapshot := cloneModelMaps(entry.models)
		r.mutex.RUnlock()
		return snapshot
	}
	r.mutex.RUnlock()

	// Slow path: another goroutine may have rebuilt the entry while we waited
	// for the write lock, so check again before rebuilding.
	r.mutex.Lock()
	defer r.mutex.Unlock()
	r.ensureAvailableModelsCacheLocked()
	if entry, cached = r.availableModelsCache[handlerType]; cached && usable(entry) {
		return cloneModelMaps(entry.models)
	}

	fresh, expiresAt := r.buildAvailableModelsLocked(handlerType, now)
	r.availableModelsCache[handlerType] = availableModelsCacheEntry{
		models:    cloneModelMaps(fresh),
		expiresAt: expiresAt,
	}
	return fresh
}
// buildAvailableModelsLocked renders every model that should be listed for
// handlerType as of now. Callers must hold r.mutex.
//
// A model is listed when either condition holds:
//   - it has clients left after removing those still in their quota window
//     and those suspended for a non-"quota" reason, or
//   - all of its unavailability is cooldown: it has clients, no non-"quota"
//     suspensions, and at least one quota-window or "quota"-suspended client.
//
// The second return value is the earliest pending quota recovery time across
// all models (zero if none). GetAvailableModels uses it as the cache expiry.
func (r *ModelRegistry) buildAvailableModelsLocked(handlerType string, now time.Time) ([]map[string]any, time.Time) {
	listed := make([]map[string]any, 0, len(r.models))
	var nextRecovery time.Time
	for _, reg := range r.models {
		total := reg.Count

		// Clients whose quota window has not elapsed yet.
		cooling := 0
		for _, hitAt := range reg.QuotaExceededClients {
			if hitAt == nil {
				continue
			}
			recoverAt := hitAt.Add(modelQuotaExceededWindow)
			if !now.Before(recoverAt) {
				continue
			}
			cooling++
			if nextRecovery.IsZero() || recoverAt.Before(nextRecovery) {
				nextRecovery = recoverAt
			}
		}

		quotaSuspended, hardSuspended := 0, 0
		for _, reason := range reg.SuspendedClients {
			if strings.EqualFold(reason, "quota") {
				quotaSuspended++
			} else {
				hardSuspended++
			}
		}

		effective := total - cooling - hardSuspended
		onlyCoolingDown := total > 0 && hardSuspended == 0 && (cooling > 0 || quotaSuspended > 0)
		if effective <= 0 && !onlyCoolingDown {
			continue
		}
		if rendered := r.convertModelToMap(reg.Info, handlerType); rendered != nil {
			listed = append(listed, rendered)
		}
	}
	return listed, nextRecovery
}
// cloneModelMaps deep-copies a list of rendered model maps so cached
// snapshots and caller-visible results never share mutable state. Nil entries
// stay nil at the same index; the result is always a non-nil slice.
func cloneModelMaps(models []map[string]any) []map[string]any {
	out := make([]map[string]any, len(models))
	for i, model := range models {
		if model == nil {
			continue // preserve the nil slot
		}
		dup := make(map[string]any, len(model))
		for key, value := range model {
			dup[key] = cloneModelMapValue(value)
		}
		out[i] = dup
	}
	return out
}
// cloneModelMapValue deep-copies the container kinds that appear in rendered
// model maps:
//   - map[string]any: copied recursively, always returned as a non-nil map.
//   - []any: copied recursively, always returned as a non-nil slice.
//   - []string: copied; an empty input yields a nil []string.
//
// Any other value is returned unchanged.
func cloneModelMapValue(value any) any {
	if nested, ok := value.(map[string]any); ok {
		out := make(map[string]any, len(nested))
		for key, item := range nested {
			out[key] = cloneModelMapValue(item)
		}
		return out
	}
	if items, ok := value.([]any); ok {
		out := make([]any, 0, len(items))
		for _, item := range items {
			out = append(out, cloneModelMapValue(item))
		}
		return out
	}
	if strs, ok := value.([]string); ok {
		return append([]string(nil), strs...)
	}
	return value
}
// GetAvailableModelsByProvider returns cloned ModelInfo values for every model
// registered by clients of the given provider that is currently listable.
// Availability uses the same rule as GetAvailableModels, but only clients that
// belong to this provider are counted.
//
// The provider is trimmed and lower-cased before matching; an empty provider
// yields nil. Each model's info prefers the first client-specific definition
// seen and falls back to the shared registration's Info. Result order is
// unspecified.
//
// Parameters:
//   - provider: Provider identifier (e.g., "codex", "gemini", "antigravity")
//
// Returns:
//   - []*ModelInfo: List of available models for the provider
func (r *ModelRegistry) GetAvailableModelsByProvider(provider string) []*ModelInfo {
	want := strings.ToLower(strings.TrimSpace(provider))
	if want == "" {
		return nil
	}
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	// Phase 1: count this provider's clients per model and pick an info.
	type tally struct {
		clients int
		info    *ModelInfo
	}
	tallies := make(map[string]*tally)
	for clientID, owner := range r.clientProviders {
		if owner != want {
			continue
		}
		perClient := r.clientModelInfos[clientID]
		for _, rawID := range r.clientModels[clientID] {
			id := strings.TrimSpace(rawID)
			if id == "" {
				continue
			}
			t, seen := tallies[id]
			if !seen {
				t = &tally{}
				tallies[id] = t
			}
			t.clients++
			if t.info != nil {
				continue
			}
			if info := perClient[id]; info != nil {
				t.info = info
			} else if reg := r.models[id]; reg != nil && reg.Info != nil {
				t.info = reg.Info
			}
		}
	}
	if len(tallies) == 0 {
		return nil
	}

	// Phase 2: subtract this provider's unavailable clients and filter.
	now := time.Now()
	belongs := func(clientID string) bool {
		if clientID == "" {
			return false
		}
		owner, ok := r.clientProviders[clientID]
		return ok && owner == want
	}
	result := make([]*ModelInfo, 0, len(tallies))
	for id, t := range tallies {
		reg := r.models[id]
		cooling, quotaSuspended, hardSuspended := 0, 0, 0
		if reg != nil {
			for clientID, hitAt := range reg.QuotaExceededClients {
				if belongs(clientID) && hitAt != nil && now.Sub(*hitAt) < modelQuotaExceededWindow {
					cooling++
				}
			}
			for clientID, reason := range reg.SuspendedClients {
				if !belongs(clientID) {
					continue
				}
				if strings.EqualFold(reason, "quota") {
					quotaSuspended++
				} else {
					hardSuspended++
				}
			}
		}

		total := t.clients
		effective := total - cooling - hardSuspended
		onlyCoolingDown := total > 0 && hardSuspended == 0 && (cooling > 0 || quotaSuspended > 0)
		if effective <= 0 && !onlyCoolingDown {
			continue
		}
		switch {
		case t.info != nil:
			result = append(result, cloneModelInfo(t.info))
		case reg != nil && reg.Info != nil:
			result = append(result, cloneModelInfo(reg.Info))
		}
	}
	return result
}
// GetModelCount returns how many clients can currently serve modelID: the
// registered client count minus every client that is either still inside its
// quota window or suspended (for any reason). Each unavailable client is
// subtracted once, even if it is both quota-limited and suspended. Unknown
// models and negative results yield 0.
//
// Parameters:
//   - modelID: The model ID to check
//
// Returns:
//   - int: Number of available clients for the model
func (r *ModelRegistry) GetModelCount(modelID string) int {
	r.mutex.RLock()
	defer r.mutex.RUnlock()
	registration, exists := r.models[modelID]
	if !exists || registration == nil {
		return 0
	}
	now := time.Now()
	// A set, not two counters, so a client marked both quota-exceeded and
	// suspended is not double-subtracted.
	unavailable := make(map[string]struct{}, len(registration.QuotaExceededClients)+len(registration.SuspendedClients))
	for clientID, quotaTime := range registration.QuotaExceededClients {
		if quotaTime != nil && now.Sub(*quotaTime) < modelQuotaExceededWindow {
			unavailable[clientID] = struct{}{}
		}
	}
	for clientID := range registration.SuspendedClients {
		unavailable[clientID] = struct{}{}
	}
	result := registration.Count - len(unavailable)
	if result < 0 {
		return 0
	}
	return result
}
// GetModelProviders returns the providers that have at least one registered
// client for modelID. Results are ordered by registered client count, highest
// first, with ties broken alphabetically. Suspended or quota-limited clients
// are not subtracted from those counts. Unknown models, or models without any
// positive provider count, yield nil.
//
// Parameters:
//   - modelID: The model ID to check
//
// Returns:
//   - []string: Provider identifiers ordered by availability count (descending)
func (r *ModelRegistry) GetModelProviders(modelID string) []string {
	r.mutex.RLock()
	defer r.mutex.RUnlock()
	reg, ok := r.models[modelID]
	if !ok || reg == nil || len(reg.Providers) == 0 {
		return nil
	}

	names := make([]string, 0, len(reg.Providers))
	for name, count := range reg.Providers {
		if count > 0 {
			names = append(names, name)
		}
	}
	if len(names) == 0 {
		return nil
	}

	sort.Slice(names, func(i, j int) bool {
		ci, cj := reg.Providers[names[i]], reg.Providers[names[j]]
		if ci != cj {
			return ci > cj
		}
		return names[i] < names[j]
	})
	return names
}
// GetModelInfo returns a clone of the ModelInfo for modelID. When provider is
// non-empty, still has a positive client count for the model, and registered
// its own definition, that provider-specific definition wins. Otherwise the
// shared registration Info is cloned. Unknown models yield nil.
func (r *ModelRegistry) GetModelInfo(modelID, provider string) *ModelInfo {
	r.mutex.RLock()
	defer r.mutex.RUnlock()
	reg, ok := r.models[modelID]
	if !ok || reg == nil {
		return nil
	}
	// Lookups on nil maps return zero values, so no separate nil checks are needed.
	if provider != "" && reg.Providers[provider] > 0 {
		if info := reg.InfoByProvider[provider]; info != nil {
			return cloneModelInfo(info)
		}
	}
	// Fallback to global info (last registered)
	return cloneModelInfo(reg.Info)
}
// convertModelToMap renders model in the wire shape used by handlerType:
//   - "gemini": camelCase keys; "name" falls back to the model ID.
//   - "claude": snake_case keys; "created_at"; any non-empty Type is reported
//     as the literal "model".
//   - "openai": snake_case keys including limits and supported_parameters.
//   - anything else: a minimal generic shape.
//
// Optional fields are emitted only when set. Slice values are copied so the
// map never aliases the ModelInfo. A nil model yields nil.
func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) map[string]any {
	if model == nil {
		return nil
	}

	if handlerType == "gemini" {
		name := model.Name
		if name == "" {
			name = model.ID
		}
		out := map[string]any{"name": name}
		if model.Version != "" {
			out["version"] = model.Version
		}
		if model.DisplayName != "" {
			out["displayName"] = model.DisplayName
		}
		if model.Description != "" {
			out["description"] = model.Description
		}
		if model.InputTokenLimit > 0 {
			out["inputTokenLimit"] = model.InputTokenLimit
		}
		if model.OutputTokenLimit > 0 {
			out["outputTokenLimit"] = model.OutputTokenLimit
		}
		if len(model.SupportedGenerationMethods) > 0 {
			out["supportedGenerationMethods"] = append([]string(nil), model.SupportedGenerationMethods...)
		}
		if len(model.SupportedInputModalities) > 0 {
			out["supportedInputModalities"] = append([]string(nil), model.SupportedInputModalities...)
		}
		if len(model.SupportedOutputModalities) > 0 {
			out["supportedOutputModalities"] = append([]string(nil), model.SupportedOutputModalities...)
		}
		return out
	}

	if handlerType == "claude" {
		out := map[string]any{
			"id":       model.ID,
			"object":   "model",
			"owned_by": model.OwnedBy,
		}
		if model.Created > 0 {
			out["created_at"] = model.Created
		}
		if model.Type != "" {
			// The stored Type is not echoed; the literal "model" is reported instead.
			out["type"] = "model"
		}
		if model.DisplayName != "" {
			out["display_name"] = model.DisplayName
		}
		return out
	}

	if handlerType == "openai" {
		out := map[string]any{
			"id":       model.ID,
			"object":   "model",
			"owned_by": model.OwnedBy,
		}
		if model.Created > 0 {
			out["created"] = model.Created
		}
		if model.Type != "" {
			out["type"] = model.Type
		}
		if model.DisplayName != "" {
			out["display_name"] = model.DisplayName
		}
		if model.Version != "" {
			out["version"] = model.Version
		}
		if model.Description != "" {
			out["description"] = model.Description
		}
		if model.ContextLength > 0 {
			out["context_length"] = model.ContextLength
		}
		if model.MaxCompletionTokens > 0 {
			out["max_completion_tokens"] = model.MaxCompletionTokens
		}
		if len(model.SupportedParameters) > 0 {
			out["supported_parameters"] = append([]string(nil), model.SupportedParameters...)
		}
		return out
	}

	// Generic format for any other handler type.
	out := map[string]any{
		"id":     model.ID,
		"object": "model",
	}
	if model.OwnedBy != "" {
		out["owned_by"] = model.OwnedBy
	}
	if model.Type != "" {
		out["type"] = model.Type
	}
	if model.Created != 0 {
		out["created"] = model.Created
	}
	return out
}
// CleanupExpiredQuotas deletes quota-exceeded markers whose
// modelQuotaExceededWindow has fully elapsed. It invalidates the
// available-models cache only if at least one marker was removed. Markers with
// a nil timestamp are left in place.
func (r *ModelRegistry) CleanupExpiredQuotas() {
	r.mutex.Lock()
	defer r.mutex.Unlock()
	now := time.Now()
	removed := 0
	for modelID, reg := range r.models {
		// Deleting from a map while ranging over it is permitted in Go.
		for clientID, hitAt := range reg.QuotaExceededClients {
			if hitAt == nil || now.Sub(*hitAt) < modelQuotaExceededWindow {
				continue
			}
			delete(reg.QuotaExceededClients, clientID)
			removed++
			log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
		}
	}
	if removed > 0 {
		r.invalidateAvailableModelsCacheLocked()
	}
}
// GetFirstAvailableModel returns the ID of the newest listed model for
// handlerType that currently has at least one usable client (GetModelCount > 0).
//
// Models are ordered by creation timestamp, newest first. The timestamp is
// read from "created" (OpenAI and generic shapes) or "created_at" (Claude
// shape). Models without a timestamp, such as the Gemini shape, sort after
// all timestamped ones. The model ID comes from "id"; the Gemini shape only
// has "name", so that is used as a fallback.
//
// Parameters:
//   - handlerType: The API handler type (e.g., "openai", "claude", "gemini")
//
// Returns:
//   - string: The model ID of the first available model, or empty string if none available
//   - error: An error if no models are available
func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) {
	models := r.GetAvailableModels(handlerType)
	if len(models) == 0 {
		return "", fmt.Errorf("no models available for handler type: %s", handlerType)
	}

	createdAt := func(model map[string]any) (int64, bool) {
		if ts, ok := model["created"].(int64); ok {
			return ts, true
		}
		ts, ok := model["created_at"].(int64)
		return ts, ok
	}
	// The less function must be a strict weak ordering for sort to be well
	// defined. Timestamped models come first (newest first) and untimestamped
	// ones are mutually equivalent.
	sort.SliceStable(models, func(i, j int) bool {
		ti, okI := createdAt(models[i])
		tj, okJ := createdAt(models[j])
		if okI != okJ {
			return okI
		}
		return okI && ti > tj
	})

	modelIDOf := func(model map[string]any) string {
		if id, ok := model["id"].(string); ok && id != "" {
			return id
		}
		// NOTE(review): Gemini names may carry a "models/" prefix; it is
		// stripped so the lookup matches the registry key. Confirm the
		// naming convention; a non-matching ID just yields a zero count.
		if name, ok := model["name"].(string); ok {
			return strings.TrimPrefix(name, "models/")
		}
		return ""
	}

	for _, model := range models {
		modelID := modelIDOf(model)
		if modelID == "" {
			continue
		}
		if r.GetModelCount(modelID) > 0 {
			return modelID, nil
		}
	}
	return "", fmt.Errorf("no available clients for any model in handler type: %s", handlerType)
}
// GetModelsForClient returns clones of the models registered by clientID, in
// registration order with duplicate IDs collapsed. The client's own stored
// definition is preferred, which preserves its original type/owned_by. The
// shared registry Info is the fallback. Models with neither are omitted.
//
// Parameters:
//   - clientID: The client identifier (typically auth file name or auth ID)
//
// Returns:
//   - []*ModelInfo: List of models registered for this client, nil if client not found
func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo {
	r.mutex.RLock()
	defer r.mutex.RUnlock()
	ids := r.clientModels[clientID]
	if len(ids) == 0 {
		return nil
	}

	own := r.clientModelInfos[clientID]
	emitted := make(map[string]struct{}, len(ids))
	out := make([]*ModelInfo, 0, len(ids))
	for _, id := range ids {
		if _, done := emitted[id]; done {
			continue
		}
		emitted[id] = struct{}{}
		if info := own[id]; info != nil {
			out = append(out, cloneModelInfo(info))
			continue
		}
		// Fallback to global registry (for backwards compatibility)
		if reg, ok := r.models[id]; ok && reg.Info != nil {
			out = append(out, cloneModelInfo(reg.Info))
		}
	}
	return out
}
================================================
FILE: internal/registry/model_registry_cache_test.go
================================================
package registry
import "testing"
// TestGetAvailableModelsReturnsClonedSnapshots verifies that mutating the maps
// returned by GetAvailableModels does not leak into the cached snapshot served
// to later callers.
func TestGetAvailableModelsReturnsClonedSnapshots(t *testing.T) {
	r := newTestModelRegistry()
	r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
	first := r.GetAvailableModels("openai")
	if len(first) != 1 {
		t.Fatalf("expected 1 model, got %d", len(first))
	}
	first[0]["id"] = "mutated"
	first[0]["display_name"] = "Mutated"
	second := r.GetAvailableModels("openai")
	// Guard the index below so a regression fails cleanly instead of panicking.
	if len(second) != 1 {
		t.Fatalf("expected 1 model on second fetch, got %d", len(second))
	}
	if got := second[0]["id"]; got != "m1" {
		t.Fatalf("expected cached snapshot to stay isolated, got id %v", got)
	}
	if got := second[0]["display_name"]; got != "Model One" {
		t.Fatalf("expected cached snapshot to stay isolated, got display_name %v", got)
	}
}
// TestGetAvailableModelsInvalidatesCacheOnRegistryChanges checks that the
// cached listing is refreshed after re-registration, suspension and resume.
func TestGetAvailableModelsInvalidatesCacheOnRegistryChanges(t *testing.T) {
	r := newTestModelRegistry()
	r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One"}})
	models := r.GetAvailableModels("openai")
	if len(models) != 1 {
		t.Fatalf("expected 1 model, got %d", len(models))
	}
	if got := models[0]["display_name"]; got != "Model One" {
		t.Fatalf("expected initial display_name Model One, got %v", got)
	}
	r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1", OwnedBy: "team-a", DisplayName: "Model One Updated"}})
	models = r.GetAvailableModels("openai")
	// Guard the index below so a regression fails cleanly instead of panicking.
	if len(models) != 1 {
		t.Fatalf("expected 1 model after re-registration, got %d", len(models))
	}
	if got := models[0]["display_name"]; got != "Model One Updated" {
		t.Fatalf("expected updated display_name after cache invalidation, got %v", got)
	}
	r.SuspendClientModel("client-1", "m1", "manual")
	models = r.GetAvailableModels("openai")
	if len(models) != 0 {
		t.Fatalf("expected no available models after suspension, got %d", len(models))
	}
	r.ResumeClientModel("client-1", "m1")
	models = r.GetAvailableModels("openai")
	if len(models) != 1 {
		t.Fatalf("expected model to reappear after resume, got %d", len(models))
	}
}
================================================
FILE: internal/registry/model_registry_hook_test.go
================================================
package registry
import (
"context"
"sync"
"testing"
"time"
)
// newTestModelRegistry builds an empty ModelRegistry for tests, with its mutex
// and client/model bookkeeping maps allocated. Other fields (such as the
// available-models cache) are left at their zero values.
func newTestModelRegistry() *ModelRegistry {
	r := new(ModelRegistry)
	r.mutex = new(sync.RWMutex)
	r.models = make(map[string]*ModelRegistration)
	r.clientModels = make(map[string][]string)
	r.clientModelInfos = make(map[string]map[string]*ModelInfo)
	r.clientProviders = make(map[string]string)
	return r
}
type (
	// registeredCall records the arguments of one OnModelsRegistered call.
	registeredCall struct {
		provider string
		clientID string
		models   []*ModelInfo
	}

	// unregisteredCall records the arguments of one OnModelsUnregistered call.
	unregisteredCall struct {
		provider string
		clientID string
	}

	// capturingHook forwards every hook invocation to a channel so tests can
	// assert on the arguments. Sends block until received or buffered.
	capturingHook struct {
		registeredCh   chan registeredCall
		unregisteredCh chan unregisteredCall
	}
)

// OnModelsRegistered publishes the call on h.registeredCh.
func (h *capturingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
	call := registeredCall{provider: provider, clientID: clientID, models: models}
	h.registeredCh <- call
}

// OnModelsUnregistered publishes the call on h.unregisteredCh.
func (h *capturingHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {
	call := unregisteredCall{provider: provider, clientID: clientID}
	h.unregisteredCh <- call
}
func TestModelRegistryHook_OnModelsRegisteredCalled(t *testing.T) {
r := newTestModelRegistry()
hook := &capturingHook{
registeredCh: make(chan registeredCall, 1),
unregisteredCh: make(chan unregisteredCall, 1),
}
r.SetHook(hook)
inputModels := []*ModelInfo{
{ID: "m1", DisplayName: "Model One"},
{ID: "m2", DisplayName: "Model Two"},
}
r.RegisterClient("client-1", "OpenAI", inputModels)
select {
case call := <-hook.registeredCh:
if call.provider != "openai" {
t.Fatalf("provider mismatch: got %q, want %q", call.provider, "openai")
}
if call.clientID != "client-1" {
t.Fatalf("clientID mismatch: got %q, want %q", call.clientID, "client-1")
}
if len(call.models) != 2 {
t.Fatalf("models length mismatch: got %d, want %d", len(call.models), 2)
}
if call.models[0] == nil || call.models[0].ID != "m1" {
t.Fatalf("models[0] mismatch: got %#v, want ID=%q", call.models[0], "m1")
}
if call.models[1] == nil || call.models[1].ID != "m2" {
t.Fatalf("models[1] mismatch: got %#v, want ID=%q", call.models[1], "m2")
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsRegistered hook call")
}
}
func TestModelRegistryHook_OnModelsUnregisteredCalled(t *testing.T) {
r := newTestModelRegistry()
hook := &capturingHook{
registeredCh: make(chan registeredCall, 1),
unregisteredCh: make(chan unregisteredCall, 1),
}
r.SetHook(hook)
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
select {
case <-hook.registeredCh:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsRegistered hook call")
}
r.UnregisterClient("client-1")
select {
case call := <-hook.unregisteredCh:
if call.provider != "openai" {
t.Fatalf("provider mismatch: got %q, want %q", call.provider, "openai")
}
if call.clientID != "client-1" {
t.Fatalf("clientID mismatch: got %q, want %q", call.clientID, "client-1")
}
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsUnregistered hook call")
}
}
// blockingHook signals when OnModelsRegistered starts and then parks that call
// until unblock is closed. Tests use it to prove the registry does not wait on
// its hooks.
type blockingHook struct {
	started chan struct{}
	unblock chan struct{}
}

// OnModelsRegistered closes h.started on the first call only, then waits for
// h.unblock to be closed.
func (h *blockingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
	select {
	case <-h.started:
		// Already signalled by an earlier call; closing again would panic.
	default:
		close(h.started)
	}
	<-h.unblock
}

// OnModelsUnregistered is intentionally a no-op.
func (h *blockingHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {}
func TestModelRegistryHook_DoesNotBlockRegisterClient(t *testing.T) {
r := newTestModelRegistry()
hook := &blockingHook{
started: make(chan struct{}),
unblock: make(chan struct{}),
}
r.SetHook(hook)
defer close(hook.unblock)
done := make(chan struct{})
go func() {
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
close(done)
}()
select {
case <-hook.started:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for hook to start")
}
select {
case <-done:
case <-time.After(200 * time.Millisecond):
t.Fatal("RegisterClient appears to be blocked by hook")
}
if !r.ClientSupportsModel("client-1", "m1") {
t.Fatal("model registration failed; expected client to support model")
}
}
// panicHook reports each invocation on its (optional) channels and then
// panics, so tests can check that hook panics do not corrupt the registry.
type panicHook struct {
	registeredCalled   chan struct{}
	unregisteredCalled chan struct{}
}

// OnModelsRegistered signals registeredCalled (if set) and then panics.
func (h *panicHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
	if ch := h.registeredCalled; ch != nil {
		ch <- struct{}{}
	}
	panic("boom")
}

// OnModelsUnregistered signals unregisteredCalled (if set) and then panics.
func (h *panicHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {
	if ch := h.unregisteredCalled; ch != nil {
		ch <- struct{}{}
	}
	panic("boom")
}
func TestModelRegistryHook_PanicDoesNotAffectRegistry(t *testing.T) {
r := newTestModelRegistry()
hook := &panicHook{
registeredCalled: make(chan struct{}, 1),
unregisteredCalled: make(chan struct{}, 1),
}
r.SetHook(hook)
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
select {
case <-hook.registeredCalled:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsRegistered hook call")
}
if !r.ClientSupportsModel("client-1", "m1") {
t.Fatal("model registration failed; expected client to support model")
}
r.UnregisterClient("client-1")
select {
case <-hook.unregisteredCalled:
case <-time.After(2 * time.Second):
t.Fatal("timeout waiting for OnModelsUnregistered hook call")
}
}
================================================
FILE: internal/registry/model_registry_safety_test.go
================================================
package registry
import (
"testing"
"time"
)
// TestGetModelInfoReturnsClone verifies that mutating a ModelInfo returned by
// GetModelInfo, including its Thinking.Levels slice, does not affect later
// lookups.
func TestGetModelInfoReturnsClone(t *testing.T) {
	r := newTestModelRegistry()
	r.RegisterClient("client-1", "gemini", []*ModelInfo{{
		ID:          "m1",
		DisplayName: "Model One",
		Thinking:    &ThinkingSupport{Min: 1, Max: 2, Levels: []string{"low", "high"}},
	}})
	first := r.GetModelInfo("m1", "gemini")
	// Guard the dereferences below so a regression fails instead of panicking.
	if first == nil || first.Thinking == nil || len(first.Thinking.Levels) == 0 {
		t.Fatalf("expected model info with thinking levels, got %+v", first)
	}
	first.DisplayName = "mutated"
	first.Thinking.Levels[0] = "mutated"
	second := r.GetModelInfo("m1", "gemini")
	if second == nil {
		t.Fatal("expected model info on second fetch")
	}
	if second.DisplayName != "Model One" {
		t.Fatalf("expected cloned display name, got %q", second.DisplayName)
	}
	if second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] != "low" {
		t.Fatalf("expected cloned thinking levels, got %+v", second.Thinking)
	}
}
// TestGetModelsForClientReturnsClones verifies that models returned by
// GetModelsForClient are independent copies of the registry's data.
func TestGetModelsForClientReturnsClones(t *testing.T) {
	r := newTestModelRegistry()
	levels := []string{"low", "high"}
	r.RegisterClient("client-1", "gemini", []*ModelInfo{{
		ID:          "m1",
		DisplayName: "Model One",
		Thinking:    &ThinkingSupport{Levels: levels},
	}})

	first := r.GetModelsForClient("client-1")
	if len(first) != 1 || first[0] == nil {
		t.Fatalf("expected one model, got %+v", first)
	}
	victim := first[0]
	victim.DisplayName = "mutated"
	victim.Thinking.Levels[0] = "mutated"

	second := r.GetModelsForClient("client-1")
	if len(second) != 1 || second[0] == nil {
		t.Fatalf("expected one model on second fetch, got %+v", second)
	}
	fresh := second[0]
	if fresh.DisplayName != "Model One" {
		t.Fatalf("expected cloned display name, got %q", fresh.DisplayName)
	}
	if fresh.Thinking == nil || len(fresh.Thinking.Levels) == 0 || fresh.Thinking.Levels[0] != "low" {
		t.Fatalf("expected cloned thinking levels, got %+v", fresh.Thinking)
	}
}
// TestGetAvailableModelsByProviderReturnsClones verifies that models returned
// by GetAvailableModelsByProvider are independent copies of the registry's data.
func TestGetAvailableModelsByProviderReturnsClones(t *testing.T) {
	r := newTestModelRegistry()
	levels := []string{"low", "high"}
	r.RegisterClient("client-1", "gemini", []*ModelInfo{{
		ID:          "m1",
		DisplayName: "Model One",
		Thinking:    &ThinkingSupport{Levels: levels},
	}})

	first := r.GetAvailableModelsByProvider("gemini")
	if len(first) != 1 || first[0] == nil {
		t.Fatalf("expected one model, got %+v", first)
	}
	victim := first[0]
	victim.DisplayName = "mutated"
	victim.Thinking.Levels[0] = "mutated"

	second := r.GetAvailableModelsByProvider("gemini")
	if len(second) != 1 || second[0] == nil {
		t.Fatalf("expected one model on second fetch, got %+v", second)
	}
	fresh := second[0]
	if fresh.DisplayName != "Model One" {
		t.Fatalf("expected cloned display name, got %q", fresh.DisplayName)
	}
	if fresh.Thinking == nil || len(fresh.Thinking.Levels) == 0 || fresh.Thinking.Levels[0] != "low" {
		t.Fatalf("expected cloned thinking levels, got %+v", fresh.Thinking)
	}
}
func TestCleanupExpiredQuotasInvalidatesAvailableModelsCache(t *testing.T) {
r := newTestModelRegistry()
r.RegisterClient("client-1", "openai", []*ModelInfo{{ID: "m1", Created: 1}})
r.SetModelQuotaExceeded("client-1", "m1")
if models := r.GetAvailableModels("openai"); len(models) != 1 {
t.Fatalf("expected cooldown model to remain listed before cleanup, got %d", len(models))
}
r.mutex.Lock()
quotaTime := time.Now().Add(-6 * time.Minute)
r.models["m1"].QuotaExceededClients["client-1"] = "aTime
r.mutex.Unlock()
r.CleanupExpiredQuotas()
if count := r.GetModelCount("m1"); count != 1 {
t.Fatalf("expected model count 1 after cleanup, got %d", count)
}
models := r.GetAvailableModels("openai")
if len(models) != 1 {
t.Fatalf("expected model to stay available after cleanup, got %d", len(models))
}
if got := models[0]["id"]; got != "m1" {
t.Fatalf("expected model id m1, got %v", got)
}
}
// TestGetAvailableModelsReturnsClonedSupportedParameters verifies that the
// supported_parameters slice in an OpenAI listing is copied, so mutating it
// does not affect later listings.
func TestGetAvailableModelsReturnsClonedSupportedParameters(t *testing.T) {
	r := newTestModelRegistry()
	r.RegisterClient("client-1", "openai", []*ModelInfo{{
		ID:                  "m1",
		DisplayName:         "Model One",
		SupportedParameters: []string{"temperature", "top_p"},
	}})
	first := r.GetAvailableModels("openai")
	if len(first) != 1 {
		t.Fatalf("expected one model, got %d", len(first))
	}
	params, ok := first[0]["supported_parameters"].([]string)
	if !ok || len(params) != 2 {
		t.Fatalf("expected supported_parameters slice, got %#v", first[0]["supported_parameters"])
	}
	params[0] = "mutated"
	second := r.GetAvailableModels("openai")
	// Guard the index below so a regression fails cleanly instead of panicking.
	if len(second) != 1 {
		t.Fatalf("expected one model on second fetch, got %d", len(second))
	}
	params, ok = second[0]["supported_parameters"].([]string)
	if !ok || len(params) != 2 || params[0] != "temperature" {
		t.Fatalf("expected cloned supported_parameters, got %#v", second[0]["supported_parameters"])
	}
}
// TestLookupModelInfoReturnsCloneForStaticDefinitions verifies that
// LookupModelInfo hands out copies of static definitions, so mutating the
// thinking levels of one result does not affect the next lookup.
func TestLookupModelInfoReturnsCloneForStaticDefinitions(t *testing.T) {
	const staticID = "glm-4.6"

	first := LookupModelInfo(staticID)
	if first == nil || first.Thinking == nil || len(first.Thinking.Levels) == 0 {
		t.Fatalf("expected static model with thinking levels, got %+v", first)
	}
	first.Thinking.Levels[0] = "mutated"

	second := LookupModelInfo(staticID)
	if second == nil || second.Thinking == nil || len(second.Thinking.Levels) == 0 || second.Thinking.Levels[0] == "mutated" {
		t.Fatalf("expected static lookup clone, got %+v", second)
	}
}
================================================
FILE: internal/registry/model_updater.go
================================================
package registry
import (
"context"
_ "embed"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
// modelsFetchTimeout bounds a single remote catalog fetch. It is used both as
// the HTTP client timeout and as the per-request context deadline.
modelsFetchTimeout = 30 * time.Second
// modelsRefreshInterval is the period between background catalog refreshes
// after the initial startup fetch.
modelsRefreshInterval = 3 * time.Hour
)
// modelsURLs lists the remote catalog sources in the order they are tried.
// The first one that downloads, decodes, and validates is used.
var modelsURLs = []string{
"https://raw.githubusercontent.com/router-for-me/models/refs/heads/main/models.json",
"https://models.router-for.me/models.json",
}
// embeddedModelsJSON is the catalog compiled into the binary. init loads it so
// a catalog is available even if no remote fetch ever succeeds.
//go:embed models/models.json
var embeddedModelsJSON []byte
// modelStore holds the active model catalog. data is replaced wholesale under
// mu's write lock and read under its read lock (see getModels).
type modelStore struct {
mu sync.RWMutex
data *staticModelsJSON
}
// modelsCatalogStore is the package-wide catalog currently in use.
var modelsCatalogStore = &modelStore{}
// updaterOnce ensures StartModelsUpdater launches at most one updater goroutine.
var updaterOnce sync.Once
// ModelRefreshCallback is invoked when startup or periodic model refresh detects changes.
// changedProviders contains the provider names whose model definitions changed.
// Names are lowercase provider identifiers such as "claude" or "codex" (all
// Codex tiers are reported as "codex"). The callback may be invoked from the
// updater goroutine or synchronously from SetModelRefreshCallback.
type ModelRefreshCallback func(changedProviders []string)
var (
// refreshCallbackMu guards refreshCallback and pendingRefreshChanges.
refreshCallbackMu sync.Mutex
// refreshCallback is the currently registered listener. It is nil until
// SetModelRefreshCallback is called with a non-nil function.
refreshCallback ModelRefreshCallback
// pendingRefreshChanges accumulates changed provider names reported while no
// callback was registered. SetModelRefreshCallback delivers and clears them.
pendingRefreshChanges []string
)
// SetModelRefreshCallback registers a callback that is invoked when startup or
// periodic model refresh detects changes. Only one callback is supported;
// subsequent calls replace the previous callback.
//
// If changes were detected while no callback was registered, a non-nil cb is
// invoked once, synchronously, with that accumulated backlog, and the backlog
// is cleared. The callback always runs after refreshCallbackMu is released.
// Passing nil unregisters the callback and leaves any backlog queued.
func SetModelRefreshCallback(cb ModelRefreshCallback) {
	backlog := func() []string {
		refreshCallbackMu.Lock()
		defer refreshCallbackMu.Unlock()
		refreshCallback = cb
		if cb == nil || len(pendingRefreshChanges) == 0 {
			return nil
		}
		// Hand out a private copy so the caller cannot alias package state.
		out := append([]string(nil), pendingRefreshChanges...)
		pendingRefreshChanges = nil
		return out
	}()
	if len(backlog) > 0 {
		cb(backlog)
	}
}
// init seeds the catalog store from the embedded models.json, so lookups work
// before (or without) any remote refresh. The catalog is embedded at build
// time, so a decode or validation failure is treated as fatal.
func init() {
	err := loadModelsFromBytes(embeddedModelsJSON, "embed")
	if err == nil {
		return
	}
	panic(fmt.Sprintf("registry: failed to parse embedded models.json: %v", err))
}
// StartModelsUpdater starts a background updater that fetches models
// immediately on startup and then refreshes the model catalog every 3 hours.
// Safe to call multiple times; only one updater will run.
func StartModelsUpdater(ctx context.Context) {
updaterOnce.Do(func() {
go runModelsUpdater(ctx)
})
}
// runModelsUpdater is the body of the updater goroutine. It performs the
// startup refresh, then blocks in the periodic refresh loop until ctx is done.
func runModelsUpdater(ctx context.Context) {
	phases := [...]func(context.Context){tryStartupRefresh, periodicRefresh}
	for _, phase := range phases {
		phase(ctx)
	}
}
// periodicRefresh calls tryPeriodicRefresh on every tick of a
// modelsRefreshInterval ticker and returns once ctx is cancelled. The first
// tick fires one full interval after the loop starts.
func periodicRefresh(ctx context.Context) {
	tick := time.NewTicker(modelsRefreshInterval)
	defer tick.Stop()
	log.Infof("periodic model refresh started (interval=%s)", modelsRefreshInterval)
	done := ctx.Done()
	for {
		select {
		case <-tick.C:
			tryPeriodicRefresh(ctx)
		case <-done:
			return
		}
	}
}
// tryPeriodicRefresh performs one scheduled refresh. It fetches the remote
// catalog, installs it, and reports any provider whose models changed to the
// refresh callback (or queues them until a callback is registered).
func tryPeriodicRefresh(ctx context.Context) {
	const label = "periodic model refresh"
	tryRefreshModels(ctx, label)
}
// tryStartupRefresh performs the initial refresh when the updater starts. It
// shares change detection with the periodic refresh. If no callback is
// registered yet, detected changes are queued and delivered once
// SetModelRefreshCallback is called, so existing auth registrations can still
// be updated.
func tryStartupRefresh(ctx context.Context) {
	const label = "startup model refresh"
	tryRefreshModels(ctx, label)
}
// tryRefreshModels fetches the remote catalog and, if a valid one is
// obtained, installs it as the active catalog. Provider changes are computed
// against the catalog that was active before the fetch and passed to
// notifyModelRefresh. label prefixes every log line, e.g.
// "startup model refresh".
//
// When no source succeeds, the current catalog is kept. If the failure is
// caused by ctx being cancelled (e.g. during shutdown), it is logged at debug
// level instead of as a misleading "fetch failed" warning.
func tryRefreshModels(ctx context.Context, label string) {
	previous := getModels()
	parsed, url := fetchModelsFromRemote(ctx)
	if parsed == nil {
		if err := ctx.Err(); err != nil {
			log.Debugf("%s: aborted (%v), keeping current data", label, err)
			return
		}
		log.Warnf("%s: fetch failed from all URLs, keeping current data", label)
		return
	}
	// Diff against the catalog that was active before this fetch.
	changed := detectChangedProviders(previous, parsed)
	// Install the fetched catalog even when no provider changed.
	modelsCatalogStore.mu.Lock()
	modelsCatalogStore.data = parsed
	modelsCatalogStore.mu.Unlock()
	if len(changed) == 0 {
		log.Infof("%s completed from %s, no changes detected", label, url)
		return
	}
	log.Infof("%s completed from %s, changes detected for providers: %v", label, url, changed)
	notifyModelRefresh(changed)
}
// modelsCatalogMaxBytes caps how much of a remote catalog response is read
// into memory. It is far larger than the catalog itself and exists only to
// stop a misbehaving source from exhausting memory.
const modelsCatalogMaxBytes = 16 << 20

// fetchModelsFromRemote tries each entry of modelsURLs in order. It returns
// the first catalog that downloads, decodes, and passes validateModelsCatalog,
// together with the URL it came from. It returns (nil, "") if all fetches
// fail, or if ctx is cancelled before one succeeds.
func fetchModelsFromRemote(ctx context.Context) (*staticModelsJSON, string) {
	client := &http.Client{Timeout: modelsFetchTimeout}
	for _, url := range modelsURLs {
		// Stop trying further sources once the caller has given up.
		if ctx.Err() != nil {
			return nil, ""
		}
		data, err := fetchModelsBody(ctx, client, url)
		if err != nil {
			log.Debugf("models fetch failed from %s: %v", url, err)
			continue
		}
		var parsed staticModelsJSON
		if err := json.Unmarshal(data, &parsed); err != nil {
			log.Warnf("models parse failed from %s: %v", url, err)
			continue
		}
		if err := validateModelsCatalog(&parsed); err != nil {
			log.Warnf("models validate failed from %s: %v", url, err)
			continue
		}
		return &parsed, url
	}
	return nil, ""
}

// fetchModelsBody performs one GET against url, bounded by modelsFetchTimeout,
// and returns the response body when the status is 200 OK. It rejects bodies
// larger than modelsCatalogMaxBytes. The per-request context and the response
// body are released on every path.
func fetchModelsBody(ctx context.Context, client *http.Client, url string) ([]byte, error) {
	reqCtx, cancel := context.WithTimeout(ctx, modelsFetchTimeout)
	defer cancel()
	req, err := http.NewRequestWithContext(reqCtx, http.MethodGet, url, nil)
	if err != nil {
		return nil, fmt.Errorf("create request: %w", err)
	}
	resp, err := client.Do(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("unexpected status %d", resp.StatusCode)
	}
	// Read one extra byte so an oversized body is detected instead of being
	// silently truncated.
	data, err := io.ReadAll(io.LimitReader(resp.Body, modelsCatalogMaxBytes+1))
	if err != nil {
		return nil, fmt.Errorf("read body: %w", err)
	}
	if len(data) > modelsCatalogMaxBytes {
		return nil, fmt.Errorf("response exceeds %d bytes", modelsCatalogMaxBytes)
	}
	return data, nil
}
// detectChangedProviders reports which providers' model lists differ between
// oldData and newData. Results are in a fixed provider order with no
// duplicates. The four Codex tiers (free/team/plus/pro) are all reported under
// the single name "codex". It returns nil when either catalog is nil or when
// nothing changed.
func detectChangedProviders(oldData, newData *staticModelsJSON) []string {
	if oldData == nil || newData == nil {
		return nil
	}
	pairs := []struct {
		name   string
		before []*ModelInfo
		after  []*ModelInfo
	}{
		{"claude", oldData.Claude, newData.Claude},
		{"gemini", oldData.Gemini, newData.Gemini},
		{"vertex", oldData.Vertex, newData.Vertex},
		{"gemini-cli", oldData.GeminiCLI, newData.GeminiCLI},
		{"aistudio", oldData.AIStudio, newData.AIStudio},
		{"codex", oldData.CodexFree, newData.CodexFree},
		{"codex", oldData.CodexTeam, newData.CodexTeam},
		{"codex", oldData.CodexPlus, newData.CodexPlus},
		{"codex", oldData.CodexPro, newData.CodexPro},
		{"qwen", oldData.Qwen, newData.Qwen},
		{"iflow", oldData.IFlow, newData.IFlow},
		{"kimi", oldData.Kimi, newData.Kimi},
		{"antigravity", oldData.Antigravity, newData.Antigravity},
	}
	reported := make(map[string]struct{}, len(pairs))
	var out []string
	for _, p := range pairs {
		// A provider already reported (e.g. an earlier Codex tier) needs no
		// further comparison.
		if _, done := reported[p.name]; done {
			continue
		}
		if !modelSectionChanged(p.before, p.after) {
			continue
		}
		reported[p.name] = struct{}{}
		out = append(out, p.name)
	}
	return out
}
// modelSectionChanged reports whether two model lists differ.
//   - Lists of different lengths always differ.
//   - Two empty lists never differ.
//   - Otherwise the JSON encodings are compared, and a marshal failure on
//     either side counts as a change.
func modelSectionChanged(a, b []*ModelInfo) bool {
	switch {
	case len(a) != len(b):
		return true
	case len(a) == 0:
		return false
	}
	left, errLeft := json.Marshal(a)
	if errLeft != nil {
		return true
	}
	right, errRight := json.Marshal(b)
	if errRight != nil {
		return true
	}
	return string(left) != string(right)
}
// notifyModelRefresh passes changedProviders to the registered refresh
// callback. If no callback is registered, the names are merged (lowercased,
// trimmed, de-duplicated) into pendingRefreshChanges, and
// SetModelRefreshCallback delivers them later. The callback is invoked only
// after refreshCallbackMu is released. An empty list is ignored.
func notifyModelRefresh(changedProviders []string) {
	if len(changedProviders) == 0 {
		return
	}
	refreshCallbackMu.Lock()
	cb := refreshCallback
	if cb != nil {
		refreshCallbackMu.Unlock()
		cb(changedProviders)
		return
	}
	pendingRefreshChanges = mergeProviderNames(pendingRefreshChanges, changedProviders)
	refreshCallbackMu.Unlock()
}
// mergeProviderNames concatenates existing and incoming into one list of
// provider names. Each name is lowercased and whitespace-trimmed, blank names
// and duplicates are dropped, and first-occurrence order is kept. If incoming
// is empty, existing is returned as-is, without normalization.
func mergeProviderNames(existing, incoming []string) []string {
	if len(incoming) == 0 {
		return existing
	}
	total := len(existing) + len(incoming)
	seen := make(map[string]struct{}, total)
	result := make([]string, 0, total)
	for _, group := range [][]string{existing, incoming} {
		for _, raw := range group {
			key := strings.ToLower(strings.TrimSpace(raw))
			if key == "" {
				continue
			}
			if _, dup := seen[key]; dup {
				continue
			}
			seen[key] = struct{}{}
			result = append(result, key)
		}
	}
	return result
}
// loadModelsFromBytes decodes data as a model catalog, validates it, and on
// success installs it as the active catalog. source labels the input in
// returned errors (init passes "embed"). On any error the current catalog is
// left untouched.
func loadModelsFromBytes(data []byte, source string) error {
	catalog := new(staticModelsJSON)
	if err := json.Unmarshal(data, catalog); err != nil {
		return fmt.Errorf("%s: decode models catalog: %w", source, err)
	}
	if err := validateModelsCatalog(catalog); err != nil {
		return fmt.Errorf("%s: validate models catalog: %w", source, err)
	}
	modelsCatalogStore.mu.Lock()
	defer modelsCatalogStore.mu.Unlock()
	modelsCatalogStore.data = catalog
	return nil
}
// getModels returns the active catalog pointer, read under the store's read
// lock. The pointer is shared rather than copied, so mutating the result would
// change the data seen by every other reader.
func getModels() *staticModelsJSON {
	store := modelsCatalogStore
	store.mu.RLock()
	current := store.data
	store.mu.RUnlock()
	return current
}
// validateModelsCatalog checks that data is non-nil and that every provider
// section, including each Codex tier, passes validateModelSection: it must be
// non-empty, with no nil entries and unique, non-blank model IDs. Sections are
// checked in a fixed order and the first failure is returned.
func validateModelsCatalog(data *staticModelsJSON) error {
	if data == nil {
		return fmt.Errorf("catalog is nil")
	}
	checks := [...]struct {
		name   string
		models []*ModelInfo
	}{
		{"claude", data.Claude},
		{"gemini", data.Gemini},
		{"vertex", data.Vertex},
		{"gemini-cli", data.GeminiCLI},
		{"aistudio", data.AIStudio},
		{"codex-free", data.CodexFree},
		{"codex-team", data.CodexTeam},
		{"codex-plus", data.CodexPlus},
		{"codex-pro", data.CodexPro},
		{"qwen", data.Qwen},
		{"iflow", data.IFlow},
		{"kimi", data.Kimi},
		{"antigravity", data.Antigravity},
	}
	for i := range checks {
		if err := validateModelSection(checks[i].name, checks[i].models); err != nil {
			return err
		}
	}
	return nil
}
// validateModelSection ensures a catalog section is non-empty and contains no
// nil entries. Every entry must have a non-blank ID that is unique within the
// section; IDs are compared after trimming surrounding whitespace. section
// names the section in error messages.
func validateModelSection(section string, models []*ModelInfo) error {
	if len(models) == 0 {
		return fmt.Errorf("%s section is empty", section)
	}
	ids := make(map[string]bool, len(models))
	for idx := range models {
		m := models[idx]
		if m == nil {
			return fmt.Errorf("%s[%d] is null", section, idx)
		}
		id := strings.TrimSpace(m.ID)
		switch {
		case id == "":
			return fmt.Errorf("%s[%d] has empty id", section, idx)
		case ids[id]:
			return fmt.Errorf("%s contains duplicate model id %q", section, id)
		}
		ids[id] = true
	}
	return nil
}
================================================
FILE: internal/registry/models/models.json
================================================
{
"claude": [
{
"id": "claude-haiku-4-5-20251001",
"object": "model",
"created": 1759276800,
"owned_by": "anthropic",
"type": "claude",
"display_name": "Claude 4.5 Haiku",
"context_length": 200000,
"max_completion_tokens": 64000,
"thinking": {
"min": 1024,
"max": 128000,
"zero_allowed": true
}
},
{
"id": "claude-sonnet-4-5-20250929",
"object": "model",
"created": 1759104000,
"owned_by": "anthropic",
"type": "claude",
"display_name": "Claude 4.5 Sonnet",
"context_length": 200000,
"max_completion_tokens": 64000,
"thinking": {
"min": 1024,
"max": 128000,
"zero_allowed": true
}
},
{
"id": "claude-sonnet-4-6",
"object": "model",
"created": 1771372800,
"owned_by": "anthropic",
"type": "claude",
"display_name": "Claude 4.6 Sonnet",
"context_length": 200000,
"max_completion_tokens": 64000,
"thinking": {
"min": 1024,
"max": 128000,
"zero_allowed": true,
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "claude-opus-4-6",
"object": "model",
"created": 1770318000,
"owned_by": "anthropic",
"type": "claude",
"display_name": "Claude 4.6 Opus",
"description": "Premium model combining maximum intelligence with practical performance",
"context_length": 1000000,
"max_completion_tokens": 128000,
"thinking": {
"min": 1024,
"max": 128000,
"zero_allowed": true,
"levels": [
"low",
"medium",
"high",
"max"
]
}
},
{
"id": "claude-opus-4-5-20251101",
"object": "model",
"created": 1761955200,
"owned_by": "anthropic",
"type": "claude",
"display_name": "Claude 4.5 Opus",
"description": "Premium model combining maximum intelligence with practical performance",
"context_length": 200000,
"max_completion_tokens": 64000,
"thinking": {
"min": 1024,
"max": 128000,
"zero_allowed": true
}
},
{
"id": "claude-opus-4-1-20250805",
"object": "model",
"created": 1722945600,
"owned_by": "anthropic",
"type": "claude",
"display_name": "Claude 4.1 Opus",
"context_length": 200000,
"max_completion_tokens": 32000,
"thinking": {
"min": 1024,
"max": 128000
}
},
{
"id": "claude-opus-4-20250514",
"object": "model",
"created": 1715644800,
"owned_by": "anthropic",
"type": "claude",
"display_name": "Claude 4 Opus",
"context_length": 200000,
"max_completion_tokens": 32000,
"thinking": {
"min": 1024,
"max": 128000
}
},
{
"id": "claude-sonnet-4-20250514",
"object": "model",
"created": 1715644800,
"owned_by": "anthropic",
"type": "claude",
"display_name": "Claude 4 Sonnet",
"context_length": 200000,
"max_completion_tokens": 64000,
"thinking": {
"min": 1024,
"max": 128000
}
},
{
"id": "claude-3-7-sonnet-20250219",
"object": "model",
"created": 1708300800,
"owned_by": "anthropic",
"type": "claude",
"display_name": "Claude 3.7 Sonnet",
"context_length": 128000,
"max_completion_tokens": 8192,
"thinking": {
"min": 1024,
"max": 128000
}
},
{
"id": "claude-3-5-haiku-20241022",
"object": "model",
"created": 1729555200,
"owned_by": "anthropic",
"type": "claude",
"display_name": "Claude 3.5 Haiku",
"context_length": 128000,
"max_completion_tokens": 8192
}
],
"gemini": [
{
"id": "gemini-2.5-pro",
"object": "model",
"created": 1750118400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 2.5 Pro",
"name": "models/gemini-2.5-pro",
"version": "2.5",
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true
}
},
{
"id": "gemini-2.5-flash",
"object": "model",
"created": 1750118400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 2.5 Flash",
"name": "models/gemini-2.5-flash",
"version": "001",
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"max": 24576,
"zero_allowed": true,
"dynamic_allowed": true
}
},
{
"id": "gemini-2.5-flash-lite",
"object": "model",
"created": 1753142400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 2.5 Flash Lite",
"name": "models/gemini-2.5-flash-lite",
"version": "2.5",
"description": "Our smallest and most cost effective model, built for at scale usage.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"max": 24576,
"zero_allowed": true,
"dynamic_allowed": true
}
},
{
"id": "gemini-3-pro-preview",
"object": "model",
"created": 1737158400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3 Pro Preview",
"name": "models/gemini-3-pro-preview",
"version": "3.0",
"description": "Gemini 3 Pro Preview",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"low",
"high"
]
}
},
{
"id": "gemini-3.1-pro-preview",
"object": "model",
"created": 1771459200,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3.1 Pro Preview",
"name": "models/gemini-3.1-pro-preview",
"version": "3.1",
"description": "Gemini 3.1 Pro Preview",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"low",
"high"
]
}
},
{
"id": "gemini-3.1-flash-image-preview",
"object": "model",
"created": 1771459200,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3.1 Flash Image Preview",
"name": "models/gemini-3.1-flash-image-preview",
"version": "3.1",
"description": "Gemini 3.1 Flash Image Preview",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"minimal",
"high"
]
}
},
{
"id": "gemini-3-flash-preview",
"object": "model",
"created": 1765929600,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3 Flash Preview",
"name": "models/gemini-3-flash-preview",
"version": "3.0",
"description": "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"minimal",
"low",
"medium",
"high"
]
}
},
{
"id": "gemini-3.1-flash-lite-preview",
"object": "model",
"created": 1776288000,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3.1 Flash Lite Preview",
"name": "models/gemini-3.1-flash-lite-preview",
"version": "3.1",
"description": "Our smallest and most cost effective model, built for at scale usage.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"minimal",
"high"
]
}
},
{
"id": "gemini-3-pro-image-preview",
"object": "model",
"created": 1737158400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3 Pro Image Preview",
"name": "models/gemini-3-pro-image-preview",
"version": "3.0",
"description": "Gemini 3 Pro Image Preview",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"low",
"high"
]
}
}
],
"vertex": [
{
"id": "gemini-2.5-pro",
"object": "model",
"created": 1750118400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 2.5 Pro",
"name": "models/gemini-2.5-pro",
"version": "2.5",
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true
}
},
{
"id": "gemini-2.5-flash",
"object": "model",
"created": 1750118400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 2.5 Flash",
"name": "models/gemini-2.5-flash",
"version": "001",
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"max": 24576,
"zero_allowed": true,
"dynamic_allowed": true
}
},
{
"id": "gemini-2.5-flash-lite",
"object": "model",
"created": 1753142400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 2.5 Flash Lite",
"name": "models/gemini-2.5-flash-lite",
"version": "2.5",
"description": "Our smallest and most cost effective model, built for at scale usage.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"max": 24576,
"zero_allowed": true,
"dynamic_allowed": true
}
},
{
"id": "gemini-3-pro-preview",
"object": "model",
"created": 1737158400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3 Pro Preview",
"name": "models/gemini-3-pro-preview",
"version": "3.0",
"description": "Gemini 3 Pro Preview",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"low",
"high"
]
}
},
{
"id": "gemini-3-flash-preview",
"object": "model",
"created": 1765929600,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3 Flash Preview",
"name": "models/gemini-3-flash-preview",
"version": "3.0",
"description": "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"minimal",
"low",
"medium",
"high"
]
}
},
{
"id": "gemini-3.1-pro-preview",
"object": "model",
"created": 1771459200,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3.1 Pro Preview",
"name": "models/gemini-3.1-pro-preview",
"version": "3.1",
"description": "Gemini 3.1 Pro Preview",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"low",
"high"
]
}
},
{
"id": "gemini-3.1-flash-image-preview",
"object": "model",
"created": 1771459200,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3.1 Flash Image Preview",
"name": "models/gemini-3.1-flash-image-preview",
"version": "3.1",
"description": "Gemini 3.1 Flash Image Preview",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"minimal",
"high"
]
}
},
{
"id": "gemini-3.1-flash-lite-preview",
"object": "model",
"created": 1776288000,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3.1 Flash Lite Preview",
"name": "models/gemini-3.1-flash-lite-preview",
"version": "3.1",
"description": "Our smallest and most cost effective model, built for at scale usage.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"minimal",
"high"
]
}
},
{
"id": "gemini-3-pro-image-preview",
"object": "model",
"created": 1737158400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3 Pro Image Preview",
"name": "models/gemini-3-pro-image-preview",
"version": "3.0",
"description": "Gemini 3 Pro Image Preview",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"low",
"high"
]
}
},
{
"id": "imagen-4.0-generate-001",
"object": "model",
"created": 1750000000,
"owned_by": "google",
"type": "gemini",
"display_name": "Imagen 4.0 Generate",
"name": "models/imagen-4.0-generate-001",
"version": "4.0",
"description": "Imagen 4.0 image generation model",
"supportedGenerationMethods": [
"predict"
]
},
{
"id": "imagen-4.0-ultra-generate-001",
"object": "model",
"created": 1750000000,
"owned_by": "google",
"type": "gemini",
"display_name": "Imagen 4.0 Ultra Generate",
"name": "models/imagen-4.0-ultra-generate-001",
"version": "4.0",
"description": "Imagen 4.0 Ultra high-quality image generation model",
"supportedGenerationMethods": [
"predict"
]
},
{
"id": "imagen-3.0-generate-002",
"object": "model",
"created": 1740000000,
"owned_by": "google",
"type": "gemini",
"display_name": "Imagen 3.0 Generate",
"name": "models/imagen-3.0-generate-002",
"version": "3.0",
"description": "Imagen 3.0 image generation model",
"supportedGenerationMethods": [
"predict"
]
},
{
"id": "imagen-3.0-fast-generate-001",
"object": "model",
"created": 1740000000,
"owned_by": "google",
"type": "gemini",
"display_name": "Imagen 3.0 Fast Generate",
"name": "models/imagen-3.0-fast-generate-001",
"version": "3.0",
"description": "Imagen 3.0 fast image generation model",
"supportedGenerationMethods": [
"predict"
]
},
{
"id": "imagen-4.0-fast-generate-001",
"object": "model",
"created": 1750000000,
"owned_by": "google",
"type": "gemini",
"display_name": "Imagen 4.0 Fast Generate",
"name": "models/imagen-4.0-fast-generate-001",
"version": "4.0",
"description": "Imagen 4.0 fast image generation model",
"supportedGenerationMethods": [
"predict"
]
}
],
"gemini-cli": [
{
"id": "gemini-2.5-pro",
"object": "model",
"created": 1750118400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 2.5 Pro",
"name": "models/gemini-2.5-pro",
"version": "2.5",
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true
}
},
{
"id": "gemini-2.5-flash",
"object": "model",
"created": 1750118400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 2.5 Flash",
"name": "models/gemini-2.5-flash",
"version": "001",
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"max": 24576,
"zero_allowed": true,
"dynamic_allowed": true
}
},
{
"id": "gemini-2.5-flash-lite",
"object": "model",
"created": 1753142400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 2.5 Flash Lite",
"name": "models/gemini-2.5-flash-lite",
"version": "2.5",
"description": "Our smallest and most cost effective model, built for at scale usage.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"max": 24576,
"zero_allowed": true,
"dynamic_allowed": true
}
},
{
"id": "gemini-3-pro-preview",
"object": "model",
"created": 1737158400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3 Pro Preview",
"name": "models/gemini-3-pro-preview",
"version": "3.0",
"description": "Our most intelligent model with SOTA reasoning and multimodal understanding, and powerful agentic and vibe coding capabilities",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"low",
"high"
]
}
},
{
"id": "gemini-3.1-pro-preview",
"object": "model",
"created": 1771459200,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3.1 Pro Preview",
"name": "models/gemini-3.1-pro-preview",
"version": "3.1",
"description": "Gemini 3.1 Pro Preview",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"low",
"high"
]
}
},
{
"id": "gemini-3-flash-preview",
"object": "model",
"created": 1765929600,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3 Flash Preview",
"name": "models/gemini-3-flash-preview",
"version": "3.0",
"description": "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"minimal",
"low",
"medium",
"high"
]
}
},
{
"id": "gemini-3.1-flash-lite-preview",
"object": "model",
"created": 1776288000,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3.1 Flash Lite Preview",
"name": "models/gemini-3.1-flash-lite-preview",
"version": "3.1",
"description": "Our smallest and most cost effective model, built for at scale usage.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"minimal",
"high"
]
}
}
],
"aistudio": [
{
"id": "gemini-2.5-pro",
"object": "model",
"created": 1750118400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 2.5 Pro",
"name": "models/gemini-2.5-pro",
"version": "2.5",
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true
}
},
{
"id": "gemini-2.5-flash",
"object": "model",
"created": 1750118400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 2.5 Flash",
"name": "models/gemini-2.5-flash",
"version": "001",
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"max": 24576,
"zero_allowed": true,
"dynamic_allowed": true
}
},
{
"id": "gemini-2.5-flash-lite",
"object": "model",
"created": 1753142400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 2.5 Flash Lite",
"name": "models/gemini-2.5-flash-lite",
"version": "2.5",
"description": "Our smallest and most cost effective model, built for at scale usage.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"max": 24576,
"zero_allowed": true,
"dynamic_allowed": true
}
},
{
"id": "gemini-3-pro-preview",
"object": "model",
"created": 1737158400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3 Pro Preview",
"name": "models/gemini-3-pro-preview",
"version": "3.0",
"description": "Gemini 3 Pro Preview",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true
}
},
{
"id": "gemini-3.1-pro-preview",
"object": "model",
"created": 1771459200,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3.1 Pro Preview",
"name": "models/gemini-3.1-pro-preview",
"version": "3.1",
"description": "Gemini 3.1 Pro Preview",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true
}
},
{
"id": "gemini-3-flash-preview",
"object": "model",
"created": 1765929600,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3 Flash Preview",
"name": "models/gemini-3-flash-preview",
"version": "3.0",
"description": "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true
}
},
{
"id": "gemini-3.1-flash-lite-preview",
"object": "model",
"created": 1776288000,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 3.1 Flash Lite Preview",
"name": "models/gemini-3.1-flash-lite-preview",
"version": "3.1",
"description": "Our smallest and most cost effective model, built for at scale usage.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"minimal",
"high"
]
}
},
{
"id": "gemini-pro-latest",
"object": "model",
"created": 1750118400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini Pro Latest",
"name": "models/gemini-pro-latest",
"version": "2.5",
"description": "Latest release of Gemini Pro",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true
}
},
{
"id": "gemini-flash-latest",
"object": "model",
"created": 1750118400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini Flash Latest",
"name": "models/gemini-flash-latest",
"version": "2.5",
"description": "Latest release of Gemini Flash",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"max": 24576,
"zero_allowed": true,
"dynamic_allowed": true
}
},
{
"id": "gemini-flash-lite-latest",
"object": "model",
"created": 1753142400,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini Flash-Lite Latest",
"name": "models/gemini-flash-lite-latest",
"version": "2.5",
"description": "Latest release of Gemini Flash-Lite",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
],
"thinking": {
"min": 512,
"max": 24576,
"zero_allowed": true,
"dynamic_allowed": true
}
},
{
"id": "gemini-2.5-flash-image",
"object": "model",
"created": 1759363200,
"owned_by": "google",
"type": "gemini",
"display_name": "Gemini 2.5 Flash Image",
"name": "models/gemini-2.5-flash-image",
"version": "2.5",
"description": "State-of-the-art image generation and editing model.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 8192,
"supportedGenerationMethods": [
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent"
]
}
],
"codex-free": [
{
"id": "gpt-5",
"object": "model",
"created": 1754524800,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5",
"version": "gpt-5-2025-08-07",
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"minimal",
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5-codex",
"object": "model",
"created": 1757894400,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5 Codex",
"version": "gpt-5-2025-09-15",
"description": "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5-codex-mini",
"object": "model",
"created": 1762473600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5 Codex Mini",
"version": "gpt-5-2025-11-07",
"description": "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5.1",
"object": "model",
"created": 1762905600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5",
"version": "gpt-5.1-2025-11-12",
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"none",
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5.1-codex",
"object": "model",
"created": 1762905600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.1 Codex",
"version": "gpt-5.1-2025-11-12",
"description": "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5.1-codex-mini",
"object": "model",
"created": 1762905600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.1 Codex Mini",
"version": "gpt-5.1-2025-11-12",
"description": "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5.1-codex-max",
"object": "model",
"created": 1763424000,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.1 Codex Max",
"version": "gpt-5.1-max",
"description": "Stable version of GPT 5.1 Codex Max",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.2",
"object": "model",
"created": 1765440000,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.2",
"version": "gpt-5.2",
"description": "Stable version of GPT 5.2",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"none",
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.2-codex",
"object": "model",
"created": 1765440000,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.2 Codex",
"version": "gpt-5.2",
"description": "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
}
],
"codex-team": [
{
"id": "gpt-5",
"object": "model",
"created": 1754524800,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5",
"version": "gpt-5-2025-08-07",
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"minimal",
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5-codex",
"object": "model",
"created": 1757894400,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5 Codex",
"version": "gpt-5-2025-09-15",
"description": "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5-codex-mini",
"object": "model",
"created": 1762473600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5 Codex Mini",
"version": "gpt-5-2025-11-07",
"description": "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5.1",
"object": "model",
"created": 1762905600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5",
"version": "gpt-5.1-2025-11-12",
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"none",
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5.1-codex",
"object": "model",
"created": 1762905600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.1 Codex",
"version": "gpt-5.1-2025-11-12",
"description": "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5.1-codex-mini",
"object": "model",
"created": 1762905600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.1 Codex Mini",
"version": "gpt-5.1-2025-11-12",
"description": "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5.1-codex-max",
"object": "model",
"created": 1763424000,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.1 Codex Max",
"version": "gpt-5.1-max",
"description": "Stable version of GPT 5.1 Codex Max",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.2",
"object": "model",
"created": 1765440000,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.2",
"version": "gpt-5.2",
"description": "Stable version of GPT 5.2",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"none",
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.2-codex",
"object": "model",
"created": 1765440000,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.2 Codex",
"version": "gpt-5.2",
"description": "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.3-codex",
"object": "model",
"created": 1770307200,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.3 Codex",
"version": "gpt-5.3",
"description": "Stable version of GPT 5.3 Codex, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.4",
"object": "model",
"created": 1772668800,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.4",
"version": "gpt-5.4",
"description": "Stable version of GPT 5.4",
"context_length": 1050000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
}
],
"codex-plus": [
{
"id": "gpt-5",
"object": "model",
"created": 1754524800,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5",
"version": "gpt-5-2025-08-07",
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"minimal",
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5-codex",
"object": "model",
"created": 1757894400,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5 Codex",
"version": "gpt-5-2025-09-15",
"description": "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5-codex-mini",
"object": "model",
"created": 1762473600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5 Codex Mini",
"version": "gpt-5-2025-11-07",
"description": "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5.1",
"object": "model",
"created": 1762905600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5",
"version": "gpt-5.1-2025-11-12",
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"none",
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5.1-codex",
"object": "model",
"created": 1762905600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.1 Codex",
"version": "gpt-5.1-2025-11-12",
"description": "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5.1-codex-mini",
"object": "model",
"created": 1762905600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.1 Codex Mini",
"version": "gpt-5.1-2025-11-12",
"description": "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5.1-codex-max",
"object": "model",
"created": 1763424000,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.1 Codex Max",
"version": "gpt-5.1-max",
"description": "Stable version of GPT 5.1 Codex Max",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.2",
"object": "model",
"created": 1765440000,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.2",
"version": "gpt-5.2",
"description": "Stable version of GPT 5.2",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"none",
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.2-codex",
"object": "model",
"created": 1765440000,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.2 Codex",
"version": "gpt-5.2",
"description": "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.3-codex",
"object": "model",
"created": 1770307200,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.3 Codex",
"version": "gpt-5.3",
"description": "Stable version of GPT 5.3 Codex, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.3-codex-spark",
"object": "model",
"created": 1770912000,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.3 Codex Spark",
"version": "gpt-5.3",
"description": "Ultra-fast coding model.",
"context_length": 128000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.4",
"object": "model",
"created": 1772668800,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.4",
"version": "gpt-5.4",
"description": "Stable version of GPT 5.4",
"context_length": 1050000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
}
],
"codex-pro": [
{
"id": "gpt-5",
"object": "model",
"created": 1754524800,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5",
"version": "gpt-5-2025-08-07",
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"minimal",
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5-codex",
"object": "model",
"created": 1757894400,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5 Codex",
"version": "gpt-5-2025-09-15",
"description": "Stable version of GPT 5 Codex, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5-codex-mini",
"object": "model",
"created": 1762473600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5 Codex Mini",
"version": "gpt-5-2025-11-07",
"description": "Stable version of GPT 5 Codex Mini: cheaper, faster, but less capable version of GPT 5 Codex.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5.1",
"object": "model",
"created": 1762905600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5",
"version": "gpt-5.1-2025-11-12",
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"none",
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5.1-codex",
"object": "model",
"created": 1762905600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.1 Codex",
"version": "gpt-5.1-2025-11-12",
"description": "Stable version of GPT 5.1 Codex, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5.1-codex-mini",
"object": "model",
"created": 1762905600,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.1 Codex Mini",
"version": "gpt-5.1-2025-11-12",
"description": "Stable version of GPT 5.1 Codex Mini: cheaper, faster, but less capable version of GPT 5.1 Codex.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high"
]
}
},
{
"id": "gpt-5.1-codex-max",
"object": "model",
"created": 1763424000,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.1 Codex Max",
"version": "gpt-5.1-max",
"description": "Stable version of GPT 5.1 Codex Max",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.2",
"object": "model",
"created": 1765440000,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.2",
"version": "gpt-5.2",
"description": "Stable version of GPT 5.2",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"none",
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.2-codex",
"object": "model",
"created": 1765440000,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.2 Codex",
"version": "gpt-5.2",
"description": "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.3-codex",
"object": "model",
"created": 1770307200,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.3 Codex",
"version": "gpt-5.3",
"description": "Stable version of GPT 5.3 Codex, The best model for coding and agentic tasks across domains.",
"context_length": 400000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.3-codex-spark",
"object": "model",
"created": 1770912000,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.3 Codex Spark",
"version": "gpt-5.3",
"description": "Ultra-fast coding model.",
"context_length": 128000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "gpt-5.4",
"object": "model",
"created": 1772668800,
"owned_by": "openai",
"type": "openai",
"display_name": "GPT 5.4",
"version": "gpt-5.4",
"description": "Stable version of GPT 5.4",
"context_length": 1050000,
"max_completion_tokens": 128000,
"supported_parameters": [
"tools"
],
"thinking": {
"levels": [
"low",
"medium",
"high",
"xhigh"
]
}
}
],
"qwen": [
{
"id": "qwen3-coder-plus",
"object": "model",
"created": 1753228800,
"owned_by": "qwen",
"type": "qwen",
"display_name": "Qwen3 Coder Plus",
"version": "3.0",
"description": "Advanced code generation and understanding model",
"context_length": 32768,
"max_completion_tokens": 8192,
"supported_parameters": [
"temperature",
"top_p",
"max_tokens",
"stream",
"stop"
]
},
{
"id": "qwen3-coder-flash",
"object": "model",
"created": 1753228800,
"owned_by": "qwen",
"type": "qwen",
"display_name": "Qwen3 Coder Flash",
"version": "3.0",
"description": "Fast code generation model",
"context_length": 8192,
"max_completion_tokens": 2048,
"supported_parameters": [
"temperature",
"top_p",
"max_tokens",
"stream",
"stop"
]
},
{
"id": "coder-model",
"object": "model",
"created": 1771171200,
"owned_by": "qwen",
"type": "qwen",
"display_name": "Qwen 3.5 Plus",
"version": "3.5",
"description": "efficient hybrid model with leading coding performance",
"context_length": 1048576,
"max_completion_tokens": 65536,
"supported_parameters": [
"temperature",
"top_p",
"max_tokens",
"stream",
"stop"
]
},
{
"id": "vision-model",
"object": "model",
"created": 1758672000,
"owned_by": "qwen",
"type": "qwen",
"display_name": "Qwen3 Vision Model",
"version": "3.0",
"description": "Vision model",
"context_length": 32768,
"max_completion_tokens": 2048,
"supported_parameters": [
"temperature",
"top_p",
"max_tokens",
"stream",
"stop"
]
}
],
"iflow": [
{
"id": "qwen3-coder-plus",
"object": "model",
"created": 1753228800,
"owned_by": "iflow",
"type": "iflow",
"display_name": "Qwen3-Coder-Plus",
"description": "Qwen3 Coder Plus code generation"
},
{
"id": "qwen3-max",
"object": "model",
"created": 1758672000,
"owned_by": "iflow",
"type": "iflow",
"display_name": "Qwen3-Max",
"description": "Qwen3 flagship model"
},
{
"id": "qwen3-vl-plus",
"object": "model",
"created": 1758672000,
"owned_by": "iflow",
"type": "iflow",
"display_name": "Qwen3-VL-Plus",
"description": "Qwen3 multimodal vision-language"
},
{
"id": "qwen3-max-preview",
"object": "model",
"created": 1757030400,
"owned_by": "iflow",
"type": "iflow",
"display_name": "Qwen3-Max-Preview",
"description": "Qwen3 Max preview build",
"thinking": {
"levels": [
"none",
"auto",
"minimal",
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "glm-4.6",
"object": "model",
"created": 1759190400,
"owned_by": "iflow",
"type": "iflow",
"display_name": "GLM-4.6",
"description": "Zhipu GLM 4.6 general model",
"thinking": {
"levels": [
"none",
"auto",
"minimal",
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "kimi-k2",
"object": "model",
"created": 1752192000,
"owned_by": "iflow",
"type": "iflow",
"display_name": "Kimi-K2",
"description": "Moonshot Kimi K2 general model"
},
{
"id": "deepseek-v3.2",
"object": "model",
"created": 1759104000,
"owned_by": "iflow",
"type": "iflow",
"display_name": "DeepSeek-V3.2-Exp",
"description": "DeepSeek V3.2 experimental",
"thinking": {
"levels": [
"none",
"auto",
"minimal",
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "deepseek-v3.1",
"object": "model",
"created": 1756339200,
"owned_by": "iflow",
"type": "iflow",
"display_name": "DeepSeek-V3.1-Terminus",
"description": "DeepSeek V3.1 Terminus",
"thinking": {
"levels": [
"none",
"auto",
"minimal",
"low",
"medium",
"high",
"xhigh"
]
}
},
{
"id": "deepseek-r1",
"object": "model",
"created": 1737331200,
"owned_by": "iflow",
"type": "iflow",
"display_name": "DeepSeek-R1",
"description": "DeepSeek reasoning model R1"
},
{
"id": "deepseek-v3",
"object": "model",
"created": 1734307200,
"owned_by": "iflow",
"type": "iflow",
"display_name": "DeepSeek-V3-671B",
"description": "DeepSeek V3 671B"
},
{
"id": "qwen3-32b",
"object": "model",
"created": 1747094400,
"owned_by": "iflow",
"type": "iflow",
"display_name": "Qwen3-32B",
"description": "Qwen3 32B"
},
{
"id": "qwen3-235b-a22b-thinking-2507",
"object": "model",
"created": 1753401600,
"owned_by": "iflow",
"type": "iflow",
"display_name": "Qwen3-235B-A22B-Thinking",
"description": "Qwen3 235B A22B Thinking (2507)"
},
{
"id": "qwen3-235b-a22b-instruct",
"object": "model",
"created": 1753401600,
"owned_by": "iflow",
"type": "iflow",
"display_name": "Qwen3-235B-A22B-Instruct",
"description": "Qwen3 235B A22B Instruct"
},
{
"id": "qwen3-235b",
"object": "model",
"created": 1753401600,
"owned_by": "iflow",
"type": "iflow",
"display_name": "Qwen3-235B-A22B",
"description": "Qwen3 235B A22B"
},
{
"id": "iflow-rome-30ba3b",
"object": "model",
"created": 1736899200,
"owned_by": "iflow",
"type": "iflow",
"display_name": "iFlow-ROME",
"description": "iFlow Rome 30BA3B model"
}
],
"kimi": [
{
"id": "kimi-k2",
"object": "model",
"created": 1752192000,
"owned_by": "moonshot",
"type": "kimi",
"display_name": "Kimi K2",
"description": "Kimi K2 - Moonshot AI's flagship coding model",
"context_length": 131072,
"max_completion_tokens": 32768
},
{
"id": "kimi-k2-thinking",
"object": "model",
"created": 1762387200,
"owned_by": "moonshot",
"type": "kimi",
"display_name": "Kimi K2 Thinking",
"description": "Kimi K2 Thinking - Extended reasoning model",
"context_length": 131072,
"max_completion_tokens": 32768,
"thinking": {
"min": 1024,
"max": 32000,
"zero_allowed": true,
"dynamic_allowed": true
}
},
{
"id": "kimi-k2.5",
"object": "model",
"created": 1769472000,
"owned_by": "moonshot",
"type": "kimi",
"display_name": "Kimi K2.5",
"description": "Kimi K2.5 - Latest Moonshot AI coding model with improved capabilities",
"context_length": 131072,
"max_completion_tokens": 32768,
"thinking": {
"min": 1024,
"max": 32000,
"zero_allowed": true,
"dynamic_allowed": true
}
}
],
"antigravity": [
{
"id": "claude-opus-4-6-thinking",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Claude Opus 4.6 (Thinking)",
"name": "claude-opus-4-6-thinking",
"description": "Claude Opus 4.6 (Thinking)",
"context_length": 200000,
"max_completion_tokens": 64000,
"thinking": {
"min": 1024,
"max": 64000,
"zero_allowed": true,
"dynamic_allowed": true
}
},
{
"id": "claude-sonnet-4-6",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Claude Sonnet 4.6 (Thinking)",
"name": "claude-sonnet-4-6",
"description": "Claude Sonnet 4.6 (Thinking)",
"context_length": 200000,
"max_completion_tokens": 64000,
"thinking": {
"min": 1024,
"max": 64000,
"zero_allowed": true,
"dynamic_allowed": true
}
},
{
"id": "gemini-2.5-flash",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Gemini 2.5 Flash",
"name": "gemini-2.5-flash",
"description": "Gemini 2.5 Flash",
"context_length": 1048576,
"max_completion_tokens": 65535,
"thinking": {
"max": 24576,
"zero_allowed": true,
"dynamic_allowed": true
}
},
{
"id": "gemini-2.5-flash-lite",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Gemini 2.5 Flash Lite",
"name": "gemini-2.5-flash-lite",
"description": "Gemini 2.5 Flash Lite",
"context_length": 1048576,
"max_completion_tokens": 65535,
"thinking": {
"max": 24576,
"zero_allowed": true,
"dynamic_allowed": true
}
},
{
"id": "gemini-3-flash",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Gemini 3 Flash",
"name": "gemini-3-flash",
"description": "Gemini 3 Flash",
"context_length": 1048576,
"max_completion_tokens": 65536,
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"minimal",
"low",
"medium",
"high"
]
}
},
{
"id": "gemini-3-pro-high",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Gemini 3 Pro (High)",
"name": "gemini-3-pro-high",
"description": "Gemini 3 Pro (High)",
"context_length": 1048576,
"max_completion_tokens": 65535,
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"low",
"high"
]
}
},
{
"id": "gemini-3-pro-low",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Gemini 3 Pro (Low)",
"name": "gemini-3-pro-low",
"description": "Gemini 3 Pro (Low)",
"context_length": 1048576,
"max_completion_tokens": 65535,
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"low",
"high"
]
}
},
{
"id": "gemini-3.1-flash-image",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Gemini 3.1 Flash Image",
"name": "gemini-3.1-flash-image",
"description": "Gemini 3.1 Flash Image",
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"minimal",
"high"
]
}
},
{
"id": "gemini-3.1-pro-high",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Gemini 3.1 Pro (High)",
"name": "gemini-3.1-pro-high",
"description": "Gemini 3.1 Pro (High)",
"context_length": 1048576,
"max_completion_tokens": 65535,
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"low",
"high"
]
}
},
{
"id": "gemini-3.1-pro-low",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "Gemini 3.1 Pro (Low)",
"name": "gemini-3.1-pro-low",
"description": "Gemini 3.1 Pro (Low)",
"context_length": 1048576,
"max_completion_tokens": 65535,
"thinking": {
"min": 128,
"max": 32768,
"dynamic_allowed": true,
"levels": [
"low",
"high"
]
}
},
{
"id": "gpt-oss-120b-medium",
"object": "model",
"owned_by": "antigravity",
"type": "antigravity",
"display_name": "GPT-OSS 120B (Medium)",
"name": "gpt-oss-120b-medium",
"description": "GPT-OSS 120B (Medium)",
"context_length": 114000,
"max_completion_tokens": 32768
}
]
}
================================================
FILE: internal/runtime/executor/aistudio_executor.go
================================================
// Package executor provides runtime execution capabilities for various AI service providers.
// This file implements the AI Studio executor that routes requests through a websocket-backed
// transport for the AI Studio provider.
package executor
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// AIStudioExecutor routes AI Studio requests through a websocket-backed transport.
// Construct it with NewAIStudioExecutor.
type AIStudioExecutor struct {
	// provider holds the provider name, lower-cased by NewAIStudioExecutor.
	provider string
	// relay is the websocket relay manager that HttpRequest forwards requests through.
	relay *wsrelay.Manager
	// cfg is the application configuration passed to NewAIStudioExecutor.
	cfg *config.Config
}
// NewAIStudioExecutor builds an AIStudioExecutor bound to the given relay.
//
// Parameters:
//   - cfg: application configuration, stored on the executor as-is
//   - provider: provider name; stored in lower case
//   - relay: websocket relay manager used to carry requests
//
// Returns:
//   - *AIStudioExecutor: the newly constructed executor
func NewAIStudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Manager) *AIStudioExecutor {
	ex := &AIStudioExecutor{cfg: cfg, relay: relay}
	ex.provider = strings.ToLower(provider)
	return ex
}
// Identifier reports the fixed identifier of this executor, "aistudio".
func (e *AIStudioExecutor) Identifier() string {
	const id = "aistudio"
	return id
}
// PrepareRequest is a no-op for AI Studio: it leaves the request untouched and
// always returns nil.
func (e *AIStudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
// HttpRequest forwards an arbitrary HTTP request through the websocket relay
// and adapts the relay's reply into an *http.Response.
//
// If req has a body, it is read in full and then closed, mirroring how an
// http.RoundTripper takes ownership of a request body. The returned response's
// Request field points at a context-bound shallow copy of req. That copy's Body
// is replaced with an in-memory reader over the same bytes, so it can still be
// inspected after the call.
//
// An error is returned when:
//   - req is nil
//   - the relay is not configured
//   - auth is nil or has an empty ID
//   - the request URL is empty
//   - the body cannot be read
//   - the relay call fails
//   - the relay returns a nil response
func (e *AIStudioExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
	if req == nil {
		return nil, fmt.Errorf("aistudio executor: request is nil")
	}
	if ctx == nil {
		// Fall back to the request's own context; WithContext panics on nil.
		ctx = req.Context()
	}
	if e.relay == nil {
		return nil, fmt.Errorf("aistudio executor: ws relay is nil")
	}
	if auth == nil || auth.ID == "" {
		return nil, fmt.Errorf("aistudio executor: missing auth")
	}
	httpReq := req.WithContext(ctx)
	if httpReq.URL == nil || strings.TrimSpace(httpReq.URL.String()) == "" {
		return nil, fmt.Errorf("aistudio executor: request URL is empty")
	}

	var body []byte
	if httpReq.Body != nil {
		b, errRead := io.ReadAll(httpReq.Body)
		// Take ownership of the body like a transport would and release it once
		// drained. The payload is already buffered (or the read failed), so a
		// close error carries no actionable information.
		_ = httpReq.Body.Close()
		if errRead != nil {
			return nil, fmt.Errorf("aistudio executor: read request body: %w", errRead)
		}
		body = b
		// Keep the payload re-readable via resp.Request.Body.
		httpReq.Body = io.NopCloser(bytes.NewReader(b))
	}

	wsReq := &wsrelay.HTTPRequest{
		Method:  httpReq.Method,
		URL:     httpReq.URL.String(),
		Headers: httpReq.Header.Clone(),
		Body:    body,
	}
	wsResp, errRelay := e.relay.NonStream(ctx, auth.ID, wsReq)
	if errRelay != nil {
		return nil, errRelay
	}
	if wsResp == nil {
		return nil, fmt.Errorf("aistudio executor: ws response is nil")
	}

	statusText := http.StatusText(wsResp.Status)
	if statusText == "" {
		statusText = "Unknown"
	}
	header := wsResp.Headers.Clone()
	if header == nil {
		// Clone of a nil header yields nil; hand callers a writable map instead
		// so resp.Header.Set/Add do not panic.
		header = make(http.Header)
	}
	resp := &http.Response{
		StatusCode:    wsResp.Status,
		Status:        fmt.Sprintf("%d %s", wsResp.Status, statusText),
		Header:        header,
		Body:          io.NopCloser(bytes.NewReader(wsResp.Body)),
		ContentLength: int64(len(wsResp.Body)),
		Request:       httpReq,
	}
	return resp, nil
}
// Execute performs a non-streaming request to the AI Studio API.
func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if opts.Alt == "responses/compact" {
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
translatedReq, body, err := e.translateRequest(req, opts, false)
if err != nil {
return resp, err
}
endpoint := e.buildEndpoint(baseModel, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost,
URL: endpoint,
Headers: http.Header{"Content-Type": []string{"application/json"}},
Body: body.payload,
}
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
Body: body.payload,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
wsResp, err := e.relay.NonStream(ctx, authID, wsReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
if len(wsResp.Body) > 0 {
appendAPIResponseChunk(ctx, e.cfg, wsResp.Body)
}
if wsResp.Status < 200 || wsResp.Status >= 300 {
return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)}
}
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
var param any
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out)), Headers: wsResp.Headers.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming request to the AI Studio API.
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
translatedReq, body, err := e.translateRequest(req, opts, true)
if err != nil {
return nil, err
}
endpoint := e.buildEndpoint(baseModel, body.action, opts.Alt)
wsReq := &wsrelay.HTTPRequest{
Method: http.MethodPost,
URL: endpoint,
Headers: http.Header{"Content-Type": []string{"application/json"}},
Body: body.payload,
}
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
Body: body.payload,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
wsStream, err := e.relay.Stream(ctx, authID, wsReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
firstEvent, ok := <-wsStream
if !ok {
err = fmt.Errorf("wsrelay: stream closed before start")
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
if firstEvent.Status > 0 && firstEvent.Status != http.StatusOK {
metadataLogged := false
if firstEvent.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, firstEvent.Status, firstEvent.Headers.Clone())
metadataLogged = true
}
var body bytes.Buffer
if len(firstEvent.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, firstEvent.Payload)
body.Write(firstEvent.Payload)
}
if firstEvent.Type == wsrelay.MessageTypeStreamEnd {
return nil, statusErr{code: firstEvent.Status, msg: body.String()}
}
for event := range wsStream {
if event.Err != nil {
recordAPIResponseError(ctx, e.cfg, event.Err)
if body.Len() == 0 {
body.WriteString(event.Err.Error())
}
break
}
if !metadataLogged && event.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
metadataLogged = true
}
if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
body.Write(event.Payload)
}
if event.Type == wsrelay.MessageTypeStreamEnd {
break
}
}
return nil, statusErr{code: firstEvent.Status, msg: body.String()}
}
out := make(chan cliproxyexecutor.StreamChunk)
go func(first wsrelay.StreamEvent) {
defer close(out)
var param any
metadataLogged := false
processEvent := func(event wsrelay.StreamEvent) bool {
if event.Err != nil {
recordAPIResponseError(ctx, e.cfg, event.Err)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
return false
}
switch event.Type {
case wsrelay.MessageTypeStreamStart:
if !metadataLogged && event.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
metadataLogged = true
}
case wsrelay.MessageTypeStreamChunk:
if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
filtered := FilterSSEUsageMetadata(event.Payload)
if detail, ok := parseGeminiStreamUsage(filtered); ok {
reporter.publish(ctx, detail)
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, ¶m)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
}
break
}
case wsrelay.MessageTypeStreamEnd:
return false
case wsrelay.MessageTypeHTTPResp:
if !metadataLogged && event.Status > 0 {
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
metadataLogged = true
}
if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
}
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, ¶m)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
}
reporter.publish(ctx, parseGeminiUsage(event.Payload))
return false
case wsrelay.MessageTypeError:
recordAPIResponseError(ctx, e.cfg, event.Err)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
return false
}
return true
}
if !processEvent(first) {
return
}
for event := range wsStream {
if !processEvent(event) {
return
}
}
}(firstEvent)
return &cliproxyexecutor.StreamResult{Headers: firstEvent.Headers.Clone(), Chunks: out}, nil
}
// CountTokens counts tokens for the given request using the AI Studio
// countTokens method.
//
// generationConfig, tools and safetySettings are removed from the translated
// payload before it is relayed; errors from those deletions are deliberately
// ignored. A non-2xx status becomes a statusErr. A response without a positive
// totalTokens value is reported as an error. Otherwise, the count is translated
// back to opts.SourceFormat.
func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	baseModel := thinking.ParseSuffix(req.Model).ModelName
	_, body, err := e.translateRequest(req, opts, false)
	if err != nil {
		return cliproxyexecutor.Response{}, err
	}
	body.payload, _ = sjson.DeleteBytes(body.payload, "generationConfig")
	body.payload, _ = sjson.DeleteBytes(body.payload, "tools")
	body.payload, _ = sjson.DeleteBytes(body.payload, "safetySettings")
	endpoint := e.buildEndpoint(baseModel, "countTokens", "")
	wsReq := &wsrelay.HTTPRequest{
		Method:  http.MethodPost,
		URL:     endpoint,
		Headers: http.Header{"Content-Type": []string{"application/json"}},
		Body:    body.payload,
	}
	var authID, authLabel, authType, authValue string
	if auth != nil {
		authID = auth.ID
		authLabel = auth.Label
		authType, authValue = auth.AccountInfo()
	}
	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
		URL:       endpoint,
		Method:    http.MethodPost,
		Headers:   wsReq.Headers.Clone(),
		Body:      body.payload,
		Provider:  e.Identifier(),
		AuthID:    authID,
		AuthLabel: authLabel,
		AuthType:  authType,
		AuthValue: authValue,
	})
	// Guard the same invariants HttpRequest checks instead of panicking on nil.
	if e.relay == nil {
		err = fmt.Errorf("aistudio executor: ws relay is nil")
		recordAPIResponseError(ctx, e.cfg, err)
		return cliproxyexecutor.Response{}, err
	}
	wsResp, err := e.relay.NonStream(ctx, authID, wsReq)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return cliproxyexecutor.Response{}, err
	}
	if wsResp == nil {
		err = fmt.Errorf("aistudio executor: ws response is nil")
		recordAPIResponseError(ctx, e.cfg, err)
		return cliproxyexecutor.Response{}, err
	}
	recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
	if len(wsResp.Body) > 0 {
		appendAPIResponseChunk(ctx, e.cfg, wsResp.Body)
	}
	if wsResp.Status < 200 || wsResp.Status >= 300 {
		return cliproxyexecutor.Response{}, statusErr{code: wsResp.Status, msg: string(wsResp.Body)}
	}
	totalTokens := gjson.GetBytes(wsResp.Body, "totalTokens").Int()
	if totalTokens <= 0 {
		return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response")
	}
	translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, wsResp.Body)
	return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
}
// Refresh is a no-op for AI Studio: the given auth is handed back unchanged
// together with a nil error.
func (e *AIStudioExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
	unchanged := auth
	return unchanged, nil
}
// translatedPayload bundles what translateRequest produces for one request.
type translatedPayload struct {
	payload  []byte               // Gemini-format request body that is relayed upstream
	action   string               // model method: "generateContent", "streamGenerateContent" or "countTokens"
	toFormat sdktranslator.Format // target translator format ("gemini" when built by translateRequest)
}
// translateRequest converts req into the Gemini wire format used by AI Studio.
//
// It translates the payload from opts.SourceFormat to "gemini" and applies
// thinking.ApplyThinking and fixGeminiImageAspectRatio. It then applies
// applyPayloadConfigWithRoot, which also receives the translation of the
// caller's original request. Finally it strips
// generationConfig.maxOutputTokens, generationConfig.responseMimeType,
// generationConfig.responseJsonSchema and session_id.
//
// The action is "countTokens" when req.Metadata["action"] says so. Otherwise it
// is "streamGenerateContent" for stream requests and "generateContent" for the
// rest. The first return value is the same slice as translatedPayload.payload.
func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
	modelName := thinking.ParseSuffix(req.Model).ModelName
	sourceFormat := opts.SourceFormat
	targetFormat := sdktranslator.FromString("gemini")

	rawOriginal := req.Payload
	if len(opts.OriginalRequest) > 0 {
		rawOriginal = opts.OriginalRequest
	}
	originalTranslated := sdktranslator.TranslateRequest(sourceFormat, targetFormat, modelName, rawOriginal, stream)

	converted := sdktranslator.TranslateRequest(sourceFormat, targetFormat, modelName, req.Payload, stream)
	converted, errThinking := thinking.ApplyThinking(converted, req.Model, sourceFormat.String(), targetFormat.String(), e.Identifier())
	if errThinking != nil {
		return nil, translatedPayload{}, errThinking
	}
	converted = fixGeminiImageAspectRatio(modelName, converted)
	converted = applyPayloadConfigWithRoot(e.cfg, modelName, targetFormat.String(), "", converted, originalTranslated, payloadRequestedModel(opts, req.Model))

	// These fields are always stripped before relaying; deletion errors are ignored.
	for _, path := range []string{
		"generationConfig.maxOutputTokens",
		"generationConfig.responseMimeType",
		"generationConfig.responseJsonSchema",
		"session_id",
	} {
		converted, _ = sjson.DeleteBytes(converted, path)
	}

	// Indexing a nil map is safe in Go, so no explicit nil check is needed.
	action := "generateContent"
	if requested, _ := req.Metadata["action"].(string); requested == "countTokens" {
		action = requested
	} else if stream {
		action = "streamGenerateContent"
	}

	return converted, translatedPayload{payload: converted, action: action, toFormat: targetFormat}, nil
}
// buildEndpoint assembles "<glEndpoint>/<glAPIVersion>/models/<model>:<action>"
// and appends a response-format query:
//   - streamGenerateContent with no alt gets "?alt=sse";
//   - any other non-empty alt gets "?$alt=<escaped alt>", except for countTokens,
//     which never carries a query.
func (e *AIStudioExecutor) buildEndpoint(model, action, alt string) string {
	endpoint := fmt.Sprintf("%s/%s/models/%s:%s", glEndpoint, glAPIVersion, model, action)
	switch {
	case action == "streamGenerateContent" && alt == "":
		return endpoint + "?alt=sse"
	case alt == "" || action == "countTokens":
		return endpoint
	default:
		return endpoint + "?$alt=" + url.QueryEscape(alt)
	}
}
// ensureColonSpacedJSON normalizes a JSON document so that every key/value
// colon is followed by a single space, while keeping the payload otherwise
// compact (no newlines, no space after commas). For example,
// {"a":1,"b":[1, 2]} becomes {"a": 1,"b": [1,2]}.
//
// Only whitespace is changed. Key order, number literals (including integers
// beyond float64 precision), string escapes and characters such as <, > and &
// are preserved byte-for-byte. Empty or invalid JSON (e.g. SSE "data:" lines)
// is returned unchanged.
func ensureColonSpacedJSON(payload []byte) []byte {
	trimmed := bytes.TrimSpace(payload)
	if len(trimmed) == 0 {
		return payload
	}
	// json.Compact validates the input and removes insignificant whitespace
	// without re-encoding values.
	var compact bytes.Buffer
	if err := json.Compact(&compact, trimmed); err != nil {
		return payload
	}
	src := compact.Bytes()
	out := make([]byte, 0, len(src)+bytes.Count(src, []byte(":")))
	inString, escaped := false, false
	for _, ch := range src {
		out = append(out, ch)
		if inString {
			switch {
			case escaped:
				escaped = false
			case ch == '\\':
				escaped = true
			case ch == '"':
				inString = false
			}
			continue
		}
		switch ch {
		case '"':
			inString = true
		case ':':
			// Outside a string literal a colon can only separate a key from its value.
			out = append(out, ' ')
		}
	}
	return out
}
================================================
FILE: internal/runtime/executor/antigravity_executor.go
================================================
// Package executor provides runtime execution capabilities for various AI service providers.
// This file implements the Antigravity executor that proxies requests to the antigravity
// upstream using OAuth credentials.
package executor
import (
"bufio"
"bytes"
"context"
"crypto/sha256"
"crypto/tls"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"io"
"math/rand"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Antigravity upstream base URLs and request paths.
const (
	antigravityBaseURLDaily        = "https://daily-cloudcode-pa.googleapis.com"
	antigravitySandboxBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
	antigravityBaseURLProd         = "https://cloudcode-pa.googleapis.com"

	antigravityCountTokensPath = "/v1internal:countTokens"
	antigravityStreamPath      = "/v1internal:streamGenerateContent"
	antigravityGeneratePath    = "/v1internal:generateContent"
)

// OAuth client identity, default User-Agent and auth bookkeeping for Antigravity.
const (
	// NOTE(review): presumably the public OAuth client of the Antigravity
	// desktop app (installed-app clients cannot keep a secret) — confirm before
	// treating these values as non-confidential.
	antigravityClientID     = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
	antigravityClientSecret = "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"

	defaultAntigravityAgent = "antigravity/1.19.6 darwin/arm64"
	antigravityAuthType     = "antigravity"

	// refreshSkew is 3000s (50 minutes). NOTE(review): presumably the margin
	// before token expiry at which a refresh is triggered — confirm in ensureAccessToken.
	refreshSkew = 3000 * time.Second

	// systemInstruction = "You are Antigravity, a powerful agentic AI coding assistant designed by the Google Deepmind team working on Advanced Agentic Coding.You are pair programming with a USER to solve their coding task. The task may require creating a new codebase, modifying or debugging an existing codebase, or simply answering a question.**Absolute paths only****Proactiveness**"
)
// randSource is a package-wide pseudo-random generator seeded from the clock.
// A *rand.Rand built from rand.NewSource is not safe for concurrent use, so it
// must only be used while holding randSourceMutex.
var randSource = rand.New(rand.NewSource(time.Now().UnixNano()))

// randSourceMutex serializes access to randSource.
var randSourceMutex sync.Mutex
// AntigravityExecutor proxies requests to the antigravity upstream using
// OAuth access tokens obtained through ensureAccessToken.
type AntigravityExecutor struct {
	cfg *config.Config // application config (proxy-aware clients, payload rules, request logging)
}
// NewAntigravityExecutor returns an AntigravityExecutor bound to cfg.
func NewAntigravityExecutor(cfg *config.Config) *AntigravityExecutor {
	ex := new(AntigravityExecutor)
	ex.cfg = cfg
	return ex
}
// antigravityTransport is the shared HTTP/1.1-only transport for Antigravity
// clients that come without a transport of their own. It is built lazily and
// exactly once, through antigravityTransportOnce, so requests reuse one
// connection pool instead of allocating (and leaking) a new one each time.
var antigravityTransport *http.Transport

// antigravityTransportOnce guards the one-time construction of antigravityTransport.
var antigravityTransportOnce sync.Once
func cloneTransportWithHTTP11(base *http.Transport) *http.Transport {
if base == nil {
return nil
}
clone := base.Clone()
clone.ForceAttemptHTTP2 = false
// Wipe TLSNextProto to prevent implicit HTTP/2 upgrade.
clone.TLSNextProto = make(map[string]func(authority string, c *tls.Conn) http.RoundTripper)
if clone.TLSClientConfig == nil {
clone.TLSClientConfig = &tls.Config{}
} else {
clone.TLSClientConfig = clone.TLSClientConfig.Clone()
}
// Actively advertise only HTTP/1.1 in the ALPN handshake.
clone.TLSClientConfig.NextProtos = []string{"http/1.1"}
return clone
}
// initAntigravityTransport builds antigravityTransport; it is intended to run
// via antigravityTransportOnce. http.DefaultTransport serves as the template
// when it is an *http.Transport; otherwise an empty transport is used.
func initAntigravityTransport() {
	template := &http.Transport{}
	if defaultTransport, isTransport := http.DefaultTransport.(*http.Transport); isTransport {
		template = defaultTransport
	}
	antigravityTransport = cloneTransportWithHTTP11(template)
}
// newAntigravityHTTPClient returns a proxy-aware HTTP client for Antigravity
// that is forced onto HTTP/1.1, mimicking Node.js https defaults.
//
//   - no transport from newProxyAwareHTTPClient: the shared singleton
//     antigravityTransport is installed;
//   - an *http.Transport: it is cloned with HTTP/2 disabled, keeping its proxy
//     settings;
//   - any other RoundTripper: left as-is.
//
// NOTE(review): the *http.Transport branch clones a new transport (and thus a
// new connection pool) on every call, unlike the singleton branch. Confirm
// whether newProxyAwareHTTPClient caches transports, or cache the clones here.
func newAntigravityHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
	antigravityTransportOnce.Do(initAntigravityTransport)
	client := newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
	switch rt := client.Transport.(type) {
	case nil:
		client.Transport = antigravityTransport
	case *http.Transport:
		client.Transport = cloneTransportWithHTTP11(rt)
	}
	return client
}
// Identifier returns the executor name, antigravityAuthType ("antigravity").
func (e *AntigravityExecutor) Identifier() string {
	return antigravityAuthType
}
// PrepareRequest sets "Authorization: Bearer <token>" on req using a token from
// ensureAccessToken. A nil req is ignored. Token errors are returned as-is, and
// a blank token yields a 401 statusErr.
func (e *AntigravityExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
	if req == nil {
		return nil
	}
	accessToken, _, tokenErr := e.ensureAccessToken(req.Context(), auth)
	switch {
	case tokenErr != nil:
		return tokenErr
	case strings.TrimSpace(accessToken) == "":
		return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)
	return nil
}
// HttpRequest injects Antigravity credentials into the request and executes it.
//
// It uses a whitelist approach: the outgoing request gets a brand-new header
// set containing only the caller's Content-Type (if any), the User-Agent from
// resolveUserAgent, and Authorization from PrepareRequest. The caller's req and
// its Header map are not modified. The request is sent with Close set, so the
// connection is closed after the response is read.
func (e *AntigravityExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
	if req == nil {
		return nil, fmt.Errorf("antigravity executor: request is nil")
	}
	if ctx == nil {
		ctx = req.Context()
	}
	httpReq := req.WithContext(ctx)
	// --- Whitelist: keep only the headers we need from the original request ---
	contentType := httpReq.Header.Get("Content-Type")
	// WithContext returns a shallow copy that shares req.Header, so install a
	// fresh map instead of deleting keys in place (which would wipe the
	// caller's headers too).
	httpReq.Header = make(http.Header)
	// --- Set only the headers Antigravity actually sends ---
	if contentType != "" {
		httpReq.Header.Set("Content-Type", contentType)
	}
	// Content-Length is managed automatically by Go's http.Client from the Body.
	httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
	httpReq.Close = true // sends Connection: close
	// Inject Authorization: Bearer
	if err := e.PrepareRequest(httpReq, auth); err != nil {
		return nil, err
	}
	httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
	return httpClient.Do(httpReq)
}
// Execute performs a non-streaming request to the Antigravity API.
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if opts.Alt == "responses/compact" {
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
isClaude := strings.Contains(strings.ToLower(baseModel), "claude")
if isClaude || strings.Contains(baseModel, "gemini-3-pro") || strings.Contains(baseModel, "gemini-3.1-flash-image") {
return e.executeClaudeNonStream(ctx, auth, req, opts)
}
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
if errToken != nil {
return resp, errToken
}
if updatedAuth != nil {
auth = updatedAuth
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
attempts := antigravityRetryAttempts(auth, e.cfg)
attemptLoop:
for attempt := 0; attempt < attempts; attempt++ {
var lastStatus int
var lastBody []byte
var lastErr error
for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, false, opts.Alt, baseURL)
if errReq != nil {
err = errReq
return resp, err
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return resp, errDo
}
lastStatus = 0
lastBody = nil
lastErr = errDo
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = errDo
return resp, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
err = errRead
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
if attempt+1 < attempts {
delay := antigravityNoCapacityRetryDelay(attempt)
log.Debugf("antigravity executor: no capacity for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
if errWait := antigravityWait(ctx, delay); errWait != nil {
return resp, errWait
}
continue attemptLoop
}
}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
return resp, err
}
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
return resp, nil
}
switch {
case lastStatus != 0:
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
case lastErr != nil:
err = lastErr
default:
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
}
return resp, err
}
return resp, err
}
// executeClaudeNonStream performs a claude non-streaming request to the Antigravity API.
func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
if errToken != nil {
return resp, errToken
}
if updatedAuth != nil {
auth = updatedAuth
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
attempts := antigravityRetryAttempts(auth, e.cfg)
attemptLoop:
for attempt := 0; attempt < attempts; attempt++ {
var lastStatus int
var lastBody []byte
var lastErr error
for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
if errReq != nil {
err = errReq
return resp, err
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return resp, errDo
}
lastStatus = 0
lastBody = nil
lastErr = errDo
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = errDo
return resp, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
err = errRead
return resp, err
}
if errCtx := ctx.Err(); errCtx != nil {
err = errCtx
return resp, err
}
lastStatus = 0
lastBody = nil
lastErr = errRead
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = errRead
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
if attempt+1 < attempts {
delay := antigravityNoCapacityRetryDelay(attempt)
log.Debugf("antigravity executor: no capacity for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
if errWait := antigravityWait(ctx, delay); errWait != nil {
return resp, errWait
}
continue attemptLoop
}
}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
return resp, err
}
out := make(chan cliproxyexecutor.StreamChunk)
go func(resp *http.Response) {
defer close(out)
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(nil, streamScannerBuffer)
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
// Filter usage metadata for all models
// Only retain usage statistics in the terminal chunk
line = FilterSSEUsageMetadata(line)
payload := jsonPayload(line)
if payload == nil {
continue
}
if detail, ok := parseAntigravityStreamUsage(payload); ok {
reporter.publish(ctx, detail)
}
out <- cliproxyexecutor.StreamChunk{Payload: payload}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else {
reporter.ensurePublished(ctx)
}
}(httpResp)
var buffer bytes.Buffer
for chunk := range out {
if chunk.Err != nil {
return resp, chunk.Err
}
if len(chunk.Payload) > 0 {
_, _ = buffer.Write(chunk.Payload)
_, _ = buffer.Write([]byte("\n"))
}
}
resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())}
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
var param any
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
resp = cliproxyexecutor.Response{Payload: []byte(converted), Headers: httpResp.Header.Clone()}
reporter.ensurePublished(ctx)
return resp, nil
}
switch {
case lastStatus != 0:
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
case lastErr != nil:
err = lastErr
default:
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
}
return resp, err
}
return resp, err
}
// convertStreamToNonStream folds the newline-delimited JSON chunks captured from
// a streaming generate call into a single non-streaming envelope of the form
// {"response": {...}, "traceId": "..."}.
//
// Each line is parsed on its own. Lines that are blank, not valid JSON, or that
// carry neither a "response" object nor a top-level "candidates" field are
// skipped. The response object of the last accepted line is the template.
// traceId, role, finishReason, modelVersion, responseId and usageMetadata are
// carried over from the chunks that set them.
//
// Parts from candidates[0] are merged in order:
//   - consecutive plain text fragments are concatenated into one text part
//     (dropped when the result is whitespace-only);
//   - consecutive thought fragments are concatenated into one thought part that
//     keeps the most recent non-empty thought signature;
//   - functionCall, inlineData and any other parts are copied individually, with
//     snake_case thought_signature/inline_data keys normalized to camelCase.
//
// If no usage metadata was seen and the template has none, zeroed token counts
// are written.
func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
	responseTemplate := ""
	var traceID string
	var finishReason string
	var modelVersion string
	var responseID string
	var role string
	var usageRaw string
	parts := make([]map[string]any, 0)

	// A run of text (or thought) fragments is buffered here and emitted as a
	// single part once the run ends.
	var pendingKind string
	var pendingText strings.Builder
	var pendingThoughtSig string
	flushPending := func() {
		if pendingKind == "" {
			return
		}
		text := pendingText.String()
		switch pendingKind {
		case "text":
			// Whitespace-only text runs carry no content and are dropped.
			if strings.TrimSpace(text) != "" {
				parts = append(parts, map[string]any{"text": text})
			}
		case "thought":
			// Keep a thought run if it has visible text or a signature to preserve.
			if strings.TrimSpace(text) != "" || pendingThoughtSig != "" {
				part := map[string]any{"thought": true, "text": text}
				if pendingThoughtSig != "" {
					part["thoughtSignature"] = pendingThoughtSig
				}
				parts = append(parts, part)
			}
		}
		pendingKind = ""
		pendingText.Reset()
		pendingThoughtSig = ""
	}

	// normalizePart decodes a raw part and rewrites snake_case keys to camelCase.
	normalizePart := func(partResult gjson.Result) map[string]any {
		var m map[string]any
		// UseNumber keeps numeric literals exact. The default float64 decoding would
		// round integers beyond 2^53 (e.g. IDs inside functionCall.args) when the
		// parts are re-marshaled below.
		dec := json.NewDecoder(strings.NewReader(partResult.Raw))
		dec.UseNumber()
		if errDecode := dec.Decode(&m); errDecode != nil {
			m = nil
		}
		if m == nil {
			m = map[string]any{}
		}
		sig := partResult.Get("thoughtSignature").String()
		if sig == "" {
			sig = partResult.Get("thought_signature").String()
		}
		if sig != "" {
			m["thoughtSignature"] = sig
			delete(m, "thought_signature")
		}
		if inlineData, ok := m["inline_data"]; ok {
			m["inlineData"] = inlineData
			delete(m, "inline_data")
		}
		return m
	}

	for _, line := range bytes.Split(stream, []byte("\n")) {
		trimmed := bytes.TrimSpace(line)
		if len(trimmed) == 0 || !gjson.ValidBytes(trimmed) {
			continue
		}
		root := gjson.ParseBytes(trimmed)
		responseNode := root.Get("response")
		if !responseNode.Exists() {
			if root.Get("candidates").Exists() {
				responseNode = root
			} else {
				continue
			}
		}
		responseTemplate = responseNode.Raw
		if traceResult := root.Get("traceId"); traceResult.Exists() && traceResult.String() != "" {
			traceID = traceResult.String()
		}
		if roleResult := responseNode.Get("candidates.0.content.role"); roleResult.Exists() {
			role = roleResult.String()
		}
		if finishResult := responseNode.Get("candidates.0.finishReason"); finishResult.Exists() && finishResult.String() != "" {
			finishReason = finishResult.String()
		}
		if modelResult := responseNode.Get("modelVersion"); modelResult.Exists() && modelResult.String() != "" {
			modelVersion = modelResult.String()
		}
		if responseIDResult := responseNode.Get("responseId"); responseIDResult.Exists() && responseIDResult.String() != "" {
			responseID = responseIDResult.String()
		}
		if usageResult := responseNode.Get("usageMetadata"); usageResult.Exists() {
			usageRaw = usageResult.Raw
		} else if usageMetadataResult := root.Get("usageMetadata"); usageMetadataResult.Exists() {
			usageRaw = usageMetadataResult.Raw
		}
		if partsResult := responseNode.Get("candidates.0.content.parts"); partsResult.IsArray() {
			for _, part := range partsResult.Array() {
				hasFunctionCall := part.Get("functionCall").Exists()
				hasInlineData := part.Get("inlineData").Exists() || part.Get("inline_data").Exists()
				sig := part.Get("thoughtSignature").String()
				if sig == "" {
					sig = part.Get("thought_signature").String()
				}
				text := part.Get("text").String()
				thought := part.Get("thought").Bool()
				if hasFunctionCall || hasInlineData {
					flushPending()
					parts = append(parts, normalizePart(part))
					continue
				}
				if thought || part.Get("text").Exists() {
					kind := "text"
					if thought {
						kind = "thought"
					}
					// A change of kind ends the current run.
					if pendingKind != "" && pendingKind != kind {
						flushPending()
					}
					pendingKind = kind
					pendingText.WriteString(text)
					if kind == "thought" && sig != "" {
						pendingThoughtSig = sig
					}
					continue
				}
				flushPending()
				parts = append(parts, normalizePart(part))
			}
		}
	}
	flushPending()

	if responseTemplate == "" {
		responseTemplate = `{"candidates":[{"content":{"role":"model","parts":[]}}]}`
	}
	// parts holds only decoder-produced values and plain strings/bools, so
	// Marshal is not expected to fail.
	partsJSON, _ := json.Marshal(parts)
	responseTemplate, _ = sjson.SetRaw(responseTemplate, "candidates.0.content.parts", string(partsJSON))
	if role != "" {
		responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.content.role", role)
	}
	if finishReason != "" {
		responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.finishReason", finishReason)
	}
	if modelVersion != "" {
		responseTemplate, _ = sjson.Set(responseTemplate, "modelVersion", modelVersion)
	}
	if responseID != "" {
		responseTemplate, _ = sjson.Set(responseTemplate, "responseId", responseID)
	}
	if usageRaw != "" {
		responseTemplate, _ = sjson.SetRaw(responseTemplate, "usageMetadata", usageRaw)
	} else if !gjson.Get(responseTemplate, "usageMetadata").Exists() {
		responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.promptTokenCount", 0)
		responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.candidatesTokenCount", 0)
		responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.totalTokenCount", 0)
	}
	output := `{"response":{},"traceId":""}`
	output, _ = sjson.SetRaw(output, "response", responseTemplate)
	if traceID != "" {
		output, _ = sjson.Set(output, "traceId", traceID)
	}
	return []byte(output)
}
// ExecuteStream performs a streaming request to the Antigravity API.
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
ctx = context.WithValue(ctx, "alt", "")
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
if errToken != nil {
return nil, errToken
}
if updatedAuth != nil {
auth = updatedAuth
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("antigravity")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, "antigravity", "request", translated, originalTranslated, requestedModel)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
attempts := antigravityRetryAttempts(auth, e.cfg)
attemptLoop:
for attempt := 0; attempt < attempts; attempt++ {
var lastStatus int
var lastBody []byte
var lastErr error
for idx, baseURL := range baseURLs {
httpReq, errReq := e.buildRequest(ctx, auth, token, baseModel, translated, true, opts.Alt, baseURL)
if errReq != nil {
err = errReq
return nil, err
}
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
return nil, errDo
}
lastStatus = 0
lastBody = nil
lastErr = errDo
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = errDo
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
if errors.Is(errRead, context.Canceled) || errors.Is(errRead, context.DeadlineExceeded) {
err = errRead
return nil, err
}
if errCtx := ctx.Err(); errCtx != nil {
err = errCtx
return nil, err
}
lastStatus = 0
lastBody = nil
lastErr = errRead
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
err = errRead
return nil, err
}
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), bodyBytes...)
lastErr = nil
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
if antigravityShouldRetryNoCapacity(httpResp.StatusCode, bodyBytes) {
if idx+1 < len(baseURLs) {
log.Debugf("antigravity executor: no capacity on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
continue
}
if attempt+1 < attempts {
delay := antigravityNoCapacityRetryDelay(attempt)
log.Debugf("antigravity executor: no capacity for model %s, retrying in %s (attempt %d/%d)", baseModel, delay, attempt+1, attempts)
if errWait := antigravityWait(ctx, delay); errWait != nil {
return nil, errWait
}
continue attemptLoop
}
}
sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
if httpResp.StatusCode == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
go func(resp *http.Response) {
defer close(out)
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("antigravity executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(nil, streamScannerBuffer)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
// Filter usage metadata for all models
// Only retain usage statistics in the terminal chunk
line = FilterSSEUsageMetadata(line)
payload := jsonPayload(line)
if payload == nil {
continue
}
if detail, ok := parseAntigravityStreamUsage(payload); ok {
reporter.publish(ctx, detail)
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), ¶m)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
}
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), ¶m)
for i := range tail {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
} else {
reporter.ensurePublished(ctx)
}
}(httpResp)
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
switch {
case lastStatus != 0:
sErr := statusErr{code: lastStatus, msg: string(lastBody)}
if lastStatus == http.StatusTooManyRequests {
if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
sErr.retryAfter = retryAfter
}
}
err = sErr
case lastErr != nil:
err = lastErr
default:
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
}
return nil, err
}
return nil, err
}
// Refresh obtains a new access token for auth using its stored refresh token.
//
// A nil auth is a no-op that returns (nil, nil). Otherwise a clone of auth is
// refreshed, so the caller's value is never mutated. The refreshed clone is
// returned on success; on failure the result is nil and the error is returned.
func (e *AntigravityExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
	if auth == nil {
		return nil, nil
	}
	refreshed, errRefresh := e.refreshToken(ctx, auth.Clone())
	if errRefresh != nil {
		return nil, errRefresh
	}
	return refreshed, nil
}
// CountTokens counts tokens for the given request using the Antigravity API.
//
// The request is translated from opts.SourceFormat into the Antigravity format
// (non-streaming) and thinking configuration is applied. The "project", "model"
// and "request.safetySettings" fields are removed, and the payload is POSTed to
// antigravityCountTokensPath on each base URL from
// antigravityBaseURLFallbackOrder. For a 2xx response, the "totalTokens" value
// is translated back into the caller's format with TranslateTokenCount.
//
// Only transport errors (other than context cancellation or deadline) and 429
// responses move on to the next base URL. Body read errors and other non-2xx
// statuses are returned immediately. Non-2xx statuses become a statusErr, with
// retryAfter set for 429 when parseable.
// NOTE(review): unlike ExecuteStream, a body read error here does not fall back
// to the next base URL — confirm this is intended.
func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	baseModel := thinking.ParseSuffix(req.Model).ModelName
	token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
	if errToken != nil {
		return cliproxyexecutor.Response{}, errToken
	}
	// Use the refreshed copy (when a refresh happened) for everything below.
	if updatedAuth != nil {
		auth = updatedAuth
	}
	if strings.TrimSpace(token) == "" {
		return cliproxyexecutor.Response{}, statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
	}
	from := opts.SourceFormat
	to := sdktranslator.FromString("antigravity")
	// respCtx is only used when translating the token-count response.
	respCtx := context.WithValue(ctx, "alt", opts.Alt)
	// Prepare payload once (doesn't depend on baseURL)
	payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
	payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
	if err != nil {
		return cliproxyexecutor.Response{}, err
	}
	payload = deleteJSONField(payload, "project")
	payload = deleteJSONField(payload, "model")
	payload = deleteJSONField(payload, "request.safetySettings")
	baseURLs := antigravityBaseURLFallbackOrder(auth)
	httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
	// Auth identity fields for the upstream request log.
	var authID, authLabel, authType, authValue string
	if auth != nil {
		authID = auth.ID
		authLabel = auth.Label
		authType, authValue = auth.AccountInfo()
	}
	var lastStatus int
	var lastBody []byte
	var lastErr error
	for idx, baseURL := range baseURLs {
		base := strings.TrimSuffix(baseURL, "/")
		// An empty entry falls back to the first configured base URL.
		if base == "" {
			base = buildBaseURL(auth)
		}
		var requestURL strings.Builder
		requestURL.WriteString(base)
		requestURL.WriteString(antigravityCountTokensPath)
		if opts.Alt != "" {
			requestURL.WriteString("?$alt=")
			requestURL.WriteString(url.QueryEscape(opts.Alt))
		}
		httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
		if errReq != nil {
			return cliproxyexecutor.Response{}, errReq
		}
		// Close asks net/http to close the connection after this request rather
		// than returning it to the pool.
		httpReq.Close = true
		httpReq.Header.Set("Content-Type", "application/json")
		httpReq.Header.Set("Authorization", "Bearer "+token)
		httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
		if host := resolveHost(base); host != "" {
			httpReq.Host = host
		}
		recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
			URL:       requestURL.String(),
			Method:    http.MethodPost,
			Headers:   httpReq.Header.Clone(),
			Body:      payload,
			Provider:  e.Identifier(),
			AuthID:    authID,
			AuthLabel: authLabel,
			AuthType:  authType,
			AuthValue: authValue,
		})
		httpResp, errDo := httpClient.Do(httpReq)
		if errDo != nil {
			recordAPIResponseError(ctx, e.cfg, errDo)
			// Cancellation is the caller's decision: do not try other base URLs.
			if errors.Is(errDo, context.Canceled) || errors.Is(errDo, context.DeadlineExceeded) {
				return cliproxyexecutor.Response{}, errDo
			}
			lastStatus = 0
			lastBody = nil
			lastErr = errDo
			if idx+1 < len(baseURLs) {
				log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
				continue
			}
			return cliproxyexecutor.Response{}, errDo
		}
		recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
		bodyBytes, errRead := io.ReadAll(httpResp.Body)
		if errClose := httpResp.Body.Close(); errClose != nil {
			log.Errorf("antigravity executor: close response body error: %v", errClose)
		}
		if errRead != nil {
			recordAPIResponseError(ctx, e.cfg, errRead)
			return cliproxyexecutor.Response{}, errRead
		}
		appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
		if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
			count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
			translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
			return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
		}
		lastStatus = httpResp.StatusCode
		lastBody = append([]byte(nil), bodyBytes...)
		lastErr = nil
		// Only rate limiting is worth retrying against another base URL.
		if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
			log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
			continue
		}
		sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
		if httpResp.StatusCode == http.StatusTooManyRequests {
			if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
				sErr.retryAfter = retryAfter
			}
		}
		return cliproxyexecutor.Response{}, sErr
	}
	// Every loop iteration either continues to a next base URL or returns, so in
	// practice this is only reached when baseURLs is empty.
	switch {
	case lastStatus != 0:
		sErr := statusErr{code: lastStatus, msg: string(lastBody)}
		if lastStatus == http.StatusTooManyRequests {
			if retryAfter, parseErr := parseRetryDelay(lastBody); parseErr == nil && retryAfter != nil {
				sErr.retryAfter = retryAfter
			}
		}
		return cliproxyexecutor.Response{}, sErr
	case lastErr != nil:
		return cliproxyexecutor.Response{}, lastErr
	default:
		return cliproxyexecutor.Response{}, statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
	}
}
// ensureAccessToken returns a usable access token for auth.
//
// If auth.Metadata holds a non-empty access_token whose expiry is later than
// now+refreshSkew, that token is returned with a nil *Auth. Otherwise a clone of
// auth is refreshed, and the new token is returned together with the refreshed
// clone, which the caller should use from then on.
//
// The refresh runs on a background context so it is not tied to the caller's
// cancellation. It still reuses any "cliproxy.roundtripper" transport stored in
// ctx.
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
	if auth == nil {
		return "", nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
	}
	if current := metaStringValue(auth.Metadata, "access_token"); current != "" {
		if tokenExpiry(auth.Metadata).After(time.Now().Add(refreshSkew)) {
			return current, nil, nil
		}
	}

	var transport http.RoundTripper
	if ctx != nil {
		transport, _ = ctx.Value("cliproxy.roundtripper").(http.RoundTripper)
	}
	refreshCtx := context.Background()
	if transport != nil {
		refreshCtx = context.WithValue(refreshCtx, "cliproxy.roundtripper", transport)
	}

	updated, errRefresh := e.refreshToken(refreshCtx, auth.Clone())
	if errRefresh != nil {
		return "", nil, errRefresh
	}
	return metaStringValue(updated.Metadata, "access_token"), updated, nil
}
// refreshToken exchanges the refresh_token stored in auth.Metadata for a new
// access token at https://oauth2.googleapis.com/token.
//
// On success it writes access_token, refresh_token (only when the server
// rotates it), expires_in, timestamp (Unix ms), expired (RFC 3339) and type into
// auth.Metadata. It then tries to fill in project_id; a failure there is only
// logged.
//
// auth is mutated in place, so callers in this file pass auth.Clone(). Every
// error path returns before any metadata write, so on error auth comes back
// unmodified alongside the error. Non-2xx responses become a statusErr, with
// retryAfter set for 429 when parseable.
func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
	if auth == nil {
		return nil, statusErr{code: http.StatusUnauthorized, msg: "missing auth"}
	}
	refreshToken := metaStringValue(auth.Metadata, "refresh_token")
	if refreshToken == "" {
		return auth, statusErr{code: http.StatusUnauthorized, msg: "missing refresh token"}
	}
	form := url.Values{}
	form.Set("client_id", antigravityClientID)
	form.Set("client_secret", antigravityClientSecret)
	form.Set("grant_type", "refresh_token")
	form.Set("refresh_token", refreshToken)
	httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
	if errReq != nil {
		return auth, errReq
	}
	// NOTE: net/http takes the Host line from Request.Host / URL.Host for client
	// requests, so this header entry does not change what is sent; the URL above
	// already targets oauth2.googleapis.com.
	httpReq.Header.Set("Host", "oauth2.googleapis.com")
	httpReq.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	// Real Antigravity uses Go's default User-Agent for OAuth token refresh
	httpReq.Header.Set("User-Agent", "Go-http-client/2.0")
	httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
	httpResp, errDo := httpClient.Do(httpReq)
	if errDo != nil {
		return auth, errDo
	}
	defer func() {
		if errClose := httpResp.Body.Close(); errClose != nil {
			log.Errorf("antigravity executor: close response body error: %v", errClose)
		}
	}()
	bodyBytes, errRead := io.ReadAll(httpResp.Body)
	if errRead != nil {
		return auth, errRead
	}
	if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
		sErr := statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
		if httpResp.StatusCode == http.StatusTooManyRequests {
			if retryAfter, parseErr := parseRetryDelay(bodyBytes); parseErr == nil && retryAfter != nil {
				sErr.retryAfter = retryAfter
			}
		}
		return auth, sErr
	}
	var tokenResp struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		ExpiresIn    int64  `json:"expires_in"`
		TokenType    string `json:"token_type"`
	}
	if errUnmarshal := json.Unmarshal(bodyBytes, &tokenResp); errUnmarshal != nil {
		return auth, errUnmarshal
	}
	// A 2xx reply without an access token must not be treated as success. Storing
	// an empty token would only surface later as a misleading "missing access
	// token" on the next upstream request.
	if strings.TrimSpace(tokenResp.AccessToken) == "" {
		return auth, errors.New("antigravity executor: token refresh response missing access_token")
	}
	if auth.Metadata == nil {
		auth.Metadata = make(map[string]any)
	}
	auth.Metadata["access_token"] = tokenResp.AccessToken
	if tokenResp.RefreshToken != "" {
		auth.Metadata["refresh_token"] = tokenResp.RefreshToken
	}
	auth.Metadata["expires_in"] = tokenResp.ExpiresIn
	now := time.Now()
	auth.Metadata["timestamp"] = now.UnixMilli()
	auth.Metadata["expired"] = now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339)
	auth.Metadata["type"] = antigravityAuthType
	if errProject := e.ensureAntigravityProjectID(ctx, auth, tokenResp.AccessToken); errProject != nil {
		log.Warnf("antigravity executor: ensure project id failed: %v", errProject)
	}
	return auth, nil
}
// ensureAntigravityProjectID fills auth.Metadata["project_id"] using
// sdkAuth.FetchAntigravityProjectID when no usable project ID is stored.
//
// accessToken is used when non-blank; otherwise the access_token in
// auth.Metadata is used. It does nothing when auth is nil, no token is
// available, or the fetched project ID is blank. Fetch errors are returned
// unchanged.
func (e *AntigravityExecutor) ensureAntigravityProjectID(ctx context.Context, auth *cliproxyauth.Auth, accessToken string) error {
	if auth == nil {
		return nil
	}
	// Only a non-blank string counts as present. That is the only form
	// buildRequest reads; a blank or non-string value would make it fall back to
	// a random project ID, so fetch a real one instead.
	if existing, ok := auth.Metadata["project_id"].(string); ok && strings.TrimSpace(existing) != "" {
		return nil
	}
	token := strings.TrimSpace(accessToken)
	if token == "" {
		token = metaStringValue(auth.Metadata, "access_token")
	}
	if token == "" {
		return nil
	}
	httpClient := newAntigravityHTTPClient(ctx, e.cfg, auth, 0)
	projectID, errFetch := sdkAuth.FetchAntigravityProjectID(ctx, token, httpClient)
	if errFetch != nil {
		return errFetch
	}
	projectID = strings.TrimSpace(projectID)
	if projectID == "" {
		return nil
	}
	if auth.Metadata == nil {
		auth.Metadata = make(map[string]any)
	}
	auth.Metadata["project_id"] = projectID
	return nil
}
// buildRequest assembles the upstream HTTP POST for a generate call.
//
// The URL is baseURL (one trailing "/" trimmed; buildBaseURL(auth) when empty)
// followed by antigravityStreamPath or antigravityGeneratePath. Streaming
// requests get "?$alt=<alt>" when alt is set and "?alt=sse" otherwise.
// Non-streaming requests get "?$alt=<alt>" only when alt is set.
//
// Payload processing:
//   - the payload is wrapped by geminiToAntigravity, using auth's project_id
//     when present;
//   - every "parametersJsonSchema" key is renamed to "parameters";
//   - the whole payload is passed through CleanJSONSchemaForAntigravity for
//     claude / gemini-3-pro / gemini-3.1-pro models and through
//     CleanJSONSchemaForGemini otherwise;
//   - claude models get toolConfig.functionCallingConfig.mode "VALIDATED", and
//     all other models have generationConfig.maxOutputTokens removed.
//
// The request is recorded with recordAPIRequest. The body is included in that
// record only when e.cfg.RequestLog is enabled. A statusErr(401) is returned
// when token is empty.
func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyauth.Auth, token, modelName string, payload []byte, stream bool, alt, baseURL string) (*http.Request, error) {
	if token == "" {
		return nil, statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
	}
	base := strings.TrimSuffix(baseURL, "/")
	if base == "" {
		base = buildBaseURL(auth)
	}
	path := antigravityGeneratePath
	if stream {
		path = antigravityStreamPath
	}
	var requestURL strings.Builder
	requestURL.WriteString(base)
	requestURL.WriteString(path)
	if stream {
		// An explicit alt wins; otherwise ask for server-sent events.
		if alt != "" {
			requestURL.WriteString("?$alt=")
			requestURL.WriteString(url.QueryEscape(alt))
		} else {
			requestURL.WriteString("?alt=sse")
		}
	} else if alt != "" {
		requestURL.WriteString("?$alt=")
		requestURL.WriteString(url.QueryEscape(alt))
	}
	// Extract project_id from auth metadata if available
	projectID := ""
	if auth != nil && auth.Metadata != nil {
		if pid, ok := auth.Metadata["project_id"].(string); ok {
			projectID = strings.TrimSpace(pid)
		}
	}
	payload = geminiToAntigravity(modelName, payload, projectID)
	// geminiToAntigravity already sets "model"; this re-set is redundant but harmless.
	payload, _ = sjson.SetBytes(payload, "model", modelName)
	useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro") || strings.Contains(modelName, "gemini-3.1-pro")
	payloadStr := string(payload)
	// Collect the paths of every "parametersJsonSchema" key and rename each to
	// "parameters". The slicing below assumes each collected path ends in that key name.
	paths := make([]string, 0)
	util.Walk(gjson.Parse(payloadStr), "", "parametersJsonSchema", &paths)
	for _, p := range paths {
		payloadStr, _ = util.RenameKey(payloadStr, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
	}
	if useAntigravitySchema {
		payloadStr = util.CleanJSONSchemaForAntigravity(payloadStr)
	} else {
		payloadStr = util.CleanJSONSchemaForGemini(payloadStr)
	}
	// Disabled system-instruction rewrite kept for reference. It references a
	// `systemInstruction` variable that does not exist in this function, so it
	// would need rework before being re-enabled.
	// if useAntigravitySchema {
	// systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
	// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
	// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
	// payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
	// if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
	// for _, partResult := range systemInstructionPartsResult.Array() {
	// payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
	// }
	// }
	// }
	if strings.Contains(modelName, "claude") {
		payloadStr, _ = sjson.Set(payloadStr, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
	} else {
		payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens")
	}
	httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), strings.NewReader(payloadStr))
	if errReq != nil {
		return nil, errReq
	}
	// Close asks net/http to close the connection after this request instead of
	// reusing it.
	httpReq.Close = true
	httpReq.Header.Set("Content-Type", "application/json")
	httpReq.Header.Set("Authorization", "Bearer "+token)
	httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
	if host := resolveHost(base); host != "" {
		httpReq.Host = host
	}
	var authID, authLabel, authType, authValue string
	if auth != nil {
		authID = auth.ID
		authLabel = auth.Label
		authType, authValue = auth.AccountInfo()
	}
	// Only capture the (possibly large) body when request logging is enabled.
	var payloadLog []byte
	if e.cfg != nil && e.cfg.RequestLog {
		payloadLog = []byte(payloadStr)
	}
	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
		URL:       requestURL.String(),
		Method:    http.MethodPost,
		Headers:   httpReq.Header.Clone(),
		Body:      payloadLog,
		Provider:  e.Identifier(),
		AuthID:    authID,
		AuthLabel: authLabel,
		AuthType:  authType,
		AuthValue: authValue,
	})
	return httpReq, nil
}
// tokenExpiry reports when the access token described by metadata expires.
//
// A parseable RFC 3339 "expired" string takes precedence. Otherwise the expiry
// is "timestamp" (Unix milliseconds) plus "expires_in" (seconds), both read via
// int64Value. The zero time.Time is returned when neither source is usable;
// ensureAccessToken then treats the token as needing a refresh.
func tokenExpiry(metadata map[string]any) time.Time {
	if metadata == nil {
		return time.Time{}
	}
	if expStr, ok := metadata["expired"].(string); ok {
		if expStr = strings.TrimSpace(expStr); expStr != "" {
			if parsed, errParse := time.Parse(time.RFC3339, expStr); errParse == nil {
				return parsed
			}
		}
	}
	expiresIn, hasExpires := int64Value(metadata["expires_in"])
	tsMs, hasTimestamp := int64Value(metadata["timestamp"])
	if hasExpires && hasTimestamp {
		// time.UnixMilli splits seconds from milliseconds, avoiding the int64
		// overflow of tsMs*int64(time.Millisecond) for out-of-range timestamps.
		return time.UnixMilli(tsMs).Add(time.Duration(expiresIn) * time.Second)
	}
	return time.Time{}
}
// metaStringValue returns metadata[key] with surrounding whitespace trimmed when
// the value is a string or []byte. It returns "" for a nil map, a missing key,
// or any other value type.
func metaStringValue(metadata map[string]any, key string) string {
	var raw string
	switch v := metadata[key].(type) {
	case string:
		raw = v
	case []byte:
		raw = string(v)
	default:
		return ""
	}
	return strings.TrimSpace(raw)
}
// int64Value converts a loosely-typed metadata value to int64.
//
// Supported inputs:
//   - all signed and unsigned integer types (unsigned values above
//     math.MaxInt64 are rejected);
//   - float32/float64, truncated toward zero (NaN and values outside the int64
//     range are rejected);
//   - json.Number in integer or float form;
//   - base-10 integer strings (surrounding whitespace ignored).
//
// The boolean result reports whether the conversion succeeded.
func int64Value(value any) (int64, bool) {
	// fromUint rejects values that do not fit in int64.
	fromUint := func(u uint64) (int64, bool) {
		if u > 1<<63-1 {
			return 0, false
		}
		return int64(u), true
	}
	// fromFloat truncates toward zero. NaN and out-of-range values are rejected
	// because Go's float→int conversion is implementation-specific for them.
	fromFloat := func(f float64) (int64, bool) {
		const limit = 1 << 63 // 2^63 is exactly representable as a float64
		if !(f >= -limit && f < limit) {
			return 0, false
		}
		return int64(f), true
	}

	switch typed := value.(type) {
	case int:
		return int64(typed), true
	case int8:
		return int64(typed), true
	case int16:
		return int64(typed), true
	case int32:
		return int64(typed), true
	case int64:
		return typed, true
	case uint:
		return fromUint(uint64(typed))
	case uint8:
		return int64(typed), true
	case uint16:
		return int64(typed), true
	case uint32:
		return int64(typed), true
	case uint64:
		return fromUint(typed)
	case float32:
		return fromFloat(float64(typed))
	case float64:
		return fromFloat(typed)
	case json.Number:
		if i, errParse := typed.Int64(); errParse == nil {
			return i, true
		}
		// Accept float-formatted numbers such as "3600.0" or "3.6e3".
		if f, errParse := typed.Float64(); errParse == nil {
			return fromFloat(f)
		}
	case string:
		trimmed := strings.TrimSpace(typed)
		if trimmed == "" {
			return 0, false
		}
		if i, errParse := strconv.ParseInt(trimmed, 10, 64); errParse == nil {
			return i, true
		}
	}
	return 0, false
}
// buildBaseURL returns the preferred base URL for auth: the first entry of
// antigravityBaseURLFallbackOrder, or antigravityBaseURLDaily when that list
// is empty.
func buildBaseURL(auth *cliproxyauth.Auth) string {
	candidates := antigravityBaseURLFallbackOrder(auth)
	if len(candidates) == 0 {
		return antigravityBaseURLDaily
	}
	return candidates[0]
}
// resolveHost extracts the host[:port] to send as the request Host.
//
// It returns the parsed URL's Host when present. If base parses but has no host
// component, base is returned with a leading "https://" or "http://" prefix
// removed. If base cannot be parsed at all, it returns "".
func resolveHost(base string) string {
	u, errParse := url.Parse(base)
	switch {
	case errParse != nil:
		return ""
	case u.Host != "":
		return u.Host
	default:
		withoutHTTPS := strings.TrimPrefix(base, "https://")
		return strings.TrimPrefix(withoutHTTPS, "http://")
	}
}
// resolveUserAgent picks the User-Agent for upstream requests.
//
// Lookup order: a non-blank auth.Attributes["user_agent"], then a non-blank
// string auth.Metadata["user_agent"] (both trimmed), then
// defaultAntigravityAgent.
func resolveUserAgent(auth *cliproxyauth.Auth) string {
	if auth == nil {
		return defaultAntigravityAgent
	}
	// Indexing a nil map yields the zero value, so no nil checks are needed.
	if ua := strings.TrimSpace(auth.Attributes["user_agent"]); ua != "" {
		return ua
	}
	if raw, ok := auth.Metadata["user_agent"].(string); ok {
		if ua := strings.TrimSpace(raw); ua != "" {
			return ua
		}
	}
	return defaultAntigravityAgent
}
// antigravityRetryAttempts returns the total number of attempts: the configured
// retry count plus the initial attempt.
//
// The retry count comes from cfg.RequestRetry and is overridden by
// auth.RequestRetryOverride() when that reports ok. Negative counts are treated
// as 0, and the result is never below 1.
func antigravityRetryAttempts(auth *cliproxyauth.Auth, cfg *config.Config) int {
	var retry int
	if cfg != nil {
		retry = cfg.RequestRetry
	}
	if auth != nil {
		if override, ok := auth.RequestRetryOverride(); ok {
			retry = override
		}
	}
	if retry < 0 {
		retry = 0
	}
	// retry+1 can only drop below 1 through integer overflow.
	if attempts := retry + 1; attempts >= 1 {
		return attempts
	}
	return 1
}
// antigravityShouldRetryNoCapacity reports whether a response is a 503 whose
// body contains "no capacity available" (case-insensitive), which the
// executors treat as retryable.
func antigravityShouldRetryNoCapacity(statusCode int, body []byte) bool {
	return statusCode == http.StatusServiceUnavailable &&
		len(body) > 0 &&
		strings.Contains(strings.ToLower(string(body)), "no capacity available")
}
// antigravityNoCapacityRetryDelay returns the back-off before retrying after a
// "no capacity available" response. The delay is (attempt+1) × 250ms — attempt
// 0 → 250ms, attempt 1 → 500ms, and so on — capped at 2s. Negative attempts are
// treated as 0.
func antigravityNoCapacityRetryDelay(attempt int) time.Duration {
	const (
		step     = 250 * time.Millisecond
		maxDelay = 2 * time.Second
	)
	if attempt < 0 {
		attempt = 0
	}
	// Return the cap before multiplying, so (attempt+1)*step cannot overflow
	// time.Duration and wrap negative for very large attempt values.
	if time.Duration(attempt) >= maxDelay/step {
		return maxDelay
	}
	return time.Duration(attempt+1) * step
}
// antigravityWait blocks for wait, or until ctx is done, whichever comes first.
// It returns nil at once for non-positive durations and returns ctx.Err() if
// the context finishes before the timer fires.
func antigravityWait(ctx context.Context, wait time.Duration) error {
	if wait > 0 {
		t := time.NewTimer(wait)
		defer t.Stop()
		select {
		case <-t.C:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	return nil
}
// antigravityBaseURLFallbackOrder lists the base URLs to try, in order.
//
// A custom base URL configured on auth replaces the defaults entirely.
// Otherwise the daily endpoint is tried first, then the sandbox daily endpoint.
// The production endpoint (antigravityBaseURLProd) is currently not part of the
// chain.
func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
	custom := resolveCustomAntigravityBaseURL(auth)
	if custom == "" {
		return []string{
			antigravityBaseURLDaily,
			antigravitySandboxBaseURLDaily,
			// antigravityBaseURLProd,
		}
	}
	return []string{custom}
}
// resolveCustomAntigravityBaseURL returns a user-configured base URL for auth,
// or "" when none is set.
//
// auth.Attributes["base_url"] is checked first, then a string
// auth.Metadata["base_url"]. The value is trimmed of whitespace and of one
// trailing "/".
func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
	if auth == nil {
		return ""
	}
	// Indexing a nil map yields the zero value, so no nil checks are needed.
	if fromAttr := strings.TrimSpace(auth.Attributes["base_url"]); fromAttr != "" {
		return strings.TrimSuffix(fromAttr, "/")
	}
	if raw, ok := auth.Metadata["base_url"].(string); ok {
		if fromMeta := strings.TrimSpace(raw); fromMeta != "" {
			return strings.TrimSuffix(fromMeta, "/")
		}
	}
	return ""
}
// geminiToAntigravity wraps a Gemini-style payload in the Antigravity request
// envelope.
//
// It sets model, userAgent ("antigravity"), requestType ("image_gen" for model
// names containing "image", else "agent"), project (projectID, or a random
// generated one when empty) and requestId. Non-image requests also get a
// request.sessionId derived from the payload. request.safetySettings is
// removed, and a top-level toolConfig is moved under request when request has
// none.
func geminiToAntigravity(modelName string, payload []byte, projectID string) []byte {
	isImageModel := strings.Contains(modelName, "image")

	out, _ := sjson.Set(string(payload), "model", modelName)
	out, _ = sjson.Set(out, "userAgent", "antigravity")

	requestType := "agent"
	if isImageModel {
		requestType = "image_gen"
	}
	out, _ = sjson.Set(out, "requestType", requestType)

	// Prefer the real project ID from auth; otherwise use a random legacy-style one.
	project := projectID
	if project == "" {
		project = generateProjectID()
	}
	out, _ = sjson.Set(out, "project", project)

	if isImageModel {
		out, _ = sjson.Set(out, "requestId", generateImageGenRequestID())
	} else {
		out, _ = sjson.Set(out, "requestId", generateRequestID())
		out, _ = sjson.Set(out, "request.sessionId", generateStableSessionID(payload))
	}

	out, _ = sjson.Delete(out, "request.safetySettings")

	toolConfig := gjson.Get(out, "toolConfig")
	if toolConfig.Exists() && !gjson.Get(out, "request.toolConfig").Exists() {
		out, _ = sjson.SetRaw(out, "request.toolConfig", toolConfig.Raw)
		out, _ = sjson.Delete(out, "toolConfig")
	}
	return []byte(out)
}
// generateRequestID returns a request ID of the form "agent-<uuid>".
func generateRequestID() string {
	const prefix = "agent-"
	return prefix + uuid.NewString()
}
func generateImageGenRequestID() string {
return fmt.Sprintf("image_gen/%d/%s/12", time.Now().UnixMilli(), uuid.NewString())
}
// generateSessionID returns a random session ID: "-" followed by a decimal
// integer in [0, 9e18), drawn from the shared randSource under its mutex.
func generateSessionID() string {
	var n int64
	func() {
		randSourceMutex.Lock()
		defer randSourceMutex.Unlock()
		n = randSource.Int63n(9_000_000_000_000_000_000)
	}()
	return "-" + strconv.FormatInt(n, 10)
}
// generateStableSessionID derives a deterministic session ID from the payload.
//
// It hashes (SHA-256) the parts[0].text of the first "user" entry in
// request.contents that has non-empty text. The ID is "-" followed by the first
// 8 hash bytes read as a big-endian integer with the sign bit cleared. If no
// such entry exists, it falls back to a random generateSessionID.
func generateStableSessionID(payload []byte) string {
	contents := gjson.GetBytes(payload, "request.contents")
	if !contents.IsArray() {
		return generateSessionID()
	}
	for _, entry := range contents.Array() {
		if entry.Get("role").String() != "user" {
			continue
		}
		text := entry.Get("parts.0.text").String()
		if text == "" {
			continue
		}
		digest := sha256.Sum256([]byte(text))
		// Clear the top bit so the resulting int64 is non-negative.
		n := int64(binary.BigEndian.Uint64(digest[:8]) &^ (1 << 63))
		return "-" + strconv.FormatInt(n, 10)
	}
	return generateSessionID()
}
// generateProjectID builds a random placeholder project ID such as
// "swift-wave-1a2b3": an adjective, a noun, and the first five characters of a
// lower-cased UUID, joined by "-".
func generateProjectID() string {
	adjectives := [...]string{"useful", "bright", "swift", "calm", "bold"}
	nouns := [...]string{"fuze", "wave", "spark", "flow", "core"}

	randSourceMutex.Lock()
	adjIdx := randSource.Intn(len(adjectives))
	nounIdx := randSource.Intn(len(nouns))
	randSourceMutex.Unlock()

	suffix := strings.ToLower(uuid.NewString())[:5]
	return strings.Join([]string{adjectives[adjIdx], nouns[nounIdx], suffix}, "-")
}
================================================
FILE: internal/runtime/executor/antigravity_executor_buildrequest_test.go
================================================
package executor
import (
"context"
"encoding/json"
"io"
"testing"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestAntigravityBuildRequest_SanitizesGeminiToolSchema(t *testing.T) {
body := buildRequestBodyFromPayload(t, "gemini-2.5-pro")
decl := extractFirstFunctionDeclaration(t, body)
if _, ok := decl["parametersJsonSchema"]; ok {
t.Fatalf("parametersJsonSchema should be renamed to parameters")
}
params, ok := decl["parameters"].(map[string]any)
if !ok {
t.Fatalf("parameters missing or invalid type")
}
assertSchemaSanitizedAndPropertyPreserved(t, params)
}
func TestAntigravityBuildRequest_SanitizesAntigravityToolSchema(t *testing.T) {
body := buildRequestBodyFromPayload(t, "claude-opus-4-6")
decl := extractFirstFunctionDeclaration(t, body)
params, ok := decl["parameters"].(map[string]any)
if !ok {
t.Fatalf("parameters missing or invalid type")
}
assertSchemaSanitizedAndPropertyPreserved(t, params)
}
func buildRequestBodyFromPayload(t *testing.T, modelName string) map[string]any {
t.Helper()
executor := &AntigravityExecutor{}
auth := &cliproxyauth.Auth{}
payload := []byte(`{
"request": {
"tools": [
{
"function_declarations": [
{
"name": "tool_1",
"parametersJsonSchema": {
"$schema": "http://json-schema.org/draft-07/schema#",
"$id": "root-schema",
"type": "object",
"properties": {
"$id": {"type": "string"},
"arg": {
"type": "object",
"prefill": "hello",
"properties": {
"mode": {
"type": "string",
"deprecated": true,
"enum": ["a", "b"],
"enumTitles": ["A", "B"]
}
}
}
},
"patternProperties": {
"^x-": {"type": "string"}
}
}
}
]
}
]
}
}`)
req, err := executor.buildRequest(context.Background(), auth, "token", modelName, payload, false, "", "https://example.com")
if err != nil {
t.Fatalf("buildRequest error: %v", err)
}
raw, err := io.ReadAll(req.Body)
if err != nil {
t.Fatalf("read request body error: %v", err)
}
var body map[string]any
if err := json.Unmarshal(raw, &body); err != nil {
t.Fatalf("unmarshal request body error: %v, body=%s", err, string(raw))
}
return body
}
func extractFirstFunctionDeclaration(t *testing.T, body map[string]any) map[string]any {
t.Helper()
request, ok := body["request"].(map[string]any)
if !ok {
t.Fatalf("request missing or invalid type")
}
tools, ok := request["tools"].([]any)
if !ok || len(tools) == 0 {
t.Fatalf("tools missing or empty")
}
tool, ok := tools[0].(map[string]any)
if !ok {
t.Fatalf("first tool invalid type")
}
decls, ok := tool["function_declarations"].([]any)
if !ok || len(decls) == 0 {
t.Fatalf("function_declarations missing or empty")
}
decl, ok := decls[0].(map[string]any)
if !ok {
t.Fatalf("first function declaration invalid type")
}
return decl
}
func assertSchemaSanitizedAndPropertyPreserved(t *testing.T, params map[string]any) {
t.Helper()
if _, ok := params["$id"]; ok {
t.Fatalf("root $id should be removed from schema")
}
if _, ok := params["patternProperties"]; ok {
t.Fatalf("patternProperties should be removed from schema")
}
props, ok := params["properties"].(map[string]any)
if !ok {
t.Fatalf("properties missing or invalid type")
}
if _, ok := props["$id"]; !ok {
t.Fatalf("property named $id should be preserved")
}
arg, ok := props["arg"].(map[string]any)
if !ok {
t.Fatalf("arg property missing or invalid type")
}
if _, ok := arg["prefill"]; ok {
t.Fatalf("prefill should be removed from nested schema")
}
argProps, ok := arg["properties"].(map[string]any)
if !ok {
t.Fatalf("arg.properties missing or invalid type")
}
mode, ok := argProps["mode"].(map[string]any)
if !ok {
t.Fatalf("mode property missing or invalid type")
}
if _, ok := mode["enumTitles"]; ok {
t.Fatalf("enumTitles should be removed from nested schema")
}
if _, ok := mode["deprecated"]; ok {
t.Fatalf("deprecated should be removed from nested schema")
}
}
================================================
FILE: internal/runtime/executor/cache_helpers.go
================================================
package executor
import (
"sync"
"time"
)
type codexCache struct {
ID string
Expire time.Time
}
// codexCacheMap stores prompt cache IDs keyed by model+user_id.
// Protected by codexCacheMu. Entries expire after 1 hour.
var (
codexCacheMap = make(map[string]codexCache)
codexCacheMu sync.RWMutex
)
// codexCacheCleanupInterval controls how often expired entries are purged.
const codexCacheCleanupInterval = 15 * time.Minute
// codexCacheCleanupOnce ensures the background cleanup goroutine starts only once.
var codexCacheCleanupOnce sync.Once
// startCodexCacheCleanup launches a background goroutine that periodically
// removes expired entries from codexCacheMap to prevent memory leaks.
func startCodexCacheCleanup() {
go func() {
ticker := time.NewTicker(codexCacheCleanupInterval)
defer ticker.Stop()
for range ticker.C {
purgeExpiredCodexCache()
}
}()
}
// purgeExpiredCodexCache removes entries that have expired.
func purgeExpiredCodexCache() {
now := time.Now()
codexCacheMu.Lock()
defer codexCacheMu.Unlock()
for key, cache := range codexCacheMap {
if cache.Expire.Before(now) {
delete(codexCacheMap, key)
}
}
}
// getCodexCache retrieves a cached entry, returning ok=false if not found or expired.
func getCodexCache(key string) (codexCache, bool) {
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
codexCacheMu.RLock()
cache, ok := codexCacheMap[key]
codexCacheMu.RUnlock()
if !ok || cache.Expire.Before(time.Now()) {
return codexCache{}, false
}
return cache, true
}
// setCodexCache stores a cache entry.
func setCodexCache(key string, cache codexCache) {
codexCacheCleanupOnce.Do(startCodexCacheCleanup)
codexCacheMu.Lock()
codexCacheMap[key] = cache
codexCacheMu.Unlock()
}
================================================
FILE: internal/runtime/executor/caching_verify_test.go
================================================
package executor
import (
"fmt"
"testing"
"github.com/tidwall/gjson"
)
func TestEnsureCacheControl(t *testing.T) {
// Test case 1: System prompt as string
t.Run("String System Prompt", func(t *testing.T) {
input := []byte(`{"model": "claude-3-5-sonnet", "system": "This is a long system prompt", "messages": []}`)
output := ensureCacheControl(input)
res := gjson.GetBytes(output, "system.0.cache_control.type")
if res.String() != "ephemeral" {
t.Errorf("cache_control not found in system string. Output: %s", string(output))
}
})
// Test case 2: System prompt as array
t.Run("Array System Prompt", func(t *testing.T) {
input := []byte(`{"model": "claude-3-5-sonnet", "system": [{"type": "text", "text": "Part 1"}, {"type": "text", "text": "Part 2"}], "messages": []}`)
output := ensureCacheControl(input)
// cache_control should only be on the LAST element
res0 := gjson.GetBytes(output, "system.0.cache_control")
res1 := gjson.GetBytes(output, "system.1.cache_control.type")
if res0.Exists() {
t.Errorf("cache_control should NOT be on the first element")
}
if res1.String() != "ephemeral" {
t.Errorf("cache_control not found on last system element. Output: %s", string(output))
}
})
// Test case 3: Tools are cached
t.Run("Tools Caching", func(t *testing.T) {
input := []byte(`{
"model": "claude-3-5-sonnet",
"tools": [
{"name": "tool1", "description": "First tool", "input_schema": {"type": "object"}},
{"name": "tool2", "description": "Second tool", "input_schema": {"type": "object"}}
],
"system": "System prompt",
"messages": []
}`)
output := ensureCacheControl(input)
// cache_control should only be on the LAST tool
tool0Cache := gjson.GetBytes(output, "tools.0.cache_control")
tool1Cache := gjson.GetBytes(output, "tools.1.cache_control.type")
if tool0Cache.Exists() {
t.Errorf("cache_control should NOT be on the first tool")
}
if tool1Cache.String() != "ephemeral" {
t.Errorf("cache_control not found on last tool. Output: %s", string(output))
}
// System should also have cache_control
systemCache := gjson.GetBytes(output, "system.0.cache_control.type")
if systemCache.String() != "ephemeral" {
t.Errorf("cache_control not found in system. Output: %s", string(output))
}
})
// Test case 4: Tools and system are INDEPENDENT breakpoints
// Per Anthropic docs: Up to 4 breakpoints allowed, tools and system are cached separately
t.Run("Independent Cache Breakpoints", func(t *testing.T) {
input := []byte(`{
"model": "claude-3-5-sonnet",
"tools": [
{"name": "tool1", "description": "First tool", "input_schema": {"type": "object"}, "cache_control": {"type": "ephemeral"}}
],
"system": [{"type": "text", "text": "System"}],
"messages": []
}`)
output := ensureCacheControl(input)
// Tool already has cache_control - should not be changed
tool0Cache := gjson.GetBytes(output, "tools.0.cache_control.type")
if tool0Cache.String() != "ephemeral" {
t.Errorf("existing cache_control was incorrectly removed")
}
// System SHOULD get cache_control because it is an INDEPENDENT breakpoint
// Tools and system are separate cache levels in the hierarchy
systemCache := gjson.GetBytes(output, "system.0.cache_control.type")
if systemCache.String() != "ephemeral" {
t.Errorf("system should have its own cache_control breakpoint (independent of tools)")
}
})
// Test case 5: Only tools, no system
t.Run("Only Tools No System", func(t *testing.T) {
input := []byte(`{
"model": "claude-3-5-sonnet",
"tools": [
{"name": "tool1", "description": "Tool", "input_schema": {"type": "object"}}
],
"messages": [{"role": "user", "content": "Hi"}]
}`)
output := ensureCacheControl(input)
toolCache := gjson.GetBytes(output, "tools.0.cache_control.type")
if toolCache.String() != "ephemeral" {
t.Errorf("cache_control not found on tool. Output: %s", string(output))
}
})
// Test case 6: Many tools (Claude Code scenario)
t.Run("Many Tools (Claude Code Scenario)", func(t *testing.T) {
// Simulate Claude Code with many tools
toolsJSON := `[`
for i := 0; i < 50; i++ {
if i > 0 {
toolsJSON += ","
}
toolsJSON += fmt.Sprintf(`{"name": "tool%d", "description": "Tool %d", "input_schema": {"type": "object"}}`, i, i)
}
toolsJSON += `]`
input := []byte(fmt.Sprintf(`{
"model": "claude-3-5-sonnet",
"tools": %s,
"system": [{"type": "text", "text": "You are Claude Code"}],
"messages": [{"role": "user", "content": "Hello"}]
}`, toolsJSON))
output := ensureCacheControl(input)
// Only the last tool (index 49) should have cache_control
for i := 0; i < 49; i++ {
path := fmt.Sprintf("tools.%d.cache_control", i)
if gjson.GetBytes(output, path).Exists() {
t.Errorf("tool %d should NOT have cache_control", i)
}
}
lastToolCache := gjson.GetBytes(output, "tools.49.cache_control.type")
if lastToolCache.String() != "ephemeral" {
t.Errorf("last tool (49) should have cache_control")
}
// System should also have cache_control
systemCache := gjson.GetBytes(output, "system.0.cache_control.type")
if systemCache.String() != "ephemeral" {
t.Errorf("system should have cache_control")
}
t.Log("test passed: 50 tools - cache_control only on last tool")
})
// Test case 7: Empty tools array
t.Run("Empty Tools Array", func(t *testing.T) {
input := []byte(`{"model": "claude-3-5-sonnet", "tools": [], "system": "Test", "messages": []}`)
output := ensureCacheControl(input)
// System should still get cache_control
systemCache := gjson.GetBytes(output, "system.0.cache_control.type")
if systemCache.String() != "ephemeral" {
t.Errorf("system should have cache_control even with empty tools array")
}
})
// Test case 8: Messages caching for multi-turn (second-to-last user)
t.Run("Messages Caching Second-To-Last User", func(t *testing.T) {
input := []byte(`{
"model": "claude-3-5-sonnet",
"messages": [
{"role": "user", "content": "First user"},
{"role": "assistant", "content": "Assistant reply"},
{"role": "user", "content": "Second user"},
{"role": "assistant", "content": "Assistant reply 2"},
{"role": "user", "content": "Third user"}
]
}`)
output := ensureCacheControl(input)
cacheType := gjson.GetBytes(output, "messages.2.content.0.cache_control.type")
if cacheType.String() != "ephemeral" {
t.Errorf("cache_control not found on second-to-last user turn. Output: %s", string(output))
}
lastUserCache := gjson.GetBytes(output, "messages.4.content.0.cache_control")
if lastUserCache.Exists() {
t.Errorf("last user turn should NOT have cache_control")
}
})
// Test case 9: Existing message cache_control should skip injection
t.Run("Messages Skip When Cache Control Exists", func(t *testing.T) {
input := []byte(`{
"model": "claude-3-5-sonnet",
"messages": [
{"role": "user", "content": [{"type": "text", "text": "First user"}]},
{"role": "assistant", "content": [{"type": "text", "text": "Assistant reply", "cache_control": {"type": "ephemeral"}}]},
{"role": "user", "content": [{"type": "text", "text": "Second user"}]}
]
}`)
output := ensureCacheControl(input)
userCache := gjson.GetBytes(output, "messages.0.content.0.cache_control")
if userCache.Exists() {
t.Errorf("cache_control should NOT be injected when a message already has cache_control")
}
existingCache := gjson.GetBytes(output, "messages.1.content.0.cache_control.type")
if existingCache.String() != "ephemeral" {
t.Errorf("existing cache_control should be preserved. Output: %s", string(output))
}
})
}
// TestCacheControlOrder verifies the correct order: tools -> system -> messages
func TestCacheControlOrder(t *testing.T) {
input := []byte(`{
"model": "claude-sonnet-4",
"tools": [
{"name": "Read", "description": "Read file", "input_schema": {"type": "object", "properties": {"path": {"type": "string"}}}},
{"name": "Write", "description": "Write file", "input_schema": {"type": "object", "properties": {"path": {"type": "string"}, "content": {"type": "string"}}}}
],
"system": [
{"type": "text", "text": "You are Claude Code, Anthropic's official CLI for Claude."},
{"type": "text", "text": "Additional instructions here..."}
],
"messages": [
{"role": "user", "content": "Hello"}
]
}`)
output := ensureCacheControl(input)
// 1. Last tool has cache_control
if gjson.GetBytes(output, "tools.1.cache_control.type").String() != "ephemeral" {
t.Error("last tool should have cache_control")
}
// 2. First tool has NO cache_control
if gjson.GetBytes(output, "tools.0.cache_control").Exists() {
t.Error("first tool should NOT have cache_control")
}
// 3. Last system element has cache_control
if gjson.GetBytes(output, "system.1.cache_control.type").String() != "ephemeral" {
t.Error("last system element should have cache_control")
}
// 4. First system element has NO cache_control
if gjson.GetBytes(output, "system.0.cache_control").Exists() {
t.Error("first system element should NOT have cache_control")
}
t.Log("cache order correct: tools -> system")
}
================================================
FILE: internal/runtime/executor/claude_executor.go
================================================
package executor
import (
"bufio"
"bytes"
"compress/flate"
"compress/gzip"
"context"
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/textproto"
"runtime"
"strings"
"time"
"github.com/andybalholm/brotli"
"github.com/klauspost/compress/zstd"
claudeauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/gin-gonic/gin"
)
// ClaudeExecutor is a stateless executor for Anthropic Claude over the messages API.
// If api_key is unavailable on auth, it falls back to legacy via ClientAdapter.
type ClaudeExecutor struct {
cfg *config.Config
}
// claudeToolPrefix is empty to match real Claude Code behavior (no tool name prefix).
// Previously "proxy_" was used but this is a detectable fingerprint difference.
const claudeToolPrefix = ""
func NewClaudeExecutor(cfg *config.Config) *ClaudeExecutor { return &ClaudeExecutor{cfg: cfg} }
func (e *ClaudeExecutor) Identifier() string { return "claude" }
// PrepareRequest injects Claude credentials into the outgoing HTTP request.
func (e *ClaudeExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
if req == nil {
return nil
}
apiKey, _ := claudeCreds(auth)
if strings.TrimSpace(apiKey) == "" {
return nil
}
useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
isAnthropicBase := req.URL != nil && strings.EqualFold(req.URL.Scheme, "https") && strings.EqualFold(req.URL.Host, "api.anthropic.com")
if isAnthropicBase && useAPIKey {
req.Header.Del("Authorization")
req.Header.Set("x-api-key", apiKey)
} else {
req.Header.Del("x-api-key")
req.Header.Set("Authorization", "Bearer "+apiKey)
}
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(req, attrs)
return nil
}
// HttpRequest injects Claude credentials into the request and executes it.
func (e *ClaudeExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
if req == nil {
return nil, fmt.Errorf("claude executor: request is nil")
}
if ctx == nil {
ctx = req.Context()
}
httpReq := req.WithContext(ctx)
if err := e.PrepareRequest(httpReq, auth); err != nil {
return nil, err
}
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
return httpClient.Do(httpReq)
}
func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if opts.Alt == "responses/compact" {
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := claudeCreds(auth)
if baseURL == "" {
baseURL = "https://api.anthropic.com"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude.
stream := from != to
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, stream)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
// Auto-inject cache_control if missing (optimization for ClawdBot/clients without caching support)
if countCacheControls(body) == 0 {
body = ensureCacheControl(body)
}
// Enforce Anthropic's cache_control block limit (max 4 breakpoints per request).
// Cloaking and ensureCacheControl may push the total over 4 when the client
// (e.g. Amp CLI) already sends multiple cache_control blocks.
body = enforceCacheControlLimit(body, 4)
// Normalize TTL values to prevent ordering violations under prompt-caching-scope-2026-01-05.
// A 1h-TTL block must not appear after a 5m-TTL block in evaluation order (tools→system→messages).
body = normalizeCacheControlTTL(body)
// Extract betas from body and convert to header
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
bodyForTranslation := body
bodyForUpstream := body
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyForUpstream))
if err != nil {
return resp, err
}
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: bodyForUpstream,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
// Decompress error responses — pass the Content-Encoding value (may be empty)
// and let decodeResponseBody handle both header-declared and magic-byte-detected
// compression. This keeps error-path behaviour consistent with the success path.
errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
logWithRequestID(ctx).Warn(msg)
return resp, statusErr{code: httpResp.StatusCode, msg: msg}
}
b, readErr := io.ReadAll(errBody)
if readErr != nil {
recordAPIResponseError(ctx, e.cfg, readErr)
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
logWithRequestID(ctx).Warn(msg)
b = []byte(msg)
}
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
if errClose := errBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
return resp, err
}
decodedBody, err := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
return resp, err
}
defer func() {
if errClose := decodedBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
}()
data, err := io.ReadAll(decodedBody)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
if stream {
lines := bytes.Split(data, []byte("\n"))
for _, line := range lines {
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
}
} else {
reporter.publish(ctx, parseClaudeUsage(data))
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
data = stripClaudeToolPrefixFromResponse(data, claudeToolPrefix)
}
var param any
out := sdktranslator.TranslateNonStream(
ctx,
to,
from,
req.Model,
opts.OriginalRequest,
bodyForTranslation,
data,
¶m,
)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := claudeCreds(auth)
if baseURL == "" {
baseURL = "https://api.anthropic.com"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("claude")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
// Apply cloaking (system prompt injection, fake user ID, sensitive word obfuscation)
// based on client type and configuration.
body = applyCloaking(ctx, e.cfg, auth, body, baseModel, apiKey)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
// Disable thinking if tool_choice forces tool use (Anthropic API constraint)
body = disableThinkingIfToolChoiceForced(body)
// Auto-inject cache_control if missing (optimization for ClawdBot/clients without caching support)
if countCacheControls(body) == 0 {
body = ensureCacheControl(body)
}
// Enforce Anthropic's cache_control block limit (max 4 breakpoints per request).
body = enforceCacheControlLimit(body, 4)
// Normalize TTL values to prevent ordering violations under prompt-caching-scope-2026-01-05.
body = normalizeCacheControlTTL(body)
// Extract betas from body and convert to header
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
bodyForTranslation := body
bodyForUpstream := body
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
bodyForUpstream = applyClaudeToolPrefix(body, claudeToolPrefix)
}
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(bodyForUpstream))
if err != nil {
return nil, err
}
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: bodyForUpstream,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
// Decompress error responses — pass the Content-Encoding value (may be empty)
// and let decodeResponseBody handle both header-declared and magic-byte-detected
// compression. This keeps error-path behaviour consistent with the success path.
errBody, decErr := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
logWithRequestID(ctx).Warn(msg)
return nil, statusErr{code: httpResp.StatusCode, msg: msg}
}
b, readErr := io.ReadAll(errBody)
if readErr != nil {
recordAPIResponseError(ctx, e.cfg, readErr)
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
logWithRequestID(ctx).Warn(msg)
b = []byte(msg)
}
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := errBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return nil, err
}
decodedBody, err := decodeResponseBody(httpResp.Body, httpResp.Header.Get("Content-Encoding"))
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
defer func() {
if errClose := decodedBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
}()
// If from == to (Claude → Claude), directly forward the SSE stream without translation
if from == to {
scanner := bufio.NewScanner(decodedBody)
scanner.Buffer(nil, 52_428_800) // 50MB
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
// Forward the line as-is to preserve SSE format
cloned := make([]byte, len(line)+1)
copy(cloned, line)
cloned[len(line)] = '\n'
out <- cliproxyexecutor.StreamChunk{Payload: cloned}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
return
}
// For other formats, use translation
scanner := bufio.NewScanner(decodedBody)
scanner.Buffer(nil, 52_428_800) // 50MB
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseClaudeStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
line = stripClaudeToolPrefixFromStreamLine(line, claudeToolPrefix)
}
chunks := sdktranslator.TranslateStream(
ctx,
to,
from,
req.Model,
opts.OriginalRequest,
bodyForTranslation,
bytes.Clone(line),
¶m,
)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := claudeCreds(auth)
if baseURL == "" {
baseURL = "https://api.anthropic.com"
}
from := opts.SourceFormat
to := sdktranslator.FromString("claude")
// Use streaming translation to preserve function calling, except for claude.
stream := from != to
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, stream)
body, _ = sjson.SetBytes(body, "model", baseModel)
if !strings.HasPrefix(baseModel, "claude-3-5-haiku") {
body = checkSystemInstructions(body)
}
// Keep count_tokens requests compatible with Anthropic cache-control constraints too.
body = enforceCacheControlLimit(body, 4)
body = normalizeCacheControlTTL(body)
// Extract betas from body and convert to header (for count_tokens too)
var extraBetas []string
extraBetas, body = extractAndRemoveBetas(body)
if isClaudeOAuthToken(apiKey) && !auth.ToolPrefixDisabled() {
body = applyClaudeToolPrefix(body, claudeToolPrefix)
}
url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL)
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return cliproxyexecutor.Response{}, err
}
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
resp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err
}
recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
// Decompress error responses — pass the Content-Encoding value (may be empty)
// and let decodeResponseBody handle both header-declared and magic-byte-detected
// compression. This keeps error-path behaviour consistent with the success path.
errBody, decErr := decodeResponseBody(resp.Body, resp.Header.Get("Content-Encoding"))
if decErr != nil {
recordAPIResponseError(ctx, e.cfg, decErr)
msg := fmt.Sprintf("failed to decode error response body: %v", decErr)
logWithRequestID(ctx).Warn(msg)
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: msg}
}
b, readErr := io.ReadAll(errBody)
if readErr != nil {
recordAPIResponseError(ctx, e.cfg, readErr)
msg := fmt.Sprintf("failed to read error response body: %v", readErr)
logWithRequestID(ctx).Warn(msg)
b = []byte(msg)
}
appendAPIResponseChunk(ctx, e.cfg, b)
if errClose := errBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(b)}
}
decodedBody, err := decodeResponseBody(resp.Body, resp.Header.Get("Content-Encoding"))
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
return cliproxyexecutor.Response{}, err
}
defer func() {
if errClose := decodedBody.Close(); errClose != nil {
log.Errorf("response body close error: %v", errClose)
}
}()
data, err := io.ReadAll(decodedBody)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
count := gjson.GetBytes(data, "input_tokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out), Headers: resp.Header.Clone()}, nil
}
// Refresh renews the OAuth credentials stored in auth.Metadata using the saved
// "refresh_token" entry.
//
// When auth has no non-empty refresh token (for example API-key credentials),
// auth is returned unchanged and no network call is made. On success the new
// access token, expiry, type and a last_refresh timestamp (RFC 3339) are
// written into auth.Metadata in place and the same *Auth is returned. The
// refresh token and email are only overwritten when the refresh response
// supplies a non-empty value, so a partial response never erases stored data.
// Errors from the token refresh call are returned as-is.
func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
	log.Debugf("claude executor: refresh called")
	if auth == nil {
		return nil, fmt.Errorf("claude executor: auth is nil")
	}
	// Indexing a nil map yields the zero value, so a nil Metadata simply
	// produces an empty refresh token here.
	refreshToken, _ := auth.Metadata["refresh_token"].(string)
	if refreshToken == "" {
		return auth, nil
	}
	svc := claudeauth.NewClaudeAuth(e.cfg)
	td, err := svc.RefreshTokens(ctx, refreshToken)
	if err != nil {
		return nil, err
	}
	if auth.Metadata == nil {
		auth.Metadata = make(map[string]any)
	}
	auth.Metadata["access_token"] = td.AccessToken
	if td.RefreshToken != "" {
		auth.Metadata["refresh_token"] = td.RefreshToken
	}
	// Keep the previously stored email when the refresh response omits it,
	// mirroring the refresh_token handling above.
	if td.Email != "" {
		auth.Metadata["email"] = td.Email
	}
	auth.Metadata["expired"] = td.Expire
	auth.Metadata["type"] = "claude"
	auth.Metadata["last_refresh"] = time.Now().Format(time.RFC3339)
	return auth, nil
}
// extractAndRemoveBetas pulls the top-level "betas" field out of a request body
// so it can be sent as an Anthropic-Beta header instead.
//
// "betas" may be an array (each element's string form is used) or a single
// scalar. Values are whitespace-trimmed and blanks are skipped. When "betas"
// is present it is always deleted from the returned body, even if no usable
// values were found. When absent, (nil, body) is returned untouched.
func extractAndRemoveBetas(body []byte) ([]string, []byte) {
	field := gjson.GetBytes(body, "betas")
	if !field.Exists() {
		return nil, body
	}
	candidates := []gjson.Result{field}
	if field.IsArray() {
		candidates = field.Array()
	}
	var betas []string
	for _, candidate := range candidates {
		value := strings.TrimSpace(candidate.String())
		if value == "" {
			continue
		}
		betas = append(betas, value)
	}
	body, _ = sjson.DeleteBytes(body, "betas")
	return betas, body
}
// disableThinkingIfToolChoiceForced strips thinking controls from a request
// whose tool_choice forces tool use ("any" or a specific "tool").
//
// The Anthropic API does not allow extended thinking together with a forced
// tool_choice, so "thinking" and "output_config.effort" are removed, and
// "output_config" itself is dropped if that leaves it an empty object.
// tool_choice "auto" (or no tool_choice) leaves the body unchanged.
// See: https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#important-considerations
func disableThinkingIfToolChoiceForced(body []byte) []byte {
	switch gjson.GetBytes(body, "tool_choice.type").String() {
	case "any", "tool":
		// forced tool use: fall through to the cleanup below
	default:
		return body
	}
	body, _ = sjson.DeleteBytes(body, "thinking")
	// Adaptive thinking may also set output_config.effort.
	body, _ = sjson.DeleteBytes(body, "output_config.effort")
	outputConfig := gjson.GetBytes(body, "output_config")
	if outputConfig.Exists() && outputConfig.IsObject() && len(outputConfig.Map()) == 0 {
		body, _ = sjson.DeleteBytes(body, "output_config")
	}
	return body
}
// compositeReadCloser reads from a (typically decompressing) Reader and, on
// Close, runs a list of cleanup functions such as closing the decoder and
// then the underlying body.
type compositeReadCloser struct {
	io.Reader
	closers []func() error
}

// Close invokes every non-nil closer in order, even if earlier ones fail, and
// reports the first error encountered (nil if all succeeded).
func (c *compositeReadCloser) Close() error {
	var first error
	for _, closeFn := range c.closers {
		if closeFn == nil {
			continue
		}
		err := closeFn()
		if first == nil && err != nil {
			first = err
		}
	}
	return first
}
// peekableBody buffers the original body behind a bufio.Reader so leading
// magic bytes can be inspected with Peek without being consumed; subsequent
// reads still return the full stream.
type peekableBody struct {
	*bufio.Reader
	closer io.Closer
}

// Close releases the wrapped original body; the bufio buffer needs no cleanup.
func (p *peekableBody) Close() error {
	underlying := p.closer
	return underlying.Close()
}
// decodeResponseBody wraps an upstream response body with the decompressor
// matching its Content-Encoding. It returns a ReadCloser whose Close also
// closes the original body.
//
// contentEncoding is the raw Content-Encoding header value and may be empty:
//   - empty: up to 4 leading bytes are peeked through a bufio.Reader. A gzip
//     (1f 8b) or zstd (28 b5 2f fd) magic sequence selects the matching
//     decoder. Anything else is returned as a buffered passthrough that still
//     yields every byte of the stream.
//   - non-empty: the comma-separated codings are trimmed and lower-cased, and
//     the first recognized one (gzip, deflate, br, zstd) is applied. "",
//     "identity" and unknown codings are skipped. If none is recognized, body
//     is returned unchanged.
//
// A nil body yields an error. If a gzip or zstd reader cannot be created, body
// is closed before the error is returned.
//
// NOTE(review): only one coding layer is ever decoded. RFC 9110 lists stacked
// codings in the order they were applied, so e.g. "gzip, br" would need br
// decoded first; this function decodes only gzip. Confirm stacked codings
// never occur upstream.
// NOTE(review): compress/flate implements raw DEFLATE (RFC 1951), whereas HTTP
// "deflate" is specified as zlib-wrapped (RFC 1950). A zlib-framed body would
// likely fail to decode here; verify against real upstream responses.
func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadCloser, error) {
	if body == nil {
		return nil, fmt.Errorf("response body is nil")
	}
	if contentEncoding == "" {
		// No Content-Encoding header. Attempt best-effort magic-byte detection to
		// handle misbehaving upstreams that compress without setting the header.
		// Only gzip (1f 8b) and zstd (28 b5 2f fd) have reliable magic sequences;
		// br and deflate have none and are left as-is.
		// The bufio wrapper preserves unread bytes so callers always see the full
		// stream regardless of whether decompression was applied.
		pb := &peekableBody{Reader: bufio.NewReader(body), closer: body}
		magic, peekErr := pb.Peek(4)
		// Inspect the bytes when all 4 were available, or when the body ended
		// early but still produced the 2 bytes needed to recognize gzip. Any
		// other outcome (e.g. empty body or a read error) falls through to the
		// passthrough return below.
		if peekErr == nil || (peekErr == io.EOF && len(magic) >= 2) {
			switch {
			case len(magic) >= 2 && magic[0] == 0x1f && magic[1] == 0x8b:
				// Decode through pb (not body) so the peeked bytes are not lost.
				gzipReader, gzErr := gzip.NewReader(pb)
				if gzErr != nil {
					_ = pb.Close()
					return nil, fmt.Errorf("magic-byte gzip: failed to create reader: %w", gzErr)
				}
				return &compositeReadCloser{
					Reader: gzipReader,
					closers: []func() error{
						gzipReader.Close,
						pb.Close,
					},
				}, nil
			case len(magic) >= 4 && magic[0] == 0x28 && magic[1] == 0xb5 && magic[2] == 0x2f && magic[3] == 0xfd:
				decoder, zdErr := zstd.NewReader(pb)
				if zdErr != nil {
					_ = pb.Close()
					return nil, fmt.Errorf("magic-byte zstd: failed to create reader: %w", zdErr)
				}
				return &compositeReadCloser{
					Reader: decoder,
					closers: []func() error{
						// zstd decoder Close returns nothing; adapt it to func() error.
						func() error { decoder.Close(); return nil },
						pb.Close,
					},
				}, nil
			}
		}
		return pb, nil
	}
	// Header-declared encodings: the first recognized coding wins and the
	// function returns immediately; later codings in the list are ignored.
	encodings := strings.Split(contentEncoding, ",")
	for _, raw := range encodings {
		encoding := strings.TrimSpace(strings.ToLower(raw))
		switch encoding {
		case "", "identity":
			continue
		case "gzip":
			gzipReader, err := gzip.NewReader(body)
			if err != nil {
				_ = body.Close()
				return nil, fmt.Errorf("failed to create gzip reader: %w", err)
			}
			return &compositeReadCloser{
				Reader: gzipReader,
				closers: []func() error{
					gzipReader.Close,
					func() error { return body.Close() },
				},
			}, nil
		case "deflate":
			deflateReader := flate.NewReader(body)
			return &compositeReadCloser{
				Reader: deflateReader,
				closers: []func() error{
					deflateReader.Close,
					func() error { return body.Close() },
				},
			}, nil
		case "br":
			return &compositeReadCloser{
				Reader: brotli.NewReader(body),
				closers: []func() error{
					func() error { return body.Close() },
				},
			}, nil
		case "zstd":
			decoder, err := zstd.NewReader(body)
			if err != nil {
				_ = body.Close()
				return nil, fmt.Errorf("failed to create zstd reader: %w", err)
			}
			return &compositeReadCloser{
				Reader: decoder,
				closers: []func() error{
					func() error { decoder.Close(); return nil },
					func() error { return body.Close() },
				},
			}, nil
		default:
			// Unknown coding: skip it and try the next listed one.
			continue
		}
	}
	return body, nil
}
// mapStainlessOS translates runtime.GOOS into the OS label used by the
// Stainless-generated Anthropic SDK in its X-Stainless-Os header. Unknown
// platforms are reported as "Other::<goos>".
func mapStainlessOS() string {
	labels := map[string]string{
		"darwin":  "MacOS",
		"windows": "Windows",
		"linux":   "Linux",
		"freebsd": "FreeBSD",
	}
	if label, ok := labels[runtime.GOOS]; ok {
		return label
	}
	return "Other::" + runtime.GOOS
}
// mapStainlessArch translates runtime.GOARCH into the architecture label used
// by the Stainless-generated Anthropic SDK in its X-Stainless-Arch header.
// Unknown architectures are reported as "other::<goarch>".
func mapStainlessArch() string {
	labels := map[string]string{
		"amd64": "x64",
		"arm64": "arm64",
		"386":   "x86",
	}
	if label, ok := labels[runtime.GOARCH]; ok {
		return label
	}
	return "other::" + runtime.GOARCH
}
// applyClaudeHeaders populates r with the headers sent to the Anthropic
// Messages API, shaped to resemble a Claude Code CLI request.
//
// Parameters:
//   - r: the outgoing upstream request. Its context may carry the inbound
//     *gin.Context under the "gin" key, whose request headers are consulted.
//   - auth: credential record. A non-blank "api_key" attribute enables
//     x-api-key auth against api.anthropic.com. Its attributes are also passed
//     to util.ApplyCustomHeadersFromAttrs, after the defaults are set.
//   - apiKey: the secret sent either as x-api-key or as a Bearer token.
//   - stream: true for SSE requests; forces Accept-Encoding: identity.
//   - extraBetas: additional anthropic-beta values (e.g. taken from the body's
//     "betas" field), merged without duplicates.
//   - cfg: optional config. ClaudeHeaderDefaults overrides the
//     X-Stainless-Runtime-Version/Package-Version/Timeout defaults and the
//     default User-Agent.
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string, cfg *config.Config) {
	// hdrDefault returns cfgVal unless it is empty, in which case fallback.
	hdrDefault := func(cfgVal, fallback string) string {
		if cfgVal != "" {
			return cfgVal
		}
		return fallback
	}
	var hd config.ClaudeHeaderDefaults
	if cfg != nil {
		hd = cfg.ClaudeHeaderDefaults
	}
	// x-api-key is used only for API-key credentials talking to
	// https://api.anthropic.com. Every other case (e.g. OAuth access tokens or
	// a custom base URL) sends the secret as a Bearer token.
	useAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
	isAnthropicBase := r.URL != nil && strings.EqualFold(r.URL.Scheme, "https") && strings.EqualFold(r.URL.Host, "api.anthropic.com")
	if isAnthropicBase && useAPIKey {
		r.Header.Del("Authorization")
		r.Header.Set("x-api-key", apiKey)
	} else {
		r.Header.Set("Authorization", "Bearer "+apiKey)
	}
	r.Header.Set("Content-Type", "application/json")
	// ginHeaders stays nil outside a gin request; Get on a nil http.Header
	// safely returns "".
	var ginHeaders http.Header
	if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
		ginHeaders = ginCtx.Request.Header
	}
	// A client-supplied Anthropic-Beta replaces the default list wholesale;
	// oauth-2025-04-20 is appended when the client value never mentions "oauth".
	baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,context-management-2025-06-27,prompt-caching-scope-2026-01-05"
	if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
		baseBetas = val
		if !strings.Contains(val, "oauth") {
			baseBetas += ",oauth-2025-04-20"
		}
	}
	// Presence of X-CPA-CLAUDE-1M (any value) requests the 1M-context beta.
	hasClaude1MHeader := false
	if ginHeaders != nil {
		if _, ok := ginHeaders[textproto.CanonicalMIMEHeaderKey("X-CPA-CLAUDE-1M")]; ok {
			hasClaude1MHeader = true
		}
	}
	// Merge extra betas from request body and request flags.
	if len(extraBetas) > 0 || hasClaude1MHeader {
		existingSet := make(map[string]bool)
		for _, b := range strings.Split(baseBetas, ",") {
			betaName := strings.TrimSpace(b)
			if betaName != "" {
				existingSet[betaName] = true
			}
		}
		for _, beta := range extraBetas {
			beta = strings.TrimSpace(beta)
			if beta != "" && !existingSet[beta] {
				baseBetas += "," + beta
				existingSet[beta] = true
			}
		}
		if hasClaude1MHeader && !existingSet["context-1m-2025-08-07"] {
			baseBetas += ",context-1m-2025-08-07"
		}
	}
	r.Header.Set("Anthropic-Beta", baseBetas)
	// NOTE(review): misc.EnsureHeader presumably keeps a client-provided value
	// from ginHeaders and falls back to the given default otherwise; confirm
	// against its implementation.
	misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
	misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
	misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
	// Values below match Claude Code 2.1.63 / @anthropic-ai/sdk 0.74.0 (updated 2026-02-28).
	misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Retry-Count", "0")
	misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime-Version", hdrDefault(hd.RuntimeVersion, "v24.3.0"))
	misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Package-Version", hdrDefault(hd.PackageVersion, "0.74.0"))
	misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Runtime", "node")
	misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Lang", "js")
	// Keep OS/Arch mapping dynamic (not configurable): these defaults
	// intentionally derive from runtime.GOOS/runtime.GOARCH.
	misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Arch", mapStainlessArch())
	misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Os", mapStainlessOS())
	misc.EnsureHeader(r.Header, ginHeaders, "X-Stainless-Timeout", hdrDefault(hd.Timeout, "600"))
	// For User-Agent, only forward the client's header if it's already a Claude Code client.
	// Non-Claude-Code clients (e.g. curl, OpenAI SDKs) get the default Claude Code User-Agent
	// to avoid leaking the real client identity during cloaking.
	clientUA := ""
	if ginHeaders != nil {
		clientUA = ginHeaders.Get("User-Agent")
	}
	if isClaudeCodeClient(clientUA) {
		r.Header.Set("User-Agent", clientUA)
	} else {
		r.Header.Set("User-Agent", hdrDefault(hd.UserAgent, "claude-cli/2.1.63 (external, cli)"))
	}
	r.Header.Set("Connection", "keep-alive")
	if stream {
		r.Header.Set("Accept", "text/event-stream")
		// SSE streams must not be compressed: the downstream scanner reads
		// line-delimited text and cannot parse compressed bytes. Using
		// "identity" tells the upstream to send an uncompressed stream.
		r.Header.Set("Accept-Encoding", "identity")
	} else {
		r.Header.Set("Accept", "application/json")
		// Non-stream bodies are decompressed by decodeResponseBody.
		r.Header.Set("Accept-Encoding", "gzip, deflate, br, zstd")
	}
	// Custom headers from auth attributes are applied last so they can
	// override the defaults above.
	var attrs map[string]string
	if auth != nil {
		attrs = auth.Attributes
	}
	util.ApplyCustomHeadersFromAttrs(r, attrs)
	// Re-enforce Accept-Encoding: identity after ApplyCustomHeadersFromAttrs, which
	// may override it with a user-configured value. Compressed SSE breaks the line
	// scanner regardless of user preference, so this is non-negotiable for streams.
	if stream {
		r.Header.Set("Accept-Encoding", "identity")
	}
}
// claudeCreds returns the credential secret and optional base URL for a
// Claude auth record.
//
// apiKey comes from the "api_key" attribute. When that is empty, it falls
// back to the "access_token" metadata string (OAuth credentials). baseURL is
// the "base_url" attribute. A nil auth yields two empty strings.
func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
	if a == nil {
		return "", ""
	}
	// Indexing a nil map is safe and yields "" / nil, so no nil checks are
	// needed before these lookups.
	attrs := a.Attributes
	apiKey, baseURL = attrs["api_key"], attrs["base_url"]
	if apiKey == "" {
		if token, ok := a.Metadata["access_token"].(string); ok {
			apiKey = token
		}
	}
	return apiKey, baseURL
}
// checkSystemInstructions applies the non-strict Claude Code system prompt
// injection, keeping the caller's own system content after the injected
// blocks.
func checkSystemInstructions(payload []byte) []byte {
	const strict = false
	return checkSystemInstructionsWithMode(payload, strict)
}
// isClaudeOAuthToken reports whether apiKey looks like a Claude OAuth access
// token, i.e. it contains the "sk-ant-oat" marker anywhere in the string.
func isClaudeOAuthToken(apiKey string) bool {
	const oauthMarker = "sk-ant-oat"
	return strings.Index(apiKey, oauthMarker) >= 0
}
// applyClaudeToolPrefix prepends prefix to client-defined tool names
// throughout a Messages request body. It rewrites:
//   - tools[].name for tools without a non-empty "type" (typed tools are
//     treated as built-ins and left untouched),
//   - tool_choice.name when tool_choice.type is "tool",
//   - tool_use.name and tool_reference.tool_name in messages[].content[],
//   - tool_reference.tool_name nested inside tool_result.content[].
//
// Names that already start with prefix, and names in the built-in set, are
// never changed, so applying the function twice is harmless. An empty prefix
// returns body unchanged. The stripClaudeToolPrefix* helpers remove the
// prefix again from responses.
func applyClaudeToolPrefix(body []byte, prefix string) []byte {
	if prefix == "" {
		return body
	}
	// Collect built-in tool names (those with a non-empty "type" field) so we can
	// skip them consistently in both tools and message history.
	// NOTE(review): this static seed list is meant to cover built-in names that
	// may appear only in history; typed tools declared in the request are added
	// below. Confirm it still matches Anthropic's current built-in tool names.
	builtinTools := map[string]bool{}
	for _, name := range []string{"web_search", "code_execution", "text_editor", "computer"} {
		builtinTools[name] = true
	}
	// Paths below use indices from the original parse. Renaming a field does
	// not shift array positions, so they stay valid after each SetBytes.
	if tools := gjson.GetBytes(body, "tools"); tools.Exists() && tools.IsArray() {
		tools.ForEach(func(index, tool gjson.Result) bool {
			// Skip built-in tools (web_search, code_execution, etc.) which have
			// a "type" field and require their name to remain unchanged.
			if tool.Get("type").Exists() && tool.Get("type").String() != "" {
				if n := tool.Get("name").String(); n != "" {
					builtinTools[n] = true
				}
				return true
			}
			name := tool.Get("name").String()
			if name == "" || strings.HasPrefix(name, prefix) {
				return true
			}
			path := fmt.Sprintf("tools.%d.name", index.Int())
			body, _ = sjson.SetBytes(body, path, prefix+name)
			return true
		})
	}
	// A forced specific tool must reference the renamed tool.
	if gjson.GetBytes(body, "tool_choice.type").String() == "tool" {
		name := gjson.GetBytes(body, "tool_choice.name").String()
		if name != "" && !strings.HasPrefix(name, prefix) && !builtinTools[name] {
			body, _ = sjson.SetBytes(body, "tool_choice.name", prefix+name)
		}
	}
	// Rewrite tool references in conversation history so they match the
	// renamed tool definitions.
	if messages := gjson.GetBytes(body, "messages"); messages.Exists() && messages.IsArray() {
		messages.ForEach(func(msgIndex, msg gjson.Result) bool {
			content := msg.Get("content")
			if !content.Exists() || !content.IsArray() {
				return true
			}
			content.ForEach(func(contentIndex, part gjson.Result) bool {
				partType := part.Get("type").String()
				switch partType {
				case "tool_use":
					name := part.Get("name").String()
					if name == "" || strings.HasPrefix(name, prefix) || builtinTools[name] {
						return true
					}
					path := fmt.Sprintf("messages.%d.content.%d.name", msgIndex.Int(), contentIndex.Int())
					body, _ = sjson.SetBytes(body, path, prefix+name)
				case "tool_reference":
					toolName := part.Get("tool_name").String()
					if toolName == "" || strings.HasPrefix(toolName, prefix) || builtinTools[toolName] {
						return true
					}
					path := fmt.Sprintf("messages.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int())
					body, _ = sjson.SetBytes(body, path, prefix+toolName)
				case "tool_result":
					// Handle nested tool_reference blocks inside tool_result.content[]
					nestedContent := part.Get("content")
					if nestedContent.Exists() && nestedContent.IsArray() {
						nestedContent.ForEach(func(nestedIndex, nestedPart gjson.Result) bool {
							if nestedPart.Get("type").String() == "tool_reference" {
								nestedToolName := nestedPart.Get("tool_name").String()
								if nestedToolName != "" && !strings.HasPrefix(nestedToolName, prefix) && !builtinTools[nestedToolName] {
									nestedPath := fmt.Sprintf("messages.%d.content.%d.content.%d.tool_name", msgIndex.Int(), contentIndex.Int(), nestedIndex.Int())
									body, _ = sjson.SetBytes(body, nestedPath, prefix+nestedToolName)
								}
							}
							return true
						})
					}
				}
				return true
			})
			return true
		})
	}
	return body
}
// stripClaudeToolPrefixFromResponse removes prefix from tool names in a
// non-streaming Messages response. It touches content[] tool_use.name,
// tool_reference.tool_name, and tool_reference.tool_name nested inside
// tool_result.content[]. Names without the prefix are left alone. An empty
// prefix, or a body without a content array, is returned unchanged.
func stripClaudeToolPrefixFromResponse(body []byte, prefix string) []byte {
	if prefix == "" {
		return body
	}
	content := gjson.GetBytes(body, "content")
	if !content.Exists() || !content.IsArray() {
		return body
	}
	// unprefix rewrites the string at path when value carries the prefix.
	unprefix := func(path, value string) {
		if !strings.HasPrefix(value, prefix) {
			return
		}
		body, _ = sjson.SetBytes(body, path, strings.TrimPrefix(value, prefix))
	}
	for i, block := range content.Array() {
		switch block.Get("type").String() {
		case "tool_use":
			unprefix(fmt.Sprintf("content.%d.name", i), block.Get("name").String())
		case "tool_reference":
			unprefix(fmt.Sprintf("content.%d.tool_name", i), block.Get("tool_name").String())
		case "tool_result":
			nested := block.Get("content")
			if !nested.Exists() || !nested.IsArray() {
				continue
			}
			for j, inner := range nested.Array() {
				if inner.Get("type").String() != "tool_reference" {
					continue
				}
				unprefix(fmt.Sprintf("content.%d.content.%d.tool_name", i, j), inner.Get("tool_name").String())
			}
		}
	}
	return body
}
// stripClaudeToolPrefixFromStreamLine removes prefix from the tool name in a
// single SSE line whose JSON payload has a content_block of type "tool_use"
// (field "name") or "tool_reference" (field "tool_name").
//
// Lines that are not valid JSON, carry no content_block, have another block
// type, have an unprefixed name, or fail to rewrite are returned unchanged.
// A rewritten line that started with "data:" is re-emitted as "data: <json>".
// Otherwise the bare JSON is returned.
func stripClaudeToolPrefixFromStreamLine(line []byte, prefix string) []byte {
	if prefix == "" {
		return line
	}
	payload := jsonPayload(line)
	if len(payload) == 0 || !gjson.ValidBytes(payload) {
		return line
	}
	block := gjson.GetBytes(payload, "content_block")
	if !block.Exists() {
		return line
	}
	var field string
	switch block.Get("type").String() {
	case "tool_use":
		field = "name"
	case "tool_reference":
		field = "tool_name"
	default:
		return line
	}
	current := block.Get(field).String()
	if !strings.HasPrefix(current, prefix) {
		return line
	}
	updated, err := sjson.SetBytes(payload, "content_block."+field, strings.TrimPrefix(current, prefix))
	if err != nil {
		return line
	}
	if !bytes.HasPrefix(bytes.TrimSpace(line), []byte("data:")) {
		return updated
	}
	return append([]byte("data: "), updated...)
}
// getClientUserAgent returns the inbound client's User-Agent header, read from
// the *gin.Context stored under the "gin" key of ctx. It returns "" when no
// gin context (or no request) is available.
func getClientUserAgent(ctx context.Context) string {
	ginCtx, ok := ctx.Value("gin").(*gin.Context)
	if !ok || ginCtx == nil || ginCtx.Request == nil {
		return ""
	}
	return ginCtx.GetHeader("User-Agent")
}
// getCloakConfigFromAuth reads cloaking settings from the cloak_* auth
// attributes. applyCloaking uses it as the fallback when no ClaudeKey config
// entry supplies the settings.
//
// Returns (cloakMode, strictMode, sensitiveWords, cacheUserID):
//   - cloakMode: "cloak_mode", whitespace-trimmed; "auto" when unset or blank.
//   - strictMode: "cloak_strict_mode" equals "true" (trimmed, case-insensitive).
//   - sensitiveWords: the comma-separated "cloak_sensitive_words", each entry
//     trimmed. Blank entries (e.g. from "a,,b" or a trailing comma) are
//     dropped. The result is nil when nothing remains.
//   - cacheUserID: "cloak_cache_user_id" equals "true" (trimmed, case-insensitive).
//
// A nil auth or nil Attributes yields ("auto", false, nil, false).
func getCloakConfigFromAuth(auth *cliproxyauth.Auth) (string, bool, []string, bool) {
	if auth == nil || auth.Attributes == nil {
		return "auto", false, nil, false
	}
	attrs := auth.Attributes
	isTrue := func(v string) bool {
		return strings.EqualFold(strings.TrimSpace(v), "true")
	}
	cloakMode := strings.TrimSpace(attrs["cloak_mode"])
	if cloakMode == "" {
		cloakMode = "auto"
	}
	strictMode := isTrue(attrs["cloak_strict_mode"])
	var sensitiveWords []string
	if wordsStr := attrs["cloak_sensitive_words"]; wordsStr != "" {
		for _, word := range strings.Split(wordsStr, ",") {
			// Empty words would be meaningless (or harmful) to the obfuscation matcher.
			if w := strings.TrimSpace(word); w != "" {
				sensitiveWords = append(sensitiveWords, w)
			}
		}
	}
	cacheUserID := isTrue(attrs["cloak_cache_user_id"])
	return cloakMode, strictMode, sensitiveWords, cacheUserID
}
// resolveClaudeKeyCloakConfig returns the Cloak settings of the ClaudeKey
// config entry that matches auth's credentials, or nil if none matches.
//
// Entries are matched by API key using exact comparison after trimming
// whitespace, because API keys are case-sensitive secrets. When both the auth
// and the entry specify a base URL, those must also match
// (case-insensitively). The first matching entry wins. A nil cfg or auth, or
// an auth without a usable key, yields nil.
func resolveClaudeKeyCloakConfig(cfg *config.Config, auth *cliproxyauth.Auth) *config.CloakConfig {
	if cfg == nil || auth == nil {
		return nil
	}
	apiKey, baseURL := claudeCreds(auth)
	apiKey = strings.TrimSpace(apiKey)
	baseURL = strings.TrimSpace(baseURL)
	if apiKey == "" {
		return nil
	}
	for i := range cfg.ClaudeKey {
		entry := &cfg.ClaudeKey[i]
		if strings.TrimSpace(entry.APIKey) != apiKey {
			continue
		}
		// If both sides specify a base URL they must agree.
		cfgBase := strings.TrimSpace(entry.BaseURL)
		if baseURL != "" && cfgBase != "" && !strings.EqualFold(cfgBase, baseURL) {
			continue
		}
		return entry.Cloak
	}
	return nil
}
// injectFakeUserID generates and injects a fake user ID into the request metadata.
// When useCache is false, a new user ID is generated for every call.
func injectFakeUserID(payload []byte, apiKey string, useCache bool) []byte {
generateID := func() string {
if useCache {
return cachedUserID(apiKey)
}
return generateFakeUserID()
}
metadata := gjson.GetBytes(payload, "metadata")
if !metadata.Exists() {
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
return payload
}
existingUserID := gjson.GetBytes(payload, "metadata.user_id").String()
if existingUserID == "" || !isValidUserID(existingUserID) {
payload, _ = sjson.SetBytes(payload, "metadata.user_id", generateID())
}
return payload
}
// generateBillingHeader builds the x-anthropic-billing-header text that real
// Claude Code places first in the system prompt array:
//
//	x-anthropic-billing-header: cc_version=2.1.63.<build>; cc_entrypoint=cli; cch=<hash>;
//
// <hash> is the first 5 hex chars of SHA-256(payload), so it is deterministic
// per payload. <build> is 3 hex chars drawn from random bytes on every call.
func generateBillingHeader(payload []byte) string {
	digest := sha256.Sum256(payload)
	contentHash := hex.EncodeToString(digest[:])[:5]

	var nonce [2]byte
	_, _ = rand.Read(nonce[:]) // cosmetic value; a short read just yields zeros
	build := hex.EncodeToString(nonce[:])[:3]

	return fmt.Sprintf("x-anthropic-billing-header: cc_version=2.1.63.%s; cc_entrypoint=cli; cch=%s;", build, contentHash)
}
// checkSystemInstructionsWithMode rewrites the payload's "system" field into
// the Claude Code layout:
//
//	system[0]: billing header (no cache_control)
//	system[1]: agent identifier (no cache_control)
//	system[2..]: user system messages (cache_control added when missing)
//
// With strictMode the caller's own system content is dropped and only the two
// injected blocks remain; this is rebuilt on every call. In non-strict mode:
//   - the payload is returned untouched when system[0] already holds a billing
//     header, so repeated calls do not stack injections;
//   - only "text" parts of an array system prompt are kept;
//   - a non-empty string system prompt becomes a single text block.
func checkSystemInstructionsWithMode(payload []byte, strictMode bool) []byte {
	// No cache_control on the agent block. It is a cloaking artifact with zero cache
	// value (the last system block is what actually triggers caching of all system content).
	// Including any cache_control here creates an intra-system TTL ordering violation
	// when the client's system blocks use ttl='1h' (prompt-caching-scope-2026-01-05 beta
	// forbids 1h blocks after 5m blocks, and a no-TTL block defaults to 5m).
	const agentBlock = `{"type":"text","text":"You are a Claude agent, built on Anthropic's Claude Agent SDK."}`

	// Check the "already injected" marker before building the billing header.
	// generateBillingHeader hashes the entire payload, which would be wasted
	// work when the result is discarded.
	if !strictMode && strings.HasPrefix(gjson.GetBytes(payload, "system.0.text").String(), "x-anthropic-billing-header:") {
		return payload
	}

	billingBlock := fmt.Sprintf(`{"type":"text","text":"%s"}`, generateBillingHeader(payload))

	var out strings.Builder
	out.WriteString("[")
	out.WriteString(billingBlock)
	out.WriteString(",")
	out.WriteString(agentBlock)

	if !strictMode {
		// Non-strict mode: append the caller's system messages after the injected blocks.
		system := gjson.GetBytes(payload, "system")
		if system.IsArray() {
			system.ForEach(func(_, part gjson.Result) bool {
				if part.Get("type").String() != "text" {
					return true
				}
				// Add cache_control to user system messages if not present.
				// Do NOT add ttl — let it inherit the default (5m) to avoid
				// TTL ordering violations with the prompt-caching-scope-2026-01-05 beta.
				partJSON := part.Raw
				if !part.Get("cache_control").Exists() {
					partJSON, _ = sjson.Set(partJSON, "cache_control.type", "ephemeral")
				}
				out.WriteString(",")
				out.WriteString(partJSON)
				return true
			})
		} else if system.Type == gjson.String && system.String() != "" {
			partJSON, _ := sjson.Set(`{"type":"text","cache_control":{"type":"ephemeral"}}`, "text", system.String())
			out.WriteString(",")
			out.WriteString(partJSON)
		}
	}

	out.WriteString("]")
	payload, _ = sjson.SetRawBytes(payload, "system", []byte(out.String()))
	return payload
}
// applyCloaking disguises a request as a Claude Code request when the resolved
// cloak mode calls for it for the current client User-Agent.
//
// Settings are taken from the matching ClaudeKey config entry first. When that
// entry sets no mode, the mode comes from the cloak_* auth attributes, which
// also fill in strict mode (OR-ed), empty sensitive words, and an unset
// CacheUserID. An unset config CacheUserID always falls back to the attribute.
//
// When cloaking applies, the payload gets:
//   - the Claude Code system blocks (skipped for claude-3-5-haiku models),
//   - a fake metadata.user_id,
//   - sensitive-word obfuscation, if any words are configured.
func applyCloaking(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, payload []byte, model string, apiKey string) []byte {
	clientUserAgent := getClientUserAgent(ctx)
	cloakCfg := resolveClaudeKeyCloakConfig(cfg, auth)

	var (
		cloakMode      string
		strictMode     bool
		sensitiveWords []string
		cacheUserID    bool
		cacheFromCfg   bool // true when the config entry explicitly set CacheUserID
	)
	if cloakCfg != nil {
		cloakMode = cloakCfg.Mode
		strictMode = cloakCfg.StrictMode
		sensitiveWords = cloakCfg.SensitiveWords
		if cloakCfg.CacheUserID != nil {
			cacheUserID = *cloakCfg.CacheUserID
			cacheFromCfg = true
		}
	}

	switch {
	case cloakMode == "":
		// No mode from config: take it from auth attributes and fill the gaps.
		attrMode, attrStrict, attrWords, attrCache := getCloakConfigFromAuth(auth)
		cloakMode = attrMode
		strictMode = strictMode || attrStrict
		if len(sensitiveWords) == 0 {
			sensitiveWords = attrWords
		}
		if !cacheFromCfg {
			cacheUserID = attrCache
		}
	case !cacheFromCfg:
		_, _, _, cacheUserID = getCloakConfigFromAuth(auth)
	}

	if !shouldCloak(cloakMode, clientUserAgent) {
		return payload
	}
	if !strings.HasPrefix(model, "claude-3-5-haiku") {
		payload = checkSystemInstructionsWithMode(payload, strictMode)
	}
	payload = injectFakeUserID(payload, apiKey, cacheUserID)
	if len(sensitiveWords) > 0 {
		payload = obfuscateSensitiveWords(payload, buildSensitiveWordMatcher(sensitiveWords))
	}
	return payload
}
// ensureCacheControl adds prompt-caching breakpoints to a Messages payload.
//
// Anthropic builds cache prefixes in the order tools -> system -> messages.
// The breakpoints are applied in that same order:
//  1. the LAST tool (caches all tool definitions),
//  2. the LAST system element (caches the system prompt),
//  3. the messages, for multi-turn history (second-to-last user turn).
//
// Tools, system and messages are independent breakpoints; Anthropic allows at
// most 4 per request. Cached reads are billed at a fraction of the base input
// price.
// See: https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
func ensureCacheControl(payload []byte) []byte {
	steps := []func([]byte) []byte{
		injectToolsCacheControl,
		injectSystemCacheControl,
		injectMessagesCacheControl,
	}
	for _, step := range steps {
		payload = step(payload)
	}
	return payload
}
// countCacheControls counts the blocks in a raw JSON payload that carry a
// cache_control field. It looks at system[], tools[] and the content[] array
// of every message. Non-array sections are ignored.
func countCacheControls(payload []byte) int {
	total := 0
	tally := func(arr gjson.Result) {
		if !arr.IsArray() {
			return
		}
		arr.ForEach(func(_, item gjson.Result) bool {
			if item.Get("cache_control").Exists() {
				total++
			}
			return true
		})
	}
	tally(gjson.GetBytes(payload, "system"))
	tally(gjson.GetBytes(payload, "tools"))
	if messages := gjson.GetBytes(payload, "messages"); messages.IsArray() {
		messages.ForEach(func(_, msg gjson.Result) bool {
			tally(msg.Get("content"))
			return true
		})
	}
	return total
}
// parsePayloadObject decodes payload as a JSON object.
//
// It returns (root, true) on success and (nil, false) for an empty payload,
// invalid JSON, or a non-object top-level value. Note that a literal "null"
// succeeds with a nil map. JSON numbers decode as float64 (encoding/json's
// default for interface values), so integers beyond 2^53 lose precision if
// the map is re-marshaled.
func parsePayloadObject(payload []byte) (map[string]any, bool) {
	if len(payload) > 0 {
		var root map[string]any
		if json.Unmarshal(payload, &root) == nil {
			return root, true
		}
	}
	return nil, false
}
// marshalPayloadObject re-encodes root as JSON and falls back to original when
// root is nil or cannot be marshaled. encoding/json emits map keys in sorted
// order and HTML-escapes <, > and &, so the output can differ textually from
// the original bytes while staying semantically equivalent.
func marshalPayloadObject(original []byte, root map[string]any) []byte {
	if root != nil {
		if encoded, err := json.Marshal(root); err == nil {
			return encoded
		}
	}
	return original
}
// asObject reports whether v is a decoded JSON object (map[string]any) and
// returns it. It yields (nil, false) otherwise.
func asObject(v any) (map[string]any, bool) {
	if obj, ok := v.(map[string]any); ok {
		return obj, true
	}
	return nil, false
}
// asArray reports whether v is a decoded JSON array ([]any) and returns it.
// It yields (nil, false) otherwise.
func asArray(v any) ([]any, bool) {
	if arr, ok := v.([]any); ok {
		return arr, true
	}
	return nil, false
}
// countCacheControlsMap counts the objects carrying a cache_control key in a
// decoded payload. It looks at system[], tools[] and every messages[].content[]
// array. Non-object items and non-array sections are ignored.
func countCacheControlsMap(root map[string]any) int {
	total := 0
	tally := func(section any) {
		items, ok := asArray(section)
		if !ok {
			return
		}
		for _, item := range items {
			obj, isObj := asObject(item)
			if !isObj {
				continue
			}
			if _, has := obj["cache_control"]; has {
				total++
			}
		}
	}
	tally(root["system"])
	tally(root["tools"])
	if messages, ok := asArray(root["messages"]); ok {
		for _, msg := range messages {
			if msgObj, isObj := asObject(msg); isObj {
				tally(msgObj["content"])
			}
		}
	}
	return total
}
// normalizeTTLForBlock applies the "no 1h after 5m" TTL rule to one block.
//
// A block without cache_control is ignored. A cache_control that is not an
// object, or whose ttl is anything other than the string "1h", counts as a
// 5m breakpoint and sets *seen5m. A "1h" ttl seen after a 5m breakpoint is
// deleted (downgrading it to the default). The return value reports whether
// obj was modified.
func normalizeTTLForBlock(obj map[string]any, seen5m *bool) bool {
	raw, present := obj["cache_control"]
	if !present {
		return false
	}
	if cc, isObj := asObject(raw); isObj {
		if ttl, _ := cc["ttl"].(string); ttl == "1h" {
			if !*seen5m {
				return false
			}
			delete(cc, "ttl")
			return true
		}
	}
	*seen5m = true
	return false
}
// findLastCacheControlIndex returns the index of the last object in arr that
// has a cache_control key, or -1 when there is none.
func findLastCacheControlIndex(arr []any) int {
	for i := len(arr) - 1; i >= 0; i-- {
		obj, ok := asObject(arr[i])
		if !ok {
			continue
		}
		if _, has := obj["cache_control"]; has {
			return i
		}
	}
	return -1
}
// stripCacheControlExceptIndex deletes cache_control from objects in arr,
// earliest first, skipping the object at preserveIdx. Each deletion decrements
// *excess, and the function stops as soon as *excess reaches zero.
func stripCacheControlExceptIndex(arr []any, preserveIdx int, excess *int) {
	for i := 0; i < len(arr) && *excess > 0; i++ {
		if i == preserveIdx {
			continue
		}
		obj, ok := asObject(arr[i])
		if !ok {
			continue
		}
		if _, has := obj["cache_control"]; !has {
			continue
		}
		delete(obj, "cache_control")
		*excess--
	}
}
func stripAllCacheControl(arr []any, excess *int) {
for _, item := range arr {
if *excess <= 0 {
return
}
obj, ok := asObject(item)
if !ok {
continue
}
if _, exists := obj["cache_control"]; exists {
delete(obj, "cache_control")
*excess--
}
}
}
// stripMessageCacheControl deletes cache_control from message content blocks,
// walking messages and their content[] arrays from earliest to latest.
// Each deletion decrements *excess, and it stops once *excess reaches zero.
// Non-object messages and messages without an array content are skipped.
func stripMessageCacheControl(messages []any, excess *int) {
	for _, msg := range messages {
		if *excess <= 0 {
			return
		}
		msgObj, isObj := asObject(msg)
		if !isObj {
			continue
		}
		content, isArr := asArray(msgObj["content"])
		if !isArr {
			continue
		}
		stripAllCacheControl(content, excess)
	}
}
// normalizeCacheControlTTL enforces the prompt-caching-scope-2026-01-05
// ordering rule: a 1h-TTL cache_control block must not appear after a 5m-TTL
// (default) block anywhere in evaluation order.
//
// Anthropic evaluates blocks as tools -> system (index 0..N) -> messages, and
// within each section in array order. The violation can occur within a single
// section too (e.g. system[1]=5m followed by system[3]=1h).
//
// All cache_control blocks are visited in that order. Once a 5m block has been
// seen, every later 1h block has its ttl removed (downgrading it to 5m).
// The payload is only re-marshaled when something changed. Unparseable
// payloads are returned unchanged.
func normalizeCacheControlTTL(payload []byte) []byte {
	root, ok := parsePayloadObject(payload)
	if !ok {
		return payload
	}
	var seen5m, changed bool
	visit := func(items []any) {
		for _, item := range items {
			if obj, isObj := asObject(item); isObj && normalizeTTLForBlock(obj, &seen5m) {
				changed = true
			}
		}
	}
	if tools, isArr := asArray(root["tools"]); isArr {
		visit(tools)
	}
	if system, isArr := asArray(root["system"]); isArr {
		visit(system)
	}
	if messages, isArr := asArray(root["messages"]); isArr {
		for _, msg := range messages {
			msgObj, isObj := asObject(msg)
			if !isObj {
				continue
			}
			if content, hasContent := asArray(msgObj["content"]); hasContent {
				visit(content)
			}
		}
	}
	if !changed {
		return payload
	}
	return marshalPayloadObject(payload, root)
}
// enforceCacheControlLimit removes excess cache_control blocks from a payload
// so the total does not exceed the Anthropic API limit (currently 4).
//
// Anthropic evaluates cache breakpoints in order: tools → system → messages.
// The most valuable breakpoints are:
//  1. Last tool — caches ALL tool definitions
//  2. Last system block — caches ALL system content
//  3. Recent messages — cache conversation context
//
// Removal priority (strip lowest-value first):
//
//	Phase 1: system blocks earliest-first, preserving the last one.
//	Phase 2: tool blocks earliest-first, preserving the last one.
//	Phase 3: message content blocks earliest-first.
//	Phase 4: remaining system blocks (last system).
//	Phase 5: remaining tool blocks (last tool).
//
// The payload is returned unchanged if it cannot be parsed or is already within
// maxBlocks; otherwise the mutated object is re-marshaled.
func enforceCacheControlLimit(payload []byte, maxBlocks int) []byte {
	root, ok := parsePayloadObject(payload)
	if !ok {
		return payload
	}
	total := countCacheControlsMap(root)
	if total <= maxBlocks {
		return payload
	}
	excess := total - maxBlocks

	// section returns root[key] as an array, or nil when absent or not an array.
	section := func(key string) []any {
		if arr, isArr := asArray(root[key]); isArr {
			return arr
		}
		return nil
	}
	system, tools, messages := section("system"), section("tools"), section("messages")

	// Phases in removal-priority order. Each helper also stops on its own as
	// soon as excess reaches zero; nil sections are no-ops.
	phases := []func(){
		func() { stripCacheControlExceptIndex(system, findLastCacheControlIndex(system), &excess) },
		func() { stripCacheControlExceptIndex(tools, findLastCacheControlIndex(tools), &excess) },
		func() { stripMessageCacheControl(messages, &excess) },
		func() { stripAllCacheControl(system, &excess) },
		func() { stripAllCacheControl(tools, &excess) },
	}
	for _, phase := range phases {
		if excess <= 0 {
			break
		}
		phase()
	}
	return marshalPayloadObject(payload, root)
}
// injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching.
// Per Anthropic docs: "Place cache_control on the second-to-last User message to let the model reuse the earlier cache."
// This enables caching of conversation history, which is especially beneficial for long multi-turn conversations.
//
// The payload is left unchanged unless:
//   - "messages" is an array with at least 2 user turns, and
//   - no message content block already has cache_control.
//
// For array content the breakpoint goes on the last content block; string
// content is converted to a single text block carrying the breakpoint. sjson
// failures are logged and the original payload is returned.
func injectMessagesCacheControl(payload []byte) []byte {
	messages := gjson.GetBytes(payload, "messages")
	if !messages.Exists() || !messages.IsArray() {
		return payload
	}

	// One pass: give up as soon as any content block already carries a
	// breakpoint; otherwise remember where the user turns are.
	var userTurns []int
	for idx, msg := range messages.Array() {
		if content := msg.Get("content"); content.IsArray() {
			for _, block := range content.Array() {
				if block.Get("cache_control").Exists() {
					return payload
				}
			}
		}
		if msg.Get("role").String() == "user" {
			userTurns = append(userTurns, idx)
		}
	}
	if len(userTurns) < 2 {
		return payload
	}

	target := userTurns[len(userTurns)-2]
	contentPath := fmt.Sprintf("messages.%d.content", target)
	content := gjson.GetBytes(payload, contentPath)
	switch {
	case content.IsArray():
		blockCount := int(content.Get("#").Int())
		if blockCount == 0 {
			return payload
		}
		breakpointPath := fmt.Sprintf("%s.%d.cache_control", contentPath, blockCount-1)
		updated, err := sjson.SetBytes(payload, breakpointPath, map[string]string{"type": "ephemeral"})
		if err != nil {
			log.Warnf("failed to inject cache_control into messages: %v", err)
			return payload
		}
		return updated
	case content.Type == gjson.String:
		converted := []map[string]any{{
			"type":          "text",
			"text":          content.String(),
			"cache_control": map[string]string{"type": "ephemeral"},
		}}
		updated, err := sjson.SetBytes(payload, contentPath, converted)
		if err != nil {
			log.Warnf("failed to inject cache_control into message string content: %v", err)
			return payload
		}
		return updated
	}
	return payload
}
// injectToolsCacheControl adds cache_control to the last tool in the tools array.
// Per Anthropic docs: "The cache_control parameter on the last tool definition caches all tool definitions."
// Nothing is changed when "tools" is missing, not an array, empty, or when any
// tool already has cache_control. sjson failures are logged and the original
// payload is returned.
func injectToolsCacheControl(payload []byte) []byte {
	tools := gjson.GetBytes(payload, "tools")
	if !tools.Exists() || !tools.IsArray() {
		return payload
	}
	toolList := tools.Array()
	if len(toolList) == 0 {
		return payload
	}
	for _, tool := range toolList {
		if tool.Get("cache_control").Exists() {
			// The caller already chose its tool breakpoints; leave them alone.
			return payload
		}
	}
	lastToolPath := fmt.Sprintf("tools.%d.cache_control", len(toolList)-1)
	updated, err := sjson.SetBytes(payload, lastToolPath, map[string]string{"type": "ephemeral"})
	if err != nil {
		log.Warnf("failed to inject cache_control into tools array: %v", err)
		return payload
	}
	return updated
}
// injectSystemCacheControl adds cache_control to the last element in the system prompt.
// Converts string system prompts to array format if needed.
// This only adds cache_control if NO system element already has it.
//
// The payload is returned unchanged when "system" is absent, is an empty array,
// already has cache_control on any element, is neither an array nor a string,
// or is a string that is empty or whitespace-only. The last case matters because
// the Anthropic API rejects cache_control on empty text blocks (and rejects
// whitespace-only text blocks), so converting such a prompt into a cached text
// block would turn an otherwise valid request into a 400.
// sjson failures are logged and the original payload is returned.
func injectSystemCacheControl(payload []byte) []byte {
	system := gjson.GetBytes(payload, "system")
	if !system.Exists() {
		return payload
	}
	switch {
	case system.IsArray():
		count := int(system.Get("#").Int())
		if count == 0 {
			return payload
		}
		// Respect any breakpoint the caller already placed in the system prompt.
		hasCacheControlInSystem := false
		system.ForEach(func(_, item gjson.Result) bool {
			if item.Get("cache_control").Exists() {
				hasCacheControlInSystem = true
				return false
			}
			return true
		})
		if hasCacheControlInSystem {
			return payload
		}
		// Add cache_control to the last system element
		lastSystemPath := fmt.Sprintf("system.%d.cache_control", count-1)
		result, err := sjson.SetBytes(payload, lastSystemPath, map[string]string{"type": "ephemeral"})
		if err != nil {
			log.Warnf("failed to inject cache_control into system array: %v", err)
			return payload
		}
		return result
	case system.Type == gjson.String:
		text := system.String()
		if isBlankSystemPromptText(text) {
			// Nothing worth caching, and an empty cached text block is rejected upstream.
			return payload
		}
		// Convert string system prompt to array with cache_control
		// "system": "text" -> "system": [{"type": "text", "text": "text", "cache_control": {"type": "ephemeral"}}]
		newSystem := []map[string]any{{
			"type":          "text",
			"text":          text,
			"cache_control": map[string]string{"type": "ephemeral"},
		}}
		result, err := sjson.SetBytes(payload, "system", newSystem)
		if err != nil {
			log.Warnf("failed to inject cache_control into system string: %v", err)
			return payload
		}
		return result
	}
	return payload
}

// isBlankSystemPromptText reports whether s is empty or consists solely of
// ASCII whitespace characters.
func isBlankSystemPromptText(s string) bool {
	for i := 0; i < len(s); i++ {
		switch s[i] {
		case ' ', '\t', '\n', '\r', '\v', '\f':
		default:
			return false
		}
	}
	return true
}
================================================
FILE: internal/runtime/executor/claude_executor_test.go
================================================
package executor
import (
"bytes"
"compress/gzip"
"context"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/klauspost/compress/zstd"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// TestApplyClaudeToolPrefix verifies that tool definitions, tool_choice and
// tool_use names gain the prefix, while already-prefixed names are left as-is.
func TestApplyClaudeToolPrefix(t *testing.T) {
	input := []byte(`{"tools":[{"name":"alpha"},{"name":"proxy_bravo"}],"tool_choice":{"type":"tool","name":"charlie"},"messages":[{"role":"assistant","content":[{"type":"tool_use","name":"delta","id":"t1","input":{}}]}]}`)
	out := applyClaudeToolPrefix(input, "proxy_")
	for _, tc := range []struct{ path, want string }{
		{"tools.0.name", "proxy_alpha"},
		{"tools.1.name", "proxy_bravo"},
		{"tool_choice.name", "proxy_charlie"},
		{"messages.0.content.0.name", "proxy_delta"},
	} {
		if got := gjson.GetBytes(out, tc.path).String(); got != tc.want {
			t.Fatalf("%s = %q, want %q", tc.path, got, tc.want)
		}
	}
}
// TestApplyClaudeToolPrefix_WithToolReference verifies that tool_reference
// blocks have their tool_name prefixed unless it already carries the prefix.
func TestApplyClaudeToolPrefix_WithToolReference(t *testing.T) {
	input := []byte(`{"tools":[{"name":"alpha"}],"messages":[{"role":"user","content":[{"type":"tool_reference","tool_name":"beta"},{"type":"tool_reference","tool_name":"proxy_gamma"}]}]}`)
	out := applyClaudeToolPrefix(input, "proxy_")
	for _, tc := range []struct{ path, want string }{
		{"messages.0.content.0.tool_name", "proxy_beta"},
		{"messages.0.content.1.tool_name", "proxy_gamma"},
	} {
		if got := gjson.GetBytes(out, tc.path).String(); got != tc.want {
			t.Fatalf("%s = %q, want %q", tc.path, got, tc.want)
		}
	}
}
// TestApplyClaudeToolPrefix_SkipsBuiltinTools verifies that typed built-in
// tools keep their name while custom tools are prefixed.
func TestApplyClaudeToolPrefix_SkipsBuiltinTools(t *testing.T) {
	input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"},{"name":"my_custom_tool","input_schema":{"type":"object"}}]}`)
	out := applyClaudeToolPrefix(input, "proxy_")
	for _, tc := range []struct{ desc, path, want string }{
		{"built-in tool name should not be prefixed", "tools.0.name", "web_search"},
		{"custom tool should be prefixed", "tools.1.name", "proxy_my_custom_tool"},
	} {
		if got := gjson.GetBytes(out, tc.path).String(); got != tc.want {
			t.Fatalf("%s: %s = %q, want %q", tc.desc, tc.path, got, tc.want)
		}
	}
}
// TestApplyClaudeToolPrefix_BuiltinToolSkipped verifies that a built-in tool is
// left unprefixed both in the tool list and in tool_use history, while a custom
// tool is prefixed in both places.
func TestApplyClaudeToolPrefix_BuiltinToolSkipped(t *testing.T) {
	body := []byte(`{
"tools": [
{"type": "web_search_20250305", "name": "web_search", "max_uses": 5},
{"name": "Read"}
],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}},
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}}
]}
]
}`)
	out := applyClaudeToolPrefix(body, "proxy_")
	for _, tc := range []struct{ path, want string }{
		{"tools.0.name", "web_search"},
		{"messages.0.content.0.name", "web_search"},
		{"tools.1.name", "proxy_Read"},
		{"messages.0.content.1.name", "proxy_Read"},
	} {
		if got := gjson.GetBytes(out, tc.path).String(); got != tc.want {
			t.Fatalf("%s = %q, want %q", tc.path, got, tc.want)
		}
	}
}
// TestApplyClaudeToolPrefix_KnownBuiltinInHistoryOnly verifies that a known
// built-in name appearing only in tool_use history is not prefixed, even when
// it is absent from the tools list.
func TestApplyClaudeToolPrefix_KnownBuiltinInHistoryOnly(t *testing.T) {
	body := []byte(`{
"tools": [
{"name": "Read"}
],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "web_search", "id": "ws1", "input": {}}
]}
]
}`)
	out := applyClaudeToolPrefix(body, "proxy_")
	for _, tc := range []struct{ path, want string }{
		{"messages.0.content.0.name", "web_search"},
		{"tools.0.name", "proxy_Read"},
	} {
		if got := gjson.GetBytes(out, tc.path).String(); got != tc.want {
			t.Fatalf("%s = %q, want %q", tc.path, got, tc.want)
		}
	}
}
// TestApplyClaudeToolPrefix_CustomToolsPrefixed verifies that custom tools are
// prefixed consistently in definitions and in tool_use history.
func TestApplyClaudeToolPrefix_CustomToolsPrefixed(t *testing.T) {
	body := []byte(`{
"tools": [{"name": "Read"}, {"name": "Write"}],
"messages": [
{"role": "user", "content": [
{"type": "tool_use", "name": "Read", "id": "r1", "input": {}},
{"type": "tool_use", "name": "Write", "id": "w1", "input": {}}
]}
]
}`)
	out := applyClaudeToolPrefix(body, "proxy_")
	for _, tc := range []struct{ path, want string }{
		{"tools.0.name", "proxy_Read"},
		{"tools.1.name", "proxy_Write"},
		{"messages.0.content.0.name", "proxy_Read"},
		{"messages.0.content.1.name", "proxy_Write"},
	} {
		if got := gjson.GetBytes(out, tc.path).String(); got != tc.want {
			t.Fatalf("%s = %q, want %q", tc.path, got, tc.want)
		}
	}
}
func TestApplyClaudeToolPrefix_ToolChoiceBuiltin(t *testing.T) {
body := []byte(`{
"tools": [
{"type": "web_search_20250305", "name": "web_search"},
{"name": "Read"}
],
"tool_choice": {"type": "tool", "name": "web_search"}
}`)
out := applyClaudeToolPrefix(body, "proxy_")
if got := gjson.GetBytes(out, "tool_choice.name").String(); got != "web_search" {
t.Fatalf("tool_choice.name = %q, want %q", got, "web_search")
}
}
// TestStripClaudeToolPrefixFromResponse verifies that the prefix is removed from
// tool_use names in a response, and unprefixed names are left alone.
func TestStripClaudeToolPrefixFromResponse(t *testing.T) {
	input := []byte(`{"content":[{"type":"tool_use","name":"proxy_alpha","id":"t1","input":{}},{"type":"tool_use","name":"bravo","id":"t2","input":{}}]}`)
	out := stripClaudeToolPrefixFromResponse(input, "proxy_")
	for _, tc := range []struct{ path, want string }{
		{"content.0.name", "alpha"},
		{"content.1.name", "bravo"},
	} {
		if got := gjson.GetBytes(out, tc.path).String(); got != tc.want {
			t.Fatalf("%s = %q, want %q", tc.path, got, tc.want)
		}
	}
}
// TestStripClaudeToolPrefixFromResponse_WithToolReference verifies prefix removal
// on tool_reference tool_name fields in a response.
func TestStripClaudeToolPrefixFromResponse_WithToolReference(t *testing.T) {
	input := []byte(`{"content":[{"type":"tool_reference","tool_name":"proxy_alpha"},{"type":"tool_reference","tool_name":"bravo"}]}`)
	out := stripClaudeToolPrefixFromResponse(input, "proxy_")
	for _, tc := range []struct{ path, want string }{
		{"content.0.tool_name", "alpha"},
		{"content.1.tool_name", "bravo"},
	} {
		if got := gjson.GetBytes(out, tc.path).String(); got != tc.want {
			t.Fatalf("%s = %q, want %q", tc.path, got, tc.want)
		}
	}
}
// TestStripClaudeToolPrefixFromStreamLine verifies that an SSE "data:" line has
// the prefix removed from its content_block tool_use name.
func TestStripClaudeToolPrefixFromStreamLine(t *testing.T) {
	line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_use","name":"proxy_alpha","id":"t1"},"index":0}`)
	payload := bytes.TrimSpace(stripClaudeToolPrefixFromStreamLine(line, "proxy_"))
	if rest, found := bytes.CutPrefix(payload, []byte("data:")); found {
		payload = bytes.TrimSpace(rest)
	}
	const want = "alpha"
	if got := gjson.GetBytes(payload, "content_block.name").String(); got != want {
		t.Fatalf("content_block.name = %q, want %q", got, want)
	}
}
// TestStripClaudeToolPrefixFromStreamLine_WithToolReference verifies prefix
// removal on a streamed tool_reference content block.
func TestStripClaudeToolPrefixFromStreamLine_WithToolReference(t *testing.T) {
	line := []byte(`data: {"type":"content_block_start","content_block":{"type":"tool_reference","tool_name":"proxy_beta"},"index":0}`)
	payload := bytes.TrimSpace(stripClaudeToolPrefixFromStreamLine(line, "proxy_"))
	if rest, found := bytes.CutPrefix(payload, []byte("data:")); found {
		payload = bytes.TrimSpace(rest)
	}
	const want = "beta"
	if got := gjson.GetBytes(payload, "content_block.tool_name").String(); got != want {
		t.Fatalf("content_block.tool_name = %q, want %q", got, want)
	}
}
// TestApplyClaudeToolPrefix_NestedToolReference verifies that a tool_reference
// nested inside a tool_result's content array is prefixed.
func TestApplyClaudeToolPrefix_NestedToolReference(t *testing.T) {
	const want = "proxy_mcp__nia__manage_resource"
	input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"mcp__nia__manage_resource"}]}]}]}`)
	out := applyClaudeToolPrefix(input, "proxy_")
	if got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String(); got != want {
		t.Fatalf("nested tool_reference tool_name = %q, want %q", got, want)
	}
}
// TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled is an end-to-end
// check that, with cloak CacheUserID enabled, requests for different models
// through the same key send the same (valid) metadata.user_id.
func TestClaudeExecutor_ReusesUserIDAcrossModelsWhenCacheEnabled(t *testing.T) {
	resetUserIDCache()
	type capturedRequest struct{ userID, model string }
	var captured []capturedRequest
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, _ := io.ReadAll(r.Body)
		rec := capturedRequest{
			userID: gjson.GetBytes(body, "metadata.user_id").String(),
			model:  gjson.GetBytes(body, "model").String(),
		}
		captured = append(captured, rec)
		t.Logf("HTTP Server received request: model=%s, user_id=%s, url=%s", rec.model, rec.userID, r.URL.String())
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
	}))
	defer server.Close()
	t.Logf("End-to-end test: Fake HTTP server started at %s", server.URL)

	cacheUserID := true
	claudeExec := NewClaudeExecutor(&config.Config{
		ClaudeKey: []config.ClaudeKey{{
			APIKey:  "key-123",
			BaseURL: server.URL,
			Cloak:   &config.CloakConfig{CacheUserID: &cacheUserID},
		}},
	})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"api_key":  "key-123",
		"base_url": server.URL,
	}}
	payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
	opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}
	for _, model := range []string{"claude-3-5-sonnet", "claude-3-5-haiku"} {
		t.Logf("Sending request for model: %s", model)
		modelPayload, _ := sjson.SetBytes(payload, "model", model)
		req := cliproxyexecutor.Request{Model: model, Payload: modelPayload}
		if _, err := claudeExec.Execute(context.Background(), auth, req, opts); err != nil {
			t.Fatalf("Execute(%s) error: %v", model, err)
		}
	}

	if len(captured) != 2 {
		t.Fatalf("expected 2 requests, got %d", len(captured))
	}
	first, second := captured[0], captured[1]
	if first.userID == "" || second.userID == "" {
		t.Fatal("expected user_id to be populated")
	}
	t.Logf("user_id[0] (model=%s): %s", first.model, first.userID)
	t.Logf("user_id[1] (model=%s): %s", second.model, second.userID)
	if first.userID != second.userID {
		t.Fatalf("expected user_id to be reused across models, got %q and %q", first.userID, second.userID)
	}
	if !isValidUserID(first.userID) {
		t.Fatalf("user_id %q is not valid", first.userID)
	}
	t.Logf("✓ End-to-end test passed: Same user_id (%s) was used for both models", first.userID)
}
// TestClaudeExecutor_GeneratesNewUserIDByDefault verifies that without the cloak
// CacheUserID option each request gets a freshly generated, valid user_id.
func TestClaudeExecutor_GeneratesNewUserIDByDefault(t *testing.T) {
	resetUserIDCache()
	var seen []string
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, _ := io.ReadAll(r.Body)
		seen = append(seen, gjson.GetBytes(body, "metadata.user_id").String())
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet","role":"assistant","content":[{"type":"text","text":"ok"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
	}))
	defer server.Close()

	claudeExec := NewClaudeExecutor(&config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"api_key":  "key-123",
		"base_url": server.URL,
	}}
	req := cliproxyexecutor.Request{
		Model:   "claude-3-5-sonnet",
		Payload: []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`),
	}
	opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}
	for call := 0; call < 2; call++ {
		if _, err := claudeExec.Execute(context.Background(), auth, req, opts); err != nil {
			t.Fatalf("Execute call %d error: %v", call, err)
		}
	}

	if len(seen) != 2 {
		t.Fatalf("expected 2 requests, got %d", len(seen))
	}
	first, second := seen[0], seen[1]
	if first == "" || second == "" {
		t.Fatal("expected user_id to be populated")
	}
	if first == second {
		t.Fatalf("expected user_id to change when caching is not enabled, got identical values %q", first)
	}
	if !isValidUserID(first) || !isValidUserID(second) {
		t.Fatalf("user_ids should be valid, got %q and %q", first, second)
	}
}
// TestStripClaudeToolPrefixFromResponse_NestedToolReference verifies prefix
// removal on a tool_reference nested inside a tool_result's content array.
func TestStripClaudeToolPrefixFromResponse_NestedToolReference(t *testing.T) {
	const want = "mcp__nia__manage_resource"
	input := []byte(`{"content":[{"type":"tool_result","tool_use_id":"toolu_123","content":[{"type":"tool_reference","tool_name":"proxy_mcp__nia__manage_resource"}]}]}`)
	out := stripClaudeToolPrefixFromResponse(input, "proxy_")
	if got := gjson.GetBytes(out, "content.0.content.0.tool_name").String(); got != want {
		t.Fatalf("nested tool_reference tool_name = %q, want %q", got, want)
	}
}
// TestApplyClaudeToolPrefix_NestedToolReferenceWithStringContent verifies that a
// tool_result whose content is a plain string passes through untouched.
func TestApplyClaudeToolPrefix_NestedToolReferenceWithStringContent(t *testing.T) {
	const want = "plain string result"
	input := []byte(`{"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"toolu_123","content":"plain string result"}]}]}`)
	out := applyClaudeToolPrefix(input, "proxy_")
	if got := gjson.GetBytes(out, "messages.0.content.0.content").String(); got != want {
		t.Fatalf("string content should remain unchanged = %q", got)
	}
}
func TestApplyClaudeToolPrefix_SkipsBuiltinToolReference(t *testing.T) {
input := []byte(`{"tools":[{"type":"web_search_20250305","name":"web_search"}],"messages":[{"role":"user","content":[{"type":"tool_result","tool_use_id":"t1","content":[{"type":"tool_reference","tool_name":"web_search"}]}]}]}`)
out := applyClaudeToolPrefix(input, "proxy_")
got := gjson.GetBytes(out, "messages.0.content.0.content.0.tool_name").String()
if got != "web_search" {
t.Fatalf("built-in tool_reference should not be prefixed, got %q", got)
}
}
// TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks verifies that a 1h
// block before any 5m block keeps its ttl, while a 1h block after a default-5m
// block has its ttl removed.
func TestNormalizeCacheControlTTL_DowngradesLaterOneHourBlocks(t *testing.T) {
	payload := []byte(`{
"tools": [{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}],
"system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}],
"messages": [{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]}]
}`)
	out := normalizeCacheControlTTL(payload)
	if got, want := gjson.GetBytes(out, "tools.0.cache_control.ttl").String(), "1h"; got != want {
		t.Fatalf("tools.0.cache_control.ttl = %q, want %q", got, want)
	}
	if ttl := gjson.GetBytes(out, "messages.0.content.0.cache_control.ttl"); ttl.Exists() {
		t.Fatalf("messages.0.content.0.cache_control.ttl should be removed after a default-5m block")
	}
}
// TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange verifies that
// normalizeCacheControlTTL returns the input bytes verbatim when no TTL needs
// normalizing, rather than re-marshaling the payload.
func TestNormalizeCacheControlTTL_PreservesOriginalBytesWhenNoChange(t *testing.T) {
	// Payload where no TTL normalization is needed (all blocks use 1h with no
	// preceding 5m block). The text intentionally contains HTML chars (<, >, &)
	// that json.Marshal would escape to \u003c etc., altering byte identity, so
	// any needless re-marshal shows up as a byte mismatch.
	payload := []byte(`{"tools":[{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}}],"system":[{"type":"text","text":"foo <bar> & baz","cache_control":{"type":"ephemeral","ttl":"1h"}}],"messages":[{"role":"user","content":[{"type":"text","text":"hello"}]}]}`)
	out := normalizeCacheControlTTL(payload)
	if !bytes.Equal(out, payload) {
		t.Fatalf("normalizeCacheControlTTL altered bytes when no change was needed.\noriginal: %s\ngot: %s", payload, out)
	}
}
// TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages verifies that a
// non-last tool breakpoint is removed before any message breakpoint is touched.
func TestEnforceCacheControlLimit_StripsNonLastToolBeforeMessages(t *testing.T) {
	payload := []byte(`{
"tools": [
{"name":"t1","cache_control":{"type":"ephemeral"}},
{"name":"t2","cache_control":{"type":"ephemeral"}}
],
"system": [{"type":"text","text":"s1","cache_control":{"type":"ephemeral"}}],
"messages": [
{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral"}}]},
{"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral"}}]}
]
}`)
	out := enforceCacheControlLimit(payload, 4)
	has := func(path string) bool { return gjson.GetBytes(out, path).Exists() }

	if got := countCacheControls(out); got != 4 {
		t.Fatalf("cache_control count = %d, want 4", got)
	}
	if has("tools.0.cache_control") {
		t.Fatalf("tools.0.cache_control should be removed first (non-last tool)")
	}
	if !has("tools.1.cache_control") {
		t.Fatalf("tools.1.cache_control (last tool) should be preserved")
	}
	if !has("messages.0.content.0.cache_control") || !has("messages.1.content.0.cache_control") {
		t.Fatalf("message cache_control blocks should be preserved when non-last tool removal is enough")
	}
}
// TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit verifies that a
// payload with only tool breakpoints is trimmed to the limit, earliest first,
// keeping the last tool's breakpoint.
func TestEnforceCacheControlLimit_ToolOnlyPayloadStillRespectsLimit(t *testing.T) {
	payload := []byte(`{
"tools": [
{"name":"t1","cache_control":{"type":"ephemeral"}},
{"name":"t2","cache_control":{"type":"ephemeral"}},
{"name":"t3","cache_control":{"type":"ephemeral"}},
{"name":"t4","cache_control":{"type":"ephemeral"}},
{"name":"t5","cache_control":{"type":"ephemeral"}}
]
}`)
	out := enforceCacheControlLimit(payload, 4)
	has := func(path string) bool { return gjson.GetBytes(out, path).Exists() }

	if got := countCacheControls(out); got != 4 {
		t.Fatalf("cache_control count = %d, want 4", got)
	}
	if has("tools.0.cache_control") {
		t.Fatalf("tools.0.cache_control should be removed to satisfy max=4")
	}
	if !has("tools.4.cache_control") {
		t.Fatalf("last tool cache_control should be preserved when possible")
	}
}
// TestClaudeExecutor_CountTokens_AppliesCacheControlGuards verifies that the
// count_tokens request body sent upstream respects both the cache_control
// count limit and the TTL ordering constraint.
func TestClaudeExecutor_CountTokens_AppliesCacheControlGuards(t *testing.T) {
	var captured []byte
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		body, _ := io.ReadAll(r.Body)
		captured = bytes.Clone(body)
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(`{"input_tokens":42}`))
	}))
	defer server.Close()

	claudeExec := NewClaudeExecutor(&config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"api_key":  "key-123",
		"base_url": server.URL,
	}}
	payload := []byte(`{
"tools": [
{"name":"t1","cache_control":{"type":"ephemeral","ttl":"1h"}},
{"name":"t2","cache_control":{"type":"ephemeral"}}
],
"system": [
{"type":"text","text":"s1","cache_control":{"type":"ephemeral","ttl":"1h"}},
{"type":"text","text":"s2","cache_control":{"type":"ephemeral","ttl":"1h"}}
],
"messages": [
{"role":"user","content":[{"type":"text","text":"u1","cache_control":{"type":"ephemeral","ttl":"1h"}}]},
{"role":"user","content":[{"type":"text","text":"u2","cache_control":{"type":"ephemeral","ttl":"1h"}}]}
]
}`)
	req := cliproxyexecutor.Request{Model: "claude-3-5-haiku-20241022", Payload: payload}
	opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}
	if _, err := claudeExec.CountTokens(context.Background(), auth, req, opts); err != nil {
		t.Fatalf("CountTokens error: %v", err)
	}

	if len(captured) == 0 {
		t.Fatal("expected count_tokens request body to be captured")
	}
	if got := countCacheControls(captured); got > 4 {
		t.Fatalf("count_tokens body has %d cache_control blocks, want <= 4", got)
	}
	if hasTTLOrderingViolation(captured) {
		t.Fatalf("count_tokens body still has ttl ordering violations: %s", string(captured))
	}
}
// hasTTLOrderingViolation reports whether any cache_control with ttl "1h"
// appears after a cache_control whose ttl is not "1h" (the default 5m), walking
// blocks in evaluation order: tools, then system, then message content blocks.
// Only array-valued sections and array-valued message content are inspected.
func hasTTLOrderingViolation(payload []byte) bool {
	var breakpoints []gjson.Result
	collect := func(arr gjson.Result) {
		if !arr.IsArray() {
			return
		}
		for _, item := range arr.Array() {
			breakpoints = append(breakpoints, item.Get("cache_control"))
		}
	}
	collect(gjson.GetBytes(payload, "tools"))
	collect(gjson.GetBytes(payload, "system"))
	if messages := gjson.GetBytes(payload, "messages"); messages.IsArray() {
		for _, msg := range messages.Array() {
			collect(msg.Get("content"))
		}
	}

	sawDefaultTTL := false
	for _, cc := range breakpoints {
		if !cc.Exists() {
			continue
		}
		if cc.Get("ttl").String() != "1h" {
			sawDefaultTTL = true
			continue
		}
		if sawDefaultTTL {
			return true
		}
	}
	return false
}
// TestClaudeExecutor_Execute_InvalidGzipErrorBodyReturnsDecodeMessage runs the
// invalid-compressed-error-body scenario through Execute.
func TestClaudeExecutor_Execute_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
	testClaudeExecutorInvalidCompressedErrorBody(t, func(e *ClaudeExecutor, a *cliproxyauth.Auth, body []byte) error {
		req := cliproxyexecutor.Request{Model: "claude-3-5-sonnet-20241022", Payload: body}
		opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}
		_, err := e.Execute(context.Background(), a, req, opts)
		return err
	})
}
// TestClaudeExecutor_ExecuteStream_InvalidGzipErrorBodyReturnsDecodeMessage runs
// the invalid-compressed-error-body scenario through ExecuteStream.
func TestClaudeExecutor_ExecuteStream_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
	testClaudeExecutorInvalidCompressedErrorBody(t, func(e *ClaudeExecutor, a *cliproxyauth.Auth, body []byte) error {
		req := cliproxyexecutor.Request{Model: "claude-3-5-sonnet-20241022", Payload: body}
		opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}
		_, err := e.ExecuteStream(context.Background(), a, req, opts)
		return err
	})
}
// TestClaudeExecutor_CountTokens_InvalidGzipErrorBodyReturnsDecodeMessage runs
// the invalid-compressed-error-body scenario through CountTokens.
func TestClaudeExecutor_CountTokens_InvalidGzipErrorBodyReturnsDecodeMessage(t *testing.T) {
	testClaudeExecutorInvalidCompressedErrorBody(t, func(e *ClaudeExecutor, a *cliproxyauth.Auth, body []byte) error {
		req := cliproxyexecutor.Request{Model: "claude-3-5-sonnet-20241022", Payload: body}
		opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}
		_, err := e.CountTokens(context.Background(), a, req, opts)
		return err
	})
}
// testClaudeExecutorInvalidCompressedErrorBody starts a server that answers 400
// with a body labelled gzip but not actually compressed. It then asserts that
// invoke returns an error mentioning the decode failure and exposing
// StatusCode() == 400.
func testClaudeExecutorInvalidCompressedErrorBody(
	t *testing.T,
	invoke func(executor *ClaudeExecutor, auth *cliproxyauth.Auth, payload []byte) error,
) {
	t.Helper()
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		hdr := w.Header()
		hdr.Set("Content-Type", "application/json")
		hdr.Set("Content-Encoding", "gzip")
		w.WriteHeader(http.StatusBadRequest)
		_, _ = w.Write([]byte("not-a-valid-gzip-stream"))
	}))
	defer server.Close()

	claudeExec := NewClaudeExecutor(&config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"api_key":  "key-123",
		"base_url": server.URL,
	}}
	err := invoke(claudeExec, auth, []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`))
	if err == nil {
		t.Fatal("expected error, got nil")
	}
	if msg := err.Error(); !strings.Contains(msg, "failed to decode error response body") {
		t.Fatalf("expected decode failure message, got: %v", err)
	}
	statusProvider, hasStatus := err.(interface{ StatusCode() int })
	if !hasStatus || statusProvider.StatusCode() != http.StatusBadRequest {
		t.Fatalf("expected status code 400, got: %v", err)
	}
}
// TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding verifies that streaming
// requests use Accept-Encoding: identity so the upstream cannot respond with a
// compressed SSE body that would silently break the line scanner. It also
// checks that Accept is text/event-stream.
func TestClaudeExecutor_ExecuteStream_SetsIdentityAcceptEncoding(t *testing.T) {
	var received http.Header
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		received = r.Header.Clone()
		w.Header().Set("Content-Type", "text/event-stream")
		_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
	}))
	defer server.Close()

	claudeExec := NewClaudeExecutor(&config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"api_key":  "key-123",
		"base_url": server.URL,
	}}
	req := cliproxyexecutor.Request{
		Model:   "claude-3-5-sonnet-20241022",
		Payload: []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`),
	}
	opts := cliproxyexecutor.Options{SourceFormat: sdktranslator.FromString("claude")}
	result, err := claudeExec.ExecuteStream(context.Background(), auth, req, opts)
	if err != nil {
		t.Fatalf("ExecuteStream error: %v", err)
	}
	// Drain the stream so the request has fully completed before inspecting headers.
	for chunk := range result.Chunks {
		if chunk.Err != nil {
			t.Fatalf("unexpected chunk error: %v", chunk.Err)
		}
	}

	if got, want := received.Get("Accept-Encoding"), "identity"; got != want {
		t.Errorf("Accept-Encoding = %q, want %q", got, want)
	}
	if got, want := received.Get("Accept"), "text/event-stream"; got != want {
		t.Errorf("Accept = %q, want %q", got, want)
	}
}
// TestClaudeExecutor_Execute_SetsCompressedAcceptEncoding verifies that non-streaming
// requests keep the full accept-encoding to allow response compression (which
// decodeResponseBody handles correctly).
func TestClaudeExecutor_Execute_SetsCompressedAcceptEncoding(t *testing.T) {
var gotEncoding, gotAccept string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotEncoding = r.Header.Get("Accept-Encoding")
gotAccept = r.Header.Get("Accept")
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"id":"msg_1","type":"message","model":"claude-3-5-sonnet-20241022","role":"assistant","content":[{"type":"text","text":"hi"}],"usage":{"input_tokens":1,"output_tokens":1}}`))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("Execute error: %v", err)
}
if gotEncoding != "gzip, deflate, br, zstd" {
t.Errorf("Accept-Encoding = %q, want %q", gotEncoding, "gzip, deflate, br, zstd")
}
if gotAccept != "application/json" {
t.Errorf("Accept = %q, want %q", gotAccept, "application/json")
}
}
// TestClaudeExecutor_ExecuteStream_GzipSuccessBodyDecoded verifies that a streaming
// HTTP 200 response with Content-Encoding: gzip is decompressed before the line
// scanner runs, so SSE chunks are not silently dropped.
func TestClaudeExecutor_ExecuteStream_GzipSuccessBodyDecoded(t *testing.T) {
	var buf bytes.Buffer
	gz := gzip.NewWriter(&buf)
	if _, err := gz.Write([]byte("data: {\"type\":\"message_stop\"}\n")); err != nil {
		t.Fatalf("gzip write: %v", err)
	}
	// Close flushes the final block and trailer; a failure here would leave a
	// truncated fixture and make the assertions below fail misleadingly.
	if err := gz.Close(); err != nil {
		t.Fatalf("gzip close: %v", err)
	}
	compressedBody := buf.Bytes()
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Content-Encoding", "gzip")
		_, _ = w.Write(compressedBody)
	}))
	defer server.Close()
	executor := NewClaudeExecutor(&config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"api_key":  "key-123",
		"base_url": server.URL,
	}}
	payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
	result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
		Model:   "claude-3-5-sonnet-20241022",
		Payload: payload,
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("claude"),
	})
	if err != nil {
		t.Fatalf("ExecuteStream error: %v", err)
	}
	var combined strings.Builder
	for chunk := range result.Chunks {
		if chunk.Err != nil {
			t.Fatalf("chunk error: %v", chunk.Err)
		}
		combined.Write(chunk.Payload)
	}
	if combined.Len() == 0 {
		t.Fatal("expected at least one chunk from gzip-encoded SSE body, got none (body was not decompressed)")
	}
	if !strings.Contains(combined.String(), "message_stop") {
		t.Errorf("expected SSE content in chunks, got: %q", combined.String())
	}
}
// TestDecodeResponseBody_MagicByteGzipNoHeader verifies that decodeResponseBody
// detects gzip-compressed content via its magic bytes and inflates it even when
// Content-Encoding is absent (empty string).
func TestDecodeResponseBody_MagicByteGzipNoHeader(t *testing.T) {
	const plaintext = "data: {\"type\":\"message_stop\"}\n"
	var buf bytes.Buffer
	gz := gzip.NewWriter(&buf)
	if _, err := gz.Write([]byte(plaintext)); err != nil {
		t.Fatalf("gzip write: %v", err)
	}
	// Close writes the gzip trailer; without it the fixture is truncated.
	if err := gz.Close(); err != nil {
		t.Fatalf("gzip close: %v", err)
	}
	rc := io.NopCloser(&buf)
	decoded, err := decodeResponseBody(rc, "")
	if err != nil {
		t.Fatalf("decodeResponseBody error: %v", err)
	}
	defer decoded.Close()
	got, err := io.ReadAll(decoded)
	if err != nil {
		t.Fatalf("ReadAll error: %v", err)
	}
	if string(got) != plaintext {
		t.Errorf("decoded = %q, want %q", got, plaintext)
	}
}
// TestDecodeResponseBody_PlainTextNoHeader checks the pass-through case: with no
// Content-Encoding and no recognised magic bytes, decodeResponseBody must hand
// back the body bytes exactly as received.
func TestDecodeResponseBody_PlainTextNoHeader(t *testing.T) {
	const plaintext = "data: {\"type\":\"message_stop\"}\n"
	decoded, decodeErr := decodeResponseBody(io.NopCloser(strings.NewReader(plaintext)), "")
	if decodeErr != nil {
		t.Fatalf("decodeResponseBody error: %v", decodeErr)
	}
	defer decoded.Close()

	got, readErr := io.ReadAll(decoded)
	if readErr != nil {
		t.Fatalf("ReadAll error: %v", readErr)
	}
	if string(got) != plaintext {
		t.Errorf("decoded = %q, want %q", got, plaintext)
	}
}
// TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader verifies the full
// pipeline when the upstream returns a gzip-compressed SSE body WITHOUT setting
// Content-Encoding (a misbehaving upstream). The magic-byte sniff in
// decodeResponseBody must still decompress it, so chunks reach the caller.
func TestClaudeExecutor_ExecuteStream_GzipNoContentEncodingHeader(t *testing.T) {
	var buf bytes.Buffer
	gz := gzip.NewWriter(&buf)
	if _, err := gz.Write([]byte("data: {\"type\":\"message_stop\"}\n")); err != nil {
		t.Fatalf("gzip write: %v", err)
	}
	// Close flushes the trailer; a truncated fixture would masquerade as a sniff failure.
	if err := gz.Close(); err != nil {
		t.Fatalf("gzip close: %v", err)
	}
	compressedBody := buf.Bytes()
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "text/event-stream")
		// Intentionally omit Content-Encoding to simulate misbehaving upstream.
		_, _ = w.Write(compressedBody)
	}))
	defer server.Close()
	executor := NewClaudeExecutor(&config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"api_key":  "key-123",
		"base_url": server.URL,
	}}
	payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
	result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
		Model:   "claude-3-5-sonnet-20241022",
		Payload: payload,
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("claude"),
	})
	if err != nil {
		t.Fatalf("ExecuteStream error: %v", err)
	}
	var combined strings.Builder
	for chunk := range result.Chunks {
		if chunk.Err != nil {
			t.Fatalf("chunk error: %v", chunk.Err)
		}
		combined.Write(chunk.Payload)
	}
	if combined.Len() == 0 {
		t.Fatal("expected chunks from gzip body without Content-Encoding header, got none (magic-byte sniff failed)")
	}
	if !strings.Contains(combined.String(), "message_stop") {
		t.Errorf("unexpected chunk content: %q", combined.String())
	}
}
// TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity verifies
// that injecting Accept-Encoding via auth.Attributes cannot override the stream
// path's enforced identity encoding.
func TestClaudeExecutor_ExecuteStream_AcceptEncodingOverrideCannotBypassIdentity(t *testing.T) {
var gotEncoding string
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
gotEncoding = r.Header.Get("Accept-Encoding")
w.Header().Set("Content-Type", "text/event-stream")
_, _ = w.Write([]byte("data: {\"type\":\"message_stop\"}\n\n"))
}))
defer server.Close()
executor := NewClaudeExecutor(&config.Config{})
// Inject Accept-Encoding via the custom header attribute mechanism.
auth := &cliproxyauth.Auth{Attributes: map[string]string{
"api_key": "key-123",
"base_url": server.URL,
"header:Accept-Encoding": "gzip, deflate, br, zstd",
}}
payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
result, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
Model: "claude-3-5-sonnet-20241022",
Payload: payload,
}, cliproxyexecutor.Options{
SourceFormat: sdktranslator.FromString("claude"),
})
if err != nil {
t.Fatalf("ExecuteStream error: %v", err)
}
for chunk := range result.Chunks {
if chunk.Err != nil {
t.Fatalf("unexpected chunk error: %v", chunk.Err)
}
}
if gotEncoding != "identity" {
t.Errorf("Accept-Encoding = %q; stream path must enforce identity regardless of auth.Attributes override", gotEncoding)
}
}
// TestDecodeResponseBody_MagicByteZstdNoHeader verifies that decodeResponseBody
// detects zstd-compressed content via magic bytes (28 b5 2f fd) even when
// Content-Encoding is absent.
func TestDecodeResponseBody_MagicByteZstdNoHeader(t *testing.T) {
	const plaintext = "data: {\"type\":\"message_stop\"}\n"
	var buf bytes.Buffer
	enc, err := zstd.NewWriter(&buf)
	if err != nil {
		t.Fatalf("zstd.NewWriter: %v", err)
	}
	if _, err := enc.Write([]byte(plaintext)); err != nil {
		t.Fatalf("zstd write: %v", err)
	}
	// Close flushes the final frame; an error here means the fixture is incomplete.
	if err := enc.Close(); err != nil {
		t.Fatalf("zstd close: %v", err)
	}
	rc := io.NopCloser(&buf)
	decoded, err := decodeResponseBody(rc, "")
	if err != nil {
		t.Fatalf("decodeResponseBody error: %v", err)
	}
	defer decoded.Close()
	got, err := io.ReadAll(decoded)
	if err != nil {
		t.Fatalf("ReadAll error: %v", err)
	}
	if string(got) != plaintext {
		t.Errorf("decoded = %q, want %q", got, plaintext)
	}
}
// TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader verifies that the
// error path (4xx) decompresses a gzip body even when the upstream omits the
// Content-Encoding header. This closes the gap left by PR #1771, which only
// fixed header-declared compression on the error path.
func TestClaudeExecutor_Execute_GzipErrorBodyNoContentEncodingHeader(t *testing.T) {
	const errJSON = `{"type":"error","error":{"type":"invalid_request_error","message":"test error"}}`
	var buf bytes.Buffer
	gz := gzip.NewWriter(&buf)
	if _, err := gz.Write([]byte(errJSON)); err != nil {
		t.Fatalf("gzip write: %v", err)
	}
	if err := gz.Close(); err != nil {
		t.Fatalf("gzip close: %v", err)
	}
	compressedBody := buf.Bytes()
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		// Intentionally omit Content-Encoding to simulate misbehaving upstream.
		w.WriteHeader(http.StatusBadRequest)
		_, _ = w.Write(compressedBody)
	}))
	defer server.Close()
	executor := NewClaudeExecutor(&config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"api_key":  "key-123",
		"base_url": server.URL,
	}}
	payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
	_, err := executor.Execute(context.Background(), auth, cliproxyexecutor.Request{
		Model:   "claude-3-5-sonnet-20241022",
		Payload: payload,
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("claude"),
	})
	if err == nil {
		t.Fatal("expected an error for 400 response, got nil")
	}
	if !strings.Contains(err.Error(), "test error") {
		t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
	}
}
// TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader verifies
// the same for the streaming executor: a 4xx gzip body without Content-Encoding
// is decoded, and the error message is readable.
func TestClaudeExecutor_ExecuteStream_GzipErrorBodyNoContentEncodingHeader(t *testing.T) {
	const errJSON = `{"type":"error","error":{"type":"invalid_request_error","message":"stream test error"}}`
	var buf bytes.Buffer
	gz := gzip.NewWriter(&buf)
	if _, err := gz.Write([]byte(errJSON)); err != nil {
		t.Fatalf("gzip write: %v", err)
	}
	if err := gz.Close(); err != nil {
		t.Fatalf("gzip close: %v", err)
	}
	compressedBody := buf.Bytes()
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.Header().Set("Content-Type", "application/json")
		// Intentionally omit Content-Encoding to simulate misbehaving upstream.
		w.WriteHeader(http.StatusBadRequest)
		_, _ = w.Write(compressedBody)
	}))
	defer server.Close()
	executor := NewClaudeExecutor(&config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"api_key":  "key-123",
		"base_url": server.URL,
	}}
	payload := []byte(`{"messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}]}`)
	_, err := executor.ExecuteStream(context.Background(), auth, cliproxyexecutor.Request{
		Model:   "claude-3-5-sonnet-20241022",
		Payload: payload,
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("claude"),
	})
	if err == nil {
		t.Fatal("expected an error for 400 response, got nil")
	}
	if !strings.Contains(err.Error(), "stream test error") {
		t.Errorf("error message should contain decompressed JSON, got: %q", err.Error())
	}
}
// Test case 1: String system prompt is preserved and converted to a content block
func TestCheckSystemInstructionsWithMode_StringSystemPreserved(t *testing.T) {
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
system := gjson.GetBytes(out, "system")
if !system.IsArray() {
t.Fatalf("system should be an array, got %s", system.Type)
}
blocks := system.Array()
if len(blocks) != 3 {
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
}
if !strings.HasPrefix(blocks[0].Get("text").String(), "x-anthropic-billing-header:") {
t.Fatalf("blocks[0] should be billing header, got %q", blocks[0].Get("text").String())
}
if blocks[1].Get("text").String() != "You are a Claude agent, built on Anthropic's Claude Agent SDK." {
t.Fatalf("blocks[1] should be agent block, got %q", blocks[1].Get("text").String())
}
if blocks[2].Get("text").String() != "You are a helpful assistant." {
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
}
if blocks[2].Get("cache_control.type").String() != "ephemeral" {
t.Fatalf("blocks[2] should have cache_control.type=ephemeral")
}
}
// Test case 2: Strict mode drops the string system prompt
func TestCheckSystemInstructionsWithMode_StringSystemStrict(t *testing.T) {
payload := []byte(`{"system":"You are a helpful assistant.","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, true)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 2 {
t.Fatalf("strict mode should produce 2 blocks, got %d", len(blocks))
}
}
// Test case 3: Empty string system prompt does not produce a spurious block
func TestCheckSystemInstructionsWithMode_EmptyStringSystemIgnored(t *testing.T) {
payload := []byte(`{"system":"","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 2 {
t.Fatalf("empty string system should produce 2 blocks, got %d", len(blocks))
}
}
// Test case 4: Array system prompt is unaffected by the string handling
func TestCheckSystemInstructionsWithMode_ArraySystemStillWorks(t *testing.T) {
payload := []byte(`{"system":[{"type":"text","text":"Be concise."}],"messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 3 {
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
}
if blocks[2].Get("text").String() != "Be concise." {
t.Fatalf("blocks[2] should be user system prompt, got %q", blocks[2].Get("text").String())
}
}
// Test case 5: Special characters in string system prompt survive conversion
func TestCheckSystemInstructionsWithMode_StringWithSpecialChars(t *testing.T) {
payload := []byte(`{"system":"Use tags & \"quotes\" in output.","messages":[{"role":"user","content":"hi"}]}`)
out := checkSystemInstructionsWithMode(payload, false)
blocks := gjson.GetBytes(out, "system").Array()
if len(blocks) != 3 {
t.Fatalf("expected 3 system blocks, got %d", len(blocks))
}
if blocks[2].Get("text").String() != `Use tags & "quotes" in output.` {
t.Fatalf("blocks[2] text mangled, got %q", blocks[2].Get("text").String())
}
}
================================================
FILE: internal/runtime/executor/cloak_obfuscate.go
================================================
package executor
import (
"regexp"
"sort"
"strings"
"unicode/utf8"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// zeroWidthSpace is U+200B ZERO WIDTH SPACE, the invisible separator that the
// obfuscation helpers splice into matched words.
const zeroWidthSpace = "\u200b"

// SensitiveWordMatcher wraps the compiled alternation regex produced by
// buildSensitiveWordMatcher. The obfuscation helpers treat a nil matcher, or
// one whose regex is nil, as "nothing to obfuscate".
type SensitiveWordMatcher struct {
	regex *regexp.Regexp // case-insensitive alternation of all configured words
}
// buildSensitiveWordMatcher compiles words into a single case-insensitive
// alternation regex.
//
// Each word is whitespace-trimmed. Words shorter than two runes (which
// obfuscateWord cannot split) and words that already contain a zero-width space
// are discarded. It returns nil when no usable words remain, or if the pattern
// fails to compile.
//
// Go's regexp alternation is leftmost-first: among alternatives matching at the
// same position, the earliest listed wins. Words are therefore ordered
// longest-first by RUNE count. Under (?i) every pattern rune matches exactly one
// input rune, so rune count (not byte length) is what determines how much text a
// word consumes. Ties are broken lexicographically so the compiled pattern is
// deterministic for a given word list.
func buildSensitiveWordMatcher(words []string) *SensitiveWordMatcher {
	if len(words) == 0 {
		return nil
	}

	validWords := make([]string, 0, len(words))
	for _, w := range words {
		w = strings.TrimSpace(w)
		if utf8.RuneCountInString(w) >= 2 && !strings.Contains(w, zeroWidthSpace) {
			validWords = append(validWords, w)
		}
	}
	if len(validWords) == 0 {
		return nil
	}

	sort.Slice(validWords, func(i, j int) bool {
		ri, rj := utf8.RuneCountInString(validWords[i]), utf8.RuneCountInString(validWords[j])
		if ri != rj {
			return ri > rj
		}
		return validWords[i] < validWords[j]
	})

	escaped := make([]string, len(validWords))
	for i, w := range validWords {
		escaped[i] = regexp.QuoteMeta(w)
	}
	re, err := regexp.Compile("(?i)" + strings.Join(escaped, "|"))
	if err != nil {
		// QuoteMeta output should always compile; treat a failure as "no matcher".
		return nil
	}
	return &SensitiveWordMatcher{regex: re}
}
// obfuscateWord returns word with a zero-width space inserted after its first
// rune. Note this is a code point, not a grapheme cluster: a combining mark that
// follows the first rune ends up after the separator.
//
// word is returned unchanged when any of these hold:
//   - it already contains a zero-width space;
//   - its first rune decodes to utf8.RuneError (invalid UTF-8, or a literal U+FFFD);
//   - it is empty or a single rune, so nothing would follow the separator.
func obfuscateWord(word string) string {
	if strings.Contains(word, zeroWidthSpace) {
		return word
	}
	first, width := utf8.DecodeRuneInString(word)
	if first == utf8.RuneError || width >= len(word) {
		return word
	}
	// word[:width] is exactly the UTF-8 encoding of first, because first decoded validly.
	head, tail := word[:width], word[width:]
	return head + zeroWidthSpace + tail
}
// obfuscateText returns text with obfuscateWord applied to every
// non-overlapping match of m's regex. A nil matcher, or one with a nil regex,
// returns text untouched.
func (m *SensitiveWordMatcher) obfuscateText(text string) string {
	if m != nil && m.regex != nil {
		return m.regex.ReplaceAllStringFunc(text, obfuscateWord)
	}
	return text
}
// obfuscateSensitiveWords returns payload with matcher applied first to the
// "system" field (obfuscateSystemBlocks) and then to the "messages" content
// (obfuscateMessages). A nil matcher, or one with a nil regex, returns payload
// unchanged.
func obfuscateSensitiveWords(payload []byte, matcher *SensitiveWordMatcher) []byte {
	if matcher != nil && matcher.regex != nil {
		withSystem := obfuscateSystemBlocks(payload, matcher)
		payload = obfuscateMessages(withSystem, matcher)
	}
	return payload
}
// obfuscateSystemBlocks applies matcher to the top-level "system" field of
// payload. It handles two shapes:
//   - an array of blocks: only the "text" of blocks whose "type" is "text" is rewritten;
//   - a plain string: the whole string is rewritten.
//
// A missing field, or any other shape, leaves payload unchanged. Values are
// written back only when obfuscation actually changed them; sjson errors are
// ignored.
func obfuscateSystemBlocks(payload []byte, matcher *SensitiveWordMatcher) []byte {
	system := gjson.GetBytes(payload, "system")
	switch {
	case !system.Exists():
		return payload
	case system.IsArray():
		system.ForEach(func(idx, block gjson.Result) bool {
			if block.Get("type").String() != "text" {
				return true
			}
			original := block.Get("text").String()
			if rewritten := matcher.obfuscateText(original); rewritten != original {
				payload, _ = sjson.SetBytes(payload, "system."+idx.String()+".text", rewritten)
			}
			return true
		})
	case system.Type == gjson.String:
		original := system.String()
		if rewritten := matcher.obfuscateText(original); rewritten != original {
			payload, _ = sjson.SetBytes(payload, "system", rewritten)
		}
	}
	return payload
}
// obfuscateMessages applies matcher to each element of the top-level
// "messages" array. For every message, "content" may be either:
//   - a string, which is rewritten as a whole; or
//   - an array of blocks, where only the "text" of blocks with "type" "text" is rewritten.
//
// Messages without content, non-array "messages", and other content shapes are
// left alone. Values are written back only when obfuscation changed them; sjson
// errors are ignored.
func obfuscateMessages(payload []byte, matcher *SensitiveWordMatcher) []byte {
	messages := gjson.GetBytes(payload, "messages")
	if !messages.IsArray() {
		return payload
	}
	messages.ForEach(func(msgIdx, msg gjson.Result) bool {
		contentPath := "messages." + msgIdx.String() + ".content"
		content := msg.Get("content")
		switch {
		case content.Type == gjson.String:
			original := content.String()
			if rewritten := matcher.obfuscateText(original); rewritten != original {
				payload, _ = sjson.SetBytes(payload, contentPath, rewritten)
			}
		case content.IsArray():
			content.ForEach(func(blockIdx, block gjson.Result) bool {
				if block.Get("type").String() != "text" {
					return true
				}
				original := block.Get("text").String()
				if rewritten := matcher.obfuscateText(original); rewritten != original {
					payload, _ = sjson.SetBytes(payload, contentPath+"."+blockIdx.String()+".text", rewritten)
				}
				return true
			})
		}
		return true
	})
	return payload
}
================================================
FILE: internal/runtime/executor/cloak_utils.go
================================================
package executor
import (
"crypto/rand"
"encoding/hex"
"regexp"
"strings"
"github.com/google/uuid"
)
// userIDPattern matches the Claude Code user ID shape:
//
//	user_<64 hex chars>_account_<uuid>_session_<uuid>
//
// The hex part accepts either case; both UUIDs must be lowercase.
var userIDPattern = regexp.MustCompile(
	`^user_[a-fA-F0-9]{64}` +
		`_account_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}` +
		`_session_[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$`,
)
// generateFakeUserID builds a random ID in Claude Code format:
//
//	user_<64 lowercase hex chars>_account_<uuid.New()>_session_<uuid.New()>
//
// The hex part comes from 32 crypto/rand bytes. The error from rand.Read is
// deliberately discarded, as in the original.
func generateFakeUserID() string {
	var raw [32]byte
	_, _ = rand.Read(raw[:])
	account := uuid.New().String()
	session := uuid.New().String()
	return "user_" + hex.EncodeToString(raw[:]) +
		"_account_" + account +
		"_session_" + session
}
// isValidUserID reports whether userID has the exact Claude Code shape enforced
// by userIDPattern. The empty string never qualifies.
func isValidUserID(userID string) bool {
	if userID == "" {
		return false
	}
	return userIDPattern.MatchString(userID)
}
// shouldCloak decides whether a request is cloaked. cloakMode is compared
// case-insensitively:
//   - "always": cloak;
//   - "never": do not cloak;
//   - anything else ("auto", "", unrecognised): cloak unless userAgent starts
//     with "claude-cli" (case-sensitive).
func shouldCloak(cloakMode string, userAgent string) bool {
	mode := strings.ToLower(cloakMode)
	if mode == "always" {
		return true
	}
	if mode == "never" {
		return false
	}
	isClaudeCLI := strings.HasPrefix(userAgent, "claude-cli")
	return !isClaudeCLI
}
// isClaudeCodeClient reports whether userAgent begins with "claude-cli"
// (case-sensitive), which is how Claude Code clients are recognised here.
func isClaudeCodeClient(userAgent string) bool {
	const claudeCLIPrefix = "claude-cli"
	return strings.HasPrefix(userAgent, claudeCLIPrefix)
}
================================================
FILE: internal/runtime/executor/codex_executor.go
================================================
package executor
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
codexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"github.com/tiktoken-go/tokenizer"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
)
const (
	// codexClientVersion is the Codex CLI version string used by this executor.
	codexClientVersion = "0.101.0"
	// codexUserAgent embeds codexClientVersion, so bumping the version cannot
	// leave a stale copy behind in the User-Agent.
	codexUserAgent = "codex_cli_rs/" + codexClientVersion + " (Mac OS 26.0.1; arm64) Apple_Terminal/464"
)

// dataTag is the SSE "data:" field prefix that marks an event payload line.
var dataTag = []byte("data:")
// CodexExecutor sends requests to the Codex upstream (OpenAI Responses API
// entrypoint). Its only field is the proxy configuration, so it holds no
// per-request state.
//
// NOTE(review): an earlier comment said the executor "falls back to legacy via
// ClientAdapter" when api_key is unavailable. No such fallback is visible next
// to this type — confirm before relying on it.
type CodexExecutor struct {
	cfg *config.Config // proxy-wide configuration passed to every request helper
}

// NewCodexExecutor returns a CodexExecutor bound to cfg.
func NewCodexExecutor(cfg *config.Config) *CodexExecutor {
	return &CodexExecutor{
		cfg: cfg,
	}
}

// Identifier returns "codex". The executor uses it as the provider key for
// usage reporting, request logging and thinking configuration.
func (e *CodexExecutor) Identifier() string {
	return "codex"
}
// PrepareRequest decorates req for the Codex upstream:
//   - when codexCreds yields a non-blank API key, it sets "Authorization: Bearer <key>";
//   - it then passes auth.Attributes (nil when auth is nil) to util.ApplyCustomHeadersFromAttrs.
//
// A nil req is a no-op. The returned error is currently always nil.
func (e *CodexExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
	if req == nil {
		return nil
	}
	if apiKey, _ := codexCreds(auth); strings.TrimSpace(apiKey) != "" {
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}
	var customAttrs map[string]string
	if auth != nil {
		customAttrs = auth.Attributes
	}
	util.ApplyCustomHeadersFromAttrs(req, customAttrs)
	return nil
}
// HttpRequest sends req to the upstream with Codex credentials applied.
//
// The request is cloned onto ctx, falling back to req.Context() when ctx is
// nil, and decorated via PrepareRequest. It is then executed with the
// proxy-aware HTTP client built by newProxyAwareHTTPClient. WithContext makes a
// shallow copy, so the header changes are visible on req's Header map too.
func (e *CodexExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
	if req == nil {
		return nil, fmt.Errorf("codex executor: request is nil")
	}
	effectiveCtx := ctx
	if effectiveCtx == nil {
		effectiveCtx = req.Context()
	}
	outbound := req.WithContext(effectiveCtx)
	if err := e.PrepareRequest(outbound, auth); err != nil {
		return nil, err
	}
	client := newProxyAwareHTTPClient(effectiveCtx, e.cfg, auth, 0)
	return client.Do(outbound)
}
func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if opts.Alt == "responses/compact" {
return e.executeCompact(ctx, auth, req, opts)
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := codexCreds(auth)
if baseURL == "" {
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
if !gjson.GetBytes(body, "instructions").Exists() {
body, _ = sjson.SetBytes(body, "instructions", "")
}
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
if err != nil {
return resp, err
}
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("codex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = newCodexStatusErr(httpResp.StatusCode, b)
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
lines := bytes.Split(data, []byte("\n"))
for _, line := range lines {
if !bytes.HasPrefix(line, dataTag) {
continue
}
line = bytes.TrimSpace(line[5:])
if gjson.GetBytes(line, "type").String() != "response.completed" {
continue
}
if detail, ok := parseCodexUsage(line); ok {
reporter.publish(ctx, detail)
}
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
err = statusErr{code: 408, msg: "stream error: stream disconnected before completion: stream closed before response.completed"}
return resp, err
}
func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := codexCreds(auth)
if baseURL == "" {
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("openai-response")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.DeleteBytes(body, "stream")
url := strings.TrimSuffix(baseURL, "/") + "/responses/compact"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
if err != nil {
return resp, err
}
applyCodexHeaders(httpReq, auth, apiKey, false, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("codex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = newCodexStatusErr(httpResp.StatusCode, b)
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseOpenAIUsage(data))
reporter.ensurePublished(ctx)
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := codexCreds(auth)
if baseURL == "" {
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
body, _ = sjson.SetBytes(body, "model", baseModel)
if !gjson.GetBytes(body, "instructions").Exists() {
body, _ = sjson.SetBytes(body, "instructions", "")
}
url := strings.TrimSuffix(baseURL, "/") + "/responses"
httpReq, err := e.cacheHelper(ctx, from, url, req, body)
if err != nil {
return nil, err
}
applyCodexHeaders(httpReq, auth, apiKey, true, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
data, readErr := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("codex executor: close response body error: %v", errClose)
}
if readErr != nil {
recordAPIResponseError(ctx, e.cfg, readErr)
return nil, readErr
}
appendAPIResponseChunk(ctx, e.cfg, data)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
err = newCodexStatusErr(httpResp.StatusCode, data)
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("codex executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, 52_428_800) // 50MB
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if bytes.HasPrefix(line, dataTag) {
data := bytes.TrimSpace(line[5:])
if gjson.GetBytes(data, "type").String() == "response.completed" {
if detail, ok := parseCodexUsage(data); ok {
reporter.publish(ctx, detail)
}
}
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, bytes.Clone(line), ¶m)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// CountTokens estimates the input token count of req locally, without any
// upstream call. The payload is translated to the Codex request shape,
// normalized the same way outgoing requests are, tokenized with the codec
// chosen for the base model, and the count is returned as a Codex usage object
// translated back into the caller's format.
func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	model := thinking.ParseSuffix(req.Model).ModelName
	source, target := opts.SourceFormat, sdktranslator.FromString("codex")

	translated := sdktranslator.TranslateRequest(source, target, model, req.Payload, false)
	payload, err := thinking.ApplyThinking(translated, req.Model, source.String(), target.String(), e.Identifier())
	if err != nil {
		return cliproxyexecutor.Response{}, err
	}

	payload, _ = sjson.SetBytes(payload, "model", model)
	for _, field := range []string{"previous_response_id", "prompt_cache_retention", "safety_identifier"} {
		payload, _ = sjson.DeleteBytes(payload, field)
	}
	payload, _ = sjson.SetBytes(payload, "stream", false)
	if !gjson.GetBytes(payload, "instructions").Exists() {
		payload, _ = sjson.SetBytes(payload, "instructions", "")
	}

	codec, err := tokenizerForCodexModel(model)
	if err != nil {
		return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: tokenizer init failed: %w", err)
	}
	inputTokens, err := countCodexInputTokens(codec, payload)
	if err != nil {
		return cliproxyexecutor.Response{}, fmt.Errorf("codex executor: token counting failed: %w", err)
	}

	usageJSON := fmt.Sprintf(`{"response":{"usage":{"input_tokens":%d,"output_tokens":0,"total_tokens":%d}}}`, inputTokens, inputTokens)
	result := sdktranslator.TranslateTokenCount(ctx, target, source, inputTokens, []byte(usageJSON))
	return cliproxyexecutor.Response{Payload: []byte(result)}, nil
}
// tokenizerForCodexModel picks the codec used to estimate token counts for
// model. The name is trimmed, lower-cased and matched by prefix, most specific
// prefix first; an empty or unrecognized name falls back to cl100k_base.
func tokenizerForCodexModel(model string) (tokenizer.Codec, error) {
	name := strings.ToLower(strings.TrimSpace(model))
	if strings.HasPrefix(name, "gpt-5") {
		return tokenizer.ForModel(tokenizer.GPT5)
	}
	if strings.HasPrefix(name, "gpt-4.1") {
		return tokenizer.ForModel(tokenizer.GPT41)
	}
	if strings.HasPrefix(name, "gpt-4o") {
		return tokenizer.ForModel(tokenizer.GPT4o)
	}
	if strings.HasPrefix(name, "gpt-4") {
		return tokenizer.ForModel(tokenizer.GPT4)
	}
	// "gpt-3" also covers every "gpt-3.5" name.
	if strings.HasPrefix(name, "gpt-3") {
		return tokenizer.ForModel(tokenizer.GPT35Turbo)
	}
	return tokenizer.Get(tokenizer.Cl100kBase)
}
// countCodexInputTokens estimates the prompt token count of a Codex Responses
// request body using enc.
//
// It gathers the text that contributes to the prompt:
//   - instructions;
//   - input items: message content, function_call name/arguments,
//     function_call_output output, and "text" for any other item type;
//   - tool names, descriptions and parameter schemas;
//   - the text.format name and schema.
//
// The non-empty, whitespace-trimmed pieces are joined with newlines and counted
// in one pass. Both shapes the Responses API accepts are handled: "input" may be
// a plain string or an array of items, a message's "content" may be a plain
// string or an array of parts, and an item without "type" that carries
// "content" is treated as a message.
//
// It returns an error when enc is nil or counting fails, and 0 for an empty body.
func countCodexInputTokens(enc tokenizer.Codec, body []byte) (int64, error) {
	if enc == nil {
		return 0, fmt.Errorf("encoder is nil")
	}
	if len(body) == 0 {
		return 0, nil
	}
	root := gjson.ParseBytes(body)
	var segments []string

	// add appends s, trimmed, when it is not blank.
	add := func(s string) {
		if trimmed := strings.TrimSpace(s); trimmed != "" {
			segments = append(segments, trimmed)
		}
	}
	// addJSON appends a JSON value verbatim, unwrapping it first when the value
	// is itself a JSON string (e.g. a schema serialized as a string).
	addJSON := func(v gjson.Result) {
		if !v.Exists() {
			return
		}
		if v.Type == gjson.String {
			add(v.String())
			return
		}
		add(v.Raw)
	}
	// addContent appends message content given as a string or as an array of
	// parts carrying "text".
	addContent := func(content gjson.Result) {
		switch {
		case content.Type == gjson.String:
			add(content.String())
		case content.IsArray():
			for _, part := range content.Array() {
				add(part.Get("text").String())
			}
		}
	}

	add(root.Get("instructions").String())

	input := root.Get("input")
	switch {
	case input.Type == gjson.String:
		add(input.String())
	case input.IsArray():
		for _, item := range input.Array() {
			itemType := item.Get("type").String()
			if itemType == "" && item.Get("content").Exists() {
				itemType = "message"
			}
			switch itemType {
			case "message":
				addContent(item.Get("content"))
			case "function_call":
				add(item.Get("name").String())
				add(item.Get("arguments").String())
			case "function_call_output":
				add(item.Get("output").String())
			default:
				add(item.Get("text").String())
			}
		}
	}

	// gjson's Array() wraps a non-array value in a one-element slice, so check
	// IsArray explicitly.
	if tools := root.Get("tools"); tools.IsArray() {
		for _, tool := range tools.Array() {
			add(tool.Get("name").String())
			add(tool.Get("description").String())
			addJSON(tool.Get("parameters"))
		}
	}

	if format := root.Get("text.format"); format.Exists() {
		add(format.Get("name").String())
		addJSON(format.Get("schema"))
	}

	if len(segments) == 0 {
		return 0, nil
	}
	count, err := enc.Count(strings.Join(segments, "\n"))
	if err != nil {
		return 0, err
	}
	return int64(count), nil
}
// Refresh exchanges the refresh_token stored in auth.Metadata for a fresh Codex
// token set and writes the result back into auth.Metadata.
//
// A nil auth yields a 500 statusErr. When there is no refresh token, auth is
// returned unchanged. id_token, access_token, expired, type and last_refresh
// are always written. refresh_token, account_id and email are only overwritten
// when the refresh response carries a non-empty value, so a sparse response
// does not erase values that were already known.
func (e *CodexExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
	log.Debugf("codex executor: refresh called")
	if auth == nil {
		return nil, statusErr{code: http.StatusInternalServerError, msg: "codex executor: auth is nil"}
	}
	// Reading a nil map is safe and yields "", which means "nothing to refresh".
	refreshToken, _ := auth.Metadata["refresh_token"].(string)
	if refreshToken == "" {
		return auth, nil
	}
	svc := codexauth.NewCodexAuth(e.cfg)
	td, err := svc.RefreshTokensWithRetry(ctx, refreshToken, 3)
	if err != nil {
		return nil, err
	}
	if auth.Metadata == nil {
		auth.Metadata = make(map[string]any)
	}
	auth.Metadata["id_token"] = td.IDToken
	auth.Metadata["access_token"] = td.AccessToken
	if td.RefreshToken != "" {
		auth.Metadata["refresh_token"] = td.RefreshToken
	}
	if td.AccountID != "" {
		auth.Metadata["account_id"] = td.AccountID
	}
	if td.Email != "" {
		auth.Metadata["email"] = td.Email
	}
	// Use unified key in files
	auth.Metadata["expired"] = td.Expire
	auth.Metadata["type"] = "codex"
	auth.Metadata["last_refresh"] = time.Now().Format(time.RFC3339)
	return auth, nil
}
// cacheHelper builds the upstream POST request for rawJSON and, when a prompt
// cache identity can be derived for the source format, injects it as
// "prompt_cache_key" in the body and as the Conversation_id/Session_id headers.
//
// Identity sources by format:
//   - claude: a per (model, metadata.user_id) UUID kept in the codex cache for one hour;
//   - openai-response: the caller-supplied prompt_cache_key;
//   - openai: a name-based UUID (SHA-1) derived from the caller's API key.
func (e *CodexExecutor) cacheHelper(ctx context.Context, from sdktranslator.Format, url string, req cliproxyexecutor.Request, rawJSON []byte) (*http.Request, error) {
	var cache codexCache
	switch from {
	case "claude":
		if userID := gjson.GetBytes(req.Payload, "metadata.user_id"); userID.Exists() {
			cacheKey := fmt.Sprintf("%s-%s", req.Model, userID.String())
			entry, found := getCodexCache(cacheKey)
			if !found {
				entry = codexCache{
					ID:     uuid.New().String(),
					Expire: time.Now().Add(1 * time.Hour),
				}
				setCodexCache(cacheKey, entry)
			}
			cache = entry
		}
	case "openai-response":
		if promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key"); promptCacheKey.Exists() {
			cache.ID = promptCacheKey.String()
		}
	case "openai":
		if apiKey := strings.TrimSpace(apiKeyFromContext(ctx)); apiKey != "" {
			cache.ID = uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:"+apiKey)).String()
		}
	}

	hasID := cache.ID != ""
	if hasID {
		rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID)
	}
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(rawJSON))
	if err != nil {
		return nil, err
	}
	if hasID {
		httpReq.Header.Set("Conversation_id", cache.ID)
		httpReq.Header.Set("Session_id", cache.ID)
	}
	return httpReq, nil
}
// applyCodexHeaders sets the headers the Codex backend request carries.
//
// It always sets JSON content type, bearer authorization, Accept (SSE when
// stream is true, JSON otherwise) and Connection. Version and Session_id are
// added via misc.EnsureHeader using the inbound gin request headers (if any).
// User-Agent is resolved with the configured value and codexUserAgent as inputs.
// Credentials without an "api_key" attribute additionally get the CLI
// Originator header and, when known, the ChatGPT account ID. Custom headers
// from auth attributes are applied last.
func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, stream bool, cfg *config.Config) {
	h := r.Header
	h.Set("Content-Type", "application/json")
	h.Set("Authorization", "Bearer "+token)

	var inbound http.Header
	if ginCtx, ok := r.Context().Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
		inbound = ginCtx.Request.Header
	}
	misc.EnsureHeader(h, inbound, "Version", codexClientVersion)
	misc.EnsureHeader(h, inbound, "Session_id", uuid.NewString())
	cfgUserAgent, _ := codexHeaderDefaults(cfg, auth)
	ensureHeaderWithConfigPrecedence(h, inbound, "User-Agent", cfgUserAgent, codexUserAgent)

	accept := "application/json"
	if stream {
		accept = "text/event-stream"
	}
	h.Set("Accept", accept)
	h.Set("Connection", "Keep-Alive")

	// Lookups on nil maps are safe and return zero values.
	var attrs map[string]string
	var metadata map[string]any
	if auth != nil {
		attrs = auth.Attributes
		metadata = auth.Metadata
	}
	if strings.TrimSpace(attrs["api_key"]) == "" {
		h.Set("Originator", "codex_cli_rs")
		if accountID, ok := metadata["account_id"].(string); ok {
			h.Set("Chatgpt-Account-Id", accountID)
		}
	}
	util.ApplyCustomHeadersFromAttrs(r, attrs)
}
// newCodexStatusErr wraps an upstream error response in a statusErr carrying
// the raw body as its message and, for usage-limit 429s, the retry delay
// computed by parseCodexRetryAfter (nil otherwise).
func newCodexStatusErr(statusCode int, body []byte) statusErr {
	return statusErr{
		code:       statusCode,
		msg:        string(body),
		retryAfter: parseCodexRetryAfter(statusCode, body, time.Now()),
	}
}
// parseCodexRetryAfter derives how long to wait before retrying after a Codex
// "usage_limit_reached" 429 response.
//
// error.resets_at (a Unix timestamp in seconds) is preferred while it is still
// after now. Otherwise a positive error.resets_in_seconds is used. Any other
// status, empty body, or error type yields nil.
func parseCodexRetryAfter(statusCode int, errorBody []byte, now time.Time) *time.Duration {
	if statusCode != http.StatusTooManyRequests {
		return nil
	}
	if len(errorBody) == 0 {
		return nil
	}
	if errType := gjson.GetBytes(errorBody, "error.type").String(); strings.TrimSpace(errType) != "usage_limit_reached" {
		return nil
	}
	if unixSeconds := gjson.GetBytes(errorBody, "error.resets_at").Int(); unixSeconds > 0 {
		if resetAt := time.Unix(unixSeconds, 0); resetAt.After(now) {
			wait := resetAt.Sub(now)
			return &wait
		}
	}
	if seconds := gjson.GetBytes(errorBody, "error.resets_in_seconds").Int(); seconds > 0 {
		wait := time.Duration(seconds) * time.Second
		return &wait
	}
	return nil
}
// codexCreds returns the bearer credential and base URL for a.
//
// Both come from the "api_key"/"base_url" attributes. When no API key
// attribute is present, the OAuth access_token from metadata is used instead.
// A nil auth yields empty strings.
func codexCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
	if a == nil {
		return "", ""
	}
	// Lookups on nil maps are safe and return zero values.
	apiKey, baseURL = a.Attributes["api_key"], a.Attributes["base_url"]
	if apiKey != "" {
		return apiKey, baseURL
	}
	if token, ok := a.Metadata["access_token"].(string); ok {
		apiKey = token
	}
	return apiKey, baseURL
}
// resolveCodexConfig finds the configured CodexKey entry that corresponds to
// auth's "api_key"/"base_url" attributes.
//
// Entries of e.cfg.CodexKey are compared in order (trimmed, case-insensitive)
// and the first hit wins:
//   - key and base both set on auth: the entry must match both;
//   - only key set: the entry key must match and the entry base must be empty;
//   - only base set: the entry base must match.
//
// If nothing matched and auth has a key, the first entry with that key is
// returned regardless of its base. Returns nil when auth or the config is nil
// or no entry matches.
func (e *CodexExecutor) resolveCodexConfig(auth *cliproxyauth.Auth) *config.CodexKey {
	if auth == nil || e.cfg == nil {
		return nil
	}
	var wantKey, wantBase string
	if auth.Attributes != nil {
		wantKey = strings.TrimSpace(auth.Attributes["api_key"])
		wantBase = strings.TrimSpace(auth.Attributes["base_url"])
	}
	for i := range e.cfg.CodexKey {
		candidate := &e.cfg.CodexKey[i]
		entryBase := strings.TrimSpace(candidate.BaseURL)
		keyMatches := strings.EqualFold(strings.TrimSpace(candidate.APIKey), wantKey)
		baseMatches := strings.EqualFold(entryBase, wantBase)
		switch {
		case wantKey != "" && wantBase != "":
			if keyMatches && baseMatches {
				return candidate
			}
		case wantKey != "":
			if keyMatches && (entryBase == "" || baseMatches) {
				return candidate
			}
		case wantBase != "":
			if baseMatches {
				return candidate
			}
		}
	}
	if wantKey == "" {
		return nil
	}
	for i := range e.cfg.CodexKey {
		if strings.EqualFold(strings.TrimSpace(e.cfg.CodexKey[i].APIKey), wantKey) {
			return &e.cfg.CodexKey[i]
		}
	}
	return nil
}
================================================
FILE: internal/runtime/executor/codex_executor_cache_test.go
================================================
package executor
import (
"context"
"io"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
// TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFromAPIKey
// checks that, for OpenAI chat-completions traffic, cacheHelper derives the
// prompt cache key from the caller's API key. The key must be identical across
// calls in the request body and in both the Conversation_id and Session_id
// headers.
func TestCodexExecutorCacheHelper_OpenAIChatCompletions_StablePromptCacheKeyFromAPIKey(t *testing.T) {
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Set("apiKey", "test-api-key")
	// The executor looks the gin context up under the plain string key "gin"
	// (see applyCodexHeaders), so the test must store it the same way.
	ctx := context.WithValue(context.Background(), "gin", ginCtx)
	executor := &CodexExecutor{}
	rawJSON := []byte(`{"model":"gpt-5.3-codex","stream":true}`)
	req := cliproxyexecutor.Request{
		Model:   "gpt-5.3-codex",
		Payload: []byte(`{"model":"gpt-5.3-codex"}`),
	}
	url := "https://example.com/responses"
	expectedKey := uuid.NewSHA1(uuid.NameSpaceOID, []byte("cli-proxy-api:codex:prompt-cache:test-api-key")).String()

	for _, call := range []string{"first call", "second call"} {
		httpReq, err := executor.cacheHelper(ctx, sdktranslator.FromString("openai"), url, req, rawJSON)
		if err != nil {
			t.Fatalf("cacheHelper error (%s): %v", call, err)
		}
		body, errRead := io.ReadAll(httpReq.Body)
		if errRead != nil {
			t.Fatalf("read request body (%s): %v", call, errRead)
		}
		if gotKey := gjson.GetBytes(body, "prompt_cache_key").String(); gotKey != expectedKey {
			t.Fatalf("prompt_cache_key (%s) = %q, want %q", call, gotKey, expectedKey)
		}
		for _, header := range []string{"Conversation_id", "Session_id"} {
			if got := httpReq.Header.Get(header); got != expectedKey {
				t.Fatalf("%s (%s) = %q, want %q", header, call, got, expectedKey)
			}
		}
	}
}
================================================
FILE: internal/runtime/executor/codex_executor_retry_test.go
================================================
package executor
import (
"net/http"
"strconv"
"testing"
"time"
)
func TestParseCodexRetryAfter(t *testing.T) {
now := time.Unix(1_700_000_000, 0)
t.Run("resets_in_seconds", func(t *testing.T) {
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":123}}`)
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
if retryAfter == nil {
t.Fatalf("expected retryAfter, got nil")
}
if *retryAfter != 123*time.Second {
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 123*time.Second)
}
})
t.Run("prefers resets_at", func(t *testing.T) {
resetAt := now.Add(5 * time.Minute).Unix()
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":1}}`)
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
if retryAfter == nil {
t.Fatalf("expected retryAfter, got nil")
}
if *retryAfter != 5*time.Minute {
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 5*time.Minute)
}
})
t.Run("fallback when resets_at is past", func(t *testing.T) {
resetAt := now.Add(-1 * time.Minute).Unix()
body := []byte(`{"error":{"type":"usage_limit_reached","resets_at":` + itoa(resetAt) + `,"resets_in_seconds":77}}`)
retryAfter := parseCodexRetryAfter(http.StatusTooManyRequests, body, now)
if retryAfter == nil {
t.Fatalf("expected retryAfter, got nil")
}
if *retryAfter != 77*time.Second {
t.Fatalf("retryAfter = %v, want %v", *retryAfter, 77*time.Second)
}
})
t.Run("non-429 status code", func(t *testing.T) {
body := []byte(`{"error":{"type":"usage_limit_reached","resets_in_seconds":30}}`)
if got := parseCodexRetryAfter(http.StatusBadRequest, body, now); got != nil {
t.Fatalf("expected nil for non-429, got %v", *got)
}
})
t.Run("non usage_limit_reached error type", func(t *testing.T) {
body := []byte(`{"error":{"type":"server_error","resets_in_seconds":30}}`)
if got := parseCodexRetryAfter(http.StatusTooManyRequests, body, now); got != nil {
t.Fatalf("expected nil for non-usage_limit_reached, got %v", *got)
}
})
}
// itoa renders v in base 10; it keeps JSON fixtures in the tests readable.
func itoa(v int64) string {
	return string(strconv.AppendInt(nil, v, 10))
}
================================================
FILE: internal/runtime/executor/codex_websockets_executor.go
================================================
// Package executor provides runtime execution capabilities for various AI service providers.
// This file implements a Codex executor that uses the Responses API WebSocket transport.
package executor
import (
"bytes"
"context"
"fmt"
"io"
"net"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/net/proxy"
)
// codexResponsesWebsocketBetaHeaderValue is the Responses websocket beta
// opt-in value; presumably sent by applyCodexWebsocketHeaders — confirm there.
const codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-06"

// codexResponsesWebsocketIdleTimeout is the idle timeout used for upstream
// websocket connections (consumed outside this view — confirm semantics).
const codexResponsesWebsocketIdleTimeout = 5 * time.Minute

// codexResponsesWebsocketHandshakeTO bounds the websocket handshake (consumed
// outside this view — confirm semantics).
const codexResponsesWebsocketHandshakeTO = 30 * time.Second
// CodexWebsocketsExecutor executes Codex Responses requests using a WebSocket transport.
//
// It preserves the existing CodexExecutor HTTP implementation as a fallback for endpoints
// not available over WebSocket (e.g. /responses/compact) and for websocket upgrade failures.
type CodexWebsocketsExecutor struct {
	// CodexExecutor is the embedded HTTP executor. It serves the fallbacks
	// above and supplies the promoted cfg and Identifier used by this type.
	*CodexExecutor

	// sessMu presumably guards sessions; the map is accessed outside this
	// view (e.g. getOrCreateSession) — confirm.
	sessMu sync.Mutex
	// sessions holds reusable upstream websocket sessions, presumably keyed
	// by execution session ID (see getOrCreateSession) — confirm.
	sessions map[string]*codexWebsocketSession
}
// codexWebsocketSession holds the upstream websocket state shared by the
// requests of one execution session.
type codexWebsocketSession struct {
	sessionID string
	// reqMu is held by Execute for the whole request, so requests on the same
	// session run one at a time.
	reqMu sync.Mutex
	// connMu presumably guards conn, wsURL, authID and readerConn, which are
	// managed outside this view (e.g. ensureUpstreamConn) — confirm.
	connMu sync.Mutex
	conn   *websocket.Conn
	wsURL  string
	authID string
	// writeMu serializes writes to the connection: writeMessage and the ping
	// handler installed by configureConn both hold it while writing.
	writeMu sync.Mutex
	// activeMu guards activeCh, activeDone and activeCancel (see setActive and
	// clearActive).
	activeMu sync.Mutex
	// activeCh is the read channel installed by the request currently in
	// flight, or nil when none is active.
	activeCh chan codexWebsocketRead
	// activeDone is closed when the current activeCh is replaced or cleared.
	activeDone   <-chan struct{}
	activeCancel context.CancelFunc
	// readerConn presumably records the connection a background reader is
	// attached to — confirm where it is set.
	readerConn *websocket.Conn
}
// NewCodexWebsocketsExecutor returns a websocket-based Codex executor backed by
// an HTTP CodexExecutor built from cfg, with an empty session table.
func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor {
	exec := &CodexWebsocketsExecutor{CodexExecutor: NewCodexExecutor(cfg)}
	exec.sessions = make(map[string]*codexWebsocketSession)
	return exec
}
// codexWebsocketRead is one message (or read error) from an upstream websocket
// connection, delivered over a session's active channel. Presumably produced
// by the session's background reader — confirm.
type codexWebsocketRead struct {
	// conn is the connection the message was read from.
	conn    *websocket.Conn
	msgType int
	payload []byte
	err     error
}
// setActive installs ch as the session's active read channel.
//
// Any previously active channel is cancelled first, which closes its done
// channel. A non-nil ch gets a fresh cancellable done channel. Passing nil just
// clears the active state. It is a no-op on a nil session.
func (s *codexWebsocketSession) setActive(ch chan codexWebsocketRead) {
	if s == nil {
		return
	}
	s.activeMu.Lock()
	defer s.activeMu.Unlock()

	if cancel := s.activeCancel; cancel != nil {
		cancel()
		s.activeCancel, s.activeDone = nil, nil
	}
	s.activeCh = ch
	if ch == nil {
		return
	}
	doneCtx, cancel := context.WithCancel(context.Background())
	s.activeDone, s.activeCancel = doneCtx.Done(), cancel
}
// clearActive removes ch as the active read channel, cancelling its done
// channel. It does nothing when ch is no longer the active channel (it was
// already replaced or cleared) or when the session is nil.
func (s *codexWebsocketSession) clearActive(ch chan codexWebsocketRead) {
	if s == nil {
		return
	}
	s.activeMu.Lock()
	defer s.activeMu.Unlock()

	if s.activeCh != ch {
		return
	}
	if cancel := s.activeCancel; cancel != nil {
		cancel()
	}
	s.activeCh, s.activeCancel, s.activeDone = nil, nil, nil
}
// writeMessage writes one message to conn while holding the session's write
// lock, so it never overlaps with other writes (including pong replies) on
// this session. Nil session or connection are reported as errors.
func (s *codexWebsocketSession) writeMessage(conn *websocket.Conn, msgType int, payload []byte) error {
	switch {
	case s == nil:
		return fmt.Errorf("codex websockets executor: session is nil")
	case conn == nil:
		return fmt.Errorf("codex websockets executor: websocket conn is nil")
	}
	s.writeMu.Lock()
	defer s.writeMu.Unlock()
	return conn.WriteMessage(msgType, payload)
}
// configureConn installs a ping handler on conn that answers each ping with a
// pong carrying the same application data.
//
// The pong is written while holding writeMu, the lock writeMessage uses, so it
// is serialized with this session's data writes. Following gorilla/websocket's
// own default ping handler, websocket.ErrCloseSent is not treated as a failure.
// Once a close frame has been sent the pong is skipped, and the read side can
// still receive the peer's close frame instead of aborting with an error.
// It is a no-op for a nil session or connection.
func (s *codexWebsocketSession) configureConn(conn *websocket.Conn) {
	if s == nil || conn == nil {
		return
	}
	conn.SetPingHandler(func(appData string) error {
		s.writeMu.Lock()
		defer s.writeMu.Unlock()
		err := conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(10*time.Second))
		if err == websocket.ErrCloseSent {
			return nil
		}
		return err
	})
}
func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if ctx == nil {
ctx = context.Background()
}
if opts.Alt == "responses/compact" {
return e.CodexExecutor.executeCompact(ctx, auth, req, opts)
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := codexCreds(auth)
if baseURL == "" {
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, _ = sjson.SetBytes(body, "stream", true)
body, _ = sjson.DeleteBytes(body, "previous_response_id")
body, _ = sjson.DeleteBytes(body, "prompt_cache_retention")
body, _ = sjson.DeleteBytes(body, "safety_identifier")
if !gjson.GetBytes(body, "instructions").Exists() {
body, _ = sjson.SetBytes(body, "instructions", "")
}
httpURL := strings.TrimSuffix(baseURL, "/") + "/responses"
wsURL, err := buildCodexResponsesWebsocketURL(httpURL)
if err != nil {
return resp, err
}
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
executionSessionID := executionSessionIDFromOptions(opts)
var sess *codexWebsocketSession
if executionSessionID != "" {
sess = e.getOrCreateSession(executionSessionID)
sess.reqMu.Lock()
defer sess.reqMu.Unlock()
}
wsReqBody := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
Headers: wsHeaders.Clone(),
Body: wsReqBody,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if respHS != nil {
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
}
if errDial != nil {
bodyErr := websocketHandshakeBody(respHS)
if len(bodyErr) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bodyErr)
}
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
return e.CodexExecutor.Execute(ctx, auth, req, opts)
}
if respHS != nil && respHS.StatusCode > 0 {
return resp, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
}
recordAPIResponseError(ctx, e.cfg, errDial)
return resp, errDial
}
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error")
if sess == nil {
logCodexWebsocketConnected(executionSessionID, authID, wsURL)
defer func() {
reason := "completed"
if err != nil {
reason = "error"
}
logCodexWebsocketDisconnected(executionSessionID, authID, wsURL, reason, err)
if errClose := conn.Close(); errClose != nil {
log.Errorf("codex websockets executor: close websocket error: %v", errClose)
}
}()
}
var readCh chan codexWebsocketRead
if sess != nil {
readCh = make(chan codexWebsocketRead, 4096)
sess.setActive(readCh)
defer sess.clearActive(readCh)
}
if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil {
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "send_error", errSend)
// Retry once with a fresh websocket connection. This is mainly to handle
// upstream closing the socket between sequential requests within the same
// execution session.
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if errDialRetry == nil && connRetry != nil {
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
Headers: wsHeaders.Clone(),
Body: wsReqBodyRetry,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry == nil {
conn = connRetry
wsReqBody = wsReqBodyRetry
} else {
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
recordAPIResponseError(ctx, e.cfg, errSendRetry)
return resp, errSendRetry
}
} else {
recordAPIResponseError(ctx, e.cfg, errDialRetry)
return resp, errDialRetry
}
} else {
recordAPIResponseError(ctx, e.cfg, errSend)
return resp, errSend
}
}
for {
if ctx != nil && ctx.Err() != nil {
return resp, ctx.Err()
}
msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
return resp, errRead
}
if msgType != websocket.TextMessage {
if msgType == websocket.BinaryMessage {
err = fmt.Errorf("codex websockets executor: unexpected binary message")
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
}
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
continue
}
payload = bytes.TrimSpace(payload)
if len(payload) == 0 {
continue
}
appendAPIResponseChunk(ctx, e.cfg, payload)
if wsErr, ok := parseCodexWebsocketError(payload); ok {
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
}
recordAPIResponseError(ctx, e.cfg, wsErr)
return resp, wsErr
}
payload = normalizeCodexWebsocketCompletion(payload)
eventType := gjson.GetBytes(payload, "type").String()
if eventType == "response.completed" {
if detail, ok := parseCodexUsage(payload); ok {
reporter.publish(ctx, detail)
}
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, ¶m)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
return resp, nil
}
}
}
func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
log.Debugf("Executing Codex Websockets stream request with auth ID: %s, model: %s", auth.ID, req.Model)
if ctx == nil {
ctx = context.Background()
}
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := codexCreds(auth)
if baseURL == "" {
baseURL = "https://chatgpt.com/backend-api/codex"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("codex")
body := req.Payload
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, body, requestedModel)
httpURL := strings.TrimSuffix(baseURL, "/") + "/responses"
wsURL, err := buildCodexResponsesWebsocketURL(httpURL)
if err != nil {
return nil, err
}
body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body)
wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey, e.cfg)
var authID, authLabel, authType, authValue string
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
executionSessionID := executionSessionIDFromOptions(opts)
var sess *codexWebsocketSession
if executionSessionID != "" {
sess = e.getOrCreateSession(executionSessionID)
if sess != nil {
sess.reqMu.Lock()
}
}
wsReqBody := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
Headers: wsHeaders.Clone(),
Body: wsReqBody,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
var upstreamHeaders http.Header
if respHS != nil {
upstreamHeaders = respHS.Header.Clone()
recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone())
}
if errDial != nil {
bodyErr := websocketHandshakeBody(respHS)
if len(bodyErr) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bodyErr)
}
if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired {
return e.CodexExecutor.ExecuteStream(ctx, auth, req, opts)
}
if respHS != nil && respHS.StatusCode > 0 {
return nil, statusErr{code: respHS.StatusCode, msg: string(bodyErr)}
}
recordAPIResponseError(ctx, e.cfg, errDial)
if sess != nil {
sess.reqMu.Unlock()
}
return nil, errDial
}
closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error")
if sess == nil {
logCodexWebsocketConnected(executionSessionID, authID, wsURL)
}
var readCh chan codexWebsocketRead
if sess != nil {
readCh = make(chan codexWebsocketRead, 4096)
sess.setActive(readCh)
}
if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil {
recordAPIResponseError(ctx, e.cfg, errSend)
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "send_error", errSend)
// Retry once with a new websocket connection for the same execution session.
connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders)
if errDialRetry != nil || connRetry == nil {
recordAPIResponseError(ctx, e.cfg, errDialRetry)
sess.clearActive(readCh)
sess.reqMu.Unlock()
return nil, errDialRetry
}
wsReqBodyRetry := buildCodexWebsocketRequestBody(body)
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: wsURL,
Method: "WEBSOCKET",
Headers: wsHeaders.Clone(),
Body: wsReqBodyRetry,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry != nil {
recordAPIResponseError(ctx, e.cfg, errSendRetry)
e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry)
sess.clearActive(readCh)
sess.reqMu.Unlock()
return nil, errSendRetry
}
conn = connRetry
wsReqBody = wsReqBodyRetry
} else {
logCodexWebsocketDisconnected(executionSessionID, authID, wsURL, "send_error", errSend)
if errClose := conn.Close(); errClose != nil {
log.Errorf("codex websockets executor: close websocket error: %v", errClose)
}
return nil, errSend
}
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
terminateReason := "completed"
var terminateErr error
defer close(out)
defer func() {
if sess != nil {
sess.clearActive(readCh)
sess.reqMu.Unlock()
return
}
logCodexWebsocketDisconnected(executionSessionID, authID, wsURL, terminateReason, terminateErr)
if errClose := conn.Close(); errClose != nil {
log.Errorf("codex websockets executor: close websocket error: %v", errClose)
}
}()
send := func(chunk cliproxyexecutor.StreamChunk) bool {
if ctx == nil {
out <- chunk
return true
}
select {
case out <- chunk:
return true
case <-ctx.Done():
return false
}
}
var param any
for {
if ctx != nil && ctx.Err() != nil {
terminateReason = "context_done"
terminateErr = ctx.Err()
_ = send(cliproxyexecutor.StreamChunk{Err: ctx.Err()})
return
}
msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh)
if errRead != nil {
if sess != nil && ctx != nil && ctx.Err() != nil {
terminateReason = "context_done"
terminateErr = ctx.Err()
_ = send(cliproxyexecutor.StreamChunk{Err: ctx.Err()})
return
}
terminateReason = "read_error"
terminateErr = errRead
recordAPIResponseError(ctx, e.cfg, errRead)
reporter.publishFailure(ctx)
_ = send(cliproxyexecutor.StreamChunk{Err: errRead})
return
}
if msgType != websocket.TextMessage {
if msgType == websocket.BinaryMessage {
err = fmt.Errorf("codex websockets executor: unexpected binary message")
terminateReason = "unexpected_binary"
terminateErr = err
recordAPIResponseError(ctx, e.cfg, err)
reporter.publishFailure(ctx)
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err)
}
_ = send(cliproxyexecutor.StreamChunk{Err: err})
return
}
continue
}
payload = bytes.TrimSpace(payload)
if len(payload) == 0 {
continue
}
appendAPIResponseChunk(ctx, e.cfg, payload)
if wsErr, ok := parseCodexWebsocketError(payload); ok {
terminateReason = "upstream_error"
terminateErr = wsErr
recordAPIResponseError(ctx, e.cfg, wsErr)
reporter.publishFailure(ctx)
if sess != nil {
e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr)
}
_ = send(cliproxyexecutor.StreamChunk{Err: wsErr})
return
}
payload = normalizeCodexWebsocketCompletion(payload)
eventType := gjson.GetBytes(payload, "type").String()
if eventType == "response.completed" || eventType == "response.done" {
if detail, ok := parseCodexUsage(payload); ok {
reporter.publish(ctx, detail)
}
}
line := encodeCodexWebsocketAsSSE(payload)
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, body, body, line, ¶m)
for i := range chunks {
if !send(cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}) {
terminateReason = "context_done"
terminateErr = ctx.Err()
return
}
}
if eventType == "response.completed" || eventType == "response.done" {
return
}
}
}()
return &cliproxyexecutor.StreamResult{Headers: upstreamHeaders, Chunks: out}, nil
}
// dialCodexWebsocket opens a fresh upstream websocket connection to wsURL using
// a proxy-aware dialer built from the executor config and the auth entry.
//
// A nil ctx is treated as context.Background(). On success, outbound
// per-message compression is disabled on the connection even though
// permessage-deflate may still be negotiated. The handshake response and the
// dial error are passed back untouched so callers can inspect rejections.
func (e *CodexWebsocketsExecutor) dialCodexWebsocket(ctx context.Context, auth *cliproxyauth.Auth, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) {
	dialCtx := ctx
	if dialCtx == nil {
		dialCtx = context.Background()
	}
	d := newProxyAwareWebsocketDialer(e.cfg, auth)
	d.EnableCompression = true
	d.HandshakeTimeout = codexResponsesWebsocketHandshakeTO

	conn, resp, errDial := d.DialContext(dialCtx, wsURL, headers)
	if conn == nil {
		return nil, resp, errDial
	}
	// Avoid gorilla/websocket flate tail validation issues on some upstreams/Go versions.
	// Negotiating permessage-deflate is fine; outbound messages simply stay uncompressed.
	conn.EnableWriteCompression(false)
	return conn, resp, errDial
}
// writeCodexWebsocketMessage sends payload upstream as a single text frame.
//
// When a session is supplied, the write is delegated to sess.writeMessage.
// Otherwise the frame is written straight to conn, and a nil conn is
// reported as an error.
func writeCodexWebsocketMessage(sess *codexWebsocketSession, conn *websocket.Conn, payload []byte) error {
	switch {
	case sess != nil:
		return sess.writeMessage(conn, websocket.TextMessage, payload)
	case conn == nil:
		return fmt.Errorf("codex websockets executor: websocket conn is nil")
	default:
		return conn.WriteMessage(websocket.TextMessage, payload)
	}
}
// buildCodexWebsocketRequestBody converts a Responses API request body into the
// websocket message sent upstream by setting its top-level "type" field.
//
// Match codex-rs websocket v2 semantics: every request is `response.create`.
// Incremental follow-up turns continue on the same websocket using
// `previous_response_id` + incremental `input`, not `response.append`.
//
// The caller's body is cloned before editing and is never modified. An empty
// body yields nil.
func buildCodexWebsocketRequestBody(body []byte) []byte {
	if len(body) == 0 {
		return nil
	}
	// Best effort: the SetBytes result is returned even if it reports an error.
	// Repeating the same call on failure cannot produce a different outcome,
	// so a single attempt is made.
	wsReqBody, _ := sjson.SetBytes(bytes.Clone(body), "type", "response.create")
	return wsReqBody
}
// readCodexWebsocketMessage returns the next upstream websocket message for a request.
//
// Without a session, it reads directly from conn after arming the idle read
// deadline (codexResponsesWebsocketIdleTimeout). With a session, messages are
// produced by the session's reader goroutine and received over readCh. Events
// tagged with a connection other than conn are skipped, so leftovers from a
// previously used connection are ignored. The wait stops early when ctx is
// done, and a closed readCh is reported as an error.
//
// A nil ctx is treated as context.Background(). Callers in this file guard
// against a nil ctx elsewhere, and calling Done on a nil Context would panic.
func readCodexWebsocketMessage(ctx context.Context, sess *codexWebsocketSession, conn *websocket.Conn, readCh chan codexWebsocketRead) (int, []byte, error) {
	if conn == nil {
		return 0, nil, fmt.Errorf("codex websockets executor: websocket conn is nil")
	}
	if sess == nil {
		// Best effort: a failed deadline update surfaces on the read itself.
		_ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout))
		return conn.ReadMessage()
	}
	if readCh == nil {
		return 0, nil, fmt.Errorf("codex websockets executor: session read channel is nil")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	for {
		select {
		case <-ctx.Done():
			return 0, nil, ctx.Err()
		case ev, ok := <-readCh:
			if !ok {
				return 0, nil, fmt.Errorf("codex websockets executor: session read channel closed")
			}
			if ev.conn != conn {
				// Stale event from a connection this request no longer uses.
				continue
			}
			if ev.err != nil {
				return 0, nil, ev.err
			}
			return ev.msgType, ev.payload, nil
		}
	}
}
// newProxyAwareWebsocketDialer builds a websocket dialer that routes through the
// proxy configured on auth (auth.ProxyURL). It falls back to the global
// cfg.ProxyURL, and then to the standard HTTP(S)_PROXY environment variables
// when neither is set.
//
// Supported proxy settings:
//   - "direct" mode: no proxy at all, environment variables included.
//   - socks5:// URLs: TCP connections are tunneled through the SOCKS5 proxy,
//     with optional username/password taken from the URL.
//   - http:// and https:// URLs: handed to the dialer's Proxy function.
//
// Parse failures and unsupported schemes are logged and leave the
// environment-based default in place.
func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *websocket.Dialer {
	// Shared TCP dialer. It is used for direct connections and as the hop to a
	// SOCKS5 proxy, so both paths get the same connect timeout and keep-alive.
	baseDialer := &net.Dialer{
		Timeout:   30 * time.Second,
		KeepAlive: 30 * time.Second,
	}
	dialer := &websocket.Dialer{
		Proxy:             http.ProxyFromEnvironment,
		HandshakeTimeout:  codexResponsesWebsocketHandshakeTO,
		EnableCompression: true,
		NetDialContext:    baseDialer.DialContext,
	}
	proxyURL := ""
	if auth != nil {
		proxyURL = strings.TrimSpace(auth.ProxyURL)
	}
	if proxyURL == "" && cfg != nil {
		proxyURL = strings.TrimSpace(cfg.ProxyURL)
	}
	if proxyURL == "" {
		return dialer
	}
	setting, errParse := proxyutil.Parse(proxyURL)
	if errParse != nil {
		log.Errorf("codex websockets executor: %v", errParse)
		return dialer
	}
	switch setting.Mode {
	case proxyutil.ModeDirect:
		dialer.Proxy = nil
		return dialer
	case proxyutil.ModeProxy:
		// Configure the proxy from setting.URL below.
	default:
		return dialer
	}
	switch setting.URL.Scheme {
	case "socks5":
		var proxyAuth *proxy.Auth
		if setting.URL.User != nil {
			username := setting.URL.User.Username()
			password, _ := setting.URL.User.Password()
			proxyAuth = &proxy.Auth{User: username, Password: password}
		}
		socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, baseDialer)
		if errSOCKS5 != nil {
			log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5)
			return dialer
		}
		dialer.Proxy = nil
		// Prefer the context-aware dial. Cancellation of the dial context then
		// also covers the SOCKS5 connect and negotiation; gorilla applies the
		// handshake timeout to that context.
		if ctxDialer, ok := socksDialer.(proxy.ContextDialer); ok {
			dialer.NetDialContext = ctxDialer.DialContext
		} else {
			dialer.NetDialContext = func(_ context.Context, network, addr string) (net.Conn, error) {
				return socksDialer.Dial(network, addr)
			}
		}
	case "http", "https":
		dialer.Proxy = http.ProxyURL(setting.URL)
	default:
		log.Errorf("codex websockets executor: unsupported proxy scheme: %s", setting.URL.Scheme)
	}
	return dialer
}
// buildCodexResponsesWebsocketURL maps an HTTP(S) endpoint URL to its websocket
// counterpart: "http" becomes "ws" and "https" becomes "wss", with the scheme
// matched case-insensitively. Any other scheme is kept as-is.
// Surrounding whitespace is ignored, and URL parse errors are returned unchanged.
func buildCodexResponsesWebsocketURL(httpURL string) (string, error) {
	u, errParse := url.Parse(strings.TrimSpace(httpURL))
	if errParse != nil {
		return "", errParse
	}
	if scheme := strings.ToLower(u.Scheme); scheme == "https" {
		u.Scheme = "wss"
	} else if scheme == "http" {
		u.Scheme = "ws"
	}
	return u.String(), nil
}
// applyCodexPromptCacheHeaders derives a prompt-cache identifier for the request.
// When one is found, it is written into rawJSON as "prompt_cache_key" and into
// the returned headers as both Conversation_id and Session_id.
//
// Where the identifier comes from:
//   - from == "claude": keyed by "<model>-<metadata.user_id>" from the original
//     payload. A cached ID is reused; otherwise a fresh UUID is cached for one hour.
//   - from == "openai-response": the payload's own prompt_cache_key is forwarded.
//
// Empty rawJSON is returned unchanged together with an empty header set.
func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecutor.Request, rawJSON []byte) ([]byte, http.Header) {
	headers := http.Header{}
	if len(rawJSON) == 0 {
		return rawJSON, headers
	}

	cacheID := ""
	switch from {
	case "claude":
		if userID := gjson.GetBytes(req.Payload, "metadata.user_id"); userID.Exists() {
			key := fmt.Sprintf("%s-%s", req.Model, userID.String())
			entry, found := getCodexCache(key)
			if !found {
				entry = codexCache{
					ID:     uuid.New().String(),
					Expire: time.Now().Add(1 * time.Hour),
				}
				setCodexCache(key, entry)
			}
			cacheID = entry.ID
		}
	case "openai-response":
		if promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key"); promptCacheKey.Exists() {
			cacheID = promptCacheKey.String()
		}
	}

	if cacheID == "" {
		return rawJSON, headers
	}
	rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cacheID)
	headers.Set("Conversation_id", cacheID)
	headers.Set("Session_id", cacheID)
	return rawJSON, headers
}
// applyCodexWebsocketHeaders fills in the handshake headers for an upstream
// Codex websocket connection and returns them. A nil headers argument gets a
// fresh header map.
//
// Per-header behavior, in the order applied:
//   - Authorization: set to "Bearer <token>" whenever token is non-blank.
//   - x-codex-beta-features: the existing value, else the incoming client (gin)
//     request header, else the configured default from codexHeaderDefaults.
//   - x-codex-turn-state, x-codex-turn-metadata,
//     x-responsesapi-include-timing-metrics, Version: filled via
//     misc.EnsureHeader from the client request with the given default.
//     NOTE(review): presumably "keep existing, else client, else default";
//     confirm against misc.EnsureHeader.
//   - OpenAI-Beta: the existing (or client) value is kept only if it contains
//     "responses_websockets="; otherwise codexResponsesWebsocketBetaHeaderValue
//     is used.
//   - Session_id: via misc.EnsureHeader with a freshly generated UUID as default.
//   - User-Agent: the existing value, else the configured default, else the
//     client header, else codexUserAgent.
//   - For auths without a non-blank "api_key" attribute, Originator is forced to
//     "codex_cli_rs", and Chatgpt-Account-Id is set from the "account_id"
//     metadata string when present.
//
// Custom headers from the auth attributes are applied last, via
// util.ApplyCustomHeadersFromAttrs.
func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string, cfg *config.Config) http.Header {
	if headers == nil {
		headers = http.Header{}
	}
	if strings.TrimSpace(token) != "" {
		headers.Set("Authorization", "Bearer "+token)
	}
	// Incoming client request headers, when this call runs inside a gin handler.
	var ginHeaders http.Header
	if ginCtx := ginContextFrom(ctx); ginCtx != nil && ginCtx.Request != nil {
		ginHeaders = ginCtx.Request.Header
	}
	// Config defaults are empty for API-key auths (see codexHeaderDefaults).
	cfgUserAgent, cfgBetaFeatures := codexHeaderDefaults(cfg, auth)
	ensureHeaderWithPriority(headers, ginHeaders, "x-codex-beta-features", cfgBetaFeatures, "")
	misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "")
	misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "")
	misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "")
	misc.EnsureHeader(headers, ginHeaders, "Version", codexClientVersion)
	// Keep a caller-supplied OpenAI-Beta only if it already opts into the
	// websocket responses protocol; otherwise force the known-good value.
	betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta"))
	if betaHeader == "" && ginHeaders != nil {
		betaHeader = strings.TrimSpace(ginHeaders.Get("OpenAI-Beta"))
	}
	if betaHeader == "" || !strings.Contains(betaHeader, "responses_websockets=") {
		betaHeader = codexResponsesWebsocketBetaHeaderValue
	}
	headers.Set("OpenAI-Beta", betaHeader)
	misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString())
	// For User-Agent the configured default outranks the client's header.
	ensureHeaderWithConfigPrecedence(headers, ginHeaders, "User-Agent", cfgUserAgent, codexUserAgent)
	isAPIKey := false
	if auth != nil && auth.Attributes != nil {
		if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" {
			isAPIKey = true
		}
	}
	if !isAPIKey {
		headers.Set("Originator", "codex_cli_rs")
		if auth != nil && auth.Metadata != nil {
			if accountID, ok := auth.Metadata["account_id"].(string); ok {
				if trimmed := strings.TrimSpace(accountID); trimmed != "" {
					headers.Set("Chatgpt-Account-Id", trimmed)
				}
			}
		}
	}
	var attrs map[string]string
	if auth != nil {
		attrs = auth.Attributes
	}
	// Wraps headers in a throwaway request only to reuse the shared helper;
	// any header changes it makes land in the same map.
	util.ApplyCustomHeadersFromAttrs(&http.Request{Header: headers}, attrs)
	return headers
}
// codexHeaderDefaults returns the trimmed configured default User-Agent and
// x-codex-beta-features values that apply to auth.
//
// Both values are empty when cfg or auth is nil, or when the auth carries a
// non-blank "api_key" attribute. Configured header defaults apply only to
// non-API-key auths.
func codexHeaderDefaults(cfg *config.Config, auth *cliproxyauth.Auth) (string, string) {
	usesAPIKey := auth != nil && auth.Attributes != nil && strings.TrimSpace(auth.Attributes["api_key"]) != ""
	if cfg == nil || auth == nil || usesAPIKey {
		return "", ""
	}
	defaults := cfg.CodexHeaderDefaults
	return strings.TrimSpace(defaults.UserAgent), strings.TrimSpace(defaults.BetaFeatures)
}
// ensureHeaderWithPriority sets target[key] only when it is currently blank.
//
// Candidates are tried in this order: the incoming client header from source,
// then configValue, then fallbackValue. The first value that is non-blank
// after trimming is stored, trimmed. A nil target is a no-op.
func ensureHeaderWithPriority(target http.Header, source http.Header, key, configValue, fallbackValue string) {
	if target == nil || strings.TrimSpace(target.Get(key)) != "" {
		return
	}
	var fromSource string
	if source != nil {
		fromSource = source.Get(key)
	}
	for _, candidate := range [...]string{fromSource, configValue, fallbackValue} {
		if v := strings.TrimSpace(candidate); v != "" {
			target.Set(key, v)
			return
		}
	}
}
// ensureHeaderWithConfigPrecedence sets target[key] only when it is currently
// blank, preferring configValue over the client header from source, and using
// fallbackValue last. Values are trimmed, and blank candidates are skipped.
// A nil target is a no-op.
func ensureHeaderWithConfigPrecedence(target http.Header, source http.Header, key, configValue, fallbackValue string) {
	if target == nil || strings.TrimSpace(target.Get(key)) != "" {
		return
	}
	trySet := func(value string) bool {
		trimmed := strings.TrimSpace(value)
		if trimmed == "" {
			return false
		}
		target.Set(key, trimmed)
		return true
	}
	if trySet(configValue) {
		return
	}
	if source != nil && trySet(source.Get(key)) {
		return
	}
	trySet(fallbackValue)
}
// statusErrWithHeaders is a statusErr that also carries HTTP headers reported
// alongside the error (for example, the "headers" object of an upstream
// websocket error event). The headers are exposed through Headers.
type statusErrWithHeaders struct {
	statusErr
	// headers may be nil when the upstream reported none.
	headers http.Header
}
// Headers returns a copy of the headers attached to the error, or nil when
// there are none. Callers may modify the result without affecting the error.
func (e statusErrWithHeaders) Headers() http.Header {
	if e.headers != nil {
		return e.headers.Clone()
	}
	return nil
}
// parseCodexWebsocketError recognizes an upstream websocket event with
// `"type":"error"` that carries a positive HTTP status. The status is read from
// `status`, or from `status_code` when `status` is absent or zero. Such an event
// is converted into a statusErrWithHeaders.
//
// The error message body is always `{"error": ...}`:
//   - an object (or any other non-string JSON value) under "error" is forwarded
//     verbatim;
//   - a plain-string "error" is wrapped as {"type":"server_error","message":<text>},
//     matching the shape used when "error" is missing, so clients always get an
//     error object;
//   - a missing "error" (or a blank string) becomes
//     {"type":"server_error","message":<HTTP status text>}.
//
// Headers found under the event's "headers" object are attached to the error.
// It returns (nil, false) for empty payloads, non-error events, and error
// events without a usable status.
func parseCodexWebsocketError(payload []byte) (error, bool) {
	if len(payload) == 0 {
		return nil, false
	}
	if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "error" {
		return nil, false
	}
	status := int(gjson.GetBytes(payload, "status").Int())
	if status == 0 {
		status = int(gjson.GetBytes(payload, "status_code").Int())
	}
	if status <= 0 {
		return nil, false
	}

	out := []byte(`{}`)
	errNode := gjson.GetBytes(payload, "error")
	if errNode.Exists() && errNode.Type != gjson.String {
		out, _ = sjson.SetRawBytes(out, "error", []byte(errNode.Raw))
	} else {
		message := http.StatusText(status)
		if errNode.Type == gjson.String {
			if text := strings.TrimSpace(errNode.String()); text != "" {
				message = text
			}
		}
		out, _ = sjson.SetBytes(out, "error.type", "server_error")
		out, _ = sjson.SetBytes(out, "error.message", message)
	}

	headers := parseCodexWebsocketErrorHeaders(payload)
	return statusErrWithHeaders{
		statusErr: statusErr{code: status, msg: string(out)},
		headers:   headers,
	}, true
}
// parseCodexWebsocketErrorHeaders converts the optional "headers" object of a
// websocket error event into an http.Header.
//
// Only entries with a non-blank name and a string, number, or boolean value are
// kept. Strings use their decoded text; numbers and booleans use their raw JSON
// text. Values are trimmed and blanks are dropped. It returns nil when there is
// no "headers" object or nothing usable in it.
func parseCodexWebsocketErrorHeaders(payload []byte) http.Header {
	node := gjson.GetBytes(payload, "headers")
	if !node.Exists() || !node.IsObject() {
		return nil
	}
	result := make(http.Header)
	node.ForEach(func(k, v gjson.Result) bool {
		name := strings.TrimSpace(k.String())
		if name == "" {
			return true
		}
		var value string
		switch v.Type {
		case gjson.String:
			value = v.String()
		case gjson.Number, gjson.True, gjson.False:
			value = v.Raw
		}
		if value = strings.TrimSpace(value); value != "" {
			result.Set(name, value)
		}
		return true
	})
	if len(result) == 0 {
		return nil
	}
	return result
}
// normalizeCodexWebsocketCompletion rewrites a "response.done" event's type to
// "response.completed" so downstream translation sees a single completion type.
// Any other payload, or a failed rewrite, is returned unchanged.
func normalizeCodexWebsocketCompletion(payload []byte) []byte {
	if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.done" {
		return payload
	}
	if updated, errSet := sjson.SetBytes(payload, "type", "response.completed"); errSet == nil && len(updated) > 0 {
		return updated
	}
	return payload
}
// encodeCodexWebsocketAsSSE wraps a websocket JSON payload as a single SSE
// "data: " line (without a trailing newline) so it can be fed to the SSE stream
// translators. The result is a fresh slice independent of payload. An empty
// payload yields nil.
func encodeCodexWebsocketAsSSE(payload []byte) []byte {
	if len(payload) == 0 {
		return nil
	}
	const prefix = "data: "
	line := make([]byte, len(prefix)+len(payload))
	n := copy(line, prefix)
	copy(line[n:], payload)
	return line
}
// websocketHandshakeBody drains and closes the body of a failed websocket
// handshake response and returns its bytes, or nil when there is no response,
// no body, or the body is empty.
func websocketHandshakeBody(resp *http.Response) []byte {
	if resp == nil || resp.Body == nil {
		return nil
	}
	// Best effort: on a read error, whatever bytes arrived before it are still used.
	data, _ := io.ReadAll(resp.Body)
	closeHTTPResponseBody(resp, "codex websockets executor: close handshake response body error")
	if len(data) > 0 {
		return data
	}
	return nil
}
// closeHTTPResponseBody closes resp.Body if there is one and logs any close
// error prefixed with logPrefix. Nil responses and bodies are ignored.
func closeHTTPResponseBody(resp *http.Response, logPrefix string) {
	if resp == nil || resp.Body == nil {
		return
	}
	errClose := resp.Body.Close()
	if errClose == nil {
		return
	}
	log.Errorf("%s: %v", logPrefix, errClose)
}
// executionSessionIDFromOptions extracts the execution session ID stored in
// opts.Metadata under cliproxyexecutor.ExecutionSessionMetadataKey.
// Both string and []byte values are accepted and returned trimmed. A missing
// entry or any other value type yields "".
func executionSessionIDFromOptions(opts cliproxyexecutor.Options) string {
	var id string
	switch v := opts.Metadata[cliproxyexecutor.ExecutionSessionMetadataKey].(type) {
	case string:
		id = v
	case []byte:
		id = string(v)
	}
	return strings.TrimSpace(id)
}
// getOrCreateSession returns the session registered under the trimmed
// sessionID. If none is registered yet, it creates and registers an empty one.
// A blank ID yields nil. Registry access is serialized by e.sessMu.
func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWebsocketSession {
	id := strings.TrimSpace(sessionID)
	if id == "" {
		return nil
	}
	e.sessMu.Lock()
	defer e.sessMu.Unlock()
	if existing := e.sessions[id]; existing != nil {
		return existing
	}
	if e.sessions == nil {
		e.sessions = make(map[string]*codexWebsocketSession)
	}
	created := &codexWebsocketSession{sessionID: id}
	e.sessions[id] = created
	return created
}
// ensureUpstreamConn returns an upstream websocket connection for a request.
//
// Without a session, every call dials a brand-new connection and returns the
// dial result as-is. With a session, the session's current connection is
// reused when present. Otherwise a new one is dialed and installed on the
// session (recording wsURL and authID), and the connection event is logged.
// The handshake response is returned only for a newly installed connection.
//
// sess.readerConn records which connection already has a readUpstreamLoop
// goroutine. Checking and updating it happens in the same critical section as
// reading sess.conn, so concurrent callers never start two readers for one
// connection.
func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *cliproxyauth.Auth, sess *codexWebsocketSession, authID string, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) {
	if sess == nil {
		return e.dialCodexWebsocket(ctx, auth, wsURL, headers)
	}

	// Fast path: reuse the session's current connection.
	sess.connMu.Lock()
	conn := sess.conn
	startReader := conn != nil && sess.readerConn != conn
	if startReader {
		sess.readerConn = conn
	}
	sess.connMu.Unlock()
	if conn != nil {
		if startReader {
			e.startSessionUpstreamReader(sess, conn)
		}
		return conn, nil, nil
	}

	conn, resp, errDial := e.dialCodexWebsocket(ctx, auth, wsURL, headers)
	if errDial != nil {
		return nil, resp, errDial
	}

	sess.connMu.Lock()
	if previous := sess.conn; previous != nil {
		// Another caller installed a connection while this one was dialing:
		// keep the installed connection and discard the freshly dialed one.
		startReader = sess.readerConn != previous
		if startReader {
			sess.readerConn = previous
		}
		sess.connMu.Unlock()
		closeHTTPResponseBody(resp, "codex websockets executor: close handshake response body error")
		if errClose := conn.Close(); errClose != nil {
			log.Errorf("codex websockets executor: close websocket error: %v", errClose)
		}
		if startReader {
			e.startSessionUpstreamReader(sess, previous)
		}
		return previous, nil, nil
	}
	sess.conn = conn
	sess.wsURL = wsURL
	sess.authID = authID
	sess.readerConn = conn
	sess.connMu.Unlock()
	e.startSessionUpstreamReader(sess, conn)
	logCodexWebsocketConnected(sess.sessionID, authID, wsURL)
	return conn, resp, nil
}

// startSessionUpstreamReader runs the session's per-connection setup
// (sess.configureConn) and launches the readUpstreamLoop goroutine for conn.
// Callers must already have recorded conn as sess.readerConn.
func (e *CodexWebsocketsExecutor) startSessionUpstreamReader(sess *codexWebsocketSession, conn *websocket.Conn) {
	sess.configureConn(conn)
	go e.readUpstreamLoop(sess, conn)
}
// readUpstreamLoop is the per-connection reader goroutine of a session. It reads
// frames from conn (re-arming the idle deadline before each read) and hands
// every text frame to the session's currently active request channel. Frames
// arriving while no request is active are dropped. A read error or a binary
// frame is delivered to the active request as an error event, after which the
// connection is invalidated and the loop exits.
//
// Every event is tagged with conn, and readCodexWebsocketMessage ignores events
// from connections other than the one its request is using. The loop
// deliberately never clears or closes the active channel. By the time this
// connection fails, that channel may belong to a request that has moved to a
// replacement connection (for example after a send-error retry). Closing it
// would make the replacement connection's reader panic with "send on closed
// channel", and clearing it would silently drop that request's messages.
// Clearing the channel is left to the request that owns it.
func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, conn *websocket.Conn) {
	if e == nil || sess == nil || conn == nil {
		return
	}
	for {
		_ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout))
		msgType, payload, errRead := conn.ReadMessage()
		if errRead != nil {
			deliverCodexWebsocketRead(sess, codexWebsocketRead{conn: conn, err: errRead})
			e.invalidateUpstreamConn(sess, conn, "upstream_disconnected", errRead)
			return
		}
		switch msgType {
		case websocket.TextMessage:
			deliverCodexWebsocketRead(sess, codexWebsocketRead{conn: conn, msgType: msgType, payload: payload})
		case websocket.BinaryMessage:
			errBinary := fmt.Errorf("codex websockets executor: unexpected binary message")
			deliverCodexWebsocketRead(sess, codexWebsocketRead{conn: conn, err: errBinary})
			e.invalidateUpstreamConn(sess, conn, "unexpected_binary", errBinary)
			return
		}
		// Any other frame type is ignored.
	}
}

// deliverCodexWebsocketRead hands ev to the session's active request channel.
// It blocks until the request accepts the event or the request's done channel
// is closed, and does nothing when no request is active.
func deliverCodexWebsocketRead(sess *codexWebsocketSession, ev codexWebsocketRead) {
	sess.activeMu.Lock()
	ch := sess.activeCh
	done := sess.activeDone
	sess.activeMu.Unlock()
	if ch == nil {
		return
	}
	select {
	case ch <- ev:
	case <-done:
	}
}
// invalidateUpstreamConn detaches conn from the session, logs the disconnect
// with reason/err, and closes conn. It does this only when conn is still the
// session's current connection. A stale or foreign connection is left
// untouched and not closed. Nil arguments are ignored.
func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSession, conn *websocket.Conn, reason string, err error) {
	if sess == nil || conn == nil {
		return
	}
	sess.connMu.Lock()
	if sess.conn != conn {
		sess.connMu.Unlock()
		return
	}
	sessionID, authID, wsURL := sess.sessionID, sess.authID, sess.wsURL
	sess.conn = nil
	if sess.readerConn == conn {
		sess.readerConn = nil
	}
	sess.connMu.Unlock()

	logCodexWebsocketDisconnected(sessionID, authID, wsURL, reason, err)
	if errClose := conn.Close(); errClose != nil {
		log.Errorf("codex websockets executor: close websocket error: %v", errClose)
	}
}
// CloseExecutionSession unregisters the session with the given (trimmed) ID and
// closes its upstream websocket, if any, logging reason "session_closed".
// The special ID cliproxyauth.CloseAllExecutionSessionsID closes every session
// instead, with reason "executor_replaced". Blank IDs and a nil executor are
// ignored.
func (e *CodexWebsocketsExecutor) CloseExecutionSession(sessionID string) {
	if e == nil {
		return
	}
	id := strings.TrimSpace(sessionID)
	if id == "" {
		return
	}
	if id == cliproxyauth.CloseAllExecutionSessionsID {
		e.closeAllExecutionSessions("executor_replaced")
		return
	}

	e.sessMu.Lock()
	target := e.sessions[id]
	delete(e.sessions, id)
	e.sessMu.Unlock()

	e.closeExecutionSession(target, "session_closed")
}
// closeAllExecutionSessions empties the session registry and then closes each
// removed session's upstream connection with the given reason. Connections are
// closed after e.sessMu is released, so slow closes do not block the registry.
func (e *CodexWebsocketsExecutor) closeAllExecutionSessions(reason string) {
	if e == nil {
		return
	}

	e.sessMu.Lock()
	pending := make([]*codexWebsocketSession, 0, len(e.sessions))
	for id, s := range e.sessions {
		delete(e.sessions, id)
		if s == nil {
			continue
		}
		pending = append(pending, s)
	}
	e.sessMu.Unlock()

	for _, s := range pending {
		e.closeExecutionSession(s, reason)
	}
}
// closeExecutionSession detaches the session's current upstream connection
// (if any), logs the disconnect, and closes it. A blank reason defaults to
// "session_closed". A nil session or a session without a connection is a no-op
// apart from clearing state.
func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) {
	if sess == nil {
		return
	}
	why := strings.TrimSpace(reason)
	if why == "" {
		why = "session_closed"
	}

	sess.connMu.Lock()
	conn, authID, wsURL, sessionID := sess.conn, sess.authID, sess.wsURL, sess.sessionID
	sess.conn = nil
	if sess.readerConn == conn {
		sess.readerConn = nil
	}
	sess.connMu.Unlock()

	if conn == nil {
		return
	}
	logCodexWebsocketDisconnected(sessionID, authID, wsURL, why, nil)
	if errClose := conn.Close(); errClose != nil {
		log.Errorf("codex websockets executor: close websocket error: %v", errClose)
	}
}
// logCodexWebsocketConnected emits an info log for a newly established upstream
// websocket. All fields are whitespace-trimmed.
func logCodexWebsocketConnected(sessionID string, authID string, wsURL string) {
	session := strings.TrimSpace(sessionID)
	authName := strings.TrimSpace(authID)
	target := strings.TrimSpace(wsURL)
	log.Infof("codex websockets: upstream connected session=%s auth=%s url=%s", session, authName, target)
}
// logCodexWebsocketDisconnected emits an info log for an upstream websocket
// teardown. The err field is included only when err is non-nil. All string
// fields are whitespace-trimmed.
func logCodexWebsocketDisconnected(sessionID string, authID string, wsURL string, reason string, err error) {
	session := strings.TrimSpace(sessionID)
	authName := strings.TrimSpace(authID)
	target := strings.TrimSpace(wsURL)
	why := strings.TrimSpace(reason)
	if err == nil {
		log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", session, authName, target, why)
		return
	}
	log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s err=%v", session, authName, target, why, err)
}
// CodexAutoExecutor routes Codex requests to the websocket transport only when:
//  1. The downstream transport is websocket, and
//  2. The selected auth enables websockets.
//
// For non-websocket downstream requests, it always uses the legacy HTTP implementation.
// Request preparation, raw HTTP requests, refresh, and token counting always go
// through the HTTP executor. Execution-session teardown goes to the websocket
// executor.
type CodexAutoExecutor struct {
	httpExec *CodexExecutor
	wsExec   *CodexWebsocketsExecutor
}

// NewCodexAutoExecutor builds an auto executor backed by new HTTP and websocket
// Codex executors that share cfg.
func NewCodexAutoExecutor(cfg *config.Config) *CodexAutoExecutor {
	httpExec := NewCodexExecutor(cfg)
	wsExec := NewCodexWebsocketsExecutor(cfg)
	return &CodexAutoExecutor{httpExec: httpExec, wsExec: wsExec}
}

// Identifier returns the executor identifier "codex".
func (e *CodexAutoExecutor) Identifier() string {
	return "codex"
}

// PrepareRequest delegates to the HTTP executor. It is a no-op when that
// executor is missing.
func (e *CodexAutoExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
	if e != nil && e.httpExec != nil {
		return e.httpExec.PrepareRequest(req, auth)
	}
	return nil
}

// HttpRequest delegates to the HTTP executor.
func (e *CodexAutoExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
	if e != nil && e.httpExec != nil {
		return e.httpExec.HttpRequest(ctx, auth, req)
	}
	return nil, fmt.Errorf("codex auto executor: http executor is nil")
}

// Execute runs a non-streaming request. It uses the websocket executor when the
// downstream connection is a websocket and the auth enables websockets;
// otherwise it uses the HTTP executor.
func (e *CodexAutoExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	if e == nil || e.httpExec == nil || e.wsExec == nil {
		return cliproxyexecutor.Response{}, fmt.Errorf("codex auto executor: executor is nil")
	}
	useWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && codexWebsocketsEnabled(auth)
	if !useWebsocket {
		return e.httpExec.Execute(ctx, auth, req, opts)
	}
	return e.wsExec.Execute(ctx, auth, req, opts)
}

// ExecuteStream runs a streaming request, routed the same way as Execute.
func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	if e == nil || e.httpExec == nil || e.wsExec == nil {
		return nil, fmt.Errorf("codex auto executor: executor is nil")
	}
	useWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && codexWebsocketsEnabled(auth)
	if !useWebsocket {
		return e.httpExec.ExecuteStream(ctx, auth, req, opts)
	}
	return e.wsExec.ExecuteStream(ctx, auth, req, opts)
}

// Refresh delegates credential refresh to the HTTP executor.
func (e *CodexAutoExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
	if e != nil && e.httpExec != nil {
		return e.httpExec.Refresh(ctx, auth)
	}
	return nil, fmt.Errorf("codex auto executor: http executor is nil")
}

// CountTokens delegates token counting to the HTTP executor.
func (e *CodexAutoExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	if e != nil && e.httpExec != nil {
		return e.httpExec.CountTokens(ctx, auth, req, opts)
	}
	return cliproxyexecutor.Response{}, fmt.Errorf("codex auto executor: http executor is nil")
}

// CloseExecutionSession forwards session teardown to the websocket executor,
// which owns all execution-session state.
func (e *CodexAutoExecutor) CloseExecutionSession(sessionID string) {
	if e != nil && e.wsExec != nil {
		e.wsExec.CloseExecutionSession(sessionID)
	}
}
// codexWebsocketsEnabled reports whether auth opts into the websocket transport.
//
// A parseable "websockets" attribute (strconv.ParseBool syntax) takes
// precedence. Otherwise the "websockets" metadata entry is consulted; it may be
// a bool or a parseable string. Anything else, including a nil auth, means
// disabled.
func codexWebsocketsEnabled(auth *cliproxyauth.Auth) bool {
	if auth == nil {
		return false
	}
	if raw := strings.TrimSpace(auth.Attributes["websockets"]); raw != "" {
		if enabled, errParse := strconv.ParseBool(raw); errParse == nil {
			return enabled
		}
	}
	switch v := auth.Metadata["websockets"].(type) {
	case bool:
		return v
	case string:
		enabled, errParse := strconv.ParseBool(strings.TrimSpace(v))
		return errParse == nil && enabled
	default:
		return false
	}
}
================================================
FILE: internal/runtime/executor/codex_websockets_executor_test.go
================================================
package executor
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/tidwall/gjson"
)
// TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID checks that the
// websocket request is always a `response.create` that keeps the incremental
// fields (previous_response_id, input), and that the caller's body is left
// unmodified.
func TestBuildCodexWebsocketRequestBodyPreservesPreviousResponseID(t *testing.T) {
	body := []byte(`{"model":"gpt-5-codex","previous_response_id":"resp-1","input":[{"type":"message","id":"msg-1"}]}`)
	wsReqBody := buildCodexWebsocketRequestBody(body)
	if got := gjson.GetBytes(wsReqBody, "type").String(); got != "response.create" {
		t.Fatalf("type = %s, want response.create", got)
	}
	if got := gjson.GetBytes(wsReqBody, "previous_response_id").String(); got != "resp-1" {
		t.Fatalf("previous_response_id = %s, want resp-1", got)
	}
	if gjson.GetBytes(wsReqBody, "input.0.id").String() != "msg-1" {
		t.Fatalf("input item id mismatch")
	}
	// The builder must work on a copy; the original body has no "type" field.
	if gjson.GetBytes(body, "type").Exists() {
		t.Fatalf("input body was mutated: %s", body)
	}
}
// TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta checks the
// fallbacks used when no auth, token, config, or client headers are present:
// the built-in OpenAI-Beta and User-Agent values are set, and
// x-codex-beta-features stays empty.
func TestApplyCodexWebsocketHeadersDefaultsToCurrentResponsesBeta(t *testing.T) {
	headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, nil, "", nil)
	if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
		t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
	}
	if got := headers.Get("User-Agent"); got != codexUserAgent {
		t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
	}
	if got := headers.Get("x-codex-beta-features"); got != "" {
		t.Fatalf("x-codex-beta-features = %q, want empty", got)
	}
}
// TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth checks that, for an
// auth without an api_key attribute and no client headers, the configured
// User-Agent and beta features are applied while OpenAI-Beta keeps its default.
func TestApplyCodexWebsocketHeadersUsesConfigDefaultsForOAuth(t *testing.T) {
	cfg := &config.Config{
		CodexHeaderDefaults: config.CodexHeaderDefaults{
			UserAgent:    "my-codex-client/1.0",
			BetaFeatures: "feature-a,feature-b",
		},
	}
	auth := &cliproxyauth.Auth{
		Provider: "codex",
		Metadata: map[string]any{"email": "user@example.com"},
	}
	headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "", cfg)
	if got := headers.Get("User-Agent"); got != "my-codex-client/1.0" {
		t.Fatalf("User-Agent = %s, want %s", got, "my-codex-client/1.0")
	}
	if got := headers.Get("x-codex-beta-features"); got != "feature-a,feature-b" {
		t.Fatalf("x-codex-beta-features = %s, want %s", got, "feature-a,feature-b")
	}
	if got := headers.Get("OpenAI-Beta"); got != codexResponsesWebsocketBetaHeaderValue {
		t.Fatalf("OpenAI-Beta = %s, want %s", got, codexResponsesWebsocketBetaHeaderValue)
	}
}
// TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig
// checks that values already present in the outgoing header map win over
// both the incoming client headers and the configured defaults.
func TestApplyCodexWebsocketHeadersPrefersExistingHeadersOverClientAndConfig(t *testing.T) {
	cfg := &config.Config{
		CodexHeaderDefaults: config.CodexHeaderDefaults{
			UserAgent:    "config-ua",
			BetaFeatures: "config-beta",
		},
	}
	auth := &cliproxyauth.Auth{
		Provider: "codex",
		Metadata: map[string]any{"email": "user@example.com"},
	}
	ctx := contextWithGinHeaders(map[string]string{
		"User-Agent":            "client-ua",
		"X-Codex-Beta-Features": "client-beta",
	})
	headers := http.Header{}
	headers.Set("User-Agent", "existing-ua")
	headers.Set("X-Codex-Beta-Features", "existing-beta")
	got := applyCodexWebsocketHeaders(ctx, headers, auth, "", cfg)
	if gotVal := got.Get("User-Agent"); gotVal != "existing-ua" {
		t.Fatalf("User-Agent = %s, want %s", gotVal, "existing-ua")
	}
	if gotVal := got.Get("x-codex-beta-features"); gotVal != "existing-beta" {
		t.Fatalf("x-codex-beta-features = %s, want %s", gotVal, "existing-beta")
	}
}
// TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader pins the
// asymmetric precedence rules. For User-Agent the configured default beats the
// client header, while for x-codex-beta-features the client header beats the
// config.
func TestApplyCodexWebsocketHeadersConfigUserAgentOverridesClientHeader(t *testing.T) {
	cfg := &config.Config{
		CodexHeaderDefaults: config.CodexHeaderDefaults{
			UserAgent:    "config-ua",
			BetaFeatures: "config-beta",
		},
	}
	auth := &cliproxyauth.Auth{
		Provider: "codex",
		Metadata: map[string]any{"email": "user@example.com"},
	}
	ctx := contextWithGinHeaders(map[string]string{
		"User-Agent":            "client-ua",
		"X-Codex-Beta-Features": "client-beta",
	})
	headers := applyCodexWebsocketHeaders(ctx, http.Header{}, auth, "", cfg)
	if got := headers.Get("User-Agent"); got != "config-ua" {
		t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
	}
	if got := headers.Get("x-codex-beta-features"); got != "client-beta" {
		t.Fatalf("x-codex-beta-features = %s, want %s", got, "client-beta")
	}
}
// TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth checks that
// configured header defaults are not applied to auths with an api_key
// attribute. The built-in User-Agent is used and no beta features are sent.
func TestApplyCodexWebsocketHeadersIgnoresConfigForAPIKeyAuth(t *testing.T) {
	cfg := &config.Config{
		CodexHeaderDefaults: config.CodexHeaderDefaults{
			UserAgent:    "config-ua",
			BetaFeatures: "config-beta",
		},
	}
	auth := &cliproxyauth.Auth{
		Provider:   "codex",
		Attributes: map[string]string{"api_key": "sk-test"},
	}
	headers := applyCodexWebsocketHeaders(context.Background(), http.Header{}, auth, "sk-test", cfg)
	if got := headers.Get("User-Agent"); got != codexUserAgent {
		t.Fatalf("User-Agent = %s, want %s", got, codexUserAgent)
	}
	if got := headers.Get("x-codex-beta-features"); got != "" {
		t.Fatalf("x-codex-beta-features = %q, want empty", got)
	}
}
// TestApplyCodexHeadersUsesConfigUserAgentForOAuth covers the HTTP transport's
// applyCodexHeaders. The configured User-Agent overrides the client's, and
// x-codex-beta-features is expected to stay unset on this path even though
// BetaFeatures is configured.
func TestApplyCodexHeadersUsesConfigUserAgentForOAuth(t *testing.T) {
	req, err := http.NewRequest(http.MethodPost, "https://example.com/responses", nil)
	if err != nil {
		t.Fatalf("NewRequest() error = %v", err)
	}
	cfg := &config.Config{
		CodexHeaderDefaults: config.CodexHeaderDefaults{
			UserAgent:    "config-ua",
			BetaFeatures: "config-beta",
		},
	}
	auth := &cliproxyauth.Auth{
		Provider: "codex",
		Metadata: map[string]any{"email": "user@example.com"},
	}
	req = req.WithContext(contextWithGinHeaders(map[string]string{
		"User-Agent": "client-ua",
	}))
	applyCodexHeaders(req, auth, "oauth-token", true, cfg)
	if got := req.Header.Get("User-Agent"); got != "config-ua" {
		t.Fatalf("User-Agent = %s, want %s", got, "config-ua")
	}
	if got := req.Header.Get("x-codex-beta-features"); got != "" {
		t.Fatalf("x-codex-beta-features = %q, want empty", got)
	}
}
// contextWithGinHeaders returns a context carrying a test gin.Context whose
// request has exactly the given headers. It is used to simulate the incoming
// client request seen by the header helpers.
func contextWithGinHeaders(headers map[string]string) context.Context {
	gin.SetMode(gin.TestMode)
	recorder := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(recorder)
	ginCtx.Request = httptest.NewRequest(http.MethodPost, "/", nil)
	ginCtx.Request.Header = make(http.Header, len(headers))
	for key, value := range headers {
		ginCtx.Request.Header.Set(key, value)
	}
	// NOTE(review): the plain string key "gin" presumably mirrors the lookup in
	// ginContextFrom; confirm before switching to a typed key (staticcheck SA1029).
	return context.WithValue(context.Background(), "gin", ginCtx)
}
// TestNewProxyAwareWebsocketDialerDirectDisablesProxy checks that an auth-level
// "direct" proxy setting overrides the global proxy and removes the dialer's
// Proxy function, so environment proxies are not consulted either.
func TestNewProxyAwareWebsocketDialerDirectDisablesProxy(t *testing.T) {
	t.Parallel()
	dialer := newProxyAwareWebsocketDialer(
		&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
		&cliproxyauth.Auth{ProxyURL: "direct"},
	)
	if dialer.Proxy != nil {
		t.Fatal("expected websocket proxy function to be nil for direct mode")
	}
}
================================================
FILE: internal/runtime/executor/gemini_cli_executor.go
================================================
// Package executor provides runtime execution capabilities for various AI service providers.
// This file implements the Gemini CLI executor that talks to Cloud Code Assist endpoints
// using OAuth credentials from auth metadata.
package executor
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"regexp"
"strconv"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
// Cloud Code Assist endpoint settings and the OAuth client credentials used by
// the Gemini CLI executor.
//
// NOTE(review): the client ID/secret appear to be the public "installed
// application" credentials distributed with Google's Gemini CLI. For such
// clients the secret is not confidential. Confirm before treating them as
// secrets or changing them.
const (
	codeAssistEndpoint      = "https://cloudcode-pa.googleapis.com"
	codeAssistVersion       = "v1internal"
	geminiOAuthClientID     = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
	geminiOAuthClientSecret = "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"
)
// geminiOAuthScopes lists the Google OAuth scopes associated with Gemini CLI
// credentials: Cloud Platform access plus the user's email and profile.
var geminiOAuthScopes = []string{
	"https://www.googleapis.com/auth/cloud-platform",
	"https://www.googleapis.com/auth/userinfo.email",
	"https://www.googleapis.com/auth/userinfo.profile",
}
// GeminiCLIExecutor sends requests to the Cloud Code Assist endpoint, authenticating with
// OAuth credentials read from auth metadata.
type GeminiCLIExecutor struct {
	// cfg is the application configuration, passed to HTTP client construction,
	// payload rewriting, and upstream request logging.
	cfg *config.Config
}
// NewGeminiCLIExecutor returns a Gemini CLI executor that uses cfg for every request.
func NewGeminiCLIExecutor(cfg *config.Config) *GeminiCLIExecutor {
	ex := new(GeminiCLIExecutor)
	ex.cfg = cfg
	return ex
}
// Identifier reports the provider key "gemini-cli"; it is used for usage reporting and
// as the Provider field of upstream request logs.
func (e *GeminiCLIExecutor) Identifier() string {
	const id = "gemini-cli"
	return id
}
// PrepareRequest attaches a Gemini CLI bearer token and the Gemini CLI client headers to req.
//
// A nil request is a no-op. The token comes from prepareGeminiCLITokenSource (which may
// refresh it); a blank access token yields a 401 statusErr. The model reported in the
// User-Agent is "unknown" because no model is known at this layer.
func (e *GeminiCLIExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
	if req == nil {
		return nil
	}
	source, _, err := prepareGeminiCLITokenSource(req.Context(), e.cfg, auth)
	if err != nil {
		return err
	}
	token, err := source.Token()
	if err != nil {
		return err
	}
	accessToken := token.AccessToken
	if strings.TrimSpace(accessToken) == "" {
		return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)
	applyGeminiCLIHeaders(req, "unknown")
	return nil
}
// HttpRequest authorizes req with Gemini CLI credentials and sends it.
//
// When ctx is nil the request's own context is used. The request is re-bound to the
// chosen context, prepared via PrepareRequest, and executed with a client from newHTTPClient.
func (e *GeminiCLIExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
	if req == nil {
		return nil, fmt.Errorf("gemini-cli executor: request is nil")
	}
	effectiveCtx := ctx
	if effectiveCtx == nil {
		effectiveCtx = req.Context()
	}
	outbound := req.WithContext(effectiveCtx)
	if err := e.PrepareRequest(outbound, auth); err != nil {
		return nil, err
	}
	return newHTTPClient(effectiveCtx, e.cfg, auth, 0).Do(outbound)
}
// Execute performs a non-streaming request to the Gemini CLI API.
func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if opts.Alt == "responses/compact" {
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
if err != nil {
return resp, err
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
requestedModel := payloadRequestedModel(opts, req.Model)
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
action := "generateContent"
if req.Metadata != nil {
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
action = "countTokens"
}
}
projectID := resolveGeminiProjectID(auth)
models := cliPreviewFallbackOrder(baseModel)
if len(models) == 0 || models[0] != baseModel {
models = append([]string{baseModel}, models...)
}
httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
var authID, authLabel, authType, authValue string
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
var lastStatus int
var lastBody []byte
for idx, attemptModel := range models {
payload := append([]byte(nil), basePayload...)
if action == "countTokens" {
payload = deleteJSONField(payload, "project")
payload = deleteJSONField(payload, "model")
} else {
payload = setJSONField(payload, "project", projectID)
payload = setJSONField(payload, "model", attemptModel)
}
tok, errTok := tokenSource.Token()
if errTok != nil {
err = errTok
return resp, err
}
updateGeminiCLITokenMetadata(auth, baseTokenData, tok)
url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
if errReq != nil {
err = errReq
return resp, err
}
reqHTTP.Header.Set("Content-Type", "application/json")
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(reqHTTP, attemptModel)
reqHTTP.Header.Set("Accept", "application/json")
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: reqHTTP.Header.Clone(),
Body: payload,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpResp, errDo := httpClient.Do(reqHTTP)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
err = errDo
return resp, err
}
data, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("gemini cli executor: close response body error: %v", errClose)
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
err = errRead
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 {
reporter.publish(ctx, parseGeminiCLIUsage(data))
var param any
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), data...)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
if httpResp.StatusCode == 429 {
if idx+1 < len(models) {
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
} else {
log.Debug("gemini cli executor: rate limited, no additional fallback model")
}
continue
}
err = newGeminiStatusErr(httpResp.StatusCode, data)
return resp, err
}
if len(lastBody) > 0 {
appendAPIResponseChunk(ctx, e.cfg, lastBody)
}
if lastStatus == 0 {
lastStatus = 429
}
err = newGeminiStatusErr(lastStatus, lastBody)
return resp, err
}
// ExecuteStream performs a streaming request to the Gemini CLI API.
func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
if err != nil {
return nil, err
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini-cli")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
basePayload = fixGeminiCLIImageAspectRatio(baseModel, basePayload)
requestedModel := payloadRequestedModel(opts, req.Model)
basePayload = applyPayloadConfigWithRoot(e.cfg, baseModel, "gemini", "request", basePayload, originalTranslated, requestedModel)
projectID := resolveGeminiProjectID(auth)
models := cliPreviewFallbackOrder(baseModel)
if len(models) == 0 || models[0] != baseModel {
models = append([]string{baseModel}, models...)
}
httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
var authID, authLabel, authType, authValue string
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
var lastStatus int
var lastBody []byte
for idx, attemptModel := range models {
payload := append([]byte(nil), basePayload...)
payload = setJSONField(payload, "project", projectID)
payload = setJSONField(payload, "model", attemptModel)
tok, errTok := tokenSource.Token()
if errTok != nil {
err = errTok
return nil, err
}
updateGeminiCLITokenMetadata(auth, baseTokenData, tok)
url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, "streamGenerateContent")
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
if errReq != nil {
err = errReq
return nil, err
}
reqHTTP.Header.Set("Content-Type", "application/json")
reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
applyGeminiCLIHeaders(reqHTTP, attemptModel)
reqHTTP.Header.Set("Accept", "text/event-stream")
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: reqHTTP.Header.Clone(),
Body: payload,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpResp, errDo := httpClient.Do(reqHTTP)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
err = errDo
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
data, errRead := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("gemini cli executor: close response body error: %v", errClose)
}
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
err = errRead
return nil, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
lastStatus = httpResp.StatusCode
lastBody = append([]byte(nil), data...)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
if httpResp.StatusCode == 429 {
if idx+1 < len(models) {
log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
} else {
log.Debug("gemini cli executor: rate limited, no additional fallback model")
}
continue
}
err = newGeminiStatusErr(httpResp.StatusCode, data)
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
go func(resp *http.Response, reqBody []byte, attemptModel string) {
defer close(out)
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("gemini cli executor: close response body error: %v", errClose)
}
}()
if opts.Alt == "" {
scanner := bufio.NewScanner(resp.Body)
scanner.Buffer(nil, streamScannerBuffer)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseGeminiCLIStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if bytes.HasPrefix(line, dataTag) {
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), ¶m)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
}
}
}
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
return
}
data, errRead := io.ReadAll(resp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errRead}
return
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiCLIUsage(data))
var param any
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, ¶m)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
}
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
for i := range segments {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
}
}(httpResp, append([]byte(nil), payload...), attemptModel)
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
if len(lastBody) > 0 {
appendAPIResponseChunk(ctx, e.cfg, lastBody)
}
if lastStatus == 0 {
lastStatus = 429
}
err = newGeminiStatusErr(lastStatus, lastBody)
return nil, err
}
// CountTokens counts tokens for the given request using the Gemini CLI countTokens endpoint.
//
// The payload is translated once into the gemini-cli format. The project, model, and
// request.safetySettings fields are then removed. Each fallback candidate is tried in turn:
// a 429 moves on to the next one, and any other non-2xx status stops the loop. Because the
// model field is stripped, the candidate only affects the User-Agent header. On success the
// upstream totalTokens value is translated back to opts.SourceFormat. On failure it returns
// a statusErr from newGeminiStatusErr for the last upstream status and body.
func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	baseModel := thinking.ParseSuffix(req.Model).ModelName
	tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, e.cfg, auth)
	if err != nil {
		return cliproxyexecutor.Response{}, err
	}
	from := opts.SourceFormat
	to := sdktranslator.FromString("gemini-cli")
	models := cliPreviewFallbackOrder(baseModel)
	if len(models) == 0 || models[0] != baseModel {
		models = append([]string{baseModel}, models...)
	}
	httpClient := newHTTPClient(ctx, e.cfg, auth, 0)
	respCtx := context.WithValue(ctx, "alt", opts.Alt)
	var authID, authLabel, authType, authValue string
	if auth != nil {
		authID = auth.ID
		authLabel = auth.Label
		authType, authValue = auth.AccountInfo()
	}

	// The request body does not depend on the attempted model (project and model are
	// stripped), so it is built once instead of on every fallback iteration.
	payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
	payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
	if err != nil {
		return cliproxyexecutor.Response{}, err
	}
	payload = deleteJSONField(payload, "project")
	payload = deleteJSONField(payload, "model")
	payload = deleteJSONField(payload, "request.safetySettings")
	payload = fixGeminiCLIImageAspectRatio(baseModel, payload)

	var lastStatus int
	var lastBody []byte
	for idx, attemptModel := range models {
		tok, errTok := tokenSource.Token()
		if errTok != nil {
			return cliproxyexecutor.Response{}, errTok
		}
		updateGeminiCLITokenMetadata(auth, baseTokenData, tok)
		url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, "countTokens")
		// NOTE(review): Execute omits $alt for the countTokens action; confirm whether it
		// is really wanted here.
		if opts.Alt != "" {
			url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
		}
		reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload))
		if errReq != nil {
			return cliproxyexecutor.Response{}, errReq
		}
		reqHTTP.Header.Set("Content-Type", "application/json")
		reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken)
		applyGeminiCLIHeaders(reqHTTP, attemptModel)
		reqHTTP.Header.Set("Accept", "application/json")
		recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
			URL:       url,
			Method:    http.MethodPost,
			Headers:   reqHTTP.Header.Clone(),
			Body:      payload,
			Provider:  e.Identifier(),
			AuthID:    authID,
			AuthLabel: authLabel,
			AuthType:  authType,
			AuthValue: authValue,
		})
		resp, errDo := httpClient.Do(reqHTTP)
		if errDo != nil {
			recordAPIResponseError(ctx, e.cfg, errDo)
			return cliproxyexecutor.Response{}, errDo
		}
		data, errRead := io.ReadAll(resp.Body)
		_ = resp.Body.Close() // body fully consumed; a close error is not actionable here
		recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
		if errRead != nil {
			recordAPIResponseError(ctx, e.cfg, errRead)
			return cliproxyexecutor.Response{}, errRead
		}
		appendAPIResponseChunk(ctx, e.cfg, data)
		if resp.StatusCode >= 200 && resp.StatusCode < 300 {
			count := gjson.GetBytes(data, "totalTokens").Int()
			translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
			return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
		}
		lastStatus = resp.StatusCode
		lastBody = append([]byte(nil), data...)
		if resp.StatusCode == 429 {
			if idx+1 < len(models) {
				log.Debugf("gemini cli executor: rate limited, retrying with next model: %s", models[idx+1])
			} else {
				log.Debug("gemini cli executor: rate limited, no additional fallback model")
			}
			continue
		}
		break
	}
	if lastStatus == 0 {
		lastStatus = 429
	}
	return cliproxyexecutor.Response{}, newGeminiStatusErr(lastStatus, lastBody)
}
// Refresh returns auth unchanged. Gemini CLI tokens are not refreshed here; the oauth2
// token source built in prepareGeminiCLITokenSource handles refresh when a request is made.
func (e *GeminiCLIExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
	unchanged := auth
	return unchanged, nil
}
// prepareGeminiCLITokenSource builds an OAuth token source for a Gemini CLI auth.
//
// The token is assembled from the nested metadata["token"] object, if present. Any empty
// field is then filled from the flat access_token, refresh_token, token_type and
// expiry (RFC 3339) metadata keys. The token source uses the Gemini OAuth client and, when
// available, the proxy-aware HTTP client. A current token is fetched immediately (which
// refreshes it if needed) and written back via updateGeminiCLITokenMetadata.
//
// It returns a reusing token source, the cloned base token map (for later metadata
// merges), and any error from the metadata lookup or token fetch.
func prepareGeminiCLITokenSource(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth) (oauth2.TokenSource, map[string]any, error) {
	metadata := geminiOAuthMetadata(auth)
	if auth == nil || metadata == nil {
		return nil, nil, fmt.Errorf("gemini-cli auth metadata missing")
	}

	// Copy the nested token object so the stored metadata is never mutated.
	base := make(map[string]any)
	if nested, ok := metadata["token"].(map[string]any); ok && nested != nil {
		base = cloneMap(nested)
	}

	var token oauth2.Token
	if len(base) > 0 {
		if encoded, errMarshal := json.Marshal(base); errMarshal == nil {
			_ = json.Unmarshal(encoded, &token) // best effort; gaps are filled below
		}
	}

	// Fall back to the flat top-level metadata fields for anything still missing.
	if token.AccessToken == "" {
		token.AccessToken = stringValue(metadata, "access_token")
	}
	if token.RefreshToken == "" {
		token.RefreshToken = stringValue(metadata, "refresh_token")
	}
	if token.TokenType == "" {
		token.TokenType = stringValue(metadata, "token_type")
	}
	if token.Expiry.IsZero() {
		if rawExpiry := stringValue(metadata, "expiry"); rawExpiry != "" {
			if parsed, errParse := time.Parse(time.RFC3339, rawExpiry); errParse == nil {
				token.Expiry = parsed
			}
		}
	}

	oauthConfig := &oauth2.Config{
		ClientID:     geminiOAuthClientID,
		ClientSecret: geminiOAuthClientSecret,
		Scopes:       geminiOAuthScopes,
		Endpoint:     google.Endpoint,
	}
	tokenCtx := ctx
	if client := newProxyAwareHTTPClient(ctx, cfg, auth, 0); client != nil {
		tokenCtx = context.WithValue(tokenCtx, oauth2.HTTPClient, client)
	}
	refreshing := oauthConfig.TokenSource(tokenCtx, &token)
	current, err := refreshing.Token()
	if err != nil {
		return nil, nil, err
	}
	updateGeminiCLITokenMetadata(auth, base, current)
	return oauth2.ReuseTokenSource(current, refreshing), base, nil
}
// updateGeminiCLITokenMetadata writes tok (merged over base) back into the auth's metadata.
//
// When the auth's runtime resolves to a shared credential, the fields are merged there.
// The resulting snapshot replaces auth.Metadata unless the runtime is virtual. Otherwise
// the fields are written directly into auth.Metadata, creating the map if needed.
// A nil auth or token is ignored.
func updateGeminiCLITokenMetadata(auth *cliproxyauth.Auth, base map[string]any, tok *oauth2.Token) {
	if auth == nil || tok == nil {
		return
	}
	fields := buildGeminiTokenFields(tok, buildGeminiTokenMap(base, tok))

	if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
		snapshot := shared.MergeMetadata(fields)
		if !geminicli.IsVirtual(auth.Runtime) {
			auth.Metadata = snapshot
		}
		return
	}

	if auth.Metadata == nil {
		auth.Metadata = make(map[string]any, len(fields))
	}
	for key, value := range fields {
		auth.Metadata[key] = value
	}
}
// buildGeminiTokenMap returns a copy of base overlaid with tok's JSON fields.
// The result is never nil. If tok cannot be round-tripped through JSON, the copy of base
// is returned unmodified.
func buildGeminiTokenMap(base map[string]any, tok *oauth2.Token) map[string]any {
	merged := cloneMap(base)
	if merged == nil {
		merged = make(map[string]any)
	}
	encoded, errMarshal := json.Marshal(tok)
	if errMarshal != nil {
		return merged
	}
	var overlay map[string]any
	if errUnmarshal := json.Unmarshal(encoded, &overlay); errUnmarshal != nil {
		return merged
	}
	for key, value := range overlay {
		merged[key] = value
	}
	return merged
}
// buildGeminiTokenFields produces the flat metadata fields for tok.
//
// The fields are access_token, token_type and refresh_token (each only when non-empty),
// expiry in RFC 3339 (only when set), and "token", a copy of merged (only when non-empty).
func buildGeminiTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any {
	fields := make(map[string]any, 5)
	putNonEmpty := func(key, value string) {
		if value != "" {
			fields[key] = value
		}
	}
	putNonEmpty("access_token", tok.AccessToken)
	putNonEmpty("token_type", tok.TokenType)
	putNonEmpty("refresh_token", tok.RefreshToken)
	if !tok.Expiry.IsZero() {
		fields["expiry"] = tok.Expiry.Format(time.RFC3339)
	}
	if len(merged) > 0 {
		fields["token"] = cloneMap(merged)
	}
	return fields
}
// resolveGeminiProjectID returns the trimmed Google Cloud project ID for auth.
// A *geminicli.VirtualCredential runtime takes precedence over the "project_id" metadata
// key. A nil auth yields "".
func resolveGeminiProjectID(auth *cliproxyauth.Auth) string {
	if auth == nil {
		return ""
	}
	// A type assertion on a nil interface simply fails, so no separate nil check is needed.
	if virtual, ok := auth.Runtime.(*geminicli.VirtualCredential); ok && virtual != nil {
		return strings.TrimSpace(virtual.ProjectID)
	}
	projectID := stringValue(auth.Metadata, "project_id")
	return strings.TrimSpace(projectID)
}
// geminiOAuthMetadata returns the metadata that holds the OAuth token for auth.
// It prefers a non-empty snapshot from the shared credential resolved from auth.Runtime and
// otherwise falls back to auth.Metadata. A nil auth yields nil.
func geminiOAuthMetadata(auth *cliproxyauth.Auth) map[string]any {
	if auth == nil {
		return nil
	}
	shared := geminicli.ResolveSharedCredential(auth.Runtime)
	if shared == nil {
		return auth.Metadata
	}
	if snapshot := shared.MetadataSnapshot(); len(snapshot) > 0 {
		return snapshot
	}
	return auth.Metadata
}
// newHTTPClient is a thin alias for newProxyAwareHTTPClient. It passes every argument
// through unchanged; presumably that honors cfg/auth proxy settings, so verify there.
func newHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
	client := newProxyAwareHTTPClient(ctx, cfg, auth, timeout)
	return client
}
// cloneMap returns a shallow copy of in. Nested maps and slices are shared with the
// original. A nil input yields nil, while an empty non-nil input yields an empty non-nil map.
func cloneMap(in map[string]any) map[string]any {
	if in == nil {
		return nil
	}
	dup := make(map[string]any, len(in))
	for key := range in {
		dup[key] = in[key]
	}
	return dup
}
// stringValue returns m[key] as a string when it is a string or a fmt.Stringer.
// Missing keys, nil maps and other value types yield "".
func stringValue(m map[string]any, key string) string {
	raw, found := m[key] // indexing a nil map is safe and reports !found
	if !found {
		return ""
	}
	if s, isString := raw.(string); isString {
		return s
	}
	if stringer, isStringer := raw.(fmt.Stringer); isStringer {
		return stringer.String()
	}
	return ""
}
// applyGeminiCLIHeaders stamps r with the Gemini CLI client identity headers.
//
// X-Goog-Api-Client is set to misc.GeminiCLIApiClientHeader. User-Agent is always
// overwritten with misc.GeminiCLIUserAgent(model), discarding any client-supplied value,
// so upstream sees the request as coming from a native Gemini CLI client.
func applyGeminiCLIHeaders(r *http.Request, model string) {
	headers := r.Header
	headers.Set("X-Goog-Api-Client", misc.GeminiCLIApiClientHeader)
	headers.Set("User-Agent", misc.GeminiCLIUserAgent(model))
}
// cliPreviewFallbackOrder returns the preview models to try after model is rate limited.
//
// Fallbacks are currently disabled. The known Gemini 2.5 models return an empty non-nil
// list, and any other model returns nil. The previously used candidates are kept in
// comments so they can be re-enabled per model.
func cliPreviewFallbackOrder(model string) []string {
	switch model {
	case "gemini-2.5-pro":
		// Disabled: "gemini-2.5-pro-preview-05-06", "gemini-2.5-pro-preview-06-05".
		return []string{}
	case "gemini-2.5-flash":
		// Disabled: "gemini-2.5-flash-preview-04-17", "gemini-2.5-flash-preview-05-20".
		return []string{}
	case "gemini-2.5-flash-lite":
		// Disabled: "gemini-2.5-flash-lite-preview-06-17".
		return []string{}
	}
	return nil
}
// setJSONField sets the string value at the sjson path key in body.
//
// key is an sjson path, so dotted keys address nested fields. An empty key, or any sjson
// error, returns body unchanged (best effort).
func setJSONField(body []byte, key, value string) []byte {
	if key == "" {
		return body
	}
	if updated, err := sjson.SetBytes(body, key, value); err == nil {
		return updated
	}
	return body
}
// deleteJSONField removes the value at the sjson path key from body, if present.
//
// key is an sjson path (e.g. "request.safetySettings" addresses a nested field). An empty
// key, an empty body, or any sjson error returns body unchanged (best effort).
func deleteJSONField(body []byte, key string) []byte {
	if key == "" || len(body) == 0 {
		return body
	}
	if updated, err := sjson.DeleteBytes(body, key); err == nil {
		return updated
	}
	return body
}
// fixGeminiCLIImageAspectRatio adapts "gemini-2.5-flash-image-preview" requests that set
// request.generationConfig.imageConfig.aspectRatio.
//
// If no content part already carries inlineData, the function prepends two parts to the
// first content's parts: an instruction text part and a blank white PNG (from
// util.CreateWhiteImageBase64) for the requested aspect ratio. It also sets
// responseModalities to ["IMAGE", "TEXT"]. Whenever aspectRatio is present, imageConfig is
// then removed from generationConfig. Other models, and requests without aspectRatio, are
// returned unchanged. JSON edit errors are ignored.
func fixGeminiCLIImageAspectRatio(modelName string, rawJSON []byte) []byte {
	if modelName != "gemini-2.5-flash-image-preview" {
		return rawJSON
	}
	aspectRatio := gjson.GetBytes(rawJSON, "request.generationConfig.imageConfig.aspectRatio")
	if !aspectRatio.Exists() {
		return rawJSON
	}

	contents := gjson.GetBytes(rawJSON, "request.contents").Array()
	if len(contents) > 0 {
		foundInline := false
		for _, content := range contents {
			for _, part := range content.Get("parts").Array() {
				if part.Get("inlineData").Exists() {
					foundInline = true
					break
				}
			}
			if foundInline {
				break
			}
		}

		if !foundInline {
			whiteImage, _ := util.CreateWhiteImageBase64(aspectRatio.String())
			imagePart, _ := sjson.Set(`{"inlineData":{"mime_type":"image/png","data":""}}`, "inlineData.data", whiteImage)

			// New parts order: instruction text, blank canvas, then the original parts.
			prepended := `[]`
			prepended, _ = sjson.SetRaw(prepended, "-1", `{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`)
			prepended, _ = sjson.SetRaw(prepended, "-1", imagePart)
			for _, existing := range contents[0].Get("parts").Array() {
				prepended, _ = sjson.SetRaw(prepended, "-1", existing.Raw)
			}

			rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents.0.parts", []byte(prepended))
			rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
		}
	}

	rawJSON, _ = sjson.DeleteBytes(rawJSON, "request.generationConfig.imageConfig")
	return rawJSON
}
// newGeminiStatusErr wraps an upstream failure as a statusErr carrying statusCode and the
// raw body as its message. For 429 responses it also records the retry delay that
// parseRetryDelay extracts from the body, when one is found.
func newGeminiStatusErr(statusCode int, body []byte) statusErr {
	result := statusErr{code: statusCode, msg: string(body)}
	if statusCode != http.StatusTooManyRequests {
		return result
	}
	if delay, errParse := parseRetryDelay(body); errParse == nil && delay != nil {
		result.retryAfter = delay
	}
	return result
}
// geminiQuotaResetAfterPattern matches whole-second delays in messages such as
// "Your quota will reset after 12s." It is compiled once instead of on every 429.
var geminiQuotaResetAfterPattern = regexp.MustCompile(`after\s+(\d+)s\.?`)

// parseRetryDelay extracts the retry delay from a Google API 429 error response.
//
// Sources are consulted in order, and the first usable value wins:
//  1. error.details[] entries of @type google.rpc.RetryInfo, field retryDelay
//     (e.g. "0.847655010s");
//  2. error.details[] entries of @type google.rpc.ErrorInfo, field
//     metadata.quotaResetDelay (e.g. "373.801628ms");
//  3. error.message text of the form "... reset after Xs."
//
// An unparseable retryDelay does not stop the search; the later sources are still tried.
// If none yields a delay, the retryDelay parse error is returned when there was one,
// otherwise a "no RetryInfo found" error.
func parseRetryDelay(errorBody []byte) (*time.Duration, error) {
	var retryInfoErr error

	details := gjson.GetBytes(errorBody, "error.details")
	if details.IsArray() {
		entries := details.Array()
		for _, detail := range entries {
			if detail.Get("@type").String() != "type.googleapis.com/google.rpc.RetryInfo" {
				continue
			}
			retryDelay := detail.Get("retryDelay").String()
			if retryDelay == "" {
				continue
			}
			duration, err := time.ParseDuration(retryDelay)
			if err != nil {
				retryInfoErr = fmt.Errorf("parse RetryInfo.retryDelay %q: %w", retryDelay, err)
				continue
			}
			return &duration, nil
		}
		for _, detail := range entries {
			if detail.Get("@type").String() != "type.googleapis.com/google.rpc.ErrorInfo" {
				continue
			}
			quotaResetDelay := detail.Get("metadata.quotaResetDelay").String()
			if quotaResetDelay == "" {
				continue
			}
			if duration, err := time.ParseDuration(quotaResetDelay); err == nil {
				return &duration, nil
			}
		}
	}

	if message := gjson.GetBytes(errorBody, "error.message").String(); message != "" {
		if matches := geminiQuotaResetAfterPattern.FindStringSubmatch(message); len(matches) > 1 {
			if seconds, err := strconv.Atoi(matches[1]); err == nil {
				duration := time.Duration(seconds) * time.Second
				return &duration, nil
			}
		}
	}

	if retryInfoErr != nil {
		return nil, retryInfoErr
	}
	return nil, fmt.Errorf("no RetryInfo found")
}
================================================
FILE: internal/runtime/executor/gemini_executor.go
================================================
// Package executor provides runtime execution capabilities for various AI service providers.
// It includes stateless executors that handle API requests, streaming responses,
// token counting, and authentication refresh for different AI service providers.
package executor
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Generative Language API settings shared by the Gemini executors.
const (
	// glEndpoint is the base URL of the Google Generative Language API.
	glEndpoint = "https://generativelanguage.googleapis.com"
	// glAPIVersion is the API version path segment used for Gemini requests.
	glAPIVersion = "v1beta"
	// streamScannerBuffer is the maximum token size (50 MiB) handed to bufio.Scanner.Buffer
	// when reading SSE streams line by line.
	streamScannerBuffer = 50 << 20
)
// GeminiExecutor is a stateless executor for the official Gemini (Generative Language) API.
// Requests authenticate with an API key when one is available, otherwise with an OAuth
// bearer token. Both regular and streaming calls are supported.
type GeminiExecutor struct {
	// cfg is the application configuration used for HTTP clients, payload rules and logging.
	cfg *config.Config
}
// NewGeminiExecutor returns a Gemini executor that uses cfg for every request.
func NewGeminiExecutor(cfg *config.Config) *GeminiExecutor {
	ex := new(GeminiExecutor)
	ex.cfg = cfg
	return ex
}
// Identifier reports the provider key "gemini"; it is used for usage reporting, request
// logging and thinking/payload translation.
func (e *GeminiExecutor) Identifier() string {
	const id = "gemini"
	return id
}
// PrepareRequest injects Gemini credentials into the outgoing HTTP request.
//
// An API key (x-goog-api-key) takes precedence over an OAuth bearer token. Whichever one is
// applied, the other credential header is removed; if neither exists, both are left alone.
// applyGeminiHeaders runs afterwards in every case. A nil request is a no-op.
func (e *GeminiExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
	if req == nil {
		return nil
	}
	apiKey, bearer := geminiCreds(auth)
	switch {
	case apiKey != "":
		req.Header.Set("x-goog-api-key", apiKey)
		req.Header.Del("Authorization")
	case bearer != "":
		req.Header.Set("Authorization", "Bearer "+bearer)
		req.Header.Del("x-goog-api-key")
	}
	applyGeminiHeaders(req, auth)
	return nil
}
// HttpRequest authorizes req with Gemini credentials and sends it.
//
// When ctx is nil the request's own context is used. The request is re-bound to the
// chosen context, prepared via PrepareRequest, and sent with a proxy-aware HTTP client.
func (e *GeminiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
	if req == nil {
		return nil, fmt.Errorf("gemini executor: request is nil")
	}
	effectiveCtx := ctx
	if effectiveCtx == nil {
		effectiveCtx = req.Context()
	}
	outbound := req.WithContext(effectiveCtx)
	if err := e.PrepareRequest(outbound, auth); err != nil {
		return nil, err
	}
	return newProxyAwareHTTPClient(effectiveCtx, e.cfg, auth, 0).Do(outbound)
}
// Execute performs a non-streaming request to the Gemini API.
// It translates the request to Gemini format, sends it to the API, and translates
// the response back to the requested format.
//
// Parameters:
// - ctx: The context for the request
// - auth: The authentication information
// - req: The request to execute
// - opts: Additional execution options
//
// Returns:
// - cliproxyexecutor.Response: The response from the API
// - error: An error if the request fails
func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if opts.Alt == "responses/compact" {
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, bearer := geminiCreds(auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
// Official Gemini API via API key or OAuth bearer
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := "generateContent"
if req.Metadata != nil {
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
action = "countTokens"
}
}
baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, baseModel, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
body, _ = sjson.DeleteBytes(body, "session_id")
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return resp, err
}
httpReq.Header.Set("Content-Type", "application/json")
if apiKey != "" {
httpReq.Header.Set("x-goog-api-key", apiKey)
} else if bearer != "" {
httpReq.Header.Set("Authorization", "Bearer "+bearer)
}
applyGeminiHeaders(httpReq, auth)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("gemini executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming request to the Gemini API.
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, bearer := geminiCreds(auth)
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
baseURL := resolveGeminiBaseURL(auth)
url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, baseModel, "streamGenerateContent")
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
body, _ = sjson.DeleteBytes(body, "session_id")
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", "application/json")
if apiKey != "" {
httpReq.Header.Set("x-goog-api-key", apiKey)
} else {
httpReq.Header.Set("Authorization", "Bearer "+bearer)
}
applyGeminiHeaders(httpReq, auth)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("gemini executor: close response body error: %v", errClose)
}
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("gemini executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, streamScannerBuffer)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
filtered := FilterSSEUsageMetadata(line)
payload := jsonPayload(filtered)
if len(payload) == 0 {
continue
}
if detail, ok := parseGeminiStreamUsage(payload); ok {
reporter.publish(ctx, detail)
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), ¶m)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// CountTokens counts tokens for the given request using the Gemini API. The request is
// translated to Gemini format. "tools", "generationConfig", "safetySettings" and
// "session_id" are removed before models/{model}:countTokens is called. The upstream
// "totalTokens" value is translated back into opts.SourceFormat. A non-2xx status is
// returned as a statusErr carrying the upstream body.
func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	baseModel := thinking.ParseSuffix(req.Model).ModelName
	apiKey, bearer := geminiCreds(auth)
	from := opts.SourceFormat
	to := sdktranslator.FromString("gemini")
	translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
	translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
	if err != nil {
		return cliproxyexecutor.Response{}, err
	}
	translatedReq = fixGeminiImageAspectRatio(baseModel, translatedReq)
	// NOTE(review): plain string context key. The token-count translators presumably read
	// ctx.Value("alt"), so the key type must stay a string — confirm before changing.
	respCtx := context.WithValue(ctx, "alt", opts.Alt)
	// session_id is stripped here as in Execute/ExecuteStream, so the countTokens body
	// does not carry a field those paths never forward upstream.
	for _, field := range []string{"tools", "generationConfig", "safetySettings", "session_id"} {
		translatedReq, _ = sjson.DeleteBytes(translatedReq, field)
	}
	translatedReq, _ = sjson.SetBytes(translatedReq, "model", baseModel)
	baseURL := resolveGeminiBaseURL(auth)
	url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, baseModel, "countTokens")
	requestBody := bytes.NewReader(translatedReq)
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, requestBody)
	if err != nil {
		return cliproxyexecutor.Response{}, err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	if apiKey != "" {
		httpReq.Header.Set("x-goog-api-key", apiKey)
	} else if bearer != "" {
		// Same as Execute: never send an empty "Bearer " credential.
		httpReq.Header.Set("Authorization", "Bearer "+bearer)
	}
	applyGeminiHeaders(httpReq, auth)
	var authID, authLabel, authType, authValue string
	if auth != nil {
		authID = auth.ID
		authLabel = auth.Label
		authType, authValue = auth.AccountInfo()
	}
	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
		URL:       url,
		Method:    http.MethodPost,
		Headers:   httpReq.Header.Clone(),
		Body:      translatedReq,
		Provider:  e.Identifier(),
		AuthID:    authID,
		AuthLabel: authLabel,
		AuthType:  authType,
		AuthValue: authValue,
	})
	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
	resp, err := httpClient.Do(httpReq)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return cliproxyexecutor.Response{}, err
	}
	defer func() { _ = resp.Body.Close() }()
	recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())
	data, err := io.ReadAll(resp.Body)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return cliproxyexecutor.Response{}, err
	}
	appendAPIResponseChunk(ctx, e.cfg, data)
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data))
		return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)}
	}
	count := gjson.GetBytes(data, "totalTokens").Int()
	translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
	return cliproxyexecutor.Response{Payload: []byte(translated), Headers: resp.Header.Clone()}, nil
}
// Refresh hands auth back untouched: the Gemini executor performs no credential
// refresh of its own and never fails.
func (e *GeminiExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
	unchanged := auth
	return unchanged, nil
}
// geminiCreds extracts the Gemini credentials carried by an auth record.
//
// apiKey is Attributes["api_key"]. bearer is Metadata["access_token"], except that a
// non-empty Metadata["token"]["access_token"] string takes precedence. A nil auth
// yields two empty strings.
func geminiCreds(a *cliproxyauth.Auth) (apiKey, bearer string) {
	if a == nil {
		return "", ""
	}
	// Indexing a nil map is safe in Go and yields the zero value.
	key := a.Attributes["api_key"]
	if nested, isMap := a.Metadata["token"].(map[string]any); isMap {
		if tok, _ := nested["access_token"].(string); tok != "" {
			return key, tok
		}
	}
	topLevel, _ := a.Metadata["access_token"].(string)
	return key, topLevel
}
// resolveGeminiBaseURL picks the Gemini API base URL for auth. It uses
// Attributes["base_url"] when that value, after trimming whitespace and trailing
// slashes, is non-empty. Otherwise it uses glEndpoint.
func resolveGeminiBaseURL(auth *cliproxyauth.Auth) string {
	if auth == nil {
		return glEndpoint
	}
	override := strings.TrimRight(strings.TrimSpace(auth.Attributes["base_url"]), "/")
	if override == "" {
		return glEndpoint
	}
	return override
}
// resolveGeminiConfig finds the config.GeminiKey entry that corresponds to auth, matching
// its trimmed "api_key" and "base_url" attributes against each entry's trimmed APIKey and
// BaseURL (compared with strings.EqualFold). Match priority:
//  1. key and base both set: the first entry whose APIKey and BaseURL both match;
//  2. only key set: the first entry whose APIKey matches and whose BaseURL is empty;
//  3. only base set: the first entry whose BaseURL matches;
//  4. key set, fallback: the first entry whose APIKey matches, regardless of BaseURL.
//
// The result points into e.cfg.GeminiKey. It is nil when auth or e.cfg is nil, when both
// attributes are empty, or when nothing matches.
//
// NOTE(review): API keys are compared case-insensitively via strings.EqualFold; keys are
// usually case-sensitive — confirm this is intended.
func (e *GeminiExecutor) resolveGeminiConfig(auth *cliproxyauth.Auth) *config.GeminiKey {
if auth == nil || e.cfg == nil {
return nil
}
var attrKey, attrBase string
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range e.cfg.GeminiKey {
entry := &e.cfg.GeminiKey[i]
cfgKey := strings.TrimSpace(entry.APIKey)
cfgBase := strings.TrimSpace(entry.BaseURL)
// Both attributes set: only an exact key+base match counts in this pass.
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
// Only the key is set here (attrBase is empty), so this effectively requires cfgBase == "".
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
// Only the base URL is set: match on base alone.
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
// Fallback: accept any entry with the same key, whatever its BaseURL.
if attrKey != "" {
for i := range e.cfg.GeminiKey {
entry := &e.cfg.GeminiKey[i]
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
return entry
}
}
}
return nil
}
// applyGeminiHeaders delegates to util.ApplyCustomHeadersFromAttrs with auth's Attributes
// (presumably per-auth custom header settings — see that helper). A nil auth passes a
// nil attribute map.
func applyGeminiHeaders(req *http.Request, auth *cliproxyauth.Auth) {
	if auth == nil {
		util.ApplyCustomHeadersFromAttrs(req, nil)
		return
	}
	util.ApplyCustomHeadersFromAttrs(req, auth.Attributes)
}
// fixGeminiImageAspectRatio applies an aspect-ratio workaround for the
// "gemini-2.5-flash-image-preview" model; every other model is returned unchanged.
//
// When the request sets generationConfig.imageConfig.aspectRatio and no part in
// "contents" carries inlineData, the first content's parts are prefixed with two parts:
// an instruction text part and a blank PNG of the requested ratio (built by
// util.CreateWhiteImageBase64). generationConfig.responseModalities is then set to
// ["IMAGE", "TEXT"]. Whenever aspectRatio is present, generationConfig.imageConfig is
// deleted afterwards.
//
// If the placeholder image cannot be created, the injection is skipped and logged.
// Otherwise the request would carry an inlineData part with empty data.
func fixGeminiImageAspectRatio(modelName string, rawJSON []byte) []byte {
	if modelName != "gemini-2.5-flash-image-preview" {
		return rawJSON
	}
	aspectRatio := gjson.GetBytes(rawJSON, "generationConfig.imageConfig.aspectRatio")
	if !aspectRatio.Exists() {
		return rawJSON
	}
	contents := gjson.GetBytes(rawJSON, "contents").Array()
	hasInlineData := false
scan:
	for _, content := range contents {
		for _, part := range content.Get("parts").Array() {
			if part.Get("inlineData").Exists() {
				hasInlineData = true
				break scan
			}
		}
	}
	if len(contents) > 0 && !hasInlineData {
		blankImage, errImage := util.CreateWhiteImageBase64(aspectRatio.String())
		if errImage != nil {
			log.Warnf("gemini executor: skip aspect ratio placeholder image for %q: %v", aspectRatio.String(), errImage)
		} else {
			imagePart, _ := sjson.Set(`{"inlineData":{"mime_type":"image/png","data":""}}`, "inlineData.data", blankImage)
			newPartsJSON := `[]`
			newPartsJSON, _ = sjson.SetRaw(newPartsJSON, "-1", `{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`)
			newPartsJSON, _ = sjson.SetRaw(newPartsJSON, "-1", imagePart)
			// The original parts of the first content follow the injected ones.
			for _, part := range contents[0].Get("parts").Array() {
				newPartsJSON, _ = sjson.SetRaw(newPartsJSON, "-1", part.Raw)
			}
			rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", []byte(newPartsJSON))
			rawJSON, _ = sjson.SetRawBytes(rawJSON, "generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
		}
	}
	rawJSON, _ = sjson.DeleteBytes(rawJSON, "generationConfig.imageConfig")
	return rawJSON
}
================================================
FILE: internal/runtime/executor/gemini_vertex_executor.go
================================================
// Package executor provides runtime execution capabilities for various AI service providers.
// This file implements the Vertex AI Gemini executor that talks to Google Vertex AI
// endpoints using service account credentials or API keys.
package executor
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
// vertexAPIVersion is the version path segment used in every Vertex AI request URL this
// executor builds; it tracks the current public Vertex Generative AI API.
const vertexAPIVersion = "v1"
// isImagenModel reports whether model names an Imagen image-generation model, i.e.
// whether it contains "imagen" in any letter case. Imagen models are called with the
// :predict action rather than :generateContent.
func isImagenModel(model string) bool {
	const marker = "imagen"
	normalized := strings.ToLower(model)
	return strings.Contains(normalized, marker)
}
// getVertexAction selects the Vertex AI method for model: "predict" for Imagen models
// (regardless of isStream), otherwise "streamGenerateContent" or "generateContent"
// depending on isStream.
func getVertexAction(model string, isStream bool) string {
	switch {
	case isImagenModel(model):
		return "predict"
	case isStream:
		return "streamGenerateContent"
	default:
		return "generateContent"
	}
}
// convertImagenToGeminiResponse reshapes an Imagen :predict response into a Gemini
// generateContent response, so the standard Gemini translators can process it like the
// output of gemini-3-pro-image-preview.
//
// Each prediction with a non-empty "bytesBase64Encoded" becomes an inlineData part, with
// mimeType defaulting to "image/png". The result has a single candidate (role "model",
// finishReason "STOP"), a time-based responseId, and modelVersion set to model.
// usageMetadata is zeroed because Imagen reports no token counts. data is returned
// unchanged when it has no "predictions" array or marshalling fails.
func convertImagenToGeminiResponse(data []byte, model string) []byte {
	predictions := gjson.GetBytes(data, "predictions")
	if !predictions.IsArray() {
		return data
	}
	// A non-nil empty slice keeps "parts" encoded as [] rather than null.
	parts := []map[string]any{}
	predictions.ForEach(func(_, prediction gjson.Result) bool {
		encoded := prediction.Get("bytesBase64Encoded").String()
		if encoded == "" {
			return true
		}
		mime := prediction.Get("mimeType").String()
		if mime == "" {
			mime = "image/png"
		}
		parts = append(parts, map[string]any{
			"inlineData": map[string]any{
				"mimeType": mime,
				"data":     encoded,
			},
		})
		return true
	})
	candidate := map[string]any{
		"content": map[string]any{
			"parts": parts,
			"role":  "model",
		},
		"finishReason": "STOP",
	}
	usage := map[string]any{
		"promptTokenCount":     0,
		"candidatesTokenCount": 0,
		"totalTokenCount":      0,
	}
	encodedResponse, errMarshal := json.Marshal(map[string]any{
		"candidates":    []map[string]any{candidate},
		"responseId":    fmt.Sprintf("imagen-%d", time.Now().UnixNano()),
		"modelVersion":  model,
		"usageMetadata": usage,
	})
	if errMarshal != nil {
		return data
	}
	return encodedResponse
}
// convertToImagenRequest converts a chat-style request payload into the Imagen :predict
// request format: {"instances":[{"prompt":...}],"parameters":{"sampleCount":1,...}}.
//
// The prompt is taken from the first of these sources that yields non-empty text:
//  1. Gemini-style contents[0].parts[0].text;
//  2. an OpenAI/Claude-style "messages" array. The last "user" message is preferred,
//     otherwise the first message with any text. Message content may be a plain string
//     or an array of parts, whose "text" fields are joined with newlines;
//  3. a top-level "prompt" field.
//
// Optional top-level "aspectRatio", "sampleCount" and "negativePrompt" fields are copied
// into the request. An error is returned when no prompt can be found.
func convertToImagenRequest(payload []byte) ([]byte, error) {
	prompt := gjson.GetBytes(payload, "contents.0.parts.0.text").String()
	if prompt == "" {
		prompt = imagenPromptFromMessages(gjson.GetBytes(payload, "messages"))
	}
	if prompt == "" {
		prompt = gjson.GetBytes(payload, "prompt").String()
	}
	if prompt == "" {
		return nil, fmt.Errorf("imagen: no prompt found in request")
	}
	instance := map[string]any{"prompt": prompt}
	parameters := map[string]any{"sampleCount": 1}
	if aspectRatio := gjson.GetBytes(payload, "aspectRatio"); aspectRatio.Exists() {
		parameters["aspectRatio"] = aspectRatio.String()
	}
	if sampleCount := gjson.GetBytes(payload, "sampleCount"); sampleCount.Exists() {
		parameters["sampleCount"] = int(sampleCount.Int())
	}
	if negativePrompt := gjson.GetBytes(payload, "negativePrompt"); negativePrompt.Exists() {
		instance["negativePrompt"] = negativePrompt.String()
	}
	return json.Marshal(map[string]any{
		"instances":  []map[string]any{instance},
		"parameters": parameters,
	})
}

// imagenPromptFromMessages extracts prompt text from an OpenAI/Claude-style "messages"
// array. The most recent "user" message with non-empty text wins, so a leading system
// message is not mistaken for the prompt. Otherwise the first message with any text is
// used. It returns "" when messages is not an array or carries no text.
func imagenPromptFromMessages(messages gjson.Result) string {
	if !messages.IsArray() {
		return ""
	}
	items := messages.Array()
	for i := len(items) - 1; i >= 0; i-- {
		if items[i].Get("role").String() != "user" {
			continue
		}
		if text := imagenMessageText(items[i].Get("content")); text != "" {
			return text
		}
	}
	for _, item := range items {
		if text := imagenMessageText(item.Get("content")); text != "" {
			return text
		}
	}
	return ""
}

// imagenMessageText returns the text of a message "content" value. For an array of
// parts, the non-empty "text" fields are joined with newlines and non-text parts (e.g.
// images) are ignored. Any other value is returned in gjson's String form, which is the
// plain text for a JSON string.
func imagenMessageText(content gjson.Result) string {
	if !content.IsArray() {
		return content.String()
	}
	var texts []string
	for _, part := range content.Array() {
		if text := part.Get("text").String(); text != "" {
			texts = append(texts, text)
		}
	}
	return strings.Join(texts, "\n")
}
// GeminiVertexExecutor is the provider executor for Google Vertex AI Gemini (and Imagen)
// models. Requests authenticate with either a Vertex API key or service-account
// credentials.
type GeminiVertexExecutor struct {
	// cfg is the application configuration handed to the HTTP-client, token,
	// payload-config and request-logging helpers.
	cfg *config.Config
}
// NewGeminiVertexExecutor returns a Vertex AI Gemini executor bound to cfg, the
// application configuration consulted by its request helpers.
func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor {
	vertexExec := new(GeminiVertexExecutor)
	vertexExec.cfg = cfg
	return vertexExec
}
// Identifier reports this executor's provider name, "vertex". The same value is passed to
// usage reporting, thinking configuration and request logging.
func (e *GeminiVertexExecutor) Identifier() string {
	const providerID = "vertex"
	return providerID
}
// PrepareRequest sets Vertex credentials on req. A nil req is a no-op.
//
// When auth carries a non-blank Vertex API key, the key is sent as "x-goog-api-key" and
// any Authorization header is dropped. Otherwise a service-account access token is
// fetched and sent as "Authorization: Bearer <token>", and any "x-goog-api-key" header
// is dropped. A blank token yields a 401 statusErr.
func (e *GeminiVertexExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
	if req == nil {
		return nil
	}
	if apiKey, _ := vertexAPICreds(auth); strings.TrimSpace(apiKey) != "" {
		req.Header.Del("Authorization")
		req.Header.Set("x-goog-api-key", apiKey)
		return nil
	}
	_, _, saJSON, errCreds := vertexCreds(auth)
	if errCreds != nil {
		return errCreds
	}
	token, errToken := vertexAccessToken(req.Context(), e.cfg, auth, saJSON)
	switch {
	case errToken != nil:
		return errToken
	case strings.TrimSpace(token) == "":
		return statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
	}
	req.Header.Del("x-goog-api-key")
	req.Header.Set("Authorization", "Bearer "+token)
	return nil
}
// HttpRequest applies Vertex credentials to req via PrepareRequest and sends it with a
// proxy-aware HTTP client. When ctx is nil, the request's own context is used.
//
// NOTE(review): http.Request.WithContext makes a shallow copy, so the credential headers
// set by PrepareRequest are also visible on the caller's req (the Header map is shared).
func (e *GeminiVertexExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
	if req == nil {
		return nil, fmt.Errorf("vertex executor: request is nil")
	}
	if ctx == nil {
		ctx = req.Context()
	}
	outbound := req.WithContext(ctx)
	if errPrepare := e.PrepareRequest(outbound, auth); errPrepare != nil {
		return nil, errPrepare
	}
	return newProxyAwareHTTPClient(ctx, e.cfg, auth, 0).Do(outbound)
}
// Execute performs a non-streaming request to the Vertex AI API. A Vertex API key on auth
// takes precedence. Without one, service-account credentials (project, location, JSON
// key) are used. The "responses/compact" alt is rejected with 501.
func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
	if opts.Alt == "responses/compact" {
		return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
	}
	if apiKey, baseURL := vertexAPICreds(auth); apiKey != "" {
		return e.executeWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
	}
	projectID, location, saJSON, errCreds := vertexCreds(auth)
	if errCreds != nil {
		return resp, errCreds
	}
	return e.executeWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON)
}
// ExecuteStream performs a streaming request to the Vertex AI API. A Vertex API key on
// auth takes precedence. Without one, service-account credentials are used. The
// "responses/compact" alt is rejected with 501.
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	if opts.Alt == "responses/compact" {
		return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
	}
	if apiKey, baseURL := vertexAPICreds(auth); apiKey != "" {
		return e.executeStreamWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
	}
	projectID, location, saJSON, errCreds := vertexCreds(auth)
	if errCreds != nil {
		return nil, errCreds
	}
	return e.executeStreamWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON)
}
// CountTokens counts tokens for the given request using the Vertex AI API. A Vertex API
// key on auth takes precedence. Without one, service-account credentials are used.
func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	if apiKey, baseURL := vertexAPICreds(auth); apiKey != "" {
		return e.countTokensWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
	}
	projectID, location, saJSON, errCreds := vertexCreds(auth)
	if errCreds != nil {
		return cliproxyexecutor.Response{}, errCreds
	}
	return e.countTokensWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON)
}
// Refresh hands auth back untouched: the Vertex executor refreshes nothing itself and
// never fails. Service-account access tokens are obtained per request via
// vertexAccessToken instead.
func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
	unchanged := auth
	return unchanged, nil
}
// executeWithServiceAccount handles authentication using service account credentials.
// This method contains the original service account authentication logic.
func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
var body []byte
// Handle Imagen models with special request format
if isImagenModel(baseModel) {
imagenBody, errImagen := convertToImagenRequest(req.Payload)
if errImagen != nil {
return resp, errImagen
}
body = imagenBody
} else {
// Standard Gemini translation flow
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
body = sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
}
action := getVertexAction(baseModel, false)
if req.Metadata != nil {
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
action = "countTokens"
}
}
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
body, _ = sjson.DeleteBytes(body, "session_id")
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if errNewReq != nil {
return resp, errNewReq
}
httpReq.Header.Set("Content-Type", "application/json")
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
httpReq.Header.Set("Authorization", "Bearer "+token)
} else if errTok != nil {
log.Errorf("vertex executor: access token error: %v", errTok)
return resp, statusErr{code: 500, msg: "internal server error"}
}
applyGeminiHeaders(httpReq, auth)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
return resp, errDo
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
return resp, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiUsage(data))
// For Imagen models, convert response to Gemini format before translation
// This ensures Imagen responses use the same format as gemini-3-pro-image-preview
if isImagenModel(baseModel) {
data = convertImagenToGeminiResponse(data, baseModel)
}
// Standard Gemini translation (works for both Gemini and converted Imagen responses)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// executeWithAPIKey handles authentication using API key credentials.
func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, false)
if req.Metadata != nil {
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
action = "countTokens"
}
}
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://aiplatform.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
body, _ = sjson.DeleteBytes(body, "session_id")
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if errNewReq != nil {
return resp, errNewReq
}
httpReq.Header.Set("Content-Type", "application/json")
if apiKey != "" {
httpReq.Header.Set("x-goog-api-key", apiKey)
}
applyGeminiHeaders(httpReq, auth)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
return resp, errDo
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
return resp, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, true)
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, action)
// Imagen models don't support streaming, skip SSE params
if !isImagenModel(baseModel) {
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
}
body, _ = sjson.DeleteBytes(body, "session_id")
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if errNewReq != nil {
return nil, errNewReq
}
httpReq.Header.Set("Content-Type", "application/json")
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
httpReq.Header.Set("Authorization", "Bearer "+token)
} else if errTok != nil {
log.Errorf("vertex executor: access token error: %v", errTok)
return nil, statusErr{code: 500, msg: "internal server error"}
}
applyGeminiHeaders(httpReq, auth)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
return nil, errDo
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
return nil, statusErr{code: httpResp.StatusCode, msg: string(b)}
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, streamScannerBuffer)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseGeminiStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
body = fixGeminiImageAspectRatio(baseModel, body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
body, _ = sjson.SetBytes(body, "model", baseModel)
action := getVertexAction(baseModel, true)
// For API key auth, use simpler URL format without project/location
if baseURL == "" {
baseURL = "https://aiplatform.googleapis.com"
}
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
// Imagen models don't support streaming, skip SSE params
if !isImagenModel(baseModel) {
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
}
body, _ = sjson.DeleteBytes(body, "session_id")
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if errNewReq != nil {
return nil, errNewReq
}
httpReq.Header.Set("Content-Type", "application/json")
if apiKey != "" {
httpReq.Header.Set("x-goog-api-key", apiKey)
}
applyGeminiHeaders(httpReq, auth)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
return nil, errDo
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
return nil, statusErr{code: httpResp.StatusCode, msg: string(b)}
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, streamScannerBuffer)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseGeminiStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// countTokensWithServiceAccount calls Vertex AI's countTokens endpoint for req,
// authenticating with an access token minted from service-account credentials.
//
// Payload preparation:
//   - translate the request to the Gemini format and apply thinking settings;
//   - run the image aspect-ratio fix and write the base model into "model";
//   - strip "tools", "generationConfig" and "safetySettings".
//
// The upstream "totalTokens" value is translated back into the caller's source
// format. A token error yields a generic 500 statusErr. A non-2xx status yields
// a statusErr carrying the upstream body.
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
	modelName := thinking.ParseSuffix(req.Model).ModelName
	sourceFormat := opts.SourceFormat
	targetFormat := sdktranslator.FromString("gemini")

	payload, err := thinking.ApplyThinking(
		sdktranslator.TranslateRequest(sourceFormat, targetFormat, modelName, req.Payload, false),
		req.Model, sourceFormat.String(), targetFormat.String(), e.Identifier(),
	)
	if err != nil {
		return cliproxyexecutor.Response{}, err
	}
	payload = fixGeminiImageAspectRatio(modelName, payload)
	payload, _ = sjson.SetBytes(payload, "model", modelName)
	for _, field := range []string{"tools", "generationConfig", "safetySettings"} {
		payload, _ = sjson.DeleteBytes(payload, field)
	}
	requestCtx := context.WithValue(ctx, "alt", opts.Alt)

	endpoint := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", vertexBaseURL(location), vertexAPIVersion, projectID, location, modelName, "countTokens")
	httpReq, err := http.NewRequestWithContext(requestCtx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return cliproxyexecutor.Response{}, err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON)
	switch {
	case errTok != nil:
		log.Errorf("vertex executor: access token error: %v", errTok)
		return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
	case token != "":
		httpReq.Header.Set("Authorization", "Bearer "+token)
	}
	applyGeminiHeaders(httpReq, auth)

	logEntry := upstreamRequestLog{
		URL:      endpoint,
		Method:   http.MethodPost,
		Headers:  httpReq.Header.Clone(),
		Body:     payload,
		Provider: e.Identifier(),
	}
	if auth != nil {
		logEntry.AuthID = auth.ID
		logEntry.AuthLabel = auth.Label
		logEntry.AuthType, logEntry.AuthValue = auth.AccountInfo()
	}
	recordAPIRequest(ctx, e.cfg, logEntry)

	httpResp, err := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0).Do(httpReq)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return cliproxyexecutor.Response{}, err
	}
	defer func() {
		if errClose := httpResp.Body.Close(); errClose != nil {
			log.Errorf("vertex executor: close response body error: %v", errClose)
		}
	}()
	recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
	if status := httpResp.StatusCode; status < 200 || status >= 300 {
		errBody, _ := io.ReadAll(httpResp.Body)
		appendAPIResponseChunk(ctx, e.cfg, errBody)
		logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", status, summarizeErrorBody(httpResp.Header.Get("Content-Type"), errBody))
		return cliproxyexecutor.Response{}, statusErr{code: status, msg: string(errBody)}
	}
	respBody, err := io.ReadAll(httpResp.Body)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return cliproxyexecutor.Response{}, err
	}
	appendAPIResponseChunk(ctx, e.cfg, respBody)
	total := gjson.GetBytes(respBody, "totalTokens").Int()
	translated := sdktranslator.TranslateTokenCount(ctx, targetFormat, sourceFormat, total, respBody)
	return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
}
// countTokensWithAPIKey calls Vertex AI's countTokens endpoint for req,
// authenticating with an API key.
//
// Payload preparation:
//   - translate the request to the Gemini format and apply thinking settings;
//   - run the image aspect-ratio fix and write the base model into "model";
//   - strip "tools", "generationConfig" and "safetySettings".
//
// The URL has no project/location segment. baseURL defaults to
// https://aiplatform.googleapis.com, and a non-empty apiKey is sent in the
// x-goog-api-key header. The upstream "totalTokens" value is translated back
// into the caller's source format. A non-2xx status yields a statusErr carrying
// the upstream body.
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
	modelName := thinking.ParseSuffix(req.Model).ModelName
	sourceFormat := opts.SourceFormat
	targetFormat := sdktranslator.FromString("gemini")

	payload, err := thinking.ApplyThinking(
		sdktranslator.TranslateRequest(sourceFormat, targetFormat, modelName, req.Payload, false),
		req.Model, sourceFormat.String(), targetFormat.String(), e.Identifier(),
	)
	if err != nil {
		return cliproxyexecutor.Response{}, err
	}
	payload = fixGeminiImageAspectRatio(modelName, payload)
	payload, _ = sjson.SetBytes(payload, "model", modelName)
	for _, field := range []string{"tools", "generationConfig", "safetySettings"} {
		payload, _ = sjson.DeleteBytes(payload, field)
	}
	requestCtx := context.WithValue(ctx, "alt", opts.Alt)

	// API-key auth addresses the model without a project/location path segment.
	if baseURL == "" {
		baseURL = "https://aiplatform.googleapis.com"
	}
	endpoint := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, modelName, "countTokens")
	httpReq, err := http.NewRequestWithContext(requestCtx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return cliproxyexecutor.Response{}, err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	if apiKey != "" {
		httpReq.Header.Set("x-goog-api-key", apiKey)
	}
	applyGeminiHeaders(httpReq, auth)

	logEntry := upstreamRequestLog{
		URL:      endpoint,
		Method:   http.MethodPost,
		Headers:  httpReq.Header.Clone(),
		Body:     payload,
		Provider: e.Identifier(),
	}
	if auth != nil {
		logEntry.AuthID = auth.ID
		logEntry.AuthLabel = auth.Label
		logEntry.AuthType, logEntry.AuthValue = auth.AccountInfo()
	}
	recordAPIRequest(ctx, e.cfg, logEntry)

	httpResp, err := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0).Do(httpReq)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return cliproxyexecutor.Response{}, err
	}
	defer func() {
		if errClose := httpResp.Body.Close(); errClose != nil {
			log.Errorf("vertex executor: close response body error: %v", errClose)
		}
	}()
	recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
	if status := httpResp.StatusCode; status < 200 || status >= 300 {
		errBody, _ := io.ReadAll(httpResp.Body)
		appendAPIResponseChunk(ctx, e.cfg, errBody)
		logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", status, summarizeErrorBody(httpResp.Header.Get("Content-Type"), errBody))
		return cliproxyexecutor.Response{}, statusErr{code: status, msg: string(errBody)}
	}
	respBody, err := io.ReadAll(httpResp.Body)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return cliproxyexecutor.Response{}, err
	}
	appendAPIResponseChunk(ctx, e.cfg, respBody)
	total := gjson.GetBytes(respBody, "totalTokens").Int()
	translated := sdktranslator.TranslateTokenCount(ctx, targetFormat, sourceFormat, total, respBody)
	return cliproxyexecutor.Response{Payload: []byte(translated), Headers: httpResp.Header.Clone()}, nil
}
// vertexCreds reads the project ID, location and a normalized service-account
// JSON document from the auth entry's metadata.
//
// Lookup rules:
//   - The project ID comes from "project_id". It falls back to "project" when
//     the former is missing or blank; both values are trimmed.
//   - The location comes from "location" and defaults to "us-central1" when
//     missing or blank.
//   - "service_account" must be a JSON object. It is normalized with
//     vertexauth.NormalizeServiceAccountMap and then marshalled back to JSON.
//
// An error is returned when the metadata, project ID or service account is
// missing, or when normalization or marshalling fails.
func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccountJSON []byte, err error) {
	if a == nil || a.Metadata == nil {
		return "", "", nil, fmt.Errorf("vertex executor: missing auth metadata")
	}
	meta := a.Metadata
	trimmedString := func(key string) string {
		if s, ok := meta[key].(string); ok {
			return strings.TrimSpace(s)
		}
		return ""
	}

	// Some service accounts may use "project"; the standard field wins.
	if projectID = trimmedString("project_id"); projectID == "" {
		projectID = trimmedString("project")
	}
	if projectID == "" {
		return "", "", nil, fmt.Errorf("vertex executor: missing project_id in credentials")
	}
	if location = trimmedString("location"); location == "" {
		location = "us-central1"
	}

	rawAccount, _ := meta["service_account"].(map[string]any)
	if rawAccount == nil {
		return "", "", nil, fmt.Errorf("vertex executor: missing service_account in credentials")
	}
	normalized, errNorm := vertexauth.NormalizeServiceAccountMap(rawAccount)
	if errNorm != nil {
		return "", "", nil, fmt.Errorf("vertex executor: %w", errNorm)
	}
	encoded, errMarshal := json.Marshal(normalized)
	if errMarshal != nil {
		return "", "", nil, fmt.Errorf("vertex executor: marshal service_account failed: %w", errMarshal)
	}
	return projectID, location, encoded, nil
}
// vertexAPICreds returns the API key and base URL for API-key based Vertex
// access.
//
// Both values are read from the "api_key" and "base_url" attributes. When no
// API key attribute is present, a string "access_token" metadata entry is used
// as the key instead. A nil auth yields two empty strings.
func vertexAPICreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
	if a == nil {
		return "", ""
	}
	// Indexing a nil map yields the zero value, so no nil guard is needed here.
	apiKey, baseURL = a.Attributes["api_key"], a.Attributes["base_url"]
	if apiKey != "" {
		return apiKey, baseURL
	}
	if token, ok := a.Metadata["access_token"].(string); ok {
		apiKey = token
	}
	return apiKey, baseURL
}
// vertexBaseURL maps a Vertex AI location to its API host.
//
// The location is trimmed first. "global" maps to the non-regional host
// https://aiplatform.googleapis.com, an empty location falls back to
// us-central1, and any other value becomes https://<location>-aiplatform.googleapis.com.
func vertexBaseURL(location string) string {
	region := strings.TrimSpace(location)
	switch region {
	case "global":
		return "https://aiplatform.googleapis.com"
	case "":
		region = "us-central1"
	}
	return fmt.Sprintf("https://%s-aiplatform.googleapis.com", region)
}
// vertexAccessToken exchanges service-account JSON for an OAuth2 access token
// with the cloud-platform scope.
//
// When a proxy-aware HTTP client can be built for auth, it is placed in the
// context under oauth2.HTTPClient so the token exchange goes through it.
//
// NOTE(review): new credentials are built on every call, so nothing here reuses
// a previously fetched token.
func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, saJSON []byte) (string, error) {
	const cloudPlatformScope = "https://www.googleapis.com/auth/cloud-platform"

	tokenCtx := ctx
	if client := newProxyAwareHTTPClient(ctx, cfg, auth, 0); client != nil {
		tokenCtx = context.WithValue(ctx, oauth2.HTTPClient, client)
	}
	creds, err := google.CredentialsFromJSON(tokenCtx, saJSON, cloudPlatformScope)
	if err != nil {
		return "", fmt.Errorf("vertex executor: parse service account json failed: %w", err)
	}
	token, err := creds.TokenSource.Token()
	if err != nil {
		return "", fmt.Errorf("vertex executor: get access token failed: %w", err)
	}
	return token.AccessToken, nil
}
// resolveVertexConfig returns the vertex-api-key config entry that matches
// auth, or nil when none does.
//
// The trimmed "api_key" and "base_url" auth attributes are compared
// case-insensitively against each entry's trimmed APIKey and BaseURL:
//   - key and base both set: the first entry matching both wins;
//   - only the key set: the first entry with that key and an empty base URL wins;
//   - only the base set: the first entry with that base URL wins.
//
// If a key is set but nothing matched above, the first entry with that key is
// returned regardless of its base URL.
func (e *GeminiVertexExecutor) resolveVertexConfig(auth *cliproxyauth.Auth) *config.VertexCompatKey {
	if auth == nil || e.cfg == nil {
		return nil
	}
	// Indexing a nil attribute map yields "", so no nil guard is needed.
	wantKey := strings.TrimSpace(auth.Attributes["api_key"])
	wantBase := strings.TrimSpace(auth.Attributes["base_url"])

	for idx := range e.cfg.VertexCompatAPIKey {
		candidate := &e.cfg.VertexCompatAPIKey[idx]
		candBase := strings.TrimSpace(candidate.BaseURL)
		keyHit := wantKey != "" && strings.EqualFold(strings.TrimSpace(candidate.APIKey), wantKey)
		baseHit := strings.EqualFold(candBase, wantBase)
		switch {
		case wantKey != "" && wantBase != "":
			if keyHit && baseHit {
				return candidate
			}
		case keyHit && (candBase == "" || baseHit):
			return candidate
		case wantKey == "" && wantBase != "" && baseHit:
			return candidate
		}
	}

	if wantKey == "" {
		return nil
	}
	// Fallback: match on the API key alone.
	for idx := range e.cfg.VertexCompatAPIKey {
		if strings.EqualFold(strings.TrimSpace(e.cfg.VertexCompatAPIKey[idx].APIKey), wantKey) {
			return &e.cfg.VertexCompatAPIKey[idx]
		}
	}
	return nil
}
================================================
FILE: internal/runtime/executor/iflow_executor.go
================================================
package executor
import (
"bufio"
"bytes"
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/google/uuid"
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// iflowDefaultEndpoint is the chat-completions path appended to the iFlow base URL.
const iflowDefaultEndpoint = "/chat/completions"

// iflowUserAgent is the User-Agent value sent with iFlow requests.
const iflowUserAgent = "iFlow-Cli"
// IFlowExecutor sends OpenAI-compatible chat completion requests to the iFlow
// API. It authenticates with API keys obtained through the OAuth or
// browser-cookie refresh flows.
type IFlowExecutor struct {
	// cfg is passed to HTTP client construction, request/response logging and
	// payload-config rules.
	cfg *config.Config
}

// NewIFlowExecutor returns an IFlowExecutor bound to cfg.
func NewIFlowExecutor(cfg *config.Config) *IFlowExecutor {
	return &IFlowExecutor{cfg: cfg}
}

// Identifier reports "iflow", the provider key for this executor.
func (e *IFlowExecutor) Identifier() string {
	return "iflow"
}
// PrepareRequest sets "Authorization: Bearer <api key>" on req, using the API
// key resolved from auth.
//
// A nil request is ignored, and no header is written when the key is blank.
// The returned error is always nil.
func (e *IFlowExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
	if req != nil {
		if key, _ := iflowCreds(auth); strings.TrimSpace(key) != "" {
			req.Header.Set("Authorization", "Bearer "+key)
		}
	}
	return nil
}
// HttpRequest attaches iFlow credentials to a copy of req bound to ctx and
// sends it through a proxy-aware HTTP client.
//
// When ctx is nil, the request's own context is used. A nil request is
// rejected with an error.
func (e *IFlowExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
	if req == nil {
		return nil, fmt.Errorf("iflow executor: request is nil")
	}
	if ctx == nil {
		ctx = req.Context()
	}
	outbound := req.WithContext(ctx)
	if errPrep := e.PrepareRequest(outbound, auth); errPrep != nil {
		return nil, errPrep
	}
	return newProxyAwareHTTPClient(ctx, e.cfg, auth, 0).Do(outbound)
}
// Execute performs a non-streaming chat completion request.
func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if opts.Alt == "responses/compact" {
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := iflowCreds(auth)
if strings.TrimSpace(apiKey) == "" {
err = fmt.Errorf("iflow executor: missing api key")
return resp, err
}
if baseURL == "" {
baseURL = iflowauth.DefaultAPIBaseURL
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
if err != nil {
return resp, err
}
body = preserveReasoningContentInMessages(body)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return resp, err
}
applyIFlowHeaders(httpReq, apiKey, false)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("iflow executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseOpenAIUsage(data))
// Ensure usage is recorded even if upstream omits usage metadata.
reporter.ensurePublished(ctx)
var param any
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
// ExecuteStream performs a streaming chat completion request.
func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
apiKey, baseURL := iflowCreds(auth)
if strings.TrimSpace(apiKey) == "" {
err = fmt.Errorf("iflow executor: missing api key")
return nil, err
}
if baseURL == "" {
baseURL = iflowauth.DefaultAPIBaseURL
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
if err != nil {
return nil, err
}
body = preserveReasoningContentInMessages(body)
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
toolsResult := gjson.GetBytes(body, "tools")
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
body = ensureToolsArray(body)
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(body))
if err != nil {
return nil, err
}
applyIFlowHeaders(httpReq, apiKey, true)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: endpoint,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
data, _ := io.ReadAll(httpResp.Body)
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("iflow executor: close response body error: %v", errClose)
}
appendAPIResponseChunk(ctx, e.cfg, data)
logWithRequestID(ctx).Debugf("request error, error status: %d error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
err = statusErr{code: httpResp.StatusCode, msg: string(data)}
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("iflow executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, 52_428_800) // 50MB
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
// Guarantee a usage record exists even if the stream never emitted usage data.
reporter.ensurePublished(ctx)
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// CountTokens counts prompt tokens locally, without calling iFlow.
//
// The request is translated to the OpenAI chat format and counted with the
// tokenizer chosen for the base model. The count is then passed through
// buildOpenAIUsageJSON and translated back to the caller's source format.
// Tokenizer setup and counting failures are wrapped in errors.
func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	modelName := thinking.ParseSuffix(req.Model).ModelName
	src, dst := opts.SourceFormat, sdktranslator.FromString("openai")
	chatPayload := sdktranslator.TranslateRequest(src, dst, modelName, req.Payload, false)

	tokenizer, errTok := tokenizerForModel(modelName)
	if errTok != nil {
		return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: tokenizer init failed: %w", errTok)
	}
	total, errCount := countOpenAIChatTokens(tokenizer, chatPayload)
	if errCount != nil {
		return cliproxyexecutor.Response{}, fmt.Errorf("iflow executor: token counting failed: %w", errCount)
	}
	translated := sdktranslator.TranslateTokenCount(ctx, dst, src, total, buildOpenAIUsageJSON(total))
	return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
}
// Refresh renews the credentials stored on auth.
//
// If the metadata holds both a non-blank "cookie" and a non-blank "email", the
// cookie-based API-key flow is used. Otherwise the OAuth refresh-token flow
// runs. A nil auth is rejected with an error.
func (e *IFlowExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
	log.Debugf("iflow executor: refresh called")
	if auth == nil {
		return nil, fmt.Errorf("iflow executor: auth is nil")
	}
	// Indexing a nil metadata map yields nil, so the assertions simply fail.
	rawCookie, _ := auth.Metadata["cookie"].(string)
	rawEmail, _ := auth.Metadata["email"].(string)
	cookie, email := strings.TrimSpace(rawCookie), strings.TrimSpace(rawEmail)

	if cookie == "" || email == "" {
		return e.refreshOAuthBased(ctx, auth)
	}
	return e.refreshCookieBased(ctx, auth, cookie, email)
}
// refreshCookieBased renews the iFlow API key using a browser cookie.
//
// The stored "expired" metadata value is checked with
// iflowauth.ShouldRefreshAPIKey, and auth is returned unchanged when no
// refresh is needed. If the check itself fails, the refresh proceeds anyway.
//
// On success the new key, its expiry, the provider type, a refresh timestamp,
// the cookie and the email are written to metadata. The key is also copied
// into the "api_key" attribute. auth is mutated in place and returned.
func (e *IFlowExecutor) refreshCookieBased(ctx context.Context, auth *cliproxyauth.Auth, cookie, email string) (*cliproxyauth.Auth, error) {
	log.Debugf("iflow executor: checking refresh need for cookie-based API key for user: %s", email)

	storedExpiry, _ := auth.Metadata["expired"].(string)
	needsRefresh, _, errCheck := iflowauth.ShouldRefreshAPIKey(strings.TrimSpace(storedExpiry))
	switch {
	case errCheck != nil:
		// The expiry could not be evaluated; refreshing is the safer choice.
		log.Warnf("iflow executor: failed to check refresh need: %v", errCheck)
	case !needsRefresh:
		log.Debugf("iflow executor: no refresh needed for user: %s", email)
		return auth, nil
	}

	log.Infof("iflow executor: refreshing cookie-based API key for user: %s", email)
	keyData, errRefresh := iflowauth.NewIFlowAuth(e.cfg).RefreshAPIKey(ctx, cookie, email)
	if errRefresh != nil {
		log.Errorf("iflow executor: cookie-based API key refresh failed: %v", errRefresh)
		return nil, errRefresh
	}

	if auth.Metadata == nil {
		auth.Metadata = make(map[string]any)
	}
	meta := auth.Metadata
	meta["api_key"] = keyData.APIKey
	meta["expired"] = keyData.ExpireTime
	meta["type"] = "iflow"
	meta["last_refresh"] = time.Now().Format(time.RFC3339)
	meta["cookie"] = cookie
	meta["email"] = email
	log.Infof("iflow executor: cookie-based API key refreshed successfully, new expiry: %s", keyData.ExpireTime)

	if auth.Attributes == nil {
		auth.Attributes = make(map[string]string)
	}
	auth.Attributes["api_key"] = keyData.APIKey
	return auth, nil
}
// refreshOAuthBased renews iFlow tokens using the stored OAuth refresh token.
//
// Without a (non-blank) "refresh_token" in metadata, auth is returned as-is.
// On success the following metadata entries are updated:
//   - the access token;
//   - the refresh token and API key, only when non-empty;
//   - the expiry, the provider type and a refresh timestamp.
//
// A non-empty API key is also mirrored into the "api_key" attribute. auth is
// mutated in place and returned. Access tokens are only logged in masked form.
func (e *IFlowExecutor) refreshOAuthBased(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
	rawRefresh, _ := auth.Metadata["refresh_token"].(string)
	rawAccess, _ := auth.Metadata["access_token"].(string)
	refreshToken := strings.TrimSpace(rawRefresh)
	if refreshToken == "" {
		return auth, nil
	}
	if previousAccess := strings.TrimSpace(rawAccess); previousAccess != "" {
		log.Debugf("iflow executor: refreshing access token, old: %s", util.HideAPIKey(previousAccess))
	}

	tokenData, errRefresh := iflowauth.NewIFlowAuth(e.cfg).RefreshTokens(ctx, refreshToken)
	if errRefresh != nil {
		log.Errorf("iflow executor: token refresh failed: %v", errRefresh)
		return nil, errRefresh
	}

	if auth.Metadata == nil {
		auth.Metadata = make(map[string]any)
	}
	meta := auth.Metadata
	meta["access_token"] = tokenData.AccessToken
	if tokenData.RefreshToken != "" {
		meta["refresh_token"] = tokenData.RefreshToken
	}
	if tokenData.APIKey != "" {
		meta["api_key"] = tokenData.APIKey
	}
	meta["expired"] = tokenData.Expire
	meta["type"] = "iflow"
	meta["last_refresh"] = time.Now().Format(time.RFC3339)
	log.Debugf("iflow executor: token refresh successful, new: %s", util.HideAPIKey(tokenData.AccessToken))

	if auth.Attributes == nil {
		auth.Attributes = make(map[string]string)
	}
	if tokenData.APIKey != "" {
		auth.Attributes["api_key"] = tokenData.APIKey
	}
	return auth, nil
}
// applyIFlowHeaders stamps the headers an iFlow API request needs.
//
// It sets the content type, bearer auth and iFlow user agent. It also sets a
// fresh "session-" + UUID session id and the current Unix-millisecond
// timestamp. When apiKey is non-empty it adds an HMAC signature over those
// values (see createIFlowSignature). Accept is "text/event-stream" for
// streaming requests and "application/json" otherwise.
func applyIFlowHeaders(r *http.Request, apiKey string, stream bool) {
	hdr := r.Header
	hdr.Set("Content-Type", "application/json")
	hdr.Set("Authorization", "Bearer "+apiKey)
	hdr.Set("User-Agent", iflowUserAgent)

	// Session id and timestamp are generated per request and both feed the signature.
	sessionID := "session-" + generateUUID()
	hdr.Set("session-id", sessionID)
	nowMillis := time.Now().UnixMilli()
	hdr.Set("x-iflow-timestamp", fmt.Sprintf("%d", nowMillis))
	if sig := createIFlowSignature(iflowUserAgent, sessionID, nowMillis, apiKey); sig != "" {
		hdr.Set("x-iflow-signature", sig)
	}

	accept := "application/json"
	if stream {
		accept = "text/event-stream"
	}
	hdr.Set("Accept", accept)
}
// createIFlowSignature returns the lower-case hex HMAC-SHA256 signature for an
// iFlow request. The key is apiKey and the message is
// "userAgent:sessionID:timestamp".
// It returns "" when apiKey is empty so callers can omit the signature header.
func createIFlowSignature(userAgent, sessionID string, timestamp int64, apiKey string) string {
	if len(apiKey) == 0 {
		return ""
	}
	message := []byte(fmt.Sprintf("%s:%s:%d", userAgent, sessionID, timestamp))
	mac := hmac.New(sha256.New, []byte(apiKey))
	_, _ = mac.Write(message) // hash.Hash.Write never returns an error
	return hex.EncodeToString(mac.Sum(nil))
}
// generateUUID returns a new random (v4) UUID in its canonical string form.
func generateUUID() string {
	id := uuid.New()
	return id.String()
}
// iflowCreds extracts the iFlow API key and base URL from an auth record.
//
// Each value is taken from Attributes first and falls back to the same key in
// Metadata (string values only). All values are whitespace-trimmed. A nil auth
// yields two empty strings.
func iflowCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
	if a == nil {
		return "", ""
	}
	// Reads from nil maps are safe and yield zero values.
	fromMetadata := func(key string) string {
		if v, ok := a.Metadata[key].(string); ok {
			return strings.TrimSpace(v)
		}
		return ""
	}

	apiKey = strings.TrimSpace(a.Attributes["api_key"])
	if apiKey == "" {
		apiKey = fromMetadata("api_key")
	}
	baseURL = strings.TrimSpace(a.Attributes["base_url"])
	if baseURL == "" {
		baseURL = fromMetadata("base_url")
	}
	return apiKey, baseURL
}
// ensureToolsArray sets the request's "tools" field to a single no-op
// function tool ("noop").
//
// The existing "tools" value, if any, is overwritten unconditionally.
// NOTE(review): callers are presumably expected to invoke this only when the
// request has no usable tools; confirm at call sites. If sjson fails, the
// original body is returned unchanged.
func ensureToolsArray(body []byte) []byte {
	const noopTools = `[{"type":"function","function":{"name":"noop","description":"Placeholder tool to stabilise streaming","parameters":{"type":"object"}}}]`
	if patched, err := sjson.SetRawBytes(body, "tools", []byte(noopTools)); err == nil {
		return patched
	}
	return body
}
// preserveReasoningContentInMessages inspects the conversation history of
// thinking-capable iFlow models ("glm-4*" and "minimax-m2*", matched
// case-insensitively). If any assistant message already carries a non-empty
// reasoning_content, it emits a debug log.
//
// Despite its name, it never modifies body: the input slice is always returned
// as-is. For GLM-4.6/4.7 and MiniMax M2/M2.1 the client is expected to keep
// full assistant turns (including reasoning_content) in history, and this
// function only observes whether it did.
func preserveReasoningContentInMessages(body []byte) []byte {
	model := strings.ToLower(gjson.GetBytes(body, "model").String())
	if !strings.HasPrefix(model, "glm-4") && !strings.HasPrefix(model, "minimax-m2") {
		return body
	}

	messages := gjson.GetBytes(body, "messages")
	if !messages.Exists() || !messages.IsArray() {
		return body
	}
	for _, msg := range messages.Array() {
		if msg.Get("role").String() != "assistant" {
			continue
		}
		// A missing field reads as "", so this also covers non-existence.
		if msg.Get("reasoning_content").String() != "" {
			log.Debugf("iflow executor: reasoning_content found in message history for %s", model)
			break
		}
	}
	return body
}
================================================
FILE: internal/runtime/executor/iflow_executor_test.go
================================================
package executor
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
)
// TestIFlowExecutorParseSuffix checks that thinking.ParseSuffix strips a
// "(level)" suffix from iFlow model names and leaves plain names untouched.
func TestIFlowExecutorParseSuffix(t *testing.T) {
	cases := []struct {
		name     string
		model    string
		wantBase string
		// NOTE(review): wantLevel is not asserted below; add a check once the
		// ParseSuffix result field carrying the suffix is confirmed.
		wantLevel string
	}{
		{name: "no suffix", model: "glm-4", wantBase: "glm-4", wantLevel: ""},
		{name: "glm with suffix", model: "glm-4.1-flash(high)", wantBase: "glm-4.1-flash", wantLevel: "high"},
		{name: "minimax no suffix", model: "minimax-m2", wantBase: "minimax-m2", wantLevel: ""},
		{name: "minimax with suffix", model: "minimax-m2.1(medium)", wantBase: "minimax-m2.1", wantLevel: "medium"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := thinking.ParseSuffix(tc.model).ModelName; got != tc.wantBase {
				t.Errorf("ParseSuffix(%q).ModelName = %q, want %q", tc.model, got, tc.wantBase)
			}
		})
	}
}
// TestPreserveReasoningContentInMessages verifies the body is returned
// unchanged for non-thinking models and for thinking models alike.
func TestPreserveReasoningContentInMessages(t *testing.T) {
	cases := []struct {
		name  string
		input []byte
		want  []byte // nil means the output must equal the input
	}{
		{name: "non-glm model passthrough", input: []byte(`{"model":"gpt-4","messages":[]}`)},
		{name: "glm model with empty messages", input: []byte(`{"model":"glm-4","messages":[]}`)},
		{name: "glm model preserves existing reasoning_content", input: []byte(`{"model":"glm-4","messages":[{"role":"assistant","content":"hi","reasoning_content":"thinking..."}]}`)},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			expected := tc.want
			if expected == nil {
				expected = tc.input
			}
			got := preserveReasoningContentInMessages(tc.input)
			if string(got) != string(expected) {
				t.Errorf("preserveReasoningContentInMessages() = %s, want %s", got, expected)
			}
		})
	}
}
================================================
FILE: internal/runtime/executor/kimi_executor.go
================================================
package executor
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"runtime"
"strings"
"time"
kimiauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// KimiExecutor is a stateless executor for Kimi API using OpenAI-compatible chat completions.
//
// Requests whose source format is "claude", and all CountTokens calls, are
// delegated to the embedded ClaudeExecutor after the auth's base_url is pointed
// at the Kimi API. Every other request is translated to OpenAI chat-completions
// and sent by this type directly.
type KimiExecutor struct {
	// ClaudeExecutor serves the Claude-format delegation described above.
	ClaudeExecutor
	// cfg is the application config passed to the HTTP client, payload rules
	// and request logging on the OpenAI-compatible path.
	cfg *config.Config
}
// NewKimiExecutor creates a new Kimi executor.
//
// cfg is installed both on the KimiExecutor and on its embedded
// ClaudeExecutor. The embedded executor serves Claude-format requests and
// CountTokens (see Execute, ExecuteStream and CountTokens). Without this, that
// embedded executor would run with a nil config on every delegated call.
func NewKimiExecutor(cfg *config.Config) *KimiExecutor {
	return &KimiExecutor{
		ClaudeExecutor: ClaudeExecutor{cfg: cfg},
		cfg:            cfg,
	}
}
// Identifier reports this executor's provider key, "kimi". The key is used
// for usage reporting, request logs and thinking configuration.
func (e *KimiExecutor) Identifier() string {
	return "kimi"
}
// PrepareRequest adds a bearer Authorization header built from the auth's
// Kimi credentials (see kimiCreds). A nil request, or credentials that are
// blank after trimming, leave the request untouched. It never returns an
// error.
func (e *KimiExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
	if req != nil {
		if token := kimiCreds(auth); strings.TrimSpace(token) != "" {
			req.Header.Set("Authorization", "Bearer "+token)
		}
	}
	return nil
}
// HttpRequest sends an arbitrary request to Kimi with the auth's credentials
// attached.
//
// If ctx is nil, the request's own context is used. The request is cloned
// onto that context, authorized via PrepareRequest, and executed with the
// proxy-aware HTTP client for this auth. A nil request is rejected with an
// error.
func (e *KimiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
	if req == nil {
		return nil, fmt.Errorf("kimi executor: request is nil")
	}
	requestCtx := ctx
	if requestCtx == nil {
		requestCtx = req.Context()
	}
	outbound := req.WithContext(requestCtx)
	if err := e.PrepareRequest(outbound, auth); err != nil {
		return nil, err
	}
	return newProxyAwareHTTPClient(requestCtx, e.cfg, auth, 0).Do(outbound)
}
// Execute performs a non-streaming chat completion request to Kimi.
//
// Claude-format requests are delegated to the embedded ClaudeExecutor with the
// auth's base_url pointed at the Kimi API. Any other source format follows
// this sequence:
//   - translate to OpenAI chat-completions;
//   - strip the "kimi-" model prefix;
//   - apply thinking and payload config;
//   - repair tool-message links;
//   - POST to KimiAPIBaseURL + "/v1/chat/completions";
//   - translate the reply back to the caller's format.
//
// Non-2xx responses are returned as statusErr carrying the upstream body.
func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
	from := opts.SourceFormat
	if from.String() == "claude" {
		// Attributes may be nil, and writing into a nil map panics. Only write
		// when the value differs, so repeat requests do not rewrite the map.
		if auth != nil && auth.Attributes["base_url"] != kimiauth.KimiAPIBaseURL {
			if auth.Attributes == nil {
				auth.Attributes = make(map[string]string)
			}
			auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
		}
		return e.ClaudeExecutor.Execute(ctx, auth, req, opts)
	}
	baseModel := thinking.ParseSuffix(req.Model).ModelName
	token := kimiCreds(auth)
	reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
	defer reporter.trackFailure(ctx, &err)
	to := sdktranslator.FromString("openai")
	originalPayloadSource := req.Payload
	if len(opts.OriginalRequest) > 0 {
		originalPayloadSource = opts.OriginalRequest
	}
	originalPayload := bytes.Clone(originalPayloadSource)
	originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
	body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
	// Strip kimi- prefix for upstream API
	upstreamModel := stripKimiPrefix(baseModel)
	body, err = sjson.SetBytes(body, "model", upstreamModel)
	if err != nil {
		return resp, fmt.Errorf("kimi executor: failed to set model in payload: %w", err)
	}
	body, err = thinking.ApplyThinking(body, req.Model, from.String(), "kimi", e.Identifier())
	if err != nil {
		return resp, err
	}
	requestedModel := payloadRequestedModel(opts, req.Model)
	body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
	body, err = normalizeKimiToolMessageLinks(body)
	if err != nil {
		return resp, err
	}
	url := kimiauth.KimiAPIBaseURL + "/v1/chat/completions"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return resp, err
	}
	applyKimiHeadersWithAuth(httpReq, token, false, auth)
	var authID, authLabel, authType, authValue string
	if auth != nil {
		authID = auth.ID
		authLabel = auth.Label
		authType, authValue = auth.AccountInfo()
	}
	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
		URL:       url,
		Method:    http.MethodPost,
		Headers:   httpReq.Header.Clone(),
		Body:      body,
		Provider:  e.Identifier(),
		AuthID:    authID,
		AuthLabel: authLabel,
		AuthType:  authType,
		AuthValue: authValue,
	})
	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
	httpResp, err := httpClient.Do(httpReq)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return resp, err
	}
	defer func() {
		if errClose := httpResp.Body.Close(); errClose != nil {
			log.Errorf("kimi executor: close response body error: %v", errClose)
		}
	}()
	recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
	if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
		// Best effort: a partially read error body is still useful for diagnostics.
		b, _ := io.ReadAll(httpResp.Body)
		appendAPIResponseChunk(ctx, e.cfg, b)
		logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
		err = statusErr{code: httpResp.StatusCode, msg: string(b)}
		return resp, err
	}
	data, err := io.ReadAll(httpResp.Body)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return resp, err
	}
	appendAPIResponseChunk(ctx, e.cfg, data)
	reporter.publish(ctx, parseOpenAIUsage(data))
	var param any
	// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
	// the original model name in the response for client compatibility.
	out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, &param)
	resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
	return resp, nil
}
// ExecuteStream performs a streaming chat completion request to Kimi.
//
// Claude-format requests are delegated to the embedded ClaudeExecutor with the
// auth's base_url pointed at the Kimi API. Any other request is handled as
// follows:
//   - It is translated to OpenAI chat-completions with
//     stream_options.include_usage enabled.
//   - Each upstream line is translated back and forwarded on the returned
//     channel.
//   - After the upstream stream ends, a synthetic "[DONE]" line is translated
//     and forwarded, followed by a final chunk carrying the scanner error, if
//     any.
//
// The goroutine closes the channel and the response body when it finishes.
// Non-2xx responses are returned synchronously as statusErr.
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
	from := opts.SourceFormat
	if from.String() == "claude" {
		// Attributes may be nil, and writing into a nil map panics. Only write
		// when the value differs, so repeat requests do not rewrite the map.
		if auth != nil && auth.Attributes["base_url"] != kimiauth.KimiAPIBaseURL {
			if auth.Attributes == nil {
				auth.Attributes = make(map[string]string)
			}
			auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
		}
		return e.ClaudeExecutor.ExecuteStream(ctx, auth, req, opts)
	}
	baseModel := thinking.ParseSuffix(req.Model).ModelName
	token := kimiCreds(auth)
	reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
	defer reporter.trackFailure(ctx, &err)
	to := sdktranslator.FromString("openai")
	originalPayloadSource := req.Payload
	if len(opts.OriginalRequest) > 0 {
		originalPayloadSource = opts.OriginalRequest
	}
	originalPayload := bytes.Clone(originalPayloadSource)
	originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
	body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
	// Strip kimi- prefix for upstream API
	upstreamModel := stripKimiPrefix(baseModel)
	body, err = sjson.SetBytes(body, "model", upstreamModel)
	if err != nil {
		return nil, fmt.Errorf("kimi executor: failed to set model in payload: %w", err)
	}
	body, err = thinking.ApplyThinking(body, req.Model, from.String(), "kimi", e.Identifier())
	if err != nil {
		return nil, err
	}
	body, err = sjson.SetBytes(body, "stream_options.include_usage", true)
	if err != nil {
		return nil, fmt.Errorf("kimi executor: failed to set stream_options in payload: %w", err)
	}
	requestedModel := payloadRequestedModel(opts, req.Model)
	body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
	body, err = normalizeKimiToolMessageLinks(body)
	if err != nil {
		return nil, err
	}
	url := kimiauth.KimiAPIBaseURL + "/v1/chat/completions"
	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	applyKimiHeadersWithAuth(httpReq, token, true, auth)
	var authID, authLabel, authType, authValue string
	if auth != nil {
		authID = auth.ID
		authLabel = auth.Label
		authType, authValue = auth.AccountInfo()
	}
	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
		URL:       url,
		Method:    http.MethodPost,
		Headers:   httpReq.Header.Clone(),
		Body:      body,
		Provider:  e.Identifier(),
		AuthID:    authID,
		AuthLabel: authLabel,
		AuthType:  authType,
		AuthValue: authValue,
	})
	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
	httpResp, err := httpClient.Do(httpReq)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return nil, err
	}
	recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
	if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
		// Best effort: a partially read error body is still useful for diagnostics.
		b, _ := io.ReadAll(httpResp.Body)
		appendAPIResponseChunk(ctx, e.cfg, b)
		logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
		if errClose := httpResp.Body.Close(); errClose != nil {
			log.Errorf("kimi executor: close response body error: %v", errClose)
		}
		err = statusErr{code: httpResp.StatusCode, msg: string(b)}
		return nil, err
	}
	out := make(chan cliproxyexecutor.StreamChunk)
	go func() {
		defer close(out)
		defer func() {
			if errClose := httpResp.Body.Close(); errClose != nil {
				log.Errorf("kimi executor: close response body error: %v", errClose)
			}
		}()
		scanner := bufio.NewScanner(httpResp.Body)
		scanner.Buffer(nil, 1_048_576) // 1MB
		var param any
		for scanner.Scan() {
			line := scanner.Bytes()
			appendAPIResponseChunk(ctx, e.cfg, line)
			if detail, ok := parseOpenAIStreamUsage(line); ok {
				reporter.publish(ctx, detail)
			}
			// scanner.Bytes is reused on the next Scan, so hand the translator a copy.
			chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), &param)
			for i := range chunks {
				out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
			}
		}
		doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), &param)
		for i := range doneChunks {
			out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
		}
		if errScan := scanner.Err(); errScan != nil {
			recordAPIResponseError(ctx, e.cfg, errScan)
			reporter.publishFailure(ctx)
			out <- cliproxyexecutor.StreamChunk{Err: errScan}
		}
	}()
	return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// CountTokens handles token counting for Kimi requests by delegating to the
// embedded ClaudeExecutor, whatever the source format. Before delegating, it
// points the auth's base_url at the Kimi API.
//
// Attributes may be nil, and writing into a nil map panics, so the map is
// allocated on demand. The write is skipped when the value is already set.
func (e *KimiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	if auth != nil && auth.Attributes["base_url"] != kimiauth.KimiAPIBaseURL {
		if auth.Attributes == nil {
			auth.Attributes = make(map[string]string)
		}
		auth.Attributes["base_url"] = kimiauth.KimiAPIBaseURL
	}
	return e.ClaudeExecutor.CountTokens(ctx, auth, req, opts)
}
func normalizeKimiToolMessageLinks(body []byte) ([]byte, error) {
if len(body) == 0 || !gjson.ValidBytes(body) {
return body, nil
}
messages := gjson.GetBytes(body, "messages")
if !messages.Exists() || !messages.IsArray() {
return body, nil
}
out := body
pending := make([]string, 0)
patched := 0
patchedReasoning := 0
ambiguous := 0
latestReasoning := ""
hasLatestReasoning := false
removePending := func(id string) {
for idx := range pending {
if pending[idx] != id {
continue
}
pending = append(pending[:idx], pending[idx+1:]...)
return
}
}
msgs := messages.Array()
for msgIdx := range msgs {
msg := msgs[msgIdx]
role := strings.TrimSpace(msg.Get("role").String())
switch role {
case "assistant":
reasoning := msg.Get("reasoning_content")
if reasoning.Exists() {
reasoningText := reasoning.String()
if strings.TrimSpace(reasoningText) != "" {
latestReasoning = reasoningText
hasLatestReasoning = true
}
}
toolCalls := msg.Get("tool_calls")
if !toolCalls.Exists() || !toolCalls.IsArray() || len(toolCalls.Array()) == 0 {
continue
}
if !reasoning.Exists() || strings.TrimSpace(reasoning.String()) == "" {
reasoningText := fallbackAssistantReasoning(msg, hasLatestReasoning, latestReasoning)
path := fmt.Sprintf("messages.%d.reasoning_content", msgIdx)
next, err := sjson.SetBytes(out, path, reasoningText)
if err != nil {
return body, fmt.Errorf("kimi executor: failed to set assistant reasoning_content: %w", err)
}
out = next
patchedReasoning++
}
for _, tc := range toolCalls.Array() {
id := strings.TrimSpace(tc.Get("id").String())
if id == "" {
continue
}
pending = append(pending, id)
}
case "tool":
toolCallID := strings.TrimSpace(msg.Get("tool_call_id").String())
if toolCallID == "" {
toolCallID = strings.TrimSpace(msg.Get("call_id").String())
if toolCallID != "" {
path := fmt.Sprintf("messages.%d.tool_call_id", msgIdx)
next, err := sjson.SetBytes(out, path, toolCallID)
if err != nil {
return body, fmt.Errorf("kimi executor: failed to set tool_call_id from call_id: %w", err)
}
out = next
patched++
}
}
if toolCallID == "" {
if len(pending) == 1 {
toolCallID = pending[0]
path := fmt.Sprintf("messages.%d.tool_call_id", msgIdx)
next, err := sjson.SetBytes(out, path, toolCallID)
if err != nil {
return body, fmt.Errorf("kimi executor: failed to infer tool_call_id: %w", err)
}
out = next
patched++
} else if len(pending) > 1 {
ambiguous++
}
}
if toolCallID != "" {
removePending(toolCallID)
}
}
}
if patched > 0 || patchedReasoning > 0 {
log.WithFields(log.Fields{
"patched_tool_messages": patched,
"patched_reasoning_messages": patchedReasoning,
}).Debug("kimi executor: normalized tool message fields")
}
if ambiguous > 0 {
log.WithFields(log.Fields{
"ambiguous_tool_messages": ambiguous,
"pending_tool_calls": len(pending),
}).Warn("kimi executor: tool messages missing tool_call_id with ambiguous candidates")
}
return out, nil
}
func fallbackAssistantReasoning(msg gjson.Result, hasLatest bool, latest string) string {
if hasLatest && strings.TrimSpace(latest) != "" {
return latest
}
content := msg.Get("content")
if content.Type == gjson.String {
if text := strings.TrimSpace(content.String()); text != "" {
return text
}
}
if content.IsArray() {
parts := make([]string, 0, len(content.Array()))
for _, item := range content.Array() {
text := strings.TrimSpace(item.Get("text").String())
if text == "" {
continue
}
parts = append(parts, text)
}
if len(parts) > 0 {
return strings.Join(parts, "\n")
}
}
return "[reasoning unavailable]"
}
// Refresh refreshes the Kimi token using the refresh token.
func (e *KimiExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
log.Debugf("kimi executor: refresh called")
if auth == nil {
return nil, fmt.Errorf("kimi executor: auth is nil")
}
// Expect refresh_token in metadata for OAuth-based accounts
var refreshToken string
if auth.Metadata != nil {
if v, ok := auth.Metadata["refresh_token"].(string); ok && strings.TrimSpace(v) != "" {
refreshToken = v
}
}
if strings.TrimSpace(refreshToken) == "" {
// Nothing to refresh
return auth, nil
}
client := kimiauth.NewDeviceFlowClientWithDeviceID(e.cfg, resolveKimiDeviceID(auth))
td, err := client.RefreshToken(ctx, refreshToken)
if err != nil {
return nil, err
}
if auth.Metadata == nil {
auth.Metadata = make(map[string]any)
}
auth.Metadata["access_token"] = td.AccessToken
if td.RefreshToken != "" {
auth.Metadata["refresh_token"] = td.RefreshToken
}
if td.ExpiresAt > 0 {
exp := time.Unix(td.ExpiresAt, 0).UTC().Format(time.RFC3339)
auth.Metadata["expired"] = exp
}
auth.Metadata["type"] = "kimi"
now := time.Now().Format(time.RFC3339)
auth.Metadata["last_refresh"] = now
return auth, nil
}
// applyKimiHeaders sets the headers Kimi API requests need. The values mirror
// the kimi-cli client (version 1.10.6) for compatibility:
//   - JSON content type and bearer token;
//   - the kimi-cli user agent and X-Msh-* platform/version headers;
//   - this host's name, OS/arch model and device id (see getKimiDeviceID).
//
// Accept is "text/event-stream" when stream is true, otherwise
// "application/json".
func applyKimiHeaders(r *http.Request, token string, stream bool) {
	accept := "application/json"
	if stream {
		accept = "text/event-stream"
	}

	hdr := r.Header
	hdr.Set("Content-Type", "application/json")
	hdr.Set("Authorization", "Bearer "+token)
	hdr.Set("User-Agent", "KimiCLI/1.10.6")
	hdr.Set("X-Msh-Platform", "kimi_cli")
	hdr.Set("X-Msh-Version", "1.10.6")
	hdr.Set("X-Msh-Device-Name", getKimiHostname())
	hdr.Set("X-Msh-Device-Model", getKimiDeviceModel())
	hdr.Set("X-Msh-Device-Id", getKimiDeviceID())
	hdr.Set("Accept", accept)
}
// resolveKimiDeviceIDFromAuth returns the trimmed string stored at
// Metadata["device_id"]. It returns "" for a nil auth, nil metadata, a missing
// key, or a non-string value.
func resolveKimiDeviceIDFromAuth(auth *cliproxyauth.Auth) string {
	if auth == nil {
		return ""
	}
	// Indexing a nil map is safe; the type assertion rejects missing and non-string values.
	if id, isString := auth.Metadata["device_id"].(string); isString {
		return strings.TrimSpace(id)
	}
	return ""
}
// resolveKimiDeviceIDFromStorage returns the trimmed DeviceID from the auth's
// *kimiauth.KimiTokenStorage. It returns "" when auth is nil or the storage is
// absent or of another type.
func resolveKimiDeviceIDFromStorage(auth *cliproxyauth.Auth) string {
	if auth == nil {
		return ""
	}
	if storage, ok := auth.Storage.(*kimiauth.KimiTokenStorage); ok && storage != nil {
		return strings.TrimSpace(storage.DeviceID)
	}
	return ""
}
// resolveKimiDeviceID returns the device id recorded for this auth. The
// metadata value is preferred; the token storage is used when metadata has
// none. It returns "" when neither source has one.
func resolveKimiDeviceID(auth *cliproxyauth.Auth) string {
	if fromMetadata := resolveKimiDeviceIDFromAuth(auth); fromMetadata != "" {
		return fromMetadata
	}
	return resolveKimiDeviceIDFromStorage(auth)
}
// applyKimiHeadersWithAuth applies the standard Kimi headers. When the auth
// records its own device id, that id replaces the machine-wide
// X-Msh-Device-Id.
func applyKimiHeadersWithAuth(r *http.Request, token string, stream bool, auth *cliproxyauth.Auth) {
	applyKimiHeaders(r, token, stream)
	deviceID := resolveKimiDeviceID(auth)
	if deviceID == "" {
		return
	}
	r.Header.Set("X-Msh-Device-Id", deviceID)
}
// getKimiHostname returns the machine hostname, or "unknown" when the OS
// cannot report it.
func getKimiHostname() string {
	if name, err := os.Hostname(); err == nil {
		return name
	}
	return "unknown"
}
// getKimiDeviceModel describes this machine as "<GOOS> <GOARCH>" (for example
// "linux amd64"), the device-model format used for kimi-cli compatibility.
func getKimiDeviceModel() string {
	goos, goarch := runtime.GOOS, runtime.GOARCH
	return fmt.Sprintf("%s %s", goos, goarch)
}
// getKimiDeviceID returns a stable device ID. It reads the trimmed contents of
// kimi-cli's device_id file, whose location depends on the platform:
//   - macOS:   <home>/Library/Application Support/kimi/device_id
//   - Windows: %APPDATA%\kimi\device_id (defaults to <home>/AppData/Roaming)
//   - other:   <home>/.local/share/kimi/device_id
//
// The fixed ID "cli-proxy-api-device" is returned when the home directory
// cannot be resolved or the file cannot be read. It is also returned when the
// file is blank, so an empty X-Msh-Device-Id header is never sent.
func getKimiDeviceID() string {
	const fallbackDeviceID = "cli-proxy-api-device"
	homeDir, err := os.UserHomeDir()
	if err != nil {
		return fallbackDeviceID
	}
	var kimiShareDir string
	switch runtime.GOOS {
	case "darwin":
		kimiShareDir = filepath.Join(homeDir, "Library", "Application Support", "kimi")
	case "windows":
		appData := os.Getenv("APPDATA")
		if appData == "" {
			appData = filepath.Join(homeDir, "AppData", "Roaming")
		}
		kimiShareDir = filepath.Join(appData, "kimi")
	default: // linux and other unix-like
		kimiShareDir = filepath.Join(homeDir, ".local", "share", "kimi")
	}
	data, err := os.ReadFile(filepath.Join(kimiShareDir, "device_id"))
	if err != nil {
		return fallbackDeviceID
	}
	if id := strings.TrimSpace(string(data)); id != "" {
		return id
	}
	return fallbackDeviceID
}
// kimiCreds extracts the bearer token from auth. It returns "" for a nil auth
// or when no credential is set.
//
// Lookup order:
//  1. Metadata["access_token"], where the OAuth flow and Refresh store tokens;
//  2. Attributes["access_token"];
//  3. Attributes["api_key"].
//
// Every value is whitespace-trimmed and blank entries are skipped. This keeps
// stray whitespace out of the "Bearer <token>" Authorization header and
// matches how iflowCreds treats stored credentials.
func kimiCreds(a *cliproxyauth.Auth) (token string) {
	if a == nil {
		return ""
	}
	// Reads from nil maps are safe and yield zero values.
	if v, ok := a.Metadata["access_token"].(string); ok {
		if v = strings.TrimSpace(v); v != "" {
			return v
		}
	}
	// Fallback to attributes (API key style)
	for _, key := range []string{"access_token", "api_key"} {
		if v := strings.TrimSpace(a.Attributes[key]); v != "" {
			return v
		}
	}
	return ""
}
// stripKimiPrefix removes the "kimi-" prefix from model names for the upstream API.
func stripKimiPrefix(model string) string {
model = strings.TrimSpace(model)
if strings.HasPrefix(strings.ToLower(model), "kimi-") {
return model[5:]
}
return model
}
================================================
FILE: internal/runtime/executor/kimi_executor_test.go
================================================
package executor
import (
"testing"
"github.com/tidwall/gjson"
)
func TestNormalizeKimiToolMessageLinks_UsesCallIDFallback(t *testing.T) {
body := []byte(`{
"messages":[
{"role":"assistant","tool_calls":[{"id":"list_directory:1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
{"role":"tool","call_id":"list_directory:1","content":"[]"}
]
}`)
out, err := normalizeKimiToolMessageLinks(body)
if err != nil {
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
}
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
if got != "list_directory:1" {
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "list_directory:1")
}
}
func TestNormalizeKimiToolMessageLinks_InferSinglePendingID(t *testing.T) {
body := []byte(`{
"messages":[
{"role":"assistant","tool_calls":[{"id":"call_123","type":"function","function":{"name":"read_file","arguments":"{}"}}]},
{"role":"tool","content":"file-content"}
]
}`)
out, err := normalizeKimiToolMessageLinks(body)
if err != nil {
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
}
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
if got != "call_123" {
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_123")
}
}
// With several unanswered tool calls, an ID-less tool message must be left without tool_call_id.
func TestNormalizeKimiToolMessageLinks_AmbiguousMissingIDIsNotInferred(t *testing.T) {
	input := []byte(`{
"messages":[
{"role":"assistant","tool_calls":[
{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}},
{"id":"call_2","type":"function","function":{"name":"read_file","arguments":"{}"}}
]},
{"role":"tool","content":"result-without-id"}
]
}`)
	normalized, err := normalizeKimiToolMessageLinks(input)
	if err != nil {
		t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
	}
	if field := gjson.GetBytes(normalized, "messages.1.tool_call_id"); field.Exists() {
		t.Fatalf("messages.1.tool_call_id should be absent for ambiguous case, got %q", field.String())
	}
}
func TestNormalizeKimiToolMessageLinks_PreservesExistingToolCallID(t *testing.T) {
body := []byte(`{
"messages":[
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]},
{"role":"tool","tool_call_id":"call_1","call_id":"different-id","content":"result"}
]
}`)
out, err := normalizeKimiToolMessageLinks(body)
if err != nil {
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
}
got := gjson.GetBytes(out, "messages.1.tool_call_id").String()
if got != "call_1" {
t.Fatalf("messages.1.tool_call_id = %q, want %q", got, "call_1")
}
}
func TestNormalizeKimiToolMessageLinks_InheritsPreviousReasoningForAssistantToolCalls(t *testing.T) {
body := []byte(`{
"messages":[
{"role":"assistant","content":"plan","reasoning_content":"previous reasoning"},
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
]
}`)
out, err := normalizeKimiToolMessageLinks(body)
if err != nil {
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
}
got := gjson.GetBytes(out, "messages.1.reasoning_content").String()
if got != "previous reasoning" {
t.Fatalf("messages.1.reasoning_content = %q, want %q", got, "previous reasoning")
}
}
func TestNormalizeKimiToolMessageLinks_InsertsFallbackReasoningWhenMissing(t *testing.T) {
body := []byte(`{
"messages":[
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
]
}`)
out, err := normalizeKimiToolMessageLinks(body)
if err != nil {
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
}
reasoning := gjson.GetBytes(out, "messages.0.reasoning_content")
if !reasoning.Exists() {
t.Fatalf("messages.0.reasoning_content should exist")
}
if reasoning.String() != "[reasoning unavailable]" {
t.Fatalf("messages.0.reasoning_content = %q, want %q", reasoning.String(), "[reasoning unavailable]")
}
}
func TestNormalizeKimiToolMessageLinks_UsesContentAsReasoningFallback(t *testing.T) {
body := []byte(`{
"messages":[
{"role":"assistant","content":[{"type":"text","text":"first line"},{"type":"text","text":"second line"}],"tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}]}
]
}`)
out, err := normalizeKimiToolMessageLinks(body)
if err != nil {
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
}
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
if got != "first line\nsecond line" {
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "first line\nsecond line")
}
}
func TestNormalizeKimiToolMessageLinks_ReplacesEmptyReasoningContent(t *testing.T) {
body := []byte(`{
"messages":[
{"role":"assistant","content":"assistant summary","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":""}
]
}`)
out, err := normalizeKimiToolMessageLinks(body)
if err != nil {
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
}
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
if got != "assistant summary" {
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "assistant summary")
}
}
func TestNormalizeKimiToolMessageLinks_PreservesExistingAssistantReasoning(t *testing.T) {
body := []byte(`{
"messages":[
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":"keep me"}
]
}`)
out, err := normalizeKimiToolMessageLinks(body)
if err != nil {
t.Fatalf("normalizeKimiToolMessageLinks() error = %v", err)
}
got := gjson.GetBytes(out, "messages.0.reasoning_content").String()
if got != "keep me" {
t.Fatalf("messages.0.reasoning_content = %q, want %q", got, "keep me")
}
}
// TestNormalizeKimiToolMessageLinks_RepairsIDsAndReasoningTogether checks that a
// single normalization pass both maps tool messages' call_id to tool_call_id and
// carries the earlier assistant reasoning onto a later assistant tool-call
// message that has none.
func TestNormalizeKimiToolMessageLinks_RepairsIDsAndReasoningTogether(t *testing.T) {
	input := []byte(`{
"messages":[
{"role":"assistant","tool_calls":[{"id":"call_1","type":"function","function":{"name":"list_directory","arguments":"{}"}}],"reasoning_content":"r1"},
{"role":"tool","call_id":"call_1","content":"[]"},
{"role":"assistant","tool_calls":[{"id":"call_2","type":"function","function":{"name":"read_file","arguments":"{}"}}]},
{"role":"tool","call_id":"call_2","content":"file"}
]
}`)
	normalized, errNormalize := normalizeKimiToolMessageLinks(input)
	if errNormalize != nil {
		t.Fatalf("normalizeKimiToolMessageLinks() error = %v", errNormalize)
	}
	expectations := []struct {
		path string
		want string
	}{
		{path: "messages.1.tool_call_id", want: "call_1"},
		{path: "messages.3.tool_call_id", want: "call_2"},
		{path: "messages.2.reasoning_content", want: "r1"},
	}
	for _, exp := range expectations {
		if got := gjson.GetBytes(normalized, exp.path).String(); got != exp.want {
			t.Fatalf("%s = %q, want %q", exp.path, got, exp.want)
		}
	}
}
================================================
FILE: internal/runtime/executor/logging_helpers.go
================================================
package executor
import (
"bytes"
"context"
"fmt"
"html"
"net/http"
"sort"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// Gin context keys under which the request-logging helpers accumulate
// upstream request/response details.
const (
	// apiAttemptsKey holds the []*upstreamAttempt recorded for the current request.
	apiAttemptsKey = "API_UPSTREAM_ATTEMPTS"
	// apiRequestKey holds the concatenated request text of all attempts ([]byte).
	apiRequestKey = "API_REQUEST"
	// apiResponseKey holds the concatenated response text of all attempts ([]byte).
	apiResponseKey = "API_RESPONSE"
)
// upstreamRequestLog captures the outbound upstream request details for logging.
type upstreamRequestLog struct {
	// URL and Method are rendered as the "Upstream URL" and "HTTP Method" lines.
	URL    string
	Method string
	// Headers are written sorted by name, each value passed through
	// util.MaskSensitiveHeaderValue.
	Headers http.Header
	// Body is written verbatim.
	Body []byte
	// Provider, AuthID, AuthLabel, AuthType and AuthValue feed the "Auth:"
	// summary; api_key values are obscured with util.HideAPIKey and oauth
	// values are omitted there.
	Provider  string
	AuthID    string
	AuthLabel string
	AuthType  string
	AuthValue string
}
// upstreamAttempt accumulates the log text for one upstream call: a fully
// rendered request section and an incrementally built response section.
type upstreamAttempt struct {
	// index is the 1-based attempt number shown in the section banners.
	index int
	// request is the rendered "=== API REQUEST n ===" section.
	request string
	// response collects the "=== API RESPONSE n ===" section as it is recorded.
	response *strings.Builder
	// responseIntroWritten, statusWritten and headersWritten guard sections
	// that are emitted at most once per attempt.
	responseIntroWritten bool
	statusWritten        bool
	headersWritten       bool
	// bodyStarted marks that the "Body:" header was written; bodyHasContent
	// marks that at least one chunk followed it.
	bodyStarted    bool
	bodyHasContent bool
	// errorWritten marks that at least one "Error:" line was written.
	errorWritten bool
}
// recordAPIRequest appends a new upstream attempt, rendered from info, to the
// Gin context so it appears in the request log, and refreshes the aggregated
// request text. It does nothing when request logging is disabled or ctx
// carries no Gin context.
func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequestLog) {
	if cfg == nil || !cfg.RequestLog {
		return
	}
	ginCtx := ginContextFrom(ctx)
	if ginCtx == nil {
		return
	}

	attempts := getAttempts(ginCtx)
	attemptIndex := len(attempts) + 1

	var sb strings.Builder
	fmt.Fprintf(&sb, "=== API REQUEST %d ===\n", attemptIndex)
	fmt.Fprintf(&sb, "Timestamp: %s\n", time.Now().Format(time.RFC3339Nano))
	// The URL line is always present, even when the URL is unknown.
	fmt.Fprintf(&sb, "Upstream URL: %s\n", info.URL)
	if info.Method != "" {
		fmt.Fprintf(&sb, "HTTP Method: %s\n", info.Method)
	}
	if authSummary := formatAuthInfo(info); authSummary != "" {
		fmt.Fprintf(&sb, "Auth: %s\n", authSummary)
	}
	sb.WriteString("\nHeaders:\n")
	writeHeaders(&sb, info.Headers)
	sb.WriteString("\nBody:\n")
	sb.Write(info.Body)
	sb.WriteString("\n\n")

	attempts = append(attempts, &upstreamAttempt{
		index:    attemptIndex,
		request:  sb.String(),
		response: &strings.Builder{},
	})
	ginCtx.Set(apiAttemptsKey, attempts)
	updateAggregatedRequest(ginCtx, attempts)
}
// recordAPIResponseMetadata writes the upstream status line and response
// headers into the latest attempt's response section, each at most once, and
// refreshes the aggregated response text. A non-positive status is not
// written. It does nothing when request logging is disabled or ctx carries no
// Gin context.
func recordAPIResponseMetadata(ctx context.Context, cfg *config.Config, status int, headers http.Header) {
	if cfg == nil || !cfg.RequestLog {
		return
	}
	ginCtx := ginContextFrom(ctx)
	if ginCtx == nil {
		return
	}
	attempts, current := ensureAttempt(ginCtx)
	ensureResponseIntro(current)
	respLog := current.response
	if !current.statusWritten && status > 0 {
		fmt.Fprintf(respLog, "Status: %d\n", status)
		current.statusWritten = true
	}
	if !current.headersWritten {
		respLog.WriteString("Headers:\n")
		writeHeaders(respLog, headers)
		respLog.WriteString("\n")
		current.headersWritten = true
	}
	updateAggregatedResponse(ginCtx, attempts)
}
// recordAPIResponseError appends an "Error:" line to the latest upstream
// attempt's response section (for example when the request failed or the
// response body could not be read) and refreshes the aggregated response text.
// It does nothing when request logging is disabled, err is nil, or ctx carries
// no Gin context.
func recordAPIResponseError(ctx context.Context, cfg *config.Config, err error) {
	if cfg == nil || !cfg.RequestLog || err == nil {
		return
	}
	ginCtx := ginContextFrom(ctx)
	if ginCtx == nil {
		return
	}
	attempts, attempt := ensureAttempt(ginCtx)
	ensureResponseIntro(attempt)
	if attempt.bodyStarted && !attempt.bodyHasContent {
		// Ensure body does not stay empty marker if error arrives first.
		attempt.bodyStarted = false
	}
	switch {
	case !strings.HasSuffix(attempt.response.String(), "\n"):
		// Body chunks are stored trimmed, without a trailing newline. End the
		// last chunk and leave a blank line so the error is not glued onto it.
		attempt.response.WriteString("\n\n")
	case attempt.errorWritten:
		// Separate consecutive errors with a blank line.
		attempt.response.WriteString("\n")
	}
	fmt.Fprintf(attempt.response, "Error: %s\n", err.Error())
	attempt.errorWritten = true
	updateAggregatedResponse(ginCtx, attempts)
}
// appendAPIResponseChunk appends a whitespace-trimmed upstream response chunk
// to the latest attempt's body section and refreshes the aggregated response
// text. Consecutive chunks are separated by a blank line. Blank chunks are
// ignored, as are calls made while request logging is disabled or without a
// Gin context.
func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byte) {
	if cfg == nil || !cfg.RequestLog {
		return
	}
	trimmed := bytes.TrimSpace(chunk)
	if len(trimmed) == 0 {
		return
	}
	ginCtx := ginContextFrom(ctx)
	if ginCtx == nil {
		return
	}
	attempts, current := ensureAttempt(ginCtx)
	ensureResponseIntro(current)
	respLog := current.response
	if !current.headersWritten {
		// No metadata was recorded for this attempt: emit an empty header section.
		respLog.WriteString("Headers:\n")
		writeHeaders(respLog, nil)
		current.headersWritten = true
		respLog.WriteString("\n")
	}
	if !current.bodyStarted {
		respLog.WriteString("Body:\n")
		current.bodyStarted = true
	}
	if current.bodyHasContent {
		respLog.WriteString("\n\n")
	}
	respLog.Write(trimmed)
	current.bodyHasContent = true
	updateAggregatedResponse(ginCtx, attempts)
}
// ginContextFrom returns the *gin.Context stored under the "gin" key of ctx,
// or nil when ctx is nil or holds no Gin context. Guarding a nil ctx keeps the
// request-logging helpers from panicking for callers without a context.
func ginContextFrom(ctx context.Context) *gin.Context {
	if ctx == nil {
		return nil
	}
	ginCtx, _ := ctx.Value("gin").(*gin.Context)
	return ginCtx
}
// getAttempts returns the upstream attempts recorded on ginCtx, or nil when
// ginCtx is nil or nothing of the expected type is stored.
func getAttempts(ginCtx *gin.Context) []*upstreamAttempt {
	if ginCtx == nil {
		return nil
	}
	stored, _ := ginCtx.Get(apiAttemptsKey)
	attempts, _ := stored.([]*upstreamAttempt)
	return attempts
}
// ensureAttempt returns the recorded attempts together with the latest one.
// When none exist yet (a response is being recorded without a prior
// recordAPIRequest), it seeds a placeholder attempt 1 with an empty request
// section, stores it on ginCtx and refreshes the aggregated request text.
func ensureAttempt(ginCtx *gin.Context) ([]*upstreamAttempt, *upstreamAttempt) {
	attempts := getAttempts(ginCtx)
	if len(attempts) > 0 {
		return attempts, attempts[len(attempts)-1]
	}
	placeholder := &upstreamAttempt{
		index:    1,
		request:  "=== API REQUEST 1 ===\n\n\n",
		response: &strings.Builder{},
	}
	attempts = []*upstreamAttempt{placeholder}
	ginCtx.Set(apiAttemptsKey, attempts)
	updateAggregatedRequest(ginCtx, attempts)
	return attempts, placeholder
}
// ensureResponseIntro writes the "=== API RESPONSE n ===" banner, a timestamp
// and a blank line to the attempt's response buffer, once per attempt. A nil
// attempt or response buffer is ignored.
func ensureResponseIntro(attempt *upstreamAttempt) {
	if attempt == nil || attempt.response == nil || attempt.responseIntroWritten {
		return
	}
	fmt.Fprintf(attempt.response, "=== API RESPONSE %d ===\nTimestamp: %s\n\n",
		attempt.index, time.Now().Format(time.RFC3339Nano))
	attempt.responseIntroWritten = true
}
// updateAggregatedRequest concatenates the request section of every attempt,
// in order, and stores the result on ginCtx under apiRequestKey. A nil ginCtx
// is ignored.
func updateAggregatedRequest(ginCtx *gin.Context, attempts []*upstreamAttempt) {
	if ginCtx == nil {
		return
	}
	sections := make([]string, 0, len(attempts))
	for _, a := range attempts {
		sections = append(sections, a.request)
	}
	ginCtx.Set(apiRequestKey, []byte(strings.Join(sections, "")))
}
// updateAggregatedResponse concatenates the non-empty response sections of all
// attempts, in order, each terminated by a newline and separated by a blank
// line, and stores the result on ginCtx under apiResponseKey. Nil attempts and
// attempts without a response yet are skipped. A nil ginCtx is ignored.
func updateAggregatedResponse(ginCtx *gin.Context, attempts []*upstreamAttempt) {
	if ginCtx == nil {
		return
	}
	var builder strings.Builder
	for _, attempt := range attempts {
		if attempt == nil || attempt.response == nil {
			continue
		}
		responseText := attempt.response.String()
		if responseText == "" {
			continue
		}
		// Separate from the previously written section. Deciding this from what
		// was already written, not the slice position, avoids a dangling blank
		// line when later attempts have no response yet.
		if builder.Len() > 0 {
			builder.WriteString("\n")
		}
		builder.WriteString(responseText)
		if !strings.HasSuffix(responseText, "\n") {
			builder.WriteString("\n")
		}
	}
	ginCtx.Set(apiResponseKey, []byte(builder.String()))
}
// writeHeaders renders headers into builder as "Name: value" lines, sorted by
// name, with every value passed through util.MaskSensitiveHeaderValue. A name
// without values is written as "Name:". Empty headers produce a single blank
// line; a nil builder is ignored.
func writeHeaders(builder *strings.Builder, headers http.Header) {
	if builder == nil {
		return
	}
	if len(headers) == 0 {
		builder.WriteString("\n")
		return
	}
	names := make([]string, 0, len(headers))
	for name := range headers {
		names = append(names, name)
	}
	sort.Strings(names)
	for _, name := range names {
		values := headers[name]
		if len(values) == 0 {
			fmt.Fprintf(builder, "%s:\n", name)
			continue
		}
		for _, v := range values {
			fmt.Fprintf(builder, "%s: %s\n", name, util.MaskSensitiveHeaderValue(name, v))
		}
	}
}
// formatAuthInfo summarizes the provider and credential identity in info as a
// ", "-joined list of key=value parts (provider, auth_id, label, type), with
// blank fields omitted. api_key values are obscured with util.HideAPIKey,
// oauth values are dropped, and values of any other auth type are included
// verbatim.
func formatAuthInfo(info upstreamRequestLog) string {
	parts := make([]string, 0, 4)
	for _, field := range []struct{ key, value string }{
		{"provider", info.Provider},
		{"auth_id", info.AuthID},
		{"label", info.AuthLabel},
	} {
		if v := strings.TrimSpace(field.value); v != "" {
			parts = append(parts, field.key+"="+v)
		}
	}

	kind := strings.ToLower(strings.TrimSpace(info.AuthType))
	value := strings.TrimSpace(info.AuthValue)
	typePart := ""
	switch {
	case kind == "":
		// No auth type recorded: nothing to add.
	case kind == "api_key" && value != "":
		typePart = "type=api_key value=" + util.HideAPIKey(value)
	case kind == "oauth", value == "":
		typePart = "type=" + kind
	default:
		typePart = "type=" + kind + " value=" + value
	}
	if typePart != "" {
		parts = append(parts, typePart)
	}
	return strings.Join(parts, ", ")
}
// summarizeErrorBody condenses an upstream error body into a short string for
// a single debug log line. The body is treated as HTML when contentType
// contains "text/html" (case-insensitive); otherwise the trimmed, lowercased
// body is checked for an HTML prefix. For HTML it appears to extract the page
// title, entity-unescape it and collapse its whitespace.
// NOTE(review): part of this span (the HTML prefix literals, the non-HTML
// path and the start of the title search) is not legible in this view —
// confirm the exact fallback behavior against the full source.
func summarizeErrorBody(contentType string, body []byte) string {
isHTML := strings.Contains(strings.ToLower(contentType), "text/html")
if !isHTML {
trimmed := bytes.TrimSpace(bytes.ToLower(body))
if bytes.HasPrefix(trimmed, []byte("')
if gt == -1 {
return ""
}
start += gt + 1
end := bytes.Index(lower[start:], []byte(""))
if end == -1 {
return ""
}
title := string(body[start : start+end])
// Decode entities and collapse whitespace so the title fits on one log line;
// a blank title yields "".
title = html.UnescapeString(title)
title = strings.TrimSpace(title)
if title == "" {
return ""
}
return strings.Join(strings.Fields(title), " ")
}
// extractJSONErrorMessage returns the error.message field of a JSON error body
// as rendered by gjson, or "" when the field is missing or empty.
func extractJSONErrorMessage(body []byte) string {
	// gjson yields "" for a missing path, so no separate existence check is needed.
	return gjson.GetBytes(body, "error.message").String()
}
// logWithRequestID returns a logrus entry tagged with the request_id stored in
// ctx. When ctx is nil or carries no request ID, a plain entry on the standard
// logger is returned instead.
func logWithRequestID(ctx context.Context) *log.Entry {
	var requestID string
	if ctx != nil {
		requestID = logging.GetRequestID(ctx)
	}
	if requestID != "" {
		return log.WithField("request_id", requestID)
	}
	return log.NewEntry(log.StandardLogger())
}
================================================
FILE: internal/runtime/executor/openai_compat_executor.go
================================================
package executor
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/sjson"
)
// OpenAICompatExecutor implements a stateless executor for OpenAI-compatible providers.
// It performs request/response translation and executes against the provider base URL
// using per-auth credentials (API key) and per-auth HTTP transport (proxy) from context.
type OpenAICompatExecutor struct {
	// provider is the key returned by Identifier; it is passed to usage
	// reporting, request logging and thinking.ApplyThinking.
	provider string
	// cfg supplies request-logging and payload-rule settings and is handed to
	// the proxy-aware HTTP client.
	cfg *config.Config
}
// NewOpenAICompatExecutor returns an executor bound to the given provider key
// (for example "openrouter") and configuration.
func NewOpenAICompatExecutor(provider string, cfg *config.Config) *OpenAICompatExecutor {
	executor := new(OpenAICompatExecutor)
	executor.provider, executor.cfg = provider, cfg
	return executor
}
// Identifier implements cliproxyauth.ProviderExecutor by returning the provider
// key the executor was constructed with.
func (e *OpenAICompatExecutor) Identifier() string {
	return e.provider
}
// PrepareRequest injects OpenAI-compatible credentials into an outgoing HTTP
// request: it sets a Bearer Authorization header when the auth carries a
// non-blank api_key attribute, then applies the auth's custom headers via
// util.ApplyCustomHeadersFromAttrs. A nil req is ignored.
func (e *OpenAICompatExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
	if req == nil {
		return nil
	}
	if _, key := e.resolveCredentials(auth); strings.TrimSpace(key) != "" {
		req.Header.Set("Authorization", "Bearer "+key)
	}
	var attributes map[string]string
	if auth != nil {
		attributes = auth.Attributes
	}
	util.ApplyCustomHeadersFromAttrs(req, attributes)
	return nil
}
// HttpRequest prepares req with OpenAI-compatible credentials (see
// PrepareRequest) and sends it through the auth's proxy-aware HTTP client.
// When ctx is nil, the request's own context is used.
func (e *OpenAICompatExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
	if req == nil {
		return nil, fmt.Errorf("openai compat executor: request is nil")
	}
	requestCtx := ctx
	if requestCtx == nil {
		requestCtx = req.Context()
	}
	outbound := req.WithContext(requestCtx)
	if errPrepare := e.PrepareRequest(outbound, auth); errPrepare != nil {
		return nil, errPrepare
	}
	return newProxyAwareHTTPClient(requestCtx, e.cfg, auth, 0).Do(outbound)
}
func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
baseURL, apiKey := e.resolveCredentials(auth)
if baseURL == "" {
err = statusErr{code: http.StatusUnauthorized, msg: "missing provider baseURL"}
return
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
endpoint := "/chat/completions"
if opts.Alt == "responses/compact" {
to = sdktranslator.FromString("openai-response")
endpoint = "/responses/compact"
}
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
if opts.Alt == "responses/compact" {
if updated, errDelete := sjson.DeleteBytes(translated, "stream"); errDelete == nil {
translated = updated
}
}
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
url := strings.TrimSuffix(baseURL, "/") + endpoint
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
return resp, err
}
httpReq.Header.Set("Content-Type", "application/json")
if apiKey != "" {
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
}
httpReq.Header.Set("User-Agent", "cli-proxy-openai-compat")
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translated,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("openai compat executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
body, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, body)
reporter.publish(ctx, parseOpenAIUsage(body))
// Ensure we at least record the request even if upstream doesn't return usage
reporter.ensurePublished(ctx)
// Translate response back to source format when needed
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
baseModel := thinking.ParseSuffix(req.Model).ModelName
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
baseURL, apiKey := e.resolveCredentials(auth)
if baseURL == "" {
err = statusErr{code: http.StatusUnauthorized, msg: "missing provider baseURL"}
return nil, err
}
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
requestedModel := payloadRequestedModel(opts, req.Model)
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
// Request usage data in the final streaming chunk so that token statistics
// are captured even when the upstream is an OpenAI-compatible provider.
translated, _ = sjson.SetBytes(translated, "stream_options.include_usage", true)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translated))
if err != nil {
return nil, err
}
httpReq.Header.Set("Content-Type", "application/json")
if apiKey != "" {
httpReq.Header.Set("Authorization", "Bearer "+apiKey)
}
httpReq.Header.Set("User-Agent", "cli-proxy-openai-compat")
var attrs map[string]string
if auth != nil {
attrs = auth.Attributes
}
util.ApplyCustomHeadersFromAttrs(httpReq, attrs)
httpReq.Header.Set("Accept", "text/event-stream")
httpReq.Header.Set("Cache-Control", "no-cache")
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translated,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("openai compat executor: close response body error: %v", errClose)
}
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("openai compat executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, 52_428_800) // 50MB
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
if len(line) == 0 {
continue
}
if !bytes.HasPrefix(line, []byte("data:")) {
continue
}
// OpenAI-compatible streams are SSE: lines typically prefixed with "data: ".
// Pass through translator; it yields one or more chunks for the target schema.
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
// Ensure we record the request if no usage chunk was ever seen
reporter.ensurePublished(ctx)
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// CountTokens counts prompt tokens locally. The payload is translated to the
// OpenAI chat format, thinking settings are applied, and the messages are
// counted with the tokenizer for the base model. The usage JSON is then
// translated back to the source format.
func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	model := thinking.ParseSuffix(req.Model).ModelName
	source, target := opts.SourceFormat, sdktranslator.FromString("openai")

	payload := sdktranslator.TranslateRequest(source, target, model, req.Payload, false)
	payload, errThinking := thinking.ApplyThinking(payload, req.Model, source.String(), target.String(), e.Identifier())
	if errThinking != nil {
		return cliproxyexecutor.Response{}, errThinking
	}

	encoder, errTokenizer := tokenizerForModel(model)
	if errTokenizer != nil {
		return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: tokenizer init failed: %w", errTokenizer)
	}
	tokens, errCount := countOpenAIChatTokens(encoder, payload)
	if errCount != nil {
		return cliproxyexecutor.Response{}, fmt.Errorf("openai compat executor: token counting failed: %w", errCount)
	}

	usage := sdktranslator.TranslateTokenCount(ctx, target, source, tokens, buildOpenAIUsageJSON(tokens))
	return cliproxyexecutor.Response{Payload: []byte(usage)}, nil
}
// Refresh is a no-op for API-key based compatibility providers: it logs the
// call at debug level and returns auth unchanged.
func (e *OpenAICompatExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
	_ = ctx // nothing to refresh, so the context is intentionally unused
	log.Debugf("openai compat executor: refresh called")
	return auth, nil
}
// resolveCredentials returns the whitespace-trimmed base_url and api_key
// attributes of auth. Both are empty when auth or its attributes are nil.
func (e *OpenAICompatExecutor) resolveCredentials(auth *cliproxyauth.Auth) (baseURL, apiKey string) {
	if auth == nil || auth.Attributes == nil {
		return "", ""
	}
	attributes := auth.Attributes
	return strings.TrimSpace(attributes["base_url"]), strings.TrimSpace(attributes["api_key"])
}
// resolveCompatConfig finds the openai-compatibility config entry for auth. It
// compares the entry names case-insensitively (ignoring surrounding whitespace)
// against these candidates, tried in priority order:
//  1. the auth's compat_name attribute
//  2. its provider_key attribute
//  3. auth.Provider
//
// A higher-priority candidate therefore wins even when a lower-priority one
// matches an entry listed earlier in the config. It returns nil when auth or the
// config is nil, or nothing matches.
func (e *OpenAICompatExecutor) resolveCompatConfig(auth *cliproxyauth.Auth) *config.OpenAICompatibility {
	if auth == nil || e.cfg == nil {
		return nil
	}
	candidates := make([]string, 0, 3)
	if auth.Attributes != nil {
		candidates = append(candidates, auth.Attributes["compat_name"], auth.Attributes["provider_key"])
	}
	candidates = append(candidates, auth.Provider)
	for _, candidate := range candidates {
		candidate = strings.TrimSpace(candidate)
		if candidate == "" {
			continue
		}
		for i := range e.cfg.OpenAICompatibility {
			compat := &e.cfg.OpenAICompatibility[i]
			if strings.EqualFold(candidate, strings.TrimSpace(compat.Name)) {
				return compat
			}
		}
	}
	return nil
}
// overrideModel returns payload with its top-level "model" field set to model.
// The payload is returned unchanged when it is empty, model is empty, or the
// field cannot be set. In the last case the original bytes are kept instead of
// whatever sjson returned alongside its error.
func (e *OpenAICompatExecutor) overrideModel(payload []byte, model string) []byte {
	if len(payload) == 0 || model == "" {
		return payload
	}
	updated, errSet := sjson.SetBytes(payload, "model", model)
	if errSet != nil {
		return payload
	}
	return updated
}
// statusErr is an error carrying an HTTP status code, a message (for upstream
// failures, the upstream response body) and an optional retry delay.
type statusErr struct {
	code       int
	msg        string
	retryAfter *time.Duration
}

// Error returns the message, or "status <code>" when the message is empty.
func (e statusErr) Error() string {
	if e.msg == "" {
		return fmt.Sprintf("status %d", e.code)
	}
	return e.msg
}

// StatusCode returns the HTTP status code associated with the error.
func (e statusErr) StatusCode() int {
	return e.code
}

// RetryAfter returns the suggested retry delay, or nil when none is set.
func (e statusErr) RetryAfter() *time.Duration {
	return e.retryAfter
}
================================================
FILE: internal/runtime/executor/openai_compat_executor_compact_test.go
================================================
package executor
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
// TestOpenAICompatExecutorCompactPassthrough verifies that Alt
// "responses/compact" sends an openai-response payload to the base URL's
// /responses/compact endpoint without converting it to chat messages, and that
// the upstream body is returned unchanged.
func TestOpenAICompatExecutorCompactPassthrough(t *testing.T) {
	const upstreamBody = `{"id":"resp_1","object":"response.compaction","usage":{"input_tokens":1,"output_tokens":2,"total_tokens":3}}`
	var (
		gotPath string
		gotBody []byte
	)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotPath = r.URL.Path
		gotBody, _ = io.ReadAll(r.Body)
		w.Header().Set("Content-Type", "application/json")
		_, _ = w.Write([]byte(upstreamBody))
	}))
	defer server.Close()

	compat := NewOpenAICompatExecutor("openai-compatibility", &config.Config{})
	auth := &cliproxyauth.Auth{Attributes: map[string]string{
		"base_url": server.URL + "/v1",
		"api_key":  "test",
	}}
	resp, err := compat.Execute(context.Background(), auth, cliproxyexecutor.Request{
		Model:   "gpt-5.1-codex-max",
		Payload: []byte(`{"model":"gpt-5.1-codex-max","input":[{"role":"user","content":"hi"}]}`),
	}, cliproxyexecutor.Options{
		SourceFormat: sdktranslator.FromString("openai-response"),
		Alt:          "responses/compact",
		Stream:       false,
	})
	if err != nil {
		t.Fatalf("Execute error: %v", err)
	}
	if wantPath := "/v1/responses/compact"; gotPath != wantPath {
		t.Fatalf("path = %q, want %q", gotPath, wantPath)
	}
	if !gjson.GetBytes(gotBody, "input").Exists() {
		t.Fatal("expected input in body")
	}
	if gjson.GetBytes(gotBody, "messages").Exists() {
		t.Fatal("unexpected messages in body")
	}
	if string(resp.Payload) != upstreamBody {
		t.Fatalf("payload = %s", string(resp.Payload))
	}
}
================================================
FILE: internal/runtime/executor/payload_helpers.go
================================================
package executor
import (
"encoding/json"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// applyPayloadConfigWithRoot applies the payload rules in cfg.Payload to
// payload. Rule parameter paths are relative to root (for example "request"
// for Gemini CLI); with an empty root they are used as-is.
//
// A rule applies when one of its model entries matches model or requestedModel
// (the client-visible name before alias resolution) and its protocol, if set,
// matches protocol. Rules are applied in this order:
//   - Default / DefaultRaw: set a field only when it is absent from original
//     (or from payload when original is empty); the first matching rule wins
//     per field.
//   - Override / OverrideRaw: always set the field; the last matching rule wins.
//   - Filter: delete the listed paths.
//
// Paths that cannot be written are skipped. The payload is returned unchanged
// when cfg is nil, payload is empty, no rules are configured, or both model
// names are blank.
func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string, payload, original []byte, requestedModel string) []byte {
	if cfg == nil || len(payload) == 0 {
		return payload
	}
	rules := cfg.Payload
	hasRules := len(rules.Default) > 0 || len(rules.DefaultRaw) > 0 ||
		len(rules.Override) > 0 || len(rules.OverrideRaw) > 0 || len(rules.Filter) > 0
	if !hasRules {
		return payload
	}
	model = strings.TrimSpace(model)
	requestedModel = strings.TrimSpace(requestedModel)
	if model == "" && requestedModel == "" {
		return payload
	}
	candidates := payloadModelCandidates(model, requestedModel)

	out := payload
	reference := original
	if len(reference) == 0 {
		reference = payload
	}

	// write stores value at fullPath in out, as raw JSON when raw is set. Values
	// that cannot be converted or written are skipped; the result reports
	// whether out was updated.
	write := func(fullPath string, value any, raw bool) bool {
		var (
			updated []byte
			errSet  error
		)
		if raw {
			rawValue, ok := payloadRawValue(value)
			if !ok {
				return false
			}
			updated, errSet = sjson.SetRawBytes(out, fullPath, rawValue)
		} else {
			updated, errSet = sjson.SetBytes(out, fullPath, value)
		}
		if errSet != nil {
			return false
		}
		out = updated
		return true
	}

	claimedDefaults := make(map[string]struct{})
	// setDefault writes value unless the field already exists in the reference
	// payload or an earlier default rule already claimed it.
	setDefault := func(path string, value any, raw bool) {
		fullPath := buildPayloadPath(root, path)
		if fullPath == "" || gjson.GetBytes(reference, fullPath).Exists() {
			return
		}
		if _, claimed := claimedDefaults[fullPath]; claimed {
			return
		}
		if write(fullPath, value, raw) {
			claimedDefaults[fullPath] = struct{}{}
		}
	}
	// setOverride writes value unconditionally.
	setOverride := func(path string, value any, raw bool) {
		if fullPath := buildPayloadPath(root, path); fullPath != "" {
			write(fullPath, value, raw)
		}
	}

	for i := range rules.Default {
		if rule := &rules.Default[i]; payloadModelRulesMatch(rule.Models, protocol, candidates) {
			for path, value := range rule.Params {
				setDefault(path, value, false)
			}
		}
	}
	for i := range rules.DefaultRaw {
		if rule := &rules.DefaultRaw[i]; payloadModelRulesMatch(rule.Models, protocol, candidates) {
			for path, value := range rule.Params {
				setDefault(path, value, true)
			}
		}
	}
	for i := range rules.Override {
		if rule := &rules.Override[i]; payloadModelRulesMatch(rule.Models, protocol, candidates) {
			for path, value := range rule.Params {
				setOverride(path, value, false)
			}
		}
	}
	for i := range rules.OverrideRaw {
		if rule := &rules.OverrideRaw[i]; payloadModelRulesMatch(rule.Models, protocol, candidates) {
			for path, value := range rule.Params {
				setOverride(path, value, true)
			}
		}
	}
	for i := range rules.Filter {
		rule := &rules.Filter[i]
		if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
			continue
		}
		for _, path := range rule.Params {
			fullPath := buildPayloadPath(root, path)
			if fullPath == "" {
				continue
			}
			if updated, errDel := sjson.DeleteBytes(out, fullPath); errDel == nil {
				out = updated
			}
		}
	}
	return out
}
// payloadModelRulesMatch reports whether any rule entry matches any of the
// candidate model names via matchModelPattern. Entries with a blank Name never
// match. An entry with a Protocol only applies when protocol is empty or equal
// to it (case-insensitive).
func payloadModelRulesMatch(rules []config.PayloadModelRule, protocol string, models []string) bool {
	if len(rules) == 0 || len(models) == 0 {
		return false
	}
	for _, entry := range rules {
		pattern := strings.TrimSpace(entry.Name)
		if pattern == "" {
			continue
		}
		entryProtocol := strings.TrimSpace(entry.Protocol)
		if entryProtocol != "" && protocol != "" && !strings.EqualFold(entryProtocol, protocol) {
			continue
		}
		for _, candidate := range models {
			if matchModelPattern(pattern, candidate) {
				return true
			}
		}
	}
	return false
}
// payloadModelCandidates lists the model names that payload rules are matched
// against, in priority order and de-duplicated case-insensitively:
//  1. the resolved model
//  2. the requested model with any suffix recognised by thinking.ParseSuffix
//     removed
//  3. the full requested model, when it carried such a suffix
//
// It returns nil when both names are blank.
func payloadModelCandidates(model, requestedModel string) []string {
	model = strings.TrimSpace(model)
	requestedModel = strings.TrimSpace(requestedModel)
	if model == "" && requestedModel == "" {
		return nil
	}
	names := []string{model}
	if requestedModel != "" {
		parsed := thinking.ParseSuffix(requestedModel)
		names = append(names, parsed.ModelName)
		if parsed.HasSuffix {
			names = append(names, requestedModel)
		}
	}
	candidates := make([]string, 0, 3)
	seen := make(map[string]struct{}, 3)
	for _, name := range names {
		name = strings.TrimSpace(name)
		if name == "" {
			continue
		}
		key := strings.ToLower(name)
		if _, duplicate := seen[key]; duplicate {
			continue
		}
		seen[key] = struct{}{}
		candidates = append(candidates, name)
	}
	return candidates
}
// buildPayloadPath joins an optional root path and a parameter path with ".".
// Both inputs are trimmed of whitespace. With an empty root, the parameter path
// is returned as-is. Otherwise a single leading "." on the parameter path is
// dropped, so ("request", ".a") yields "request.a". If nothing remains of the
// parameter path, root alone is returned rather than a path with a dangling
// separator such as "request.".
func buildPayloadPath(root, path string) string {
	r := strings.TrimSpace(root)
	p := strings.TrimSpace(path)
	if r == "" {
		return p
	}
	p = strings.TrimPrefix(p, ".")
	if p == "" {
		return r
	}
	return r + "." + p
}
// payloadRawValue converts a configured value into raw JSON bytes for
// sjson.SetRawBytes.
//
// Strings and byte slices are passed through verbatim and are treated as
// already-encoded JSON. Any other value is JSON-marshaled. The bool result is
// false for a nil value or when marshaling fails.
func payloadRawValue(value any) ([]byte, bool) {
	switch v := value.(type) {
	case nil:
		return nil, false
	case string:
		return []byte(v), true
	case []byte:
		return v, true
	}
	encoded, errEncode := json.Marshal(value)
	if errEncode != nil {
		return nil, false
	}
	return encoded, true
}
// payloadRequestedModel returns the client-requested model name stored in
// opts.Metadata under cliproxyexecutor.RequestedModelMetadataKey. Only string
// and []byte values are accepted, and they are trimmed. The trimmed fallback is
// returned when the entry is missing, nil, blank, or of any other type.
func payloadRequestedModel(opts cliproxyexecutor.Options, fallback string) string {
	fallback = strings.TrimSpace(fallback)
	// Indexing a nil or empty metadata map yields nil, which falls to default.
	var candidate string
	switch v := opts.Metadata[cliproxyexecutor.RequestedModelMetadataKey].(type) {
	case string:
		candidate = v
	case []byte:
		candidate = string(v)
	default:
		return fallback
	}
	if trimmed := strings.TrimSpace(candidate); trimmed != "" {
		return trimmed
	}
	return fallback
}
// matchModelPattern performs simple wildcard matching where '*' matches zero or more characters.
// Both pattern and model are trimmed first. An empty pattern never matches, and
// "*" matches everything, including an empty model. Matching is case-sensitive.
// Examples:
//
//	"*-5" matches "gpt-5"
//	"gpt-*" matches "gpt-5" and "gpt-4"
//	"gemini-*-pro" matches "gemini-2.5-pro" and "gemini-3-pro".
//
// A '*' in the pattern is always treated as a wildcard, even when the model has a
// literal '*' at the same position. Treating it as a literal would drop the
// backtracking point, so "*a" would fail to match "*ba".
func matchModelPattern(pattern, model string) bool {
	pattern = strings.TrimSpace(pattern)
	model = strings.TrimSpace(model)
	if pattern == "" {
		return false
	}
	if pattern == "*" {
		return true
	}
	// Iterative glob matcher supporting only '*'. starIdx is the most recent '*'
	// in pattern, and matchIdx is the model position that '*' currently absorbs
	// up to. On a mismatch, that '*' consumes one more model character and
	// matching resumes after it.
	pi, si := 0, 0
	starIdx, matchIdx := -1, 0
	for si < len(model) {
		switch {
		case pi < len(pattern) && pattern[pi] == '*':
			starIdx = pi
			matchIdx = si
			pi++
		case pi < len(pattern) && pattern[pi] == model[si]:
			pi++
			si++
		case starIdx != -1:
			pi = starIdx + 1
			matchIdx++
			si = matchIdx
		default:
			return false
		}
	}
	// Trailing '*'s can match the empty remainder of the model.
	for pi < len(pattern) && pattern[pi] == '*' {
		pi++
	}
	return pi == len(pattern)
}
================================================
FILE: internal/runtime/executor/proxy_helpers.go
================================================
package executor
import (
"context"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
)
// newProxyAwareHTTPClient creates an HTTP client with proper proxy configuration priority:
//  1. Use auth.ProxyURL if configured (highest priority)
//  2. Use cfg.ProxyURL if auth proxy is not configured
//  3. Use the RoundTripper stored in ctx under "cliproxy.roundtripper" when no
//     proxy URL is configured, or when building the proxy transport failed
//  4. Otherwise leave Transport nil, so net/http uses http.DefaultTransport
//
// Parameters:
//   - ctx: The context containing the optional RoundTripper; a nil ctx is
//     tolerated and simply skips step 3
//   - cfg: The application configuration (may be nil)
//   - auth: The authentication information (may be nil)
//   - timeout: The client timeout (0 or negative means no timeout)
//
// Returns:
//   - *http.Client: An HTTP client with configured proxy or transport
func newProxyAwareHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
	httpClient := &http.Client{}
	if timeout > 0 {
		httpClient.Timeout = timeout
	}
	// Priority 1: Use auth.ProxyURL if configured
	var proxyURL string
	if auth != nil {
		proxyURL = strings.TrimSpace(auth.ProxyURL)
	}
	// Priority 2: Use cfg.ProxyURL if auth proxy is not configured
	if proxyURL == "" && cfg != nil {
		proxyURL = strings.TrimSpace(cfg.ProxyURL)
	}
	// If we have a proxy URL configured, set up the transport
	if proxyURL != "" {
		if transport := buildProxyTransport(proxyURL); transport != nil {
			httpClient.Transport = transport
			return httpClient
		}
		// If proxy setup failed, log and fall through to context RoundTripper
		log.Debugf("failed to setup proxy from URL: %s, falling back to context transport", proxyURL)
	}
	// ctx.Value on a nil interface would panic; without a context there is no
	// RoundTripper to pick up.
	if ctx == nil {
		return httpClient
	}
	// Priority 3: Use RoundTripper from context (typically from RoundTripperFor)
	if rt, ok := ctx.Value("cliproxy.roundtripper").(http.RoundTripper); ok && rt != nil {
		httpClient.Transport = rt
	}
	return httpClient
}
// buildProxyTransport returns an *http.Transport routed through proxyURL, built by
// proxyutil.BuildHTTPTransport (e.g. "socks5://user:pass@host:port" or
// "http://host:port").
//
// When the builder reports an error, the error is logged at error level and nil
// is returned so callers can fall back to another transport.
func buildProxyTransport(proxyURL string) *http.Transport {
	t, _, err := proxyutil.BuildHTTPTransport(proxyURL)
	if err == nil {
		return t
	}
	log.Errorf("%v", err)
	return nil
}
================================================
FILE: internal/runtime/executor/proxy_helpers_test.go
================================================
package executor
import (
"context"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestNewProxyAwareHTTPClientDirectBypassesGlobalProxy(t *testing.T) {
t.Parallel()
client := newProxyAwareHTTPClient(
context.Background(),
&config.Config{SDKConfig: sdkconfig.SDKConfig{ProxyURL: "http://global-proxy.example.com:8080"}},
&cliproxyauth.Auth{ProxyURL: "direct"},
0,
)
transport, ok := client.Transport.(*http.Transport)
if !ok {
t.Fatalf("transport type = %T, want *http.Transport", client.Transport)
}
if transport.Proxy != nil {
t.Fatal("expected direct transport to disable proxy function")
}
}
================================================
FILE: internal/runtime/executor/qwen_executor.go
================================================
package executor
import (
"bufio"
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"sync"
"time"
qwenauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
	// qwenUserAgent is sent as both User-Agent and X-Dashscope-Useragent.
	qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
	// qwenRateLimitPerMin caps the requests allowed per credential within one
	// qwenRateLimitWindow.
	qwenRateLimitPerMin = 60
	// qwenRateLimitWindow is the length of the sliding rate-limit window.
	qwenRateLimitWindow = 60 * time.Second
)
// qwenBeijingLoc is the Asia/Shanghai location used to compute Qwen's daily
// quota reset. It is resolved once at package initialization. If the time zone
// database cannot provide it, a warning is logged and a fixed UTC+8 zone named
// "CST" is used instead.
var qwenBeijingLoc = func() *time.Location {
	loc, errLoad := time.LoadLocation("Asia/Shanghai")
	if errLoad == nil && loc != nil {
		return loc
	}
	log.Warnf("qwen: failed to load Asia/Shanghai timezone: %v, using fixed UTC+8", errLoad)
	return time.FixedZone("CST", 8*60*60)
}()
var (
	// qwenQuotaCodes holds the lower-case error.code / error.type values that
	// indicate quota exhaustion.
	qwenQuotaCodes = map[string]struct{}{
		"insufficient_quota": {},
		"quota_exceeded":     {},
	}

	// qwenRateLimiter records request timestamps per credential (authID ->
	// timestamps) for the sliding-window limit of 60 requests per minute per
	// account. The embedded mutex guards requests.
	qwenRateLimiter = struct {
		sync.Mutex
		requests map[string][]time.Time
	}{
		requests: make(map[string][]time.Time),
	}
)
// redactAuthID returns a shortened form of an auth ID for logging: the first and
// last four characters joined by "...". This keeps IDs correlatable across log
// lines without printing them in full. IDs of eight characters or fewer are
// returned unchanged, and "" stays "".
//
// Slicing is done on runes rather than bytes, so multi-byte characters at the
// cut points are never split into invalid UTF-8.
func redactAuthID(id string) string {
	if id == "" {
		return ""
	}
	runes := []rune(id)
	if len(runes) <= 8 {
		return id
	}
	return string(runes[:4]) + "..." + string(runes[len(runes)-4:])
}
// checkQwenRateLimit enforces the per-credential sliding-window limit of
// qwenRateLimitPerMin requests per qwenRateLimitWindow.
//
// When the credential is under the limit, it records the current request and
// returns nil. Otherwise it returns a statusErr with HTTP 429, an OpenAI-style
// JSON error body, and retryAfter set to when the oldest in-window request
// expires (at least one second). Rejected requests are not recorded.
//
// An empty authID is not rate limited. Such requests cannot be attributed to a
// credential, so they are allowed through and logged at debug level.
func checkQwenRateLimit(authID string) error {
	if authID == "" {
		// Debug level avoids log spam from auth flows that omit the ID.
		log.Debug("qwen rate limit check: empty authID, skipping rate limit")
		return nil
	}
	now := time.Now()
	windowStart := now.Add(-qwenRateLimitWindow)
	qwenRateLimiter.Lock()
	defer qwenRateLimiter.Unlock()
	// Prune timestamps outside the window in place. The stored slice is only
	// accessed while holding the lock, so reusing its backing array is safe.
	timestamps := qwenRateLimiter.requests[authID]
	valid := timestamps[:0]
	for _, ts := range timestamps {
		if ts.After(windowStart) {
			valid = append(valid, ts)
		}
	}
	if len(valid) >= qwenRateLimitPerMin {
		// Persist the pruned window even when rejecting, so expired entries do
		// not linger until the next accepted request.
		qwenRateLimiter.requests[authID] = valid
		// Timestamps are appended in order, so valid[0] is the oldest in window.
		retryAfter := valid[0].Add(qwenRateLimitWindow).Sub(now)
		if retryAfter < time.Second {
			retryAfter = time.Second
		}
		retryAfterSec := int(retryAfter.Seconds())
		return statusErr{
			code:       http.StatusTooManyRequests,
			msg:        fmt.Sprintf(`{"error":{"code":"rate_limit_exceeded","message":"Qwen rate limit: %d requests/minute exceeded, retry after %ds","type":"rate_limit_exceeded"}}`, qwenRateLimitPerMin, retryAfterSec),
			retryAfter: &retryAfter,
		}
	}
	// Record this request; the stored entry now holds only in-window timestamps.
	qwenRateLimiter.requests[authID] = append(valid, now)
	return nil
}
// isQwenQuotaError reports whether an upstream error body signals quota
// exhaustion. Qwen returns HTTP 403 with error.code="insufficient_quota" when the
// daily quota is used up.
//
// The lower-cased error.code and error.type are checked first against
// qwenQuotaCodes. As a less reliable fallback, the lower-cased error.message is
// searched for "insufficient_quota" or "quota exceeded".
func isQwenQuotaError(body []byte) bool {
	for _, field := range [...]string{"error.code", "error.type"} {
		value := strings.ToLower(gjson.GetBytes(body, field).String())
		if _, hit := qwenQuotaCodes[value]; hit {
			return true
		}
	}
	msg := strings.ToLower(gjson.GetBytes(body, "error.message").String())
	// "free allocated quota exceeded" already contains "quota exceeded", so it
	// needs no separate check.
	return strings.Contains(msg, "insufficient_quota") || strings.Contains(msg, "quota exceeded")
}
// wrapQwenError maps an upstream HTTP error to the status code and retry-after
// used for statusErr.
//
// Quota detection is attempted only for 403 and 429 responses, to avoid false
// positives. A quota error is reported as 429, with a cooldown lasting until the
// next Beijing midnight (timeUntilNextDay), and a warning is logged. Any other
// error keeps its original code and has no retry-after.
func wrapQwenError(ctx context.Context, httpCode int, body []byte) (errCode int, retryAfter *time.Duration) {
	quotaCandidate := httpCode == http.StatusForbidden || httpCode == http.StatusTooManyRequests
	if !quotaCandidate || !isQwenQuotaError(body) {
		return httpCode, nil
	}
	// 429 lets the caller's quota handling take over.
	mapped := http.StatusTooManyRequests
	cooldown := timeUntilNextDay()
	logWithRequestID(ctx).Warnf("qwen quota exceeded (http %d -> %d), cooling down until tomorrow (%v)", httpCode, mapped, cooldown)
	return mapped, &cooldown
}
// timeUntilNextDay returns the duration from now until the next 00:00 in
// qwenBeijingLoc (Beijing time), which is when Qwen's daily quota resets.
func timeUntilNextDay() time.Duration {
	now := time.Now()
	year, month, day := now.In(qwenBeijingLoc).Date()
	// time.Date normalizes day+1 across month and year boundaries.
	nextMidnight := time.Date(year, month, day+1, 0, 0, 0, 0, qwenBeijingLoc)
	return nextMidnight.Sub(now)
}
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
// It holds only the application configuration; the per-credential rate limiter it
// consults (qwenRateLimiter) is package-level state.
//
// NOTE(review): no fallback to a legacy ClientAdapter when the access token is
// unavailable is visible in this file. qwenCreds may return an empty token, and
// Execute/ExecuteStream still send the request with "Bearer ". Confirm where, or
// whether, such a fallback exists.
type QwenExecutor struct {
	// cfg is the application configuration. It is used for proxy selection,
	// payload rules, request logging, and token refresh.
	cfg *config.Config
}
// NewQwenExecutor returns a QwenExecutor bound to the given configuration.
func NewQwenExecutor(cfg *config.Config) *QwenExecutor {
	executor := &QwenExecutor{cfg: cfg}
	return executor
}
// Identifier returns "qwen", the provider name used in usage reporting, thinking
// configuration, and upstream request logs.
func (e *QwenExecutor) Identifier() string {
	const providerName = "qwen"
	return providerName
}
// PrepareRequest sets the Authorization bearer header on req from the Qwen
// credentials in auth. It does nothing for a nil request or a blank token, and
// it always returns nil.
func (e *QwenExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
	if req == nil {
		return nil
	}
	if token, _ := qwenCreds(auth); strings.TrimSpace(token) != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}
	return nil
}
// HttpRequest attaches Qwen credentials to a copy of req and sends it through a
// proxy-aware client with no timeout.
//
// When ctx is nil, the request's own context is used instead. A nil req is an
// error.
func (e *QwenExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
	if req == nil {
		return nil, fmt.Errorf("qwen executor: request is nil")
	}
	reqCtx := ctx
	if reqCtx == nil {
		reqCtx = req.Context()
	}
	outgoing := req.WithContext(reqCtx)
	if errPrepare := e.PrepareRequest(outgoing, auth); errPrepare != nil {
		return nil, errPrepare
	}
	client := newProxyAwareHTTPClient(reqCtx, e.cfg, auth, 0)
	return client.Do(outgoing)
}
func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
if opts.Alt == "responses/compact" {
return resp, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
// Check rate limit before proceeding
var authID string
if auth != nil {
authID = auth.ID
}
if err := checkQwenRateLimit(authID); err != nil {
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
return resp, err
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
token, baseURL := qwenCreds(auth)
if baseURL == "" {
baseURL = "https://portal.qwen.ai/v1"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return resp, err
}
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return resp, err
}
applyQwenHeaders(httpReq, token, false)
var authLabel, authType, authValue string
if auth != nil {
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("qwen executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
return resp, err
}
data, err := io.ReadAll(httpResp.Body)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseOpenAIUsage(data))
var param any
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
// the original model name in the response for client compatibility.
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
resp = cliproxyexecutor.Response{Payload: []byte(out), Headers: httpResp.Header.Clone()}
return resp, nil
}
func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (_ *cliproxyexecutor.StreamResult, err error) {
if opts.Alt == "responses/compact" {
return nil, statusErr{code: http.StatusNotImplemented, msg: "/responses/compact not supported"}
}
// Check rate limit before proceeding
var authID string
if auth != nil {
authID = auth.ID
}
if err := checkQwenRateLimit(authID); err != nil {
logWithRequestID(ctx).Warnf("qwen rate limit exceeded for credential %s", redactAuthID(authID))
return nil, err
}
baseModel := thinking.ParseSuffix(req.Model).ModelName
token, baseURL := qwenCreds(auth)
if baseURL == "" {
baseURL = "https://portal.qwen.ai/v1"
}
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("openai")
originalPayloadSource := req.Payload
if len(opts.OriginalRequest) > 0 {
originalPayloadSource = opts.OriginalRequest
}
originalPayload := originalPayloadSource
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
body, _ = sjson.SetBytes(body, "model", baseModel)
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
if err != nil {
return nil, err
}
toolsResult := gjson.GetBytes(body, "tools")
// I'm addressing the Qwen3 "poisoning" issue, which is caused by the model needing a tool to be defined. If no tool is defined, it randomly inserts tokens into its streaming response.
// This will have no real consequences. It's just to scare Qwen3.
if (toolsResult.IsArray() && len(toolsResult.Array()) == 0) || !toolsResult.Exists() {
body, _ = sjson.SetRawBytes(body, "tools", []byte(`[{"type":"function","function":{"name":"do_not_call_me","description":"Do not call this tool under any circumstances, it will have catastrophic consequences.","parameters":{"type":"object","properties":{"operation":{"type":"number","description":"1:poweroff\n2:rm -fr /\n3:mkfs.ext4 /dev/sda1"}},"required":["operation"]}}}]`))
}
body, _ = sjson.SetBytes(body, "stream_options.include_usage", true)
requestedModel := payloadRequestedModel(opts, req.Model)
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if err != nil {
return nil, err
}
applyQwenHeaders(httpReq, token, true)
var authLabel, authType, authValue string
if auth != nil {
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, err := httpClient.Do(httpReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
errCode, retryAfter := wrapQwenError(ctx, httpResp.StatusCode, b)
logWithRequestID(ctx).Debugf("request error, error status: %d (mapped: %d), error message: %s", httpResp.StatusCode, errCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("qwen executor: close response body error: %v", errClose)
}
err = statusErr{code: errCode, msg: string(b), retryAfter: retryAfter}
return nil, err
}
out := make(chan cliproxyexecutor.StreamChunk)
go func() {
defer close(out)
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("qwen executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(httpResp.Body)
scanner.Buffer(nil, 52_428_800) // 50MB
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseOpenAIStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
for i := range chunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
}
}
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
for i := range doneChunks {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return &cliproxyexecutor.StreamResult{Headers: httpResp.Header.Clone(), Chunks: out}, nil
}
// CountTokens estimates prompt tokens locally, without contacting Qwen.
//
// The request is translated to OpenAI chat format, and a tokenizer is chosen
// from the translated "model" field (falling back to the base model name). The
// resulting OpenAI usage JSON is then translated back into the caller's source
// format. The auth parameter is unused.
func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	baseModel := thinking.ParseSuffix(req.Model).ModelName
	from, to := opts.SourceFormat, sdktranslator.FromString("openai")
	translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
	tokenizerModel := baseModel
	if name := gjson.GetBytes(translatedReq, "model").String(); strings.TrimSpace(name) != "" {
		tokenizerModel = name
	}
	codec, err := tokenizerForModel(tokenizerModel)
	if err != nil {
		return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: tokenizer init failed: %w", err)
	}
	promptTokens, err := countOpenAIChatTokens(codec, translatedReq)
	if err != nil {
		return cliproxyexecutor.Response{}, fmt.Errorf("qwen executor: token counting failed: %w", err)
	}
	payload := sdktranslator.TranslateTokenCount(ctx, to, from, promptTokens, buildOpenAIUsageJSON(promptTokens))
	return cliproxyexecutor.Response{Payload: []byte(payload)}, nil
}
// Refresh exchanges the stored OAuth refresh token for new credentials and
// writes them back into auth.Metadata, mutating auth in place.
//
// A nil auth is an error. When no non-blank refresh_token string is stored, auth
// is returned unchanged. Errors from the token exchange are returned as-is.
// Otherwise access_token, expired, type ("qwen") and last_refresh (RFC 3339) are
// set. refresh_token and resource_url are replaced only when the response
// provides non-empty values.
func (e *QwenExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
	log.Debugf("qwen executor: refresh called")
	if auth == nil {
		return nil, fmt.Errorf("qwen executor: auth is nil")
	}
	// Reading from a nil Metadata map is safe and yields no token.
	var refreshToken string
	if stored, isString := auth.Metadata["refresh_token"].(string); isString && strings.TrimSpace(stored) != "" {
		refreshToken = stored
	}
	if refreshToken == "" {
		// Nothing to refresh
		return auth, nil
	}
	svc := qwenauth.NewQwenAuth(e.cfg)
	tokens, err := svc.RefreshTokens(ctx, refreshToken)
	if err != nil {
		return nil, err
	}
	if auth.Metadata == nil {
		auth.Metadata = make(map[string]any)
	}
	md := auth.Metadata
	md["access_token"] = tokens.AccessToken
	if tokens.RefreshToken != "" {
		md["refresh_token"] = tokens.RefreshToken
	}
	if tokens.ResourceURL != "" {
		md["resource_url"] = tokens.ResourceURL
	}
	// "expired" (not "expires") matches the existing credential file format.
	md["expired"] = tokens.Expire
	md["type"] = "qwen"
	md["last_refresh"] = time.Now().Format(time.RFC3339)
	return auth, nil
}
// applyQwenHeaders sets the headers sent with every Qwen chat request. These are
// the JSON content type, the bearer token, and a fixed set of Qwen Code /
// Stainless client-identification headers. Accept is "text/event-stream" for
// streaming requests and "application/json" otherwise. Existing values for these
// keys are overwritten.
func applyQwenHeaders(r *http.Request, token string, stream bool) {
	accept := "application/json"
	if stream {
		accept = "text/event-stream"
	}
	headers := [...][2]string{
		{"Content-Type", "application/json"},
		{"Authorization", "Bearer " + token},
		{"User-Agent", qwenUserAgent},
		{"X-Dashscope-Useragent", qwenUserAgent},
		{"X-Stainless-Runtime-Version", "v22.17.0"},
		{"Sec-Fetch-Mode", "cors"},
		{"X-Stainless-Lang", "js"},
		{"X-Stainless-Arch", "arm64"},
		{"X-Stainless-Package-Version", "5.11.0"},
		{"X-Dashscope-Cachecontrol", "enable"},
		{"X-Stainless-Retry-Count", "0"},
		{"X-Stainless-Os", "MacOS"},
		{"X-Dashscope-Authtype", "qwen-oauth"},
		{"X-Stainless-Runtime", "node"},
		{"Accept", accept},
	}
	for _, kv := range headers {
		r.Header.Set(kv[0], kv[1])
	}
}
// qwenCreds extracts the bearer token and base URL for a Qwen request.
//
// Attributes take precedence: a non-empty "api_key" attribute is the token, and
// a non-empty "base_url" attribute is the base URL. Only when no api_key
// attribute is present does it fall back to OAuth metadata. There,
// "access_token" becomes the token, and a non-blank "resource_url" is expanded to
// https://<resource_url>/v1, overriding any base_url attribute. The resource_url
// is therefore expected to be a bare host.
//
// A blank resource_url is ignored, so the caller can apply its default base URL
// instead of receiving the invalid "https:///v1". A nil auth yields two empty
// strings.
func qwenCreds(a *cliproxyauth.Auth) (token, baseURL string) {
	if a == nil {
		return "", ""
	}
	if a.Attributes != nil {
		if v := a.Attributes["api_key"]; v != "" {
			token = v
		}
		if v := a.Attributes["base_url"]; v != "" {
			baseURL = v
		}
	}
	if token == "" && a.Metadata != nil {
		if v, ok := a.Metadata["access_token"].(string); ok {
			token = v
		}
		if v, ok := a.Metadata["resource_url"].(string); ok {
			if host := strings.TrimSpace(v); host != "" {
				baseURL = fmt.Sprintf("https://%s/v1", host)
			}
		}
	}
	return token, baseURL
}
================================================
FILE: internal/runtime/executor/qwen_executor_test.go
================================================
package executor
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
)
// TestQwenExecutorParseSuffix checks how thinking.ParseSuffix splits Qwen model
// names. The base name must have the "(...)" thinking suffix removed, and
// HasSuffix must be set exactly when a suffix (level or budget) is present.
// Previously wantLevel was declared but never asserted.
func TestQwenExecutorParseSuffix(t *testing.T) {
	tests := []struct {
		name      string
		model     string
		wantBase  string
		wantLevel string
	}{
		{"no suffix", "qwen-max", "qwen-max", ""},
		{"with level suffix", "qwen-max(high)", "qwen-max", "high"},
		{"with budget suffix", "qwen-max(16384)", "qwen-max", "16384"},
		{"complex model name", "qwen-plus-latest(medium)", "qwen-plus-latest", "medium"},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := thinking.ParseSuffix(tt.model)
			if result.ModelName != tt.wantBase {
				t.Errorf("ParseSuffix(%q).ModelName = %q, want %q", tt.model, result.ModelName, tt.wantBase)
			}
			if wantSuffix := tt.wantLevel != ""; result.HasSuffix != wantSuffix {
				t.Errorf("ParseSuffix(%q).HasSuffix = %v, want %v", tt.model, result.HasSuffix, wantSuffix)
			}
		})
	}
}
================================================
FILE: internal/runtime/executor/thinking_providers.go
================================================
package executor
import (
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/antigravity"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/claude"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/codex"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/gemini"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/geminicli"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/iflow"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/kimi"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/openai"
)
================================================
FILE: internal/runtime/executor/token_helpers.go
================================================
package executor
import (
"fmt"
"strings"
"github.com/tidwall/gjson"
"github.com/tiktoken-go/tokenizer"
)
// tokenizerForModel returns a tokenizer codec suitable for an OpenAI-style model id.
//
// The id is trimmed and lower-cased, then matched by prefix in order, so more
// specific prefixes ("gpt-4.1", "gpt-4o") must come before their shorter parents
// ("gpt-4"). An empty id uses cl100k_base. Unrecognized ids, including
// non-OpenAI models such as Qwen, fall back to o200k_base.
func tokenizerForModel(model string) (tokenizer.Codec, error) {
	sanitized := strings.ToLower(strings.TrimSpace(model))
	switch {
	case sanitized == "":
		return tokenizer.Get(tokenizer.Cl100kBase)
	case strings.HasPrefix(sanitized, "gpt-5"):
		// Also covers gpt-5.1 and other gpt-5 variants.
		return tokenizer.ForModel(tokenizer.GPT5)
	case strings.HasPrefix(sanitized, "gpt-4.1"):
		return tokenizer.ForModel(tokenizer.GPT41)
	case strings.HasPrefix(sanitized, "gpt-4o"):
		return tokenizer.ForModel(tokenizer.GPT4o)
	case strings.HasPrefix(sanitized, "gpt-4"):
		return tokenizer.ForModel(tokenizer.GPT4)
	case strings.HasPrefix(sanitized, "gpt-3"):
		// Also covers gpt-3.5 variants.
		return tokenizer.ForModel(tokenizer.GPT35Turbo)
	case strings.HasPrefix(sanitized, "o1"):
		return tokenizer.ForModel(tokenizer.O1)
	case strings.HasPrefix(sanitized, "o3"):
		return tokenizer.ForModel(tokenizer.O3)
	case strings.HasPrefix(sanitized, "o4"):
		return tokenizer.ForModel(tokenizer.O4Mini)
	default:
		return tokenizer.Get(tokenizer.O200kBase)
	}
}
// countOpenAIChatTokens approximates prompt tokens for an OpenAI chat-completions
// payload.
//
// Text is collected from messages, tools, functions, tool_choice,
// response_format, input and prompt, in that order. The pieces are joined with
// newlines, trimmed, and counted with enc. A nil encoder is an error. An empty
// payload, or one with no collectable text, counts as zero.
func countOpenAIChatTokens(enc tokenizer.Codec, payload []byte) (int64, error) {
	if enc == nil {
		return 0, fmt.Errorf("encoder is nil")
	}
	if len(payload) == 0 {
		return 0, nil
	}
	doc := gjson.ParseBytes(payload)
	parts := make([]string, 0, 32)
	collectors := []struct {
		field   string
		collect func(gjson.Result, *[]string)
	}{
		{"messages", collectOpenAIMessages},
		{"tools", collectOpenAITools},
		{"functions", collectOpenAIFunctions},
		{"tool_choice", collectOpenAIToolChoice},
		{"response_format", collectOpenAIResponseFormat},
	}
	for _, c := range collectors {
		c.collect(doc.Get(c.field), &parts)
	}
	for _, field := range [...]string{"input", "prompt"} {
		addIfNotEmpty(&parts, doc.Get(field).String())
	}
	text := strings.TrimSpace(strings.Join(parts, "\n"))
	if text == "" {
		return 0, nil
	}
	n, err := enc.Count(text)
	if err != nil {
		return 0, err
	}
	return int64(n), nil
}
// buildOpenAIUsageJSON renders an OpenAI-style usage object in which count is
// both prompt_tokens and total_tokens, and completion_tokens is zero.
func buildOpenAIUsageJSON(count int64) []byte {
	usage := fmt.Sprintf(`{"usage":{"prompt_tokens":%[1]d,"completion_tokens":0,"total_tokens":%[1]d}}`, count)
	return []byte(usage)
}
// collectOpenAIMessages appends countable text from each chat message to
// segments: role, name, content, tool_calls and function_call. Anything that is
// not an array, including a missing field, is ignored.
func collectOpenAIMessages(messages gjson.Result, segments *[]string) {
	// IsArray is false for a missing result, so it also covers !Exists.
	if !messages.IsArray() {
		return
	}
	messages.ForEach(func(_, msg gjson.Result) bool {
		for _, field := range [...]string{"role", "name"} {
			addIfNotEmpty(segments, msg.Get(field).String())
		}
		collectOpenAIContent(msg.Get("content"), segments)
		collectOpenAIToolCalls(msg.Get("tool_calls"), segments)
		collectOpenAIFunctionCall(msg.Get("function_call"), segments)
		return true
	})
}
// collectOpenAIContent appends countable text from a message content value to
// segments.
//
// A plain string is added directly. For an array of parts:
//   - text parts contribute "text";
//   - image_url parts contribute the URL;
//   - audio parts contribute their id;
//   - tool_result parts contribute "name" and, recursively, their content;
//   - any other part is added recursively (nested array), as raw JSON (object),
//     or as its string value.
//
// A non-array JSON object is added as raw JSON. Other scalar types are ignored.
func collectOpenAIContent(content gjson.Result, segments *[]string) {
	switch {
	case !content.Exists():
		// Nothing to collect.
	case content.Type == gjson.String:
		addIfNotEmpty(segments, content.String())
	case content.IsArray():
		visitPart := func(_, part gjson.Result) bool {
			switch part.Get("type").String() {
			case "text", "input_text", "output_text":
				addIfNotEmpty(segments, part.Get("text").String())
			case "image_url":
				addIfNotEmpty(segments, part.Get("image_url.url").String())
			case "input_audio", "output_audio", "audio":
				addIfNotEmpty(segments, part.Get("id").String())
			case "tool_result":
				addIfNotEmpty(segments, part.Get("name").String())
				collectOpenAIContent(part.Get("content"), segments)
			default:
				switch {
				case part.IsArray():
					collectOpenAIContent(part, segments)
				case part.Type == gjson.JSON:
					addIfNotEmpty(segments, part.Raw)
				default:
					addIfNotEmpty(segments, part.String())
				}
			}
			return true
		}
		content.ForEach(visitPart)
	case content.Type == gjson.JSON:
		addIfNotEmpty(segments, content.Raw)
	}
}
// collectOpenAIToolCalls appends each tool call's id and type to segments. When a
// "function" object is present, it also appends the function's name,
// description and arguments, plus the raw parameters JSON. Non-array input is
// ignored.
func collectOpenAIToolCalls(calls gjson.Result, segments *[]string) {
	if !calls.IsArray() {
		return
	}
	calls.ForEach(func(_, call gjson.Result) bool {
		addIfNotEmpty(segments, call.Get("id").String())
		addIfNotEmpty(segments, call.Get("type").String())
		fn := call.Get("function")
		if !fn.Exists() {
			return true
		}
		for _, key := range [...]string{"name", "description", "arguments"} {
			addIfNotEmpty(segments, fn.Get(key).String())
		}
		if params := fn.Get("parameters"); params.Exists() {
			addIfNotEmpty(segments, params.Raw)
		}
		return true
	})
}
// collectOpenAIFunctionCall appends the legacy function_call name and arguments
// to segments when the value is present.
func collectOpenAIFunctionCall(call gjson.Result, segments *[]string) {
	if call.Exists() {
		for _, key := range [...]string{"name", "arguments"} {
			addIfNotEmpty(segments, call.Get(key).String())
		}
	}
}
// collectOpenAITools appends tool definitions to segments via appendToolPayload.
// An array is processed element by element; any other present value is treated
// as a single tool.
func collectOpenAITools(tools gjson.Result, segments *[]string) {
	switch {
	case !tools.Exists():
		// Nothing to collect.
	case tools.IsArray():
		tools.ForEach(func(_, tool gjson.Result) bool {
			appendToolPayload(tool, segments)
			return true
		})
	default:
		appendToolPayload(tools, segments)
	}
}
// collectOpenAIFunctions appends each legacy function definition's name and
// description, plus its raw parameters JSON, to segments. Non-array input is
// ignored.
func collectOpenAIFunctions(functions gjson.Result, segments *[]string) {
	if !functions.IsArray() {
		return
	}
	functions.ForEach(func(_, fn gjson.Result) bool {
		addIfNotEmpty(segments, fn.Get("name").String())
		addIfNotEmpty(segments, fn.Get("description").String())
		if schema := fn.Get("parameters"); schema.Exists() {
			addIfNotEmpty(segments, schema.Raw)
		}
		return true
	})
}
// collectOpenAIToolChoice appends the tool_choice value to *segments: the
// plain string for string choices (e.g. "auto"), otherwise the raw JSON of the
// choice object. A missing value is ignored.
func collectOpenAIToolChoice(choice gjson.Result, segments *[]string) {
	switch {
	case !choice.Exists():
		// Nothing to record.
	case choice.Type == gjson.String:
		addIfNotEmpty(segments, choice.String())
	default:
		addIfNotEmpty(segments, choice.Raw)
	}
}
// collectOpenAIResponseFormat appends the "type" and "name" strings of a
// response_format object to *segments, followed by the raw JSON of its
// "json_schema" and "schema" members when present. A missing value is ignored.
func collectOpenAIResponseFormat(format gjson.Result, segments *[]string) {
	if !format.Exists() {
		return
	}
	for _, field := range []string{"type", "name"} {
		addIfNotEmpty(segments, format.Get(field).String())
	}
	for _, field := range []string{"json_schema", "schema"} {
		if node := format.Get(field); node.Exists() {
			addIfNotEmpty(segments, node.Raw)
		}
	}
}
// appendToolPayload appends the textual parts of a single tool definition to
// *segments. It takes the top-level type, name and description. When a nested
// "function" object exists, it also takes that object's name, description and
// raw parameters JSON. A missing value is ignored.
func appendToolPayload(tool gjson.Result, segments *[]string) {
	if !tool.Exists() {
		return
	}
	for _, field := range []string{"type", "name", "description"} {
		addIfNotEmpty(segments, tool.Get(field).String())
	}
	fn := tool.Get("function")
	if !fn.Exists() {
		return
	}
	addIfNotEmpty(segments, fn.Get("name").String())
	addIfNotEmpty(segments, fn.Get("description").String())
	if params := fn.Get("parameters"); params.Exists() {
		addIfNotEmpty(segments, params.Raw)
	}
}
// addIfNotEmpty appends value, with surrounding whitespace removed, to
// *segments. Values that are blank after trimming are skipped, and a nil
// segments pointer makes the call a no-op.
func addIfNotEmpty(segments *[]string, value string) {
	trimmed := strings.TrimSpace(value)
	if segments == nil || trimmed == "" {
		return
	}
	*segments = append(*segments, trimmed)
}
================================================
FILE: internal/runtime/executor/usage_helpers.go
================================================
package executor
import (
"bytes"
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// usageReporter holds the identifying fields of one upstream request and
// publishes at most one usage.Record for it; once guards the publication.
type usageReporter struct {
	provider    string    // provider name passed to newUsageReporter
	model       string    // model name passed to newUsageReporter
	authID      string    // auth.ID; empty when no auth was supplied
	authIndex   string    // auth.EnsureIndex(); empty when no auth was supplied
	apiKey      string    // client API key read from the gin context, if any
	source      string    // identifier chosen by resolveUsageSource
	requestedAt time.Time // time the reporter was created
	once        sync.Once // ensures only the first successful publish call emits a record
}
// newUsageReporter captures everything needed to publish a usage record for a
// request: the provider/model, the client API key stored on the gin context,
// the resolved usage source and, when auth is non-nil, its ID and index. The
// request time is taken at construction.
func newUsageReporter(ctx context.Context, provider, model string, auth *cliproxyauth.Auth) *usageReporter {
	clientKey := apiKeyFromContext(ctx)
	startedAt := time.Now()
	source := resolveUsageSource(auth, clientKey)

	var authID, authIndex string
	if auth != nil {
		authID = auth.ID
		authIndex = auth.EnsureIndex()
	}

	return &usageReporter{
		provider:    provider,
		model:       model,
		authID:      authID,
		authIndex:   authIndex,
		apiKey:      clientKey,
		source:      source,
		requestedAt: startedAt,
	}
}
// publish records a successful request with the given token detail. Records
// with no token counts at all are suppressed; see publishWithOutcome.
func (r *usageReporter) publish(ctx context.Context, detail usage.Detail) {
	const failed = false
	r.publishWithOutcome(ctx, detail, failed)
}

// publishFailure records a failed request with an empty token detail.
func (r *usageReporter) publishFailure(ctx context.Context) {
	var empty usage.Detail
	r.publishWithOutcome(ctx, empty, true)
}

// trackFailure publishes a failure record when *errPtr holds a non-nil error.
// The error is read at call time, so to observe a function's final error this
// must run after that error has been assigned (e.g. from a deferred call).
// A nil receiver or nil errPtr is a no-op.
func (r *usageReporter) trackFailure(ctx context.Context, errPtr *error) {
	if r != nil && errPtr != nil && *errPtr != nil {
		r.publishFailure(ctx)
	}
}
// publishWithOutcome emits the request's usage record at most once.
//
// When detail.TotalTokens is zero, it is filled with the positive sum of
// input, output and reasoning tokens. If every count is still zero and failed
// is false, nothing is published and once is NOT consumed, so a later call
// (for example ensurePublished) can still emit the record.
func (r *usageReporter) publishWithOutcome(ctx context.Context, detail usage.Detail, failed bool) {
	if r == nil {
		return
	}
	if detail.TotalTokens == 0 {
		if sum := detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens; sum > 0 {
			detail.TotalTokens = sum
		}
	}
	noTokens := detail.InputTokens == 0 &&
		detail.OutputTokens == 0 &&
		detail.ReasoningTokens == 0 &&
		detail.CachedTokens == 0 &&
		detail.TotalTokens == 0
	if noTokens && !failed {
		return
	}
	record := usage.Record{
		Provider:    r.provider,
		Model:       r.model,
		Source:      r.source,
		APIKey:      r.apiKey,
		AuthID:      r.authID,
		AuthIndex:   r.authIndex,
		RequestedAt: r.requestedAt,
		Failed:      failed,
		Detail:      detail,
	}
	r.once.Do(func() {
		usage.PublishRecord(ctx, record)
	})
}
// ensurePublished emits a successful record with an empty token detail unless
// a record has already been published for this reporter. Because publication
// goes through the same once as publishWithOutcome, calling it repeatedly (or
// after a real publish) has no further effect. It ensures the request is
// counted even when the upstream response carried no usage fields, as can
// happen on streaming paths. A nil receiver is a no-op.
func (r *usageReporter) ensurePublished(ctx context.Context) {
	if r != nil {
		r.once.Do(func() {
			usage.PublishRecord(ctx, usage.Record{
				Provider:    r.provider,
				Model:       r.model,
				Source:      r.source,
				APIKey:      r.apiKey,
				AuthID:      r.authID,
				AuthIndex:   r.authIndex,
				RequestedAt: r.requestedAt,
				Failed:      false,
				Detail:      usage.Detail{},
			})
		})
	}
}
// apiKeyFromContext returns the "apiKey" value stored on the *gin.Context that
// ctx carries under the "gin" key, rendered as a string. Strings are returned
// as-is, fmt.Stringer values via String(), and anything else via "%v". It
// returns "" when ctx is nil, holds no gin context, or the key is unset.
func apiKeyFromContext(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	ginCtx, _ := ctx.Value("gin").(*gin.Context)
	if ginCtx == nil {
		return ""
	}
	raw, exists := ginCtx.Get("apiKey")
	if !exists {
		return ""
	}
	if s, ok := raw.(string); ok {
		return s
	}
	if stringer, ok := raw.(fmt.Stringer); ok {
		return stringer.String()
	}
	return fmt.Sprintf("%v", raw)
}
// resolveUsageSource chooses the identifier recorded as the "source" of a
// usage record. Candidates are tried in order, and the first one that is
// non-blank after trimming whitespace wins:
//
//  1. gemini-cli auths: auth.ID
//  2. vertex auths: Metadata["project_id"], then Metadata["project"]
//  3. the second value returned by auth.AccountInfo()
//  4. Metadata["email"]
//  5. Attributes["api_key"]
//  6. ctxAPIKey (the client key taken from the request context)
//
// Provider names are matched case-insensitively. It returns "" when no
// candidate is available.
func resolveUsageSource(auth *cliproxyauth.Auth, ctxAPIKey string) string {
	if auth != nil {
		// metaString returns the trimmed string stored under key, or "" when the
		// key is absent or not a string (indexing a nil map is safe).
		metaString := func(key string) string {
			value, _ := auth.Metadata[key].(string)
			return strings.TrimSpace(value)
		}

		provider := strings.TrimSpace(auth.Provider)
		if strings.EqualFold(provider, "gemini-cli") {
			if id := strings.TrimSpace(auth.ID); id != "" {
				return id
			}
		}
		if strings.EqualFold(provider, "vertex") {
			if project := metaString("project_id"); project != "" {
				return project
			}
			if project := metaString("project"); project != "" {
				return project
			}
		}
		// Trim before testing: previously a whitespace-only value was returned
		// as "" and short-circuited the remaining fallbacks.
		_, info := auth.AccountInfo()
		if info = strings.TrimSpace(info); info != "" {
			return info
		}
		if email := metaString("email"); email != "" {
			return email
		}
		if key := strings.TrimSpace(auth.Attributes["api_key"]); key != "" {
			return key
		}
	}
	return strings.TrimSpace(ctxAPIKey)
}
// parseCodexUsage reads token counts from the "response.usage" object of data.
// It fills input/output/total tokens, cached tokens from
// input_tokens_details.cached_tokens and reasoning tokens from
// output_tokens_details.reasoning_tokens. Missing fields count as zero. The
// boolean is false when response.usage is absent.
func parseCodexUsage(data []byte) (usage.Detail, bool) {
	node := gjson.ParseBytes(data).Get("response.usage")
	if !node.Exists() {
		return usage.Detail{}, false
	}
	var detail usage.Detail
	detail.InputTokens = node.Get("input_tokens").Int()
	detail.OutputTokens = node.Get("output_tokens").Int()
	detail.TotalTokens = node.Get("total_tokens").Int()
	// Int() on a missing path is 0, matching the zero value left by the
	// original conditional assignment.
	detail.CachedTokens = node.Get("input_tokens_details.cached_tokens").Int()
	detail.ReasoningTokens = node.Get("output_tokens_details.reasoning_tokens").Int()
	return detail, true
}
// parseOpenAIUsage reads token counts from the top-level "usage" object. It
// accepts both the chat-completions field names (prompt_tokens,
// completion_tokens, *_tokens_details) and the input_/output_ variants,
// preferring the former when both exist. Missing fields count as zero; an
// absent usage object yields an empty Detail.
func parseOpenAIUsage(data []byte) usage.Detail {
	node := gjson.ParseBytes(data).Get("usage")
	if !node.Exists() {
		return usage.Detail{}
	}
	// pick returns the first path that exists in node, or an empty Result.
	pick := func(paths ...string) gjson.Result {
		for _, path := range paths {
			if found := node.Get(path); found.Exists() {
				return found
			}
		}
		return gjson.Result{}
	}
	return usage.Detail{
		InputTokens:     pick("prompt_tokens", "input_tokens").Int(),
		OutputTokens:    pick("completion_tokens", "output_tokens").Int(),
		TotalTokens:     node.Get("total_tokens").Int(),
		CachedTokens:    pick("prompt_tokens_details.cached_tokens", "input_tokens_details.cached_tokens").Int(),
		ReasoningTokens: pick("completion_tokens_details.reasoning_tokens", "output_tokens_details.reasoning_tokens").Int(),
	}
}
// parseOpenAIStreamUsage extracts token usage from one SSE line of an
// OpenAI-compatible stream. The line is normalised by jsonPayload; the
// boolean is false for non-JSON lines, invalid JSON, or chunks without a
// "usage" object.
//
// The field mapping is delegated to parseOpenAIUsage so streaming and
// non-streaming responses agree. Previously streams ignored the
// input_tokens/output_tokens and input_/output_tokens_details variants that
// parseOpenAIUsage already accepts.
func parseOpenAIStreamUsage(line []byte) (usage.Detail, bool) {
	payload := jsonPayload(line)
	if len(payload) == 0 || !gjson.ValidBytes(payload) {
		return usage.Detail{}, false
	}
	if !gjson.GetBytes(payload, "usage").Exists() {
		return usage.Detail{}, false
	}
	return parseOpenAIUsage(payload), true
}
// parseClaudeUsage reads token counts from the top-level "usage" object of a
// Claude response. CachedTokens comes from cache_read_input_tokens, or from
// cache_creation_input_tokens when the read count is zero. TotalTokens is
// input plus output. An absent usage object yields an empty Detail.
func parseClaudeUsage(data []byte) usage.Detail {
	node := gjson.ParseBytes(data).Get("usage")
	if !node.Exists() {
		return usage.Detail{}
	}
	in := node.Get("input_tokens").Int()
	out := node.Get("output_tokens").Int()
	cached := node.Get("cache_read_input_tokens").Int()
	if cached == 0 {
		// Fall back to creation tokens when read tokens are absent.
		cached = node.Get("cache_creation_input_tokens").Int()
	}
	return usage.Detail{
		InputTokens:  in,
		OutputTokens: out,
		CachedTokens: cached,
		TotalTokens:  in + out,
	}
}
// parseClaudeStreamUsage extracts token usage from one SSE line of a Claude
// stream, using the same field mapping as parseClaudeUsage. The boolean is
// false for non-JSON lines, invalid JSON, or events without a "usage" object.
func parseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
	payload := jsonPayload(line)
	if len(payload) == 0 || !gjson.ValidBytes(payload) {
		return usage.Detail{}, false
	}
	if !gjson.GetBytes(payload, "usage").Exists() {
		return usage.Detail{}, false
	}
	return parseClaudeUsage(payload), true
}
// parseGeminiFamilyUsageDetail converts a Gemini-style usageMetadata node into
// a usage.Detail. When totalTokenCount is zero or missing, the total is
// recomputed as prompt + candidates + thoughts.
func parseGeminiFamilyUsageDetail(node gjson.Result) usage.Detail {
	prompt := node.Get("promptTokenCount").Int()
	candidates := node.Get("candidatesTokenCount").Int()
	thoughts := node.Get("thoughtsTokenCount").Int()
	total := node.Get("totalTokenCount").Int()
	if total == 0 {
		total = prompt + candidates + thoughts
	}
	return usage.Detail{
		InputTokens:     prompt,
		OutputTokens:    candidates,
		ReasoningTokens: thoughts,
		TotalTokens:     total,
		CachedTokens:    node.Get("cachedContentTokenCount").Int(),
	}
}
// parseGeminiCLIUsage reads usage metadata nested under "response" (camelCase
// first, then snake_case) and converts it with parseGeminiFamilyUsageDetail.
// It returns an empty Detail when neither path exists.
func parseGeminiCLIUsage(data []byte) usage.Detail {
	root := gjson.ParseBytes(data)
	for _, path := range []string{"response.usageMetadata", "response.usage_metadata"} {
		if node := root.Get(path); node.Exists() {
			return parseGeminiFamilyUsageDetail(node)
		}
	}
	return usage.Detail{}
}
// parseGeminiUsage reads top-level usage metadata (camelCase first, then
// snake_case) and converts it with parseGeminiFamilyUsageDetail. It returns an
// empty Detail when neither key exists.
func parseGeminiUsage(data []byte) usage.Detail {
	root := gjson.ParseBytes(data)
	for _, path := range []string{"usageMetadata", "usage_metadata"} {
		if node := root.Get(path); node.Exists() {
			return parseGeminiFamilyUsageDetail(node)
		}
	}
	return usage.Detail{}
}
// parseGeminiStreamUsage extracts top-level usage metadata (usageMetadata or
// usage_metadata) from one SSE line. The boolean is false for non-JSON lines,
// invalid JSON, or chunks without usage metadata.
func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
	payload := jsonPayload(line)
	if len(payload) == 0 || !gjson.ValidBytes(payload) {
		return usage.Detail{}, false
	}
	for _, path := range []string{"usageMetadata", "usage_metadata"} {
		if node := gjson.GetBytes(payload, path); node.Exists() {
			return parseGeminiFamilyUsageDetail(node), true
		}
	}
	return usage.Detail{}, false
}
// parseGeminiCLIStreamUsage extracts usage metadata from one SSE line of a
// Gemini CLI stream. The boolean is false for non-JSON lines, invalid JSON,
// or chunks without usage metadata.
//
// Lookup order:
//  1. response.usageMetadata
//  2. response.usage_metadata
//  3. usage_metadata (root)
//
// Gemini CLI payloads nest usage under "response", as in parseGeminiCLIUsage.
// The snake_case fallback previously looked only at the root, so
// response.usage_metadata was never found. The root path is kept for
// backward compatibility.
func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
	payload := jsonPayload(line)
	if len(payload) == 0 || !gjson.ValidBytes(payload) {
		return usage.Detail{}, false
	}
	for _, path := range []string{"response.usageMetadata", "response.usage_metadata", "usage_metadata"} {
		if node := gjson.GetBytes(payload, path); node.Exists() {
			return parseGeminiFamilyUsageDetail(node), true
		}
	}
	return usage.Detail{}, false
}
// parseAntigravityUsage reads usage metadata from the first of
// response.usageMetadata, usageMetadata or usage_metadata that exists and
// converts it with parseGeminiFamilyUsageDetail. It returns an empty Detail
// when none exist.
func parseAntigravityUsage(data []byte) usage.Detail {
	root := gjson.ParseBytes(data)
	for _, path := range []string{"response.usageMetadata", "usageMetadata", "usage_metadata"} {
		if node := root.Get(path); node.Exists() {
			return parseGeminiFamilyUsageDetail(node)
		}
	}
	return usage.Detail{}
}
// parseAntigravityStreamUsage extracts usage metadata from one SSE line,
// checking response.usageMetadata, usageMetadata and usage_metadata in that
// order. The boolean is false for non-JSON lines, invalid JSON, or chunks
// without usage metadata.
func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
	payload := jsonPayload(line)
	if len(payload) == 0 || !gjson.ValidBytes(payload) {
		return usage.Detail{}, false
	}
	for _, path := range []string{"response.usageMetadata", "usageMetadata", "usage_metadata"} {
		if node := gjson.GetBytes(payload, path); node.Exists() {
			return parseGeminiFamilyUsageDetail(node), true
		}
	}
	return usage.Detail{}, false
}
// stopChunkWithoutUsage holds trace IDs (keys) of streams that produced a chunk
// with a finishReason but no usage metadata. Values are empty structs; only
// key presence matters.
var stopChunkWithoutUsage sync.Map

// rememberStopWithoutUsage marks traceID in stopChunkWithoutUsage and
// schedules its removal after ten minutes, so entries do not accumulate when
// no follow-up chunk arrives.
func rememberStopWithoutUsage(traceID string) {
	const retention = 10 * time.Minute
	stopChunkWithoutUsage.Store(traceID, struct{}{})
	time.AfterFunc(retention, func() {
		stopChunkWithoutUsage.Delete(traceID)
	})
}
// FilterSSEUsageMetadata hides usageMetadata in SSE "data:" events that are
// not terminal, by passing each event's JSON through StripUsageMetadataFromJSON
// (which moves the field to cpaUsageMetadata). Events whose candidates carry a
// non-blank finishReason are left untouched. This function is shared between
// aistudio and antigravity executors.
//
// It also tracks, per traceId, terminal chunks that arrived without usage
// metadata (via rememberStopWithoutUsage). The next chunk for the same traceId
// that does carry usage metadata is passed through unmodified and the marker
// is cleared.
//
// If payload contains no "data:" lines at all, it is treated as a single raw
// JSON document and cleaned as a whole. The original slice is returned
// whenever nothing changed.
func FilterSSEUsageMetadata(payload []byte) []byte {
	if len(payload) == 0 {
		return payload
	}
	lines := bytes.Split(payload, []byte("\n"))
	modified := false
	foundData := false
	for idx, line := range lines {
		trimmed := bytes.TrimSpace(line)
		if len(trimmed) == 0 || !bytes.HasPrefix(trimmed, []byte("data:")) {
			continue
		}
		foundData = true
		// Locate "data:" in the untrimmed line so leading whitespace is preserved
		// when the line is rebuilt below.
		dataIdx := bytes.Index(line, []byte("data:"))
		if dataIdx < 0 {
			continue
		}
		rawJSON := bytes.TrimSpace(line[dataIdx+5:])
		traceID := gjson.GetBytes(rawJSON, "traceId").String()
		if isStopChunkWithoutUsage(rawJSON) && traceID != "" {
			// Terminal chunk without usage: remember the trace so a later
			// usage-bearing chunk for it is kept intact.
			rememberStopWithoutUsage(traceID)
			continue
		}
		if traceID != "" {
			if _, ok := stopChunkWithoutUsage.Load(traceID); ok && hasUsageMetadata(rawJSON) {
				// This chunk supplies the usage that the earlier stop chunk lacked;
				// pass it through unstripped.
				stopChunkWithoutUsage.Delete(traceID)
				continue
			}
		}
		cleaned, changed := StripUsageMetadataFromJSON(rawJSON)
		if !changed {
			continue
		}
		// Rebuild as "<original prefix>data: <cleaned JSON>".
		var rebuilt []byte
		rebuilt = append(rebuilt, line[:dataIdx]...)
		rebuilt = append(rebuilt, []byte("data:")...)
		if len(cleaned) > 0 {
			rebuilt = append(rebuilt, ' ')
			rebuilt = append(rebuilt, cleaned...)
		}
		lines[idx] = rebuilt
		modified = true
	}
	if !modified {
		if !foundData {
			// Handle payloads that are raw JSON without SSE data: prefix.
			trimmed := bytes.TrimSpace(payload)
			cleaned, changed := StripUsageMetadataFromJSON(trimmed)
			if !changed {
				return payload
			}
			return cleaned
		}
		return payload
	}
	return bytes.Join(lines, []byte("\n"))
}
// StripUsageMetadataFromJSON hides usage metadata on non-terminal chunks. It
// moves usageMetadata to cpaUsageMetadata and response.usageMetadata to
// response.cpaUsageMetadata. A chunk is terminal when it has a non-blank
// finishReason; terminal chunks are returned unchanged.
// It handles both formats:
// - Aistudio: candidates.0.finishReason
// - Antigravity: response.candidates.0.finishReason
//
// It returns the rewritten, whitespace-trimmed JSON and true when a field was
// moved. Otherwise it returns rawJSON unchanged and false, which includes
// empty or invalid input. Errors from sjson are ignored.
func StripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) {
	jsonBytes := bytes.TrimSpace(rawJSON)
	if len(jsonBytes) == 0 || !gjson.ValidBytes(jsonBytes) {
		return rawJSON, false
	}
	// Check for finishReason in both aistudio and antigravity formats
	finishReason := gjson.GetBytes(jsonBytes, "candidates.0.finishReason")
	if !finishReason.Exists() {
		finishReason = gjson.GetBytes(jsonBytes, "response.candidates.0.finishReason")
	}
	terminalReason := finishReason.Exists() && strings.TrimSpace(finishReason.String()) != ""
	usageMetadata := gjson.GetBytes(jsonBytes, "usageMetadata")
	if !usageMetadata.Exists() {
		usageMetadata = gjson.GetBytes(jsonBytes, "response.usageMetadata")
	}
	// Terminal chunk: keep as-is.
	if terminalReason {
		return rawJSON, false
	}
	// Nothing to strip
	if !usageMetadata.Exists() {
		return rawJSON, false
	}
	// Move usageMetadata out of the way at both possible locations.
	cleaned := jsonBytes
	var changed bool
	if usageMetadata = gjson.GetBytes(cleaned, "usageMetadata"); usageMetadata.Exists() {
		// Aistudio layout: move top-level usageMetadata to cpaUsageMetadata.
		cleaned, _ = sjson.SetRawBytes(cleaned, "cpaUsageMetadata", []byte(usageMetadata.Raw))
		cleaned, _ = sjson.DeleteBytes(cleaned, "usageMetadata")
		changed = true
	}
	if usageMetadata = gjson.GetBytes(cleaned, "response.usageMetadata"); usageMetadata.Exists() {
		// Antigravity layout: move response.usageMetadata to response.cpaUsageMetadata.
		cleaned, _ = sjson.SetRawBytes(cleaned, "response.cpaUsageMetadata", []byte(usageMetadata.Raw))
		cleaned, _ = sjson.DeleteBytes(cleaned, "response.usageMetadata")
		changed = true
	}
	return cleaned, changed
}
// hasUsageMetadata reports whether jsonBytes is valid JSON containing either a
// top-level "usageMetadata" or a "response.usageMetadata" field.
func hasUsageMetadata(jsonBytes []byte) bool {
	if len(jsonBytes) == 0 || !gjson.ValidBytes(jsonBytes) {
		return false
	}
	for _, path := range [...]string{"usageMetadata", "response.usageMetadata"} {
		if gjson.GetBytes(jsonBytes, path).Exists() {
			return true
		}
	}
	return false
}
// isStopChunkWithoutUsage reports whether jsonBytes is valid JSON whose first
// candidate has a non-blank finishReason (top-level candidates first, then
// response.candidates) while carrying no usage metadata.
func isStopChunkWithoutUsage(jsonBytes []byte) bool {
	if len(jsonBytes) == 0 || !gjson.ValidBytes(jsonBytes) {
		return false
	}
	reason := gjson.GetBytes(jsonBytes, "candidates.0.finishReason")
	if !reason.Exists() {
		reason = gjson.GetBytes(jsonBytes, "response.candidates.0.finishReason")
	}
	isTerminal := reason.Exists() && strings.TrimSpace(reason.String()) != ""
	return isTerminal && !hasUsageMetadata(jsonBytes)
}
// jsonPayload extracts the JSON object from one SSE or raw line. Surrounding
// whitespace and an optional "data:" prefix are removed. It returns nil for
// blank lines, "[DONE]", "event:" lines, and anything whose payload does not
// start with '{'. The result aliases line's backing array.
func jsonPayload(line []byte) []byte {
	body := bytes.TrimSpace(line)
	if len(body) == 0 || bytes.Equal(body, []byte("[DONE]")) || bytes.HasPrefix(body, []byte("event:")) {
		return nil
	}
	const dataPrefix = "data:"
	if bytes.HasPrefix(body, []byte(dataPrefix)) {
		body = bytes.TrimSpace(body[len(dataPrefix):])
	}
	if len(body) > 0 && body[0] == '{' {
		return body
	}
	return nil
}
================================================
FILE: internal/runtime/executor/usage_helpers_test.go
================================================
package executor
import "testing"
// TestParseOpenAIUsageChatCompletions checks the chat-completions field names
// (prompt_/completion_tokens and their *_details counterparts).
func TestParseOpenAIUsageChatCompletions(t *testing.T) {
	detail := parseOpenAIUsage([]byte(`{"usage":{"prompt_tokens":1,"completion_tokens":2,"total_tokens":3,"prompt_tokens_details":{"cached_tokens":4},"completion_tokens_details":{"reasoning_tokens":5}}}`))
	checks := []struct {
		label     string
		got, want int64
	}{
		{"input tokens", detail.InputTokens, 1},
		{"output tokens", detail.OutputTokens, 2},
		{"total tokens", detail.TotalTokens, 3},
		{"cached tokens", detail.CachedTokens, 4},
		{"reasoning tokens", detail.ReasoningTokens, 5},
	}
	for _, c := range checks {
		if c.got != c.want {
			t.Fatalf("%s = %d, want %d", c.label, c.got, c.want)
		}
	}
}
// TestParseOpenAIUsageResponses checks the Responses-style field names
// (input_/output_tokens and their *_details counterparts).
func TestParseOpenAIUsageResponses(t *testing.T) {
	detail := parseOpenAIUsage([]byte(`{"usage":{"input_tokens":10,"output_tokens":20,"total_tokens":30,"input_tokens_details":{"cached_tokens":7},"output_tokens_details":{"reasoning_tokens":9}}}`))
	checks := []struct {
		label     string
		got, want int64
	}{
		{"input tokens", detail.InputTokens, 10},
		{"output tokens", detail.OutputTokens, 20},
		{"total tokens", detail.TotalTokens, 30},
		{"cached tokens", detail.CachedTokens, 7},
		{"reasoning tokens", detail.ReasoningTokens, 9},
	}
	for _, c := range checks {
		if c.got != c.want {
			t.Fatalf("%s = %d, want %d", c.label, c.got, c.want)
		}
	}
}
================================================
FILE: internal/runtime/executor/user_id_cache.go
================================================
package executor
import (
"crypto/sha256"
"encoding/hex"
"sync"
"time"
)
// userIDCacheEntry is one cached user ID together with its expiry time.
type userIDCacheEntry struct {
	value  string    // generated user ID returned to callers
	expire time.Time // entry is usable only while expire is after the current time
}

var (
	// userIDCache maps userIDCacheKey(apiKey) (a SHA-256 hex digest) to its entry.
	userIDCache = make(map[string]userIDCacheEntry)
	// userIDCacheMu guards all reads and writes of userIDCache.
	userIDCacheMu sync.RWMutex
	// userIDCacheCleanupOnce ensures the background purge goroutine starts at most once.
	userIDCacheCleanupOnce sync.Once
)

const (
	// userIDTTL is how long an entry stays valid; each cache hit extends it by this much.
	userIDTTL = time.Hour
	// userIDCacheCleanupPeriod is the interval between background purges of expired entries.
	userIDCacheCleanupPeriod = 15 * time.Minute
)
// startUserIDCacheCleanup launches a background goroutine that calls
// purgeExpiredUserIDs every userIDCacheCleanupPeriod. The loop never exits, so
// the goroutine lives for the rest of the process; callers start it through
// userIDCacheCleanupOnce.
func startUserIDCacheCleanup() {
	go func() {
		ticker := time.NewTicker(userIDCacheCleanupPeriod)
		defer ticker.Stop()
		for {
			<-ticker.C
			purgeExpiredUserIDs()
		}
	}()
}
// purgeExpiredUserIDs removes every cache entry whose expiry is not after the
// current time. It holds the write lock for the whole sweep.
func purgeExpiredUserIDs() {
	cutoff := time.Now()
	userIDCacheMu.Lock()
	defer userIDCacheMu.Unlock()
	for key, entry := range userIDCache {
		if entry.expire.After(cutoff) {
			continue
		}
		delete(userIDCache, key)
	}
}
// userIDCacheKey derives the cache key for apiKey as the lowercase hex SHA-256
// digest, so raw API keys are never stored as map keys.
func userIDCacheKey(apiKey string) string {
	digest := sha256.New()
	// hash.Hash.Write never returns an error.
	_, _ = digest.Write([]byte(apiKey))
	return hex.EncodeToString(digest.Sum(nil))
}
// cachedUserID returns a stable fake user ID for apiKey.
//
// An empty apiKey always gets a freshly generated ID and is never cached. For
// a non-empty key, a cached entry is reused while it is non-empty, unexpired
// and passes isValidUserID, and every reuse extends its TTL by userIDTTL.
// Otherwise a new ID is generated and stored. The first call also starts the
// background cleanup goroutine.
func cachedUserID(apiKey string) string {
	if apiKey == "" {
		return generateFakeUserID()
	}
	userIDCacheCleanupOnce.Do(startUserIDCacheCleanup)

	key := userIDCacheKey(apiKey)
	now := time.Now()
	reusable := func(e userIDCacheEntry) bool {
		return e.value != "" && e.expire.After(now) && isValidUserID(e.value)
	}

	// Optimistic read under the shared lock; take the write lock only to renew.
	userIDCacheMu.RLock()
	hit := reusable(userIDCache[key])
	userIDCacheMu.RUnlock()

	if hit {
		userIDCacheMu.Lock()
		// Re-check: the entry may have changed between the two locks.
		if current := userIDCache[key]; reusable(current) {
			current.expire = now.Add(userIDTTL)
			userIDCache[key] = current
			userIDCacheMu.Unlock()
			return current.value
		}
		userIDCacheMu.Unlock()
	}

	// Generate outside the lock; only use it if no concurrent writer
	// stored a valid ID in the meantime.
	candidate := generateFakeUserID()
	userIDCacheMu.Lock()
	defer userIDCacheMu.Unlock()
	current := userIDCache[key]
	if !reusable(current) {
		current.value = candidate
	}
	current.expire = now.Add(userIDTTL)
	userIDCache[key] = current
	return current.value
}
================================================
FILE: internal/runtime/executor/user_id_cache_test.go
================================================
package executor
import (
"testing"
"time"
)
func resetUserIDCache() {
userIDCacheMu.Lock()
userIDCache = make(map[string]userIDCacheEntry)
userIDCacheMu.Unlock()
}
// TestCachedUserID_ReusesWithinTTL verifies that two consecutive lookups for
// the same key return the same non-empty ID.
func TestCachedUserID_ReusesWithinTTL(t *testing.T) {
	resetUserIDCache()
	const apiKey = "api-key-1"
	first, second := cachedUserID(apiKey), cachedUserID(apiKey)
	switch {
	case first == "":
		t.Fatal("expected generated user_id to be non-empty")
	case first != second:
		t.Fatalf("expected cached user_id to be reused, got %q and %q", first, second)
	}
}
// TestCachedUserID_ExpiresAfterTTL forces an entry into the past and checks
// that the next lookup generates a different, non-empty ID.
func TestCachedUserID_ExpiresAfterTTL(t *testing.T) {
	resetUserIDCache()
	const apiKey = "api-key-expired"
	stale := cachedUserID(apiKey)

	userIDCacheMu.Lock()
	userIDCache[userIDCacheKey(apiKey)] = userIDCacheEntry{
		value:  stale,
		expire: time.Now().Add(-time.Minute),
	}
	userIDCacheMu.Unlock()

	fresh := cachedUserID(apiKey)
	if fresh == stale {
		t.Fatalf("expected expired user_id to be replaced, got %q", fresh)
	}
	if fresh == "" {
		t.Fatal("expected regenerated user_id to be non-empty")
	}
}
// TestCachedUserID_IsScopedByAPIKey verifies that distinct API keys receive
// distinct IDs.
func TestCachedUserID_IsScopedByAPIKey(t *testing.T) {
	resetUserIDCache()
	a := cachedUserID("api-key-1")
	if b := cachedUserID("api-key-2"); a == b {
		t.Fatalf("expected different API keys to have different user_ids, got %q", a)
	}
}
// TestCachedUserID_RenewsTTLOnHit leaves an entry with only a few seconds of
// TTL, reads it, and checks that the expiry was pushed well into the future.
func TestCachedUserID_RenewsTTLOnHit(t *testing.T) {
	resetUserIDCache()
	const apiKey = "api-key-renew"
	id := cachedUserID(apiKey)
	cacheKey := userIDCacheKey(apiKey)
	base := time.Now()

	userIDCacheMu.Lock()
	userIDCache[cacheKey] = userIDCacheEntry{value: id, expire: base.Add(2 * time.Second)}
	userIDCacheMu.Unlock()

	if refreshed := cachedUserID(apiKey); refreshed != id {
		t.Fatalf("expected cached user_id to be reused before expiry, got %q", refreshed)
	}

	userIDCacheMu.RLock()
	remaining := userIDCache[cacheKey].expire.Sub(base)
	userIDCacheMu.RUnlock()
	if remaining < 30*time.Minute {
		t.Fatalf("expected TTL to renew, got %v remaining", remaining)
	}
}
================================================
FILE: internal/runtime/geminicli/state.go
================================================
package geminicli
import (
"strings"
"sync"
)
// SharedCredential keeps canonical OAuth metadata for a multi-project Gemini CLI login.
//
// primaryID and email are set only by NewSharedCredential and never modified
// afterwards in this file. metadata and projectIDs are mutable and guarded by mu.
type SharedCredential struct {
	primaryID  string         // trimmed identifier of the owning credential
	email      string         // trimmed account email
	metadata   map[string]any // OAuth metadata; guarded by mu
	projectIDs []string       // configured project IDs; guarded by mu
	mu         sync.RWMutex   // protects metadata and projectIDs
}
// NewSharedCredential builds a shared credential container for the given
// primary entry. primaryID and email are whitespace-trimmed. metadata and
// projectIDs are copied, so later changes to the caller's map keys or slice
// elements do not affect the credential. The map copy is shallow: nested
// values are shared.
func NewSharedCredential(primaryID, email string, metadata map[string]any, projectIDs []string) *SharedCredential {
	cred := new(SharedCredential)
	cred.primaryID = strings.TrimSpace(primaryID)
	cred.email = strings.TrimSpace(email)
	cred.metadata = cloneMap(metadata)
	cred.projectIDs = cloneStrings(projectIDs)
	return cred
}
// PrimaryID returns the owning credential identifier, or "" for a nil receiver.
func (s *SharedCredential) PrimaryID() string {
	if s != nil {
		return s.primaryID
	}
	return ""
}
// Email returns the associated account email, or "" for a nil receiver.
func (s *SharedCredential) Email() string {
	if s != nil {
		return s.email
	}
	return ""
}
// ProjectIDs returns a snapshot of the configured project identifiers.
//
// The result is a fresh copy (nil when no IDs are stored, or for a nil
// receiver), so callers may modify it freely. The slice is read under the read
// lock because SetProjectIDs may replace it from another goroutine. Reading it
// unguarded was a data race.
func (s *SharedCredential) ProjectIDs() []string {
	if s == nil {
		return nil
	}
	s.mu.RLock()
	defer s.mu.RUnlock()
	return cloneStrings(s.projectIDs)
}
// MetadataSnapshot returns a copy of the stored OAuth metadata, taken under
// the read lock. The copy is shallow: top-level keys are duplicated, but
// nested values (maps, slices, pointers) are shared with the credential. It
// returns nil for a nil receiver or when no metadata is stored.
func (s *SharedCredential) MetadataSnapshot() map[string]any {
	if s == nil {
		return nil
	}
	s.mu.RLock()
	snapshot := cloneMap(s.metadata)
	s.mu.RUnlock()
	return snapshot
}
// MergeMetadata merges values into the shared metadata and returns a shallow
// copy of the result. A nil value deletes its key; any other value overwrites
// it. An empty values map leaves the metadata untouched and just returns a
// snapshot. The result is nil for a nil receiver or when the merged metadata
// ends up empty.
func (s *SharedCredential) MergeMetadata(values map[string]any) map[string]any {
	switch {
	case s == nil:
		return nil
	case len(values) == 0:
		return s.MetadataSnapshot()
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	if s.metadata == nil {
		s.metadata = make(map[string]any, len(values))
	}
	for key, value := range values {
		if value != nil {
			s.metadata[key] = value
		} else {
			delete(s.metadata, key)
		}
	}
	return cloneMap(s.metadata)
}
// SetProjectIDs replaces the stored project identifiers with a copy of ids
// (nil when ids is empty). A nil receiver is a no-op.
func (s *SharedCredential) SetProjectIDs(ids []string) {
	if s == nil {
		return
	}
	replacement := cloneStrings(ids)
	s.mu.Lock()
	defer s.mu.Unlock()
	s.projectIDs = replacement
}
// VirtualCredential tracks a per-project virtual auth entry that reuses a primary credential.
type VirtualCredential struct {
	ProjectID string            // project this virtual entry targets (trimmed by NewVirtualCredential)
	Parent    *SharedCredential // shared credential whose OAuth metadata is reused; may be nil
}
// NewVirtualCredential creates a virtual credential descriptor for projectID
// (whitespace-trimmed), bound to the shared parent.
func NewVirtualCredential(projectID string, parent *SharedCredential) *VirtualCredential {
	vc := new(VirtualCredential)
	vc.ProjectID = strings.TrimSpace(projectID)
	vc.Parent = parent
	return vc
}
// ResolveSharedCredential returns the shared credential backing the provided
// runtime payload:
//   - a *SharedCredential is returned as-is;
//   - a *VirtualCredential yields its Parent;
//   - anything else yields nil.
//
// A typed-nil *VirtualCredential (e.g. an interface holding a nil pointer)
// yields nil instead of panicking on the field access.
func ResolveSharedCredential(runtime any) *SharedCredential {
	switch typed := runtime.(type) {
	case *SharedCredential:
		return typed
	case *VirtualCredential:
		if typed == nil {
			return nil
		}
		return typed.Parent
	default:
		return nil
	}
}
// IsVirtual reports whether the runtime payload is a *VirtualCredential. The
// dynamic type alone decides, so even a typed-nil *VirtualCredential reports
// true; an untyped nil reports false.
func IsVirtual(runtime any) bool {
	switch runtime.(type) {
	case *VirtualCredential:
		return true
	default:
		return false
	}
}
// cloneMap returns a shallow copy of in: keys are duplicated, values are
// copied by assignment, so nested maps and slices are shared. It returns nil
// for a nil or empty input.
func cloneMap(in map[string]any) map[string]any {
	size := len(in)
	if size == 0 {
		return nil
	}
	dup := make(map[string]any, size)
	for key, value := range in {
		dup[key] = value
	}
	return dup
}
// cloneStrings returns an independent copy of in with exactly len(in)
// capacity. It returns nil for a nil or empty input.
func cloneStrings(in []string) []string {
	if len(in) == 0 {
		return nil
	}
	return append(make([]string, 0, len(in)), in...)
}
================================================
FILE: internal/store/gitstore.go
================================================
package store
import (
"context"
"encoding/json"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/go-git/go-git/v6"
"github.com/go-git/go-git/v6/config"
"github.com/go-git/go-git/v6/plumbing"
"github.com/go-git/go-git/v6/plumbing/object"
"github.com/go-git/go-git/v6/plumbing/transport"
"github.com/go-git/go-git/v6/plumbing/transport/http"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// gcInterval defines minimum time between garbage collection runs.
const gcInterval = 5 * time.Minute

// GitTokenStore persists token records and auth metadata using git as the backing storage.
type GitTokenStore struct {
	mu        sync.Mutex   // held around file writes and commit/push (Save, EnsureRepository)
	dirLock   sync.RWMutex // guards baseDir, repoDir and configDir
	baseDir   string       // auth directory set by SetBaseDir (made absolute when possible)
	repoDir   string       // git working-tree root; defaults to the parent of baseDir
	configDir string       // repoDir/config, home of the managed config.yaml
	remote    string       // remote repository URL used for clone/pull
	username  string       // remote credentials; presumably consumed by gitAuth (not shown here)
	password  string
	lastGC    time.Time // NOTE(review): presumably time of the last GC run, compared against gcInterval elsewhere — confirm
}
// NewGitTokenStore creates a git-backed token store for the given remote URL
// and credentials. Credentials are saved to disk through the TokenStorage
// implementation embedded in each token record. No directories are configured
// until SetBaseDir is called.
func NewGitTokenStore(remote, username, password string) *GitTokenStore {
	store := new(GitTokenStore)
	store.remote = remote
	store.username = username
	store.password = password
	return store
}
// SetBaseDir updates the default directory used for auth JSON persistence when
// no explicit path is provided.
//
// dir is trimmed and made absolute when possible. The repository root becomes
// its parent directory (or dir itself when the parent is "."), and the config
// directory becomes <root>/config. A blank dir clears all three paths.
func (s *GitTokenStore) SetBaseDir(dir string) {
	var base, repo, cfg string
	if clean := strings.TrimSpace(dir); clean != "" {
		base = clean
		if abs, err := filepath.Abs(clean); err == nil {
			base = abs
		}
		repo = filepath.Dir(base)
		if repo == "" || repo == "." {
			repo = base
		}
		cfg = filepath.Join(repo, "config")
	}

	s.dirLock.Lock()
	s.baseDir, s.repoDir, s.configDir = base, repo, cfg
	s.dirLock.Unlock()
}
// AuthDir returns the directory used for auth persistence, as reported by
// baseDirSnapshot.
func (s *GitTokenStore) AuthDir() string {
	dir := s.baseDirSnapshot()
	return dir
}

// ConfigPath returns the managed config file path (<configDir>/config.yaml),
// or "" when no config directory is configured.
func (s *GitTokenStore) ConfigPath() string {
	s.dirLock.RLock()
	dir := s.configDir
	s.dirLock.RUnlock()
	if dir == "" {
		return ""
	}
	return filepath.Join(dir, "config.yaml")
}
// EnsureRepository prepares the local git working tree by cloning or opening the repository.
//
// While holding dirLock it checks that remote and baseDir are configured,
// fills in repoDir (parent of baseDir, or baseDir itself when that parent is
// ".") and configDir if they are unset, and then:
//   - if repoDir/.git does not exist, clones s.remote into repoDir. When the
//     remote is empty, it initialises a fresh repository instead, adds an
//     "origin" remote, and creates auths/.gitkeep and config/.gitkeep under
//     repoDir as the paths for an initial commit;
//   - otherwise, opens the repository and pulls from "origin", ignoring the
//     specific pull errors listed in the switch below.
//
// It then ensures baseDir and configDir exist. The initial commit/push (empty
// remote case only) runs after dirLock is released, while holding s.mu.
// NOTE(review): the placeholders always go to repoDir/auths, which differs
// from baseDir if SetBaseDir pointed elsewhere — confirm this is intended.
func (s *GitTokenStore) EnsureRepository() error {
	s.dirLock.Lock()
	if s.remote == "" {
		s.dirLock.Unlock()
		return fmt.Errorf("git token store: remote not configured")
	}
	if s.baseDir == "" {
		s.dirLock.Unlock()
		return fmt.Errorf("git token store: base directory not configured")
	}
	repoDir := s.repoDir
	if repoDir == "" {
		repoDir = filepath.Dir(s.baseDir)
		if repoDir == "" || repoDir == "." {
			repoDir = s.baseDir
		}
		s.repoDir = repoDir
	}
	if s.configDir == "" {
		s.configDir = filepath.Join(repoDir, "config")
	}
	authDir := filepath.Join(repoDir, "auths")
	configDir := filepath.Join(repoDir, "config")
	gitDir := filepath.Join(repoDir, ".git")
	authMethod := s.gitAuth()
	// Repo-relative paths to commit and push once dirLock is released; only
	// populated when a fresh repository is initialised.
	var initPaths []string
	if _, err := os.Stat(gitDir); errors.Is(err, fs.ErrNotExist) {
		if errMk := os.MkdirAll(repoDir, 0o700); errMk != nil {
			s.dirLock.Unlock()
			return fmt.Errorf("git token store: create repo dir: %w", errMk)
		}
		if _, errClone := git.PlainClone(repoDir, &git.CloneOptions{Auth: authMethod, URL: s.remote}); errClone != nil {
			if errors.Is(errClone, transport.ErrEmptyRemoteRepository) {
				// Empty remote: discard any partial clone state and start a new repo.
				_ = os.RemoveAll(gitDir)
				repo, errInit := git.PlainInit(repoDir, false)
				if errInit != nil {
					s.dirLock.Unlock()
					return fmt.Errorf("git token store: init empty repo: %w", errInit)
				}
				if _, errRemote := repo.Remote("origin"); errRemote != nil {
					if _, errCreate := repo.CreateRemote(&config.RemoteConfig{
						Name: "origin",
						URLs: []string{s.remote},
					}); errCreate != nil && !errors.Is(errCreate, git.ErrRemoteExists) {
						s.dirLock.Unlock()
						return fmt.Errorf("git token store: configure remote: %w", errCreate)
					}
				}
				if err := os.MkdirAll(authDir, 0o700); err != nil {
					s.dirLock.Unlock()
					return fmt.Errorf("git token store: create auth dir: %w", err)
				}
				if err := os.MkdirAll(configDir, 0o700); err != nil {
					s.dirLock.Unlock()
					return fmt.Errorf("git token store: create config dir: %w", err)
				}
				// Placeholder files give the initial commit content.
				if err := ensureEmptyFile(filepath.Join(authDir, ".gitkeep")); err != nil {
					s.dirLock.Unlock()
					return fmt.Errorf("git token store: create auth placeholder: %w", err)
				}
				if err := ensureEmptyFile(filepath.Join(configDir, ".gitkeep")); err != nil {
					s.dirLock.Unlock()
					return fmt.Errorf("git token store: create config placeholder: %w", err)
				}
				initPaths = []string{
					filepath.Join("auths", ".gitkeep"),
					filepath.Join("config", ".gitkeep"),
				}
			} else {
				s.dirLock.Unlock()
				return fmt.Errorf("git token store: clone remote: %w", errClone)
			}
		}
	} else if err != nil {
		s.dirLock.Unlock()
		return fmt.Errorf("git token store: stat repo: %w", err)
	} else {
		// Existing working tree: open it and sync from origin.
		repo, errOpen := git.PlainOpen(repoDir)
		if errOpen != nil {
			s.dirLock.Unlock()
			return fmt.Errorf("git token store: open repo: %w", errOpen)
		}
		worktree, errWorktree := repo.Worktree()
		if errWorktree != nil {
			s.dirLock.Unlock()
			return fmt.Errorf("git token store: worktree: %w", errWorktree)
		}
		if errPull := worktree.Pull(&git.PullOptions{Auth: authMethod, RemoteName: "origin"}); errPull != nil {
			switch {
			case errors.Is(errPull, git.NoErrAlreadyUpToDate),
				errors.Is(errPull, git.ErrUnstagedChanges),
				errors.Is(errPull, git.ErrNonFastForwardUpdate):
				// Ignore clean syncs, local edits, and remote divergence—local changes win.
			case errors.Is(errPull, transport.ErrAuthenticationRequired),
				errors.Is(errPull, plumbing.ErrReferenceNotFound),
				errors.Is(errPull, transport.ErrEmptyRemoteRepository):
				// Ignore authentication prompts and empty remote references on initial sync.
			default:
				s.dirLock.Unlock()
				return fmt.Errorf("git token store: pull: %w", errPull)
			}
		}
	}
	if err := os.MkdirAll(s.baseDir, 0o700); err != nil {
		s.dirLock.Unlock()
		return fmt.Errorf("git token store: create auth dir: %w", err)
	}
	if err := os.MkdirAll(s.configDir, 0o700); err != nil {
		s.dirLock.Unlock()
		return fmt.Errorf("git token store: create config dir: %w", err)
	}
	s.dirLock.Unlock()
	// dirLock is released before taking s.mu for the initial commit/push.
	if len(initPaths) > 0 {
		s.mu.Lock()
		err := s.commitAndPushLocked("Initialize git token store", initPaths...)
		s.mu.Unlock()
		if err != nil {
			return err
		}
	}
	return nil
}
// Save writes the auth's token storage or metadata to its resolved file path and
// then commits and pushes that file to the remote repository.
//
// A disabled auth whose file does not exist yet is skipped and "" is returned.
// When only Metadata is present and the file already holds semantically equal
// JSON, nothing is written or committed. Metadata is written to a ".tmp"
// sibling that is then renamed into place; if either step fails the temp file
// is removed (best effort) so no stale file is left in the auth directory.
func (s *GitTokenStore) Save(_ context.Context, auth *cliproxyauth.Auth) (string, error) {
	if auth == nil {
		return "", fmt.Errorf("auth filestore: auth is nil")
	}
	path, err := s.resolveAuthPath(auth)
	if err != nil {
		return "", err
	}
	if path == "" {
		return "", fmt.Errorf("auth filestore: missing file path attribute for %s", auth.ID)
	}
	if auth.Disabled {
		if _, statErr := os.Stat(path); os.IsNotExist(statErr) {
			return "", nil
		}
	}
	if err = s.EnsureRepository(); err != nil {
		return "", err
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	if err = os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
		return "", fmt.Errorf("auth filestore: create dir failed: %w", err)
	}
	switch {
	case auth.Storage != nil:
		if err = auth.Storage.SaveTokenToFile(path); err != nil {
			return "", err
		}
	case auth.Metadata != nil:
		raw, errMarshal := json.Marshal(auth.Metadata)
		if errMarshal != nil {
			return "", fmt.Errorf("auth filestore: marshal metadata failed: %w", errMarshal)
		}
		if existing, errRead := os.ReadFile(path); errRead == nil {
			if jsonEqual(existing, raw) {
				// Content is unchanged: skip both the write and the commit.
				return path, nil
			}
		} else if !os.IsNotExist(errRead) {
			return "", fmt.Errorf("auth filestore: read existing failed: %w", errRead)
		}
		tmp := path + ".tmp"
		if errWrite := os.WriteFile(tmp, raw, 0o600); errWrite != nil {
			// Best effort: drop a partially written temp file.
			_ = os.Remove(tmp)
			return "", fmt.Errorf("auth filestore: write temp failed: %w", errWrite)
		}
		if errRename := os.Rename(tmp, path); errRename != nil {
			// Best effort: do not leave the orphaned temp file behind.
			_ = os.Remove(tmp)
			return "", fmt.Errorf("auth filestore: rename failed: %w", errRename)
		}
	default:
		return "", fmt.Errorf("auth filestore: nothing to persist for %s", auth.ID)
	}
	if auth.Attributes == nil {
		auth.Attributes = make(map[string]string)
	}
	auth.Attributes["path"] = path
	if strings.TrimSpace(auth.FileName) == "" {
		auth.FileName = auth.ID
	}
	relPath, errRel := s.relativeToRepo(path)
	if errRel != nil {
		return "", errRel
	}
	messageID := auth.ID
	if strings.TrimSpace(messageID) == "" {
		messageID = filepath.Base(path)
	}
	if errCommit := s.commitAndPushLocked(fmt.Sprintf("Update auth %s", strings.TrimSpace(messageID)), relPath); errCommit != nil {
		return "", errCommit
	}
	return path, nil
}
// List walks the auth base directory recursively and returns one Auth per
// parseable *.json file (suffix compared case-insensitively). Files that cannot
// be read or decoded are skipped silently, and empty files produce no entry.
// Errors reported by the directory walk itself are returned unwrapped.
func (s *GitTokenStore) List(_ context.Context) ([]*cliproxyauth.Auth, error) {
	if errEnsure := s.EnsureRepository(); errEnsure != nil {
		return nil, errEnsure
	}
	root := s.baseDirSnapshot()
	if root == "" {
		return nil, fmt.Errorf("auth filestore: directory not configured")
	}
	found := make([]*cliproxyauth.Auth, 0)
	visit := func(p string, entry fs.DirEntry, walkErr error) error {
		switch {
		case walkErr != nil:
			return walkErr
		case entry.IsDir():
			return nil
		case !strings.HasSuffix(strings.ToLower(entry.Name()), ".json"):
			return nil
		}
		// Unreadable or malformed files are skipped rather than aborting the walk.
		if record, errRead := s.readAuthFile(p, root); errRead == nil && record != nil {
			found = append(found, record)
		}
		return nil
	}
	if errWalk := filepath.WalkDir(root, visit); errWalk != nil {
		return nil, errWalk
	}
	return found, nil
}
// Delete removes the auth file identified by id. If a file was actually removed,
// the deletion is committed and pushed; a file that is already missing is not an
// error and produces no commit.
func (s *GitTokenStore) Delete(_ context.Context, id string) error {
	trimmedID := strings.TrimSpace(id)
	if trimmedID == "" {
		return fmt.Errorf("auth filestore: id is empty")
	}
	target, errPath := s.resolveDeletePath(trimmedID)
	if errPath != nil {
		return errPath
	}
	if errEnsure := s.EnsureRepository(); errEnsure != nil {
		return errEnsure
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	errRemove := os.Remove(target)
	if os.IsNotExist(errRemove) {
		// Nothing on disk, so there is nothing to commit.
		return nil
	}
	if errRemove != nil {
		return fmt.Errorf("auth filestore: delete failed: %w", errRemove)
	}
	rel, errRel := s.relativeToRepo(target)
	if errRel != nil {
		return errRel
	}
	return s.commitAndPushLocked(fmt.Sprintf("Delete auth %s", trimmedID), rel)
}
// PersistAuthFiles commits and pushes the given files to the remote repository.
//
// It returns nil immediately when no paths are passed. Blank entries are
// ignored, and if nothing remains after filtering no commit is attempted. Every
// remaining path must resolve inside the repository, otherwise an error is
// returned before anything is committed. An empty message defaults to
// "Sync watcher updates".
func (s *GitTokenStore) PersistAuthFiles(_ context.Context, message string, paths ...string) error {
	if len(paths) == 0 {
		return nil
	}
	if errEnsure := s.EnsureRepository(); errEnsure != nil {
		return errEnsure
	}
	relPaths := make([]string, 0, len(paths))
	for _, candidate := range paths {
		candidate = strings.TrimSpace(candidate)
		if candidate == "" {
			continue
		}
		rel, errRel := s.relativeToRepo(candidate)
		if errRel != nil {
			return errRel
		}
		relPaths = append(relPaths, rel)
	}
	if len(relPaths) == 0 {
		return nil
	}
	commitMessage := message
	if strings.TrimSpace(commitMessage) == "" {
		commitMessage = "Sync watcher updates"
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.commitAndPushLocked(commitMessage, relPaths...)
}
// resolveDeletePath maps an auth identifier to the file path that Delete removes.
//
// Absolute ids are used as-is. Every relative id is resolved against the auth
// base directory, including nested ids such as "team/foo.json", which is the
// form idFor produces for files in subdirectories. Previously such ids were
// resolved against the process working directory. This keeps ids returned by
// List usable with Delete. Relative ids that would escape the base directory
// are rejected.
func (s *GitTokenStore) resolveDeletePath(id string) (string, error) {
	if filepath.IsAbs(id) {
		return id, nil
	}
	dir := s.baseDirSnapshot()
	if dir == "" {
		return "", fmt.Errorf("auth filestore: directory not configured")
	}
	clean := filepath.Clean(filepath.FromSlash(id))
	if clean == "." || clean == ".." || strings.HasPrefix(clean, ".."+string(os.PathSeparator)) {
		return "", fmt.Errorf("auth filestore: invalid auth identifier %s", id)
	}
	return filepath.Join(dir, clean), nil
}
// readAuthFile loads one auth JSON file and converts it into an Auth record.
//
// It returns (nil, nil) for an empty file. The provider comes from the "type"
// field and falls back to "unknown". ID and FileName are both the value of
// idFor(path, baseDir). CreatedAt and UpdatedAt are the file's modification time.
// The "email" attribute is set only when the metadata carries a non-empty email
// string.
func (s *GitTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, error) {
	raw, errRead := os.ReadFile(path)
	if errRead != nil {
		return nil, fmt.Errorf("read file: %w", errRead)
	}
	if len(raw) == 0 {
		return nil, nil
	}
	metadata := make(map[string]any)
	if errDecode := json.Unmarshal(raw, &metadata); errDecode != nil {
		return nil, fmt.Errorf("unmarshal auth json: %w", errDecode)
	}
	provider := "unknown"
	if kind, _ := metadata["type"].(string); kind != "" {
		provider = kind
	}
	info, errStat := os.Stat(path)
	if errStat != nil {
		return nil, fmt.Errorf("stat file: %w", errStat)
	}
	attributes := map[string]string{"path": path}
	if email, _ := metadata["email"].(string); email != "" {
		attributes["email"] = email
	}
	id := s.idFor(path, baseDir)
	modified := info.ModTime()
	return &cliproxyauth.Auth{
		ID:               id,
		Provider:         provider,
		FileName:         id,
		Label:            s.labelFor(metadata),
		Status:           cliproxyauth.StatusActive,
		Attributes:       attributes,
		Metadata:         metadata,
		CreatedAt:        modified,
		UpdatedAt:        modified,
		LastRefreshedAt:  time.Time{},
		NextRefreshAfter: time.Time{},
	}, nil
}
// idFor derives an auth ID from a file path: the path relative to baseDir when
// baseDir is set and filepath.Rel succeeds, otherwise the path unchanged.
func (s *GitTokenStore) idFor(path, baseDir string) string {
	if baseDir != "" {
		if rel, errRel := filepath.Rel(baseDir, path); errRel == nil {
			return rel
		}
	}
	return path
}
// resolveAuthPath decides where an auth record lives on disk, in priority order:
//  1. a non-blank "path" attribute, used verbatim;
//  2. a non-blank FileName: absolute names are used as-is, relative names are
//     joined to the base directory when one is configured and returned
//     unchanged otherwise;
//  3. the ID: absolute IDs are used as-is, relative IDs require a base directory.
func (s *GitTokenStore) resolveAuthPath(auth *cliproxyauth.Auth) (string, error) {
	if auth == nil {
		return "", fmt.Errorf("auth filestore: auth is nil")
	}
	// Indexing a nil Attributes map yields "", so no separate nil check is needed.
	if explicit := strings.TrimSpace(auth.Attributes["path"]); explicit != "" {
		return explicit, nil
	}
	if name := strings.TrimSpace(auth.FileName); name != "" {
		if !filepath.IsAbs(name) {
			if dir := s.baseDirSnapshot(); dir != "" {
				return filepath.Join(dir, name), nil
			}
		}
		return name, nil
	}
	switch {
	case auth.ID == "":
		return "", fmt.Errorf("auth filestore: missing id")
	case filepath.IsAbs(auth.ID):
		return auth.ID, nil
	}
	dir := s.baseDirSnapshot()
	if dir == "" {
		return "", fmt.Errorf("auth filestore: directory not configured")
	}
	return filepath.Join(dir, auth.ID), nil
}
// labelFor picks a human-readable label from auth metadata. It returns the first
// non-empty string among "label", "email" and "project_id", or "" if none is
// present. A nil map is safe because indexing it yields nil.
func (s *GitTokenStore) labelFor(metadata map[string]any) string {
	for _, key := range [...]string{"label", "email", "project_id"} {
		if candidate, isString := metadata[key].(string); isString && candidate != "" {
			return candidate
		}
	}
	return ""
}
// baseDirSnapshot returns the auth base directory read under the dirLock read lock.
func (s *GitTokenStore) baseDirSnapshot() string {
	s.dirLock.RLock()
	dir := s.baseDir
	s.dirLock.RUnlock()
	return dir
}
// repoDirSnapshot returns the repository root directory read under the dirLock read lock.
func (s *GitTokenStore) repoDirSnapshot() string {
	s.dirLock.RLock()
	dir := s.repoDir
	s.dirLock.RUnlock()
	return dir
}
// gitAuth builds HTTP basic-auth credentials for remote operations. It returns
// nil when neither username nor password is configured. A blank username
// defaults to "git" when only a password (typically a token) is set.
func (s *GitTokenStore) gitAuth() transport.AuthMethod {
	user, pass := s.username, s.password
	if user == "" && pass == "" {
		return nil
	}
	if user == "" {
		user = "git"
	}
	return &http.BasicAuth{Username: user, Password: pass}
}
// relativeToRepo converts path into a path relative to the repository root.
// Both sides are made absolute when filepath.Abs succeeds. It fails when no
// repository is configured or when path lies outside the repository.
func (s *GitTokenStore) relativeToRepo(path string) (string, error) {
	root := s.repoDirSnapshot()
	if root == "" {
		return "", fmt.Errorf("git token store: repository path not configured")
	}
	absOrSelf := func(p string) string {
		if resolved, errAbs := filepath.Abs(p); errAbs == nil {
			return resolved
		}
		return p
	}
	rel, errRel := filepath.Rel(absOrSelf(root), absOrSelf(path))
	if errRel != nil {
		return "", fmt.Errorf("git token store: relative path: %w", errRel)
	}
	escapes := rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator))
	if escapes {
		return "", fmt.Errorf("git token store: path outside repository")
	}
	return rel, nil
}
// commitAndPushLocked stages relPaths (paths relative to the repository root),
// commits them, collapses the branch to a single parentless commit and
// force-pushes it to the remote.
//
// Every caller in this file acquires s.mu before calling. It returns nil without
// committing when every path is blank, when the worktree status is clean, or
// when the commit would be empty (git.ErrEmptyCommit). A push that reports
// git.NoErrAlreadyUpToDate is also treated as success.
func (s *GitTokenStore) commitAndPushLocked(message string, relPaths ...string) error {
repoDir := s.repoDirSnapshot()
if repoDir == "" {
return fmt.Errorf("git token store: repository path not configured")
}
repo, err := git.PlainOpen(repoDir)
if err != nil {
return fmt.Errorf("git token store: open repo: %w", err)
}
worktree, err := repo.Worktree()
if err != nil {
return fmt.Errorf("git token store: worktree: %w", err)
}
added := false
for _, rel := range relPaths {
if strings.TrimSpace(rel) == "" {
continue
}
// A path whose file no longer exists cannot be added; stage it as a removal instead.
if _, err = worktree.Add(rel); err != nil {
if errors.Is(err, os.ErrNotExist) {
if _, errRemove := worktree.Remove(rel); errRemove != nil && !errors.Is(errRemove, os.ErrNotExist) {
return fmt.Errorf("git token store: remove %s: %w", rel, errRemove)
}
} else {
return fmt.Errorf("git token store: add %s: %w", rel, err)
}
}
// Set for every non-blank path, including ones staged as removals.
added = true
}
if !added {
return nil
}
// Status covers the whole worktree, not only relPaths.
status, err := worktree.Status()
if err != nil {
return fmt.Errorf("git token store: status: %w", err)
}
if status.IsClean() {
return nil
}
if strings.TrimSpace(message) == "" {
message = "Update auth store"
}
signature := &object.Signature{
Name: "CLIProxyAPI",
Email: "cliproxy@local",
When: time.Now(),
}
commitHash, err := worktree.Commit(message, &git.CommitOptions{
Author: signature,
})
if err != nil {
if errors.Is(err, git.ErrEmptyCommit) {
return nil
}
return fmt.Errorf("git token store: commit: %w", err)
}
// Replace the branch tip with a parentless copy of the new commit so history
// stays a single commit. This is skipped when HEAD cannot be resolved
// (plumbing.ErrReferenceNotFound); any other HEAD error is returned.
headRef, errHead := repo.Head()
if errHead != nil {
if !errors.Is(errHead, plumbing.ErrReferenceNotFound) {
return fmt.Errorf("git token store: get head: %w", errHead)
}
} else if errRewrite := s.rewriteHeadAsSingleCommit(repo, headRef.Name(), commitHash, message, signature); errRewrite != nil {
return errRewrite
}
// Periodically prune the objects orphaned by the squash above.
s.maybeRunGC(repo)
// Force is needed because the squashed commit does not descend from the
// previously pushed tip, so a normal push would be rejected as non-fast-forward.
if err = repo.Push(&git.PushOptions{Auth: s.gitAuth(), Force: true}); err != nil {
if errors.Is(err, git.NoErrAlreadyUpToDate) {
return nil
}
return fmt.Errorf("git token store: push: %w", err)
}
return nil
}
// rewriteHeadAsSingleCommit points branch at a new parentless commit. The new
// commit has the same tree, encoding and extra headers as commitHash, and uses
// the given message and signature as both author and committer. The branch
// history is thereby squashed into a single commit; the previous commits stay
// in the object store until pruned.
func (s *GitTokenStore) rewriteHeadAsSingleCommit(repo *git.Repository, branch plumbing.ReferenceName, commitHash plumbing.Hash, message string, signature *object.Signature) error {
	source, errLookup := repo.CommitObject(commitHash)
	if errLookup != nil {
		return fmt.Errorf("git token store: inspect head commit: %w", errLookup)
	}
	// Omitting ParentHashes leaves it nil, which makes the commit a root commit.
	orphan := object.Commit{
		Author:       *signature,
		Committer:    *signature,
		Message:      message,
		TreeHash:     source.TreeHash,
		Encoding:     source.Encoding,
		ExtraHeaders: source.ExtraHeaders,
	}
	encoded := new(plumbing.MemoryObject)
	encoded.SetType(plumbing.CommitObject)
	if errEncode := orphan.Encode(encoded); errEncode != nil {
		return fmt.Errorf("git token store: encode squashed commit: %w", errEncode)
	}
	orphanHash, errStore := repo.Storer.SetEncodedObject(encoded)
	if errStore != nil {
		return fmt.Errorf("git token store: write squashed commit: %w", errStore)
	}
	branchRef := plumbing.NewHashReference(branch, orphanHash)
	if errRef := repo.Storer.SetReference(branchRef); errRef != nil {
		return fmt.Errorf("git token store: update branch reference: %w", errRef)
	}
	return nil
}
// maybeRunGC prunes unreachable objects and repacks the repository, at most once
// per gcInterval. The timestamp is updated before the work runs, so a failing GC
// is not retried until the next interval. Failures are ignored: GC is purely
// housekeeping.
func (s *GitTokenStore) maybeRunGC(repo *git.Repository) {
	started := time.Now()
	if started.Sub(s.lastGC) < gcInterval {
		return
	}
	s.lastGC = started
	errPrune := repo.Prune(git.PruneOptions{
		OnlyObjectsOlderThan: started,
		Handler:              repo.DeleteObject,
	})
	// Storage without loose-object support can still be repacked; any other
	// prune failure skips the repack.
	if errPrune != nil && !errors.Is(errPrune, git.ErrLooseObjectsNotSupported) {
		return
	}
	_ = repo.RepackObjects(&git.RepackConfig{})
}
// PersistConfig commits and pushes the managed config file. It returns nil
// without committing when the file does not exist yet.
func (s *GitTokenStore) PersistConfig(_ context.Context) error {
	if errEnsure := s.EnsureRepository(); errEnsure != nil {
		return errEnsure
	}
	configPath := s.ConfigPath()
	if configPath == "" {
		return fmt.Errorf("git token store: config path not configured")
	}
	_, errStat := os.Stat(configPath)
	switch {
	case errors.Is(errStat, fs.ErrNotExist):
		// Nothing to persist yet.
		return nil
	case errStat != nil:
		return fmt.Errorf("git token store: stat config: %w", errStat)
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	rel, errRel := s.relativeToRepo(configPath)
	if errRel != nil {
		return errRel
	}
	return s.commitAndPushLocked("Update config", rel)
}
// ensureEmptyFile creates an empty file with mode 0600 at path if nothing
// exists there yet. An existing file is left untouched. Stat errors other than
// "not exist" are returned as-is, and parent directories are not created.
func ensureEmptyFile(path string) error {
	_, statErr := os.Stat(path)
	switch {
	case statErr == nil:
		return nil
	case errors.Is(statErr, fs.ErrNotExist):
		return os.WriteFile(path, []byte{}, 0o600)
	default:
		return statErr
	}
}
// jsonEqual reports whether a and b decode to structurally equal JSON values,
// so key order and insignificant whitespace are ignored. Input that does not
// decode is never considered equal; a is decoded first.
func jsonEqual(a, b []byte) bool {
	var left, right any
	if json.Unmarshal(a, &left) != nil || json.Unmarshal(b, &right) != nil {
		return false
	}
	return deepEqualJSON(left, right)
}
// deepEqualJSON compares two values produced by encoding/json decoding into
// `any`: map[string]any, []any, float64, string, bool and nil. Any other
// dynamic type compares as unequal, even to itself.
func deepEqualJSON(a, b any) bool {
	switch left := a.(type) {
	case nil:
		return b == nil
	case bool:
		right, isBool := b.(bool)
		return isBool && left == right
	case string:
		right, isString := b.(string)
		return isString && left == right
	case float64:
		right, isFloat := b.(float64)
		return isFloat && left == right
	case []any:
		right, isSlice := b.([]any)
		if !isSlice || len(left) != len(right) {
			return false
		}
		for idx, elem := range left {
			if !deepEqualJSON(elem, right[idx]) {
				return false
			}
		}
		return true
	case map[string]any:
		right, isMap := b.(map[string]any)
		if !isMap || len(left) != len(right) {
			return false
		}
		for key, leftVal := range left {
			rightVal, present := right[key]
			if !present || !deepEqualJSON(leftVal, rightVal) {
				return false
			}
		}
		return true
	}
	return false
}
================================================
FILE: internal/store/objectstore.go
================================================
package store
import (
"bytes"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/fs"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// Object keys used by ObjectTokenStore; prefixedKey places both under cfg.Prefix.
const (
// objectStoreConfigKey is the object key of the mirrored config.yaml.
objectStoreConfigKey = "config/config.yaml"
// objectStoreAuthPrefix is the key prefix for mirrored auth files ("auths/<path relative to authDir>").
objectStoreAuthPrefix = "auths"
)
// ObjectStoreConfig captures configuration for the object storage-backed token store.
type ObjectStoreConfig struct {
// Endpoint is the S3-compatible host passed to minio.New; trimmed and required.
Endpoint string
// Bucket is the target bucket; trimmed, required, and created by Bootstrap if missing.
Bucket string
// AccessKey is the static V4 access key; trimmed and required.
AccessKey string
// SecretKey is the static V4 secret key; trimmed and required.
SecretKey string
// Region is passed to the client options and to MakeBucket.
Region string
// Prefix is an optional key prefix (leading/trailing "/" removed) prepended to every object key.
Prefix string
// LocalRoot is the mirror directory; when blank it defaults to ./objectstore under the working directory.
LocalRoot string
// UseSSL enables HTTPS (minio.Options.Secure).
UseSSL bool
// PathStyle forces path-style bucket lookup instead of the client default.
PathStyle bool
}
// ObjectTokenStore persists configuration and authentication metadata using an S3-compatible object storage backend.
// Files are mirrored to a local workspace so existing file-based flows continue to operate.
type ObjectTokenStore struct {
// client talks to the S3-compatible backend.
client *minio.Client
// cfg holds the normalized configuration passed to NewObjectTokenStore.
cfg ObjectStoreConfig
// spoolRoot is the absolute root of the local mirror.
spoolRoot string
// configPath is <spoolRoot>/config/config.yaml.
configPath string
// authDir is <spoolRoot>/auths, the local mirror of the "auths/" objects.
authDir string
// mu serializes Save, Delete, PersistAuthFiles and PersistConfig.
mu sync.Mutex
}
// NewObjectTokenStore validates cfg, creates the local mirror directories
// (<root>/config and <root>/auths) and builds the MinIO client. No network call
// is made here; see Bootstrap.
//
// Endpoint, Bucket, AccessKey and SecretKey are trimmed and required. Prefix has
// leading and trailing slashes removed. When LocalRoot is blank, the mirror is
// "objectstore" under the working directory, or under os.TempDir() if the
// working directory cannot be determined.
func NewObjectTokenStore(cfg ObjectStoreConfig) (*ObjectTokenStore, error) {
	cfg.Endpoint = strings.TrimSpace(cfg.Endpoint)
	cfg.Bucket = strings.TrimSpace(cfg.Bucket)
	cfg.AccessKey = strings.TrimSpace(cfg.AccessKey)
	cfg.SecretKey = strings.TrimSpace(cfg.SecretKey)
	cfg.Prefix = strings.Trim(cfg.Prefix, "/")
	required := []struct {
		value string
		msg   string
	}{
		{cfg.Endpoint, "object store: endpoint is required"},
		{cfg.Bucket, "object store: bucket is required"},
		{cfg.AccessKey, "object store: access key is required"},
		{cfg.SecretKey, "object store: secret key is required"},
	}
	for _, field := range required {
		if field.value == "" {
			return nil, errors.New(field.msg)
		}
	}
	root := strings.TrimSpace(cfg.LocalRoot)
	if root == "" {
		base := os.TempDir()
		if cwd, errCwd := os.Getwd(); errCwd == nil {
			base = cwd
		}
		root = filepath.Join(base, "objectstore")
	}
	absRoot, errAbs := filepath.Abs(root)
	if errAbs != nil {
		return nil, fmt.Errorf("object store: resolve spool directory: %w", errAbs)
	}
	configDir := filepath.Join(absRoot, "config")
	authDir := filepath.Join(absRoot, "auths")
	if errMk := os.MkdirAll(configDir, 0o700); errMk != nil {
		return nil, fmt.Errorf("object store: create config directory: %w", errMk)
	}
	if errMk := os.MkdirAll(authDir, 0o700); errMk != nil {
		return nil, fmt.Errorf("object store: create auth directory: %w", errMk)
	}
	opts := &minio.Options{
		Creds:  credentials.NewStaticV4(cfg.AccessKey, cfg.SecretKey, ""),
		Secure: cfg.UseSSL,
		Region: cfg.Region,
	}
	if cfg.PathStyle {
		opts.BucketLookup = minio.BucketLookupPath
	}
	client, errClient := minio.New(cfg.Endpoint, opts)
	if errClient != nil {
		return nil, fmt.Errorf("object store: create client: %w", errClient)
	}
	store := &ObjectTokenStore{
		client:     client,
		cfg:        cfg,
		spoolRoot:  absRoot,
		configPath: filepath.Join(configDir, "config.yaml"),
		authDir:    authDir,
	}
	return store, nil
}
// SetBaseDir satisfies the optional base-directory setter used by authenticators.
// The argument is deliberately ignored: the mirror location is fixed when the
// store is constructed.
func (s *ObjectTokenStore) SetBaseDir(_ string) {}
// ConfigPath reports the mirrored config.yaml location inside the spool
// directory, or "" on a nil store.
func (s *ObjectTokenStore) ConfigPath() string {
	if s != nil {
		return s.configPath
	}
	return ""
}
// AuthDir reports the local directory that mirrors the remote auth objects, or
// "" on a nil store.
func (s *ObjectTokenStore) AuthDir() string {
	if s != nil {
		return s.authDir
	}
	return ""
}
// Bootstrap prepares the store for use. It runs, in order: ensure the bucket
// exists, sync the config file, then sync the auth files. It stops at the first
// failing step and returns that step's error unchanged.
func (s *ObjectTokenStore) Bootstrap(ctx context.Context, exampleConfigPath string) error {
	if s == nil {
		return fmt.Errorf("object store: not initialized")
	}
	steps := []func() error{
		func() error { return s.ensureBucket(ctx) },
		func() error { return s.syncConfigFromBucket(ctx, exampleConfigPath) },
		func() error { return s.syncAuthFromBucket(ctx) },
	}
	for _, step := range steps {
		if err := step(); err != nil {
			return err
		}
	}
	return nil
}
// Save writes the auth's token storage or metadata into the local mirror and
// uploads the resulting file to the object storage backend.
//
// A disabled auth whose file does not exist yet is skipped and "" is returned.
// When only Metadata is present and the mirrored file already holds
// semantically equal JSON, nothing is written or uploaded. Metadata is written
// to a ".tmp" sibling that is then renamed into place; if either step fails the
// temp file is removed (best effort) so no stale file is left in the mirror.
func (s *ObjectTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (string, error) {
	if auth == nil {
		return "", fmt.Errorf("object store: auth is nil")
	}
	path, err := s.resolveAuthPath(auth)
	if err != nil {
		return "", err
	}
	if path == "" {
		return "", fmt.Errorf("object store: missing file path attribute for %s", auth.ID)
	}
	if auth.Disabled {
		if _, statErr := os.Stat(path); errors.Is(statErr, fs.ErrNotExist) {
			return "", nil
		}
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	if err = os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
		return "", fmt.Errorf("object store: create auth directory: %w", err)
	}
	switch {
	case auth.Storage != nil:
		if err = auth.Storage.SaveTokenToFile(path); err != nil {
			return "", err
		}
	case auth.Metadata != nil:
		raw, errMarshal := json.Marshal(auth.Metadata)
		if errMarshal != nil {
			return "", fmt.Errorf("object store: marshal metadata: %w", errMarshal)
		}
		if existing, errRead := os.ReadFile(path); errRead == nil {
			if jsonEqual(existing, raw) {
				// Content is unchanged: skip both the write and the upload.
				return path, nil
			}
		} else if !errors.Is(errRead, fs.ErrNotExist) {
			return "", fmt.Errorf("object store: read existing metadata: %w", errRead)
		}
		tmp := path + ".tmp"
		if errWrite := os.WriteFile(tmp, raw, 0o600); errWrite != nil {
			// Best effort: drop a partially written temp file.
			_ = os.Remove(tmp)
			return "", fmt.Errorf("object store: write temp auth file: %w", errWrite)
		}
		if errRename := os.Rename(tmp, path); errRename != nil {
			// Best effort: do not leave the orphaned temp file in the mirror.
			_ = os.Remove(tmp)
			return "", fmt.Errorf("object store: rename auth file: %w", errRename)
		}
	default:
		return "", fmt.Errorf("object store: nothing to persist for %s", auth.ID)
	}
	if auth.Attributes == nil {
		auth.Attributes = make(map[string]string)
	}
	auth.Attributes["path"] = path
	if strings.TrimSpace(auth.FileName) == "" {
		auth.FileName = auth.ID
	}
	if err = s.uploadAuth(ctx, path); err != nil {
		return "", err
	}
	return path, nil
}
// List walks the local mirror and returns one Auth per parseable *.json file
// (suffix compared case-insensitively). Files that fail to load are logged at
// warn level and skipped; empty files produce no entry. Walk errors are wrapped.
func (s *ObjectTokenStore) List(_ context.Context) ([]*cliproxyauth.Auth, error) {
	root := strings.TrimSpace(s.AuthDir())
	if root == "" {
		return nil, fmt.Errorf("object store: auth directory not configured")
	}
	collected := make([]*cliproxyauth.Auth, 0, 32)
	walkFn := func(p string, entry fs.DirEntry, walkErr error) error {
		if walkErr != nil {
			return walkErr
		}
		if entry.IsDir() || !strings.HasSuffix(strings.ToLower(entry.Name()), ".json") {
			return nil
		}
		record, errRead := s.readAuthFile(p, root)
		switch {
		case errRead != nil:
			log.WithError(errRead).Warnf("object store: skip auth %s", p)
		case record != nil:
			collected = append(collected, record)
		}
		return nil
	}
	if errWalk := filepath.WalkDir(root, walkFn); errWalk != nil {
		return nil, fmt.Errorf("object store: walk auth directory: %w", errWalk)
	}
	return collected, nil
}
// Delete removes the auth file from the local mirror and its object from the
// bucket. A missing local file is tolerated, and the remote delete is still
// issued in that case.
func (s *ObjectTokenStore) Delete(ctx context.Context, id string) error {
	trimmed := strings.TrimSpace(id)
	if trimmed == "" {
		return fmt.Errorf("object store: id is empty")
	}
	target, errResolve := s.resolveDeletePath(trimmed)
	if errResolve != nil {
		return errResolve
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	if errRemove := os.Remove(target); errRemove != nil && !errors.Is(errRemove, fs.ErrNotExist) {
		return fmt.Errorf("object store: delete auth file: %w", errRemove)
	}
	return s.deleteAuthObject(ctx, target)
}
// PersistAuthFiles uploads each listed auth file to the bucket. Relative paths
// are resolved against the mirror's auth directory, blank entries are skipped,
// and the commit-message argument is ignored (the backend has no history). The
// first upload error aborts the loop.
func (s *ObjectTokenStore) PersistAuthFiles(ctx context.Context, _ string, paths ...string) error {
	if len(paths) == 0 {
		return nil
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	for _, raw := range paths {
		candidate := strings.TrimSpace(raw)
		if candidate == "" {
			continue
		}
		if !filepath.IsAbs(candidate) {
			candidate = filepath.Join(s.authDir, candidate)
		}
		if err := s.uploadAuth(ctx, candidate); err != nil {
			return err
		}
	}
	return nil
}
// PersistConfig uploads the local config file to the bucket. A missing or empty
// local file deletes the remote config object instead.
func (s *ObjectTokenStore) PersistConfig(ctx context.Context) error {
	s.mu.Lock()
	defer s.mu.Unlock()
	data, errRead := os.ReadFile(s.configPath)
	switch {
	case errors.Is(errRead, fs.ErrNotExist), errRead == nil && len(data) == 0:
		return s.deleteObject(ctx, objectStoreConfigKey)
	case errRead != nil:
		return fmt.Errorf("object store: read config file: %w", errRead)
	}
	return s.putObject(ctx, objectStoreConfigKey, data, "application/x-yaml")
}
// ensureBucket creates the configured bucket in cfg.Region unless it already exists.
func (s *ObjectTokenStore) ensureBucket(ctx context.Context) error {
	exists, errCheck := s.client.BucketExists(ctx, s.cfg.Bucket)
	switch {
	case errCheck != nil:
		return fmt.Errorf("object store: check bucket: %w", errCheck)
	case exists:
		return nil
	}
	opts := minio.MakeBucketOptions{Region: s.cfg.Region}
	if errMake := s.client.MakeBucket(ctx, s.cfg.Bucket, opts); errMake != nil {
		return fmt.Errorf("object store: create bucket: %w", errMake)
	}
	return nil
}
// syncConfigFromBucket reconciles the local config file with the bucket.
//
// If a remote config object exists, it overwrites the local file, with CRLF/CR
// line endings normalized to LF. Otherwise the local file is created if it is
// missing: from the example template when one is given, or as an empty file.
// The local file is then uploaded if it is non-empty. Any other stat failure is
// returned.
func (s *ObjectTokenStore) syncConfigFromBucket(ctx context.Context, example string) error {
	key := s.prefixedKey(objectStoreConfigKey)
	_, errStat := s.client.StatObject(ctx, s.cfg.Bucket, key, minio.StatObjectOptions{})
	if errStat != nil && !isObjectNotFound(errStat) {
		return fmt.Errorf("object store: stat config: %w", errStat)
	}
	if errStat == nil {
		// The remote copy is authoritative: mirror it locally.
		reader, errGet := s.client.GetObject(ctx, s.cfg.Bucket, key, minio.GetObjectOptions{})
		if errGet != nil {
			return fmt.Errorf("object store: fetch config: %w", errGet)
		}
		defer reader.Close()
		payload, errRead := io.ReadAll(reader)
		if errRead != nil {
			return fmt.Errorf("object store: read config: %w", errRead)
		}
		if errWrite := os.WriteFile(s.configPath, normalizeLineEndingsBytes(payload), 0o600); errWrite != nil {
			return fmt.Errorf("object store: write config: %w", errWrite)
		}
		return nil
	}
	// No remote config yet: seed the local file if needed, then publish it.
	if _, errLocal := os.Stat(s.configPath); errors.Is(errLocal, fs.ErrNotExist) {
		if example == "" {
			if errCreate := os.MkdirAll(filepath.Dir(s.configPath), 0o700); errCreate != nil {
				return fmt.Errorf("object store: prepare config directory: %w", errCreate)
			}
			if errWrite := os.WriteFile(s.configPath, []byte{}, 0o600); errWrite != nil {
				return fmt.Errorf("object store: create empty config: %w", errWrite)
			}
		} else if errCopy := misc.CopyConfigTemplate(example, s.configPath); errCopy != nil {
			return fmt.Errorf("object store: copy example config: %w", errCopy)
		}
	}
	local, errReadLocal := os.ReadFile(s.configPath)
	if errReadLocal != nil {
		return fmt.Errorf("object store: read local config: %w", errReadLocal)
	}
	if len(local) == 0 {
		return nil
	}
	return s.putObject(ctx, objectStoreConfigKey, local, "application/x-yaml")
}
// syncAuthFromBucket downloads every object under "<prefix>/auths/" into the
// local auth directory, preserving the relative layout and overwriting files
// that already exist. Local files with no remote counterpart are left in place.
// Keys ending in "/" are skipped, as are keys whose relative part is absolute or
// escapes the auth directory (a warning is logged for those). Each object is
// read fully into memory before being written with mode 0600.
func (s *ObjectTokenStore) syncAuthFromBucket(ctx context.Context) error {
// NOTE: We intentionally do NOT use os.RemoveAll here.
// Wiping the directory triggers file watcher delete events, which then
// propagate deletions to the remote object store (race condition).
// Instead, we just ensure the directory exists and overwrite files incrementally.
if err := os.MkdirAll(s.authDir, 0o700); err != nil {
return fmt.Errorf("object store: create auth directory: %w", err)
}
prefix := s.prefixedKey(objectStoreAuthPrefix + "/")
objectCh := s.client.ListObjects(ctx, s.cfg.Bucket, minio.ListObjectsOptions{
Prefix: prefix,
Recursive: true,
})
for object := range objectCh {
if object.Err != nil {
return fmt.Errorf("object store: list auth objects: %w", object.Err)
}
rel := strings.TrimPrefix(object.Key, prefix)
// Skip the prefix itself and "directory" placeholder keys.
if rel == "" || strings.HasSuffix(rel, "/") {
continue
}
// Guard against keys that would be written outside the mirror directory.
relPath := filepath.FromSlash(rel)
if filepath.IsAbs(relPath) {
log.WithField("key", object.Key).Warn("object store: skip auth outside mirror")
continue
}
cleanRel := filepath.Clean(relPath)
if cleanRel == "." || cleanRel == ".." || strings.HasPrefix(cleanRel, ".."+string(os.PathSeparator)) {
log.WithField("key", object.Key).Warn("object store: skip auth outside mirror")
continue
}
local := filepath.Join(s.authDir, cleanRel)
if err := os.MkdirAll(filepath.Dir(local), 0o700); err != nil {
return fmt.Errorf("object store: prepare auth subdir: %w", err)
}
reader, errGet := s.client.GetObject(ctx, s.cfg.Bucket, object.Key, minio.GetObjectOptions{})
if errGet != nil {
return fmt.Errorf("object store: download auth %s: %w", object.Key, errGet)
}
data, errRead := io.ReadAll(reader)
// Close per iteration rather than defer, so readers do not pile up across the loop.
_ = reader.Close()
if errRead != nil {
return fmt.Errorf("object store: read auth %s: %w", object.Key, errRead)
}
if errWrite := os.WriteFile(local, data, 0o600); errWrite != nil {
return fmt.Errorf("object store: write auth %s: %w", local, errWrite)
}
}
return nil
}
// uploadAuth mirrors the local auth file at path to the object key
// "auths/<path relative to authDir>" under the configured prefix.
//
// A missing or empty file deletes the corresponding remote object instead.
// Paths that resolve outside the auth directory are rejected. Without that
// check, a stray path would produce keys with ".." segments (e.g.
// "auths/../../x") that escape the "auths/" prefix.
func (s *ObjectTokenStore) uploadAuth(ctx context.Context, path string) error {
	if path == "" {
		return nil
	}
	rel, err := filepath.Rel(s.authDir, path)
	if err != nil {
		return fmt.Errorf("object store: resolve auth relative path: %w", err)
	}
	if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
		return fmt.Errorf("object store: auth path %s is outside %s", path, s.authDir)
	}
	data, err := os.ReadFile(path)
	if err != nil {
		if errors.Is(err, fs.ErrNotExist) {
			return s.deleteAuthObject(ctx, path)
		}
		return fmt.Errorf("object store: read auth file: %w", err)
	}
	if len(data) == 0 {
		return s.deleteAuthObject(ctx, path)
	}
	key := objectStoreAuthPrefix + "/" + filepath.ToSlash(rel)
	return s.putObject(ctx, key, data, "application/json")
}
// deleteAuthObject removes the remote object that mirrors the local auth file
// at path; an empty path is a no-op.
func (s *ObjectTokenStore) deleteAuthObject(ctx context.Context, path string) error {
	if path == "" {
		return nil
	}
	rel, errRel := filepath.Rel(s.authDir, path)
	if errRel != nil {
		return fmt.Errorf("object store: resolve auth relative path: %w", errRel)
	}
	return s.deleteObject(ctx, objectStoreAuthPrefix+"/"+filepath.ToSlash(rel))
}
// putObject stores data under the prefixed key with the given content type. An
// empty payload deletes the object instead of storing zero bytes.
func (s *ObjectTokenStore) putObject(ctx context.Context, key string, data []byte, contentType string) error {
	size := int64(len(data))
	if size == 0 {
		return s.deleteObject(ctx, key)
	}
	fullKey := s.prefixedKey(key)
	opts := minio.PutObjectOptions{ContentType: contentType}
	if _, errPut := s.client.PutObject(ctx, s.cfg.Bucket, fullKey, bytes.NewReader(data), size, opts); errPut != nil {
		return fmt.Errorf("object store: put object %s: %w", fullKey, errPut)
	}
	return nil
}
// deleteObject removes the prefixed key from the bucket. Not-found responses are
// treated as success, so the call is safe to repeat.
func (s *ObjectTokenStore) deleteObject(ctx context.Context, key string) error {
	fullKey := s.prefixedKey(key)
	errRemove := s.client.RemoveObject(ctx, s.cfg.Bucket, fullKey, minio.RemoveObjectOptions{})
	if errRemove == nil || isObjectNotFound(errRemove) {
		return nil
	}
	return fmt.Errorf("object store: delete object %s: %w", fullKey, errRemove)
}
// prefixedKey strips leading slashes from key and, when a prefix is configured,
// returns "<prefix>/<key>" (also without leading slashes).
func (s *ObjectTokenStore) prefixedKey(key string) string {
	trimmed := strings.TrimLeft(key, "/")
	if s.cfg.Prefix != "" {
		return strings.TrimLeft(s.cfg.Prefix+"/"+trimmed, "/")
	}
	return trimmed
}
// resolveAuthPath decides where an auth lives in the local mirror.
//
// A non-blank "path" attribute wins; it is used as-is when absolute and joined
// to the auth directory otherwise. Failing that, FileName (or ID when FileName
// is blank) is used, with ".json" appended if missing, under the auth directory.
func (s *ObjectTokenStore) resolveAuthPath(auth *cliproxyauth.Auth) (string, error) {
	if auth == nil {
		return "", fmt.Errorf("object store: auth is nil")
	}
	// Indexing a nil Attributes map yields "", so no separate nil check is needed.
	if explicit := strings.TrimSpace(auth.Attributes["path"]); explicit != "" {
		if !filepath.IsAbs(explicit) {
			explicit = filepath.Join(s.authDir, explicit)
		}
		return explicit, nil
	}
	name := strings.TrimSpace(auth.FileName)
	if name == "" {
		name = strings.TrimSpace(auth.ID)
	}
	if name == "" {
		return "", fmt.Errorf("object store: auth %s missing filename", auth.ID)
	}
	if !strings.HasSuffix(strings.ToLower(name), ".json") {
		name += ".json"
	}
	return filepath.Join(s.authDir, name), nil
}
// resolveDeletePath maps an auth identifier to a file inside the local mirror.
//
// Absolute ids are honored as-is; callers must ensure they point inside the
// mirror. Any relative id, including nested ones such as "team/foo", is cleaned,
// checked against path traversal, given a ".json" suffix if missing, and joined
// to the auth directory.
func (s *ObjectTokenStore) resolveDeletePath(id string) (string, error) {
	trimmed := strings.TrimSpace(id)
	switch {
	case trimmed == "":
		return "", fmt.Errorf("object store: id is empty")
	case filepath.IsAbs(trimmed):
		return trimmed, nil
	}
	rel := filepath.Clean(filepath.FromSlash(trimmed))
	parentPrefix := ".." + string(os.PathSeparator)
	if rel == "." || rel == ".." || strings.HasPrefix(rel, parentPrefix) {
		return "", fmt.Errorf("object store: invalid auth identifier %s", trimmed)
	}
	if !strings.HasSuffix(strings.ToLower(rel), ".json") {
		rel += ".json"
	}
	return filepath.Join(s.authDir, rel), nil
}
// readAuthFile parses one mirrored auth JSON file into an Auth record.
//
// Empty files yield (nil, nil). The provider is the trimmed "type" value, or
// "unknown" if blank. ID and FileName are the path relative to baseDir (or the
// bare file name if that cannot be computed), passed through normalizeAuthID.
// Both timestamps are set to the file's modification time.
func (s *ObjectTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, error) {
	content, errRead := os.ReadFile(path)
	if errRead != nil {
		return nil, fmt.Errorf("read file: %w", errRead)
	}
	if len(content) == 0 {
		return nil, nil
	}
	metadata := make(map[string]any)
	if errDecode := json.Unmarshal(content, &metadata); errDecode != nil {
		return nil, fmt.Errorf("unmarshal auth json: %w", errDecode)
	}
	provider := strings.TrimSpace(valueAsString(metadata["type"]))
	if provider == "" {
		provider = "unknown"
	}
	info, errStat := os.Stat(path)
	if errStat != nil {
		return nil, fmt.Errorf("stat auth file: %w", errStat)
	}
	id := filepath.Base(path)
	if rel, errRel := filepath.Rel(baseDir, path); errRel == nil {
		id = rel
	}
	id = normalizeAuthID(id)
	attributes := map[string]string{"path": path}
	if email := strings.TrimSpace(valueAsString(metadata["email"])); email != "" {
		attributes["email"] = email
	}
	modified := info.ModTime()
	return &cliproxyauth.Auth{
		ID:               id,
		Provider:         provider,
		FileName:         id,
		Label:            labelFor(metadata),
		Status:           cliproxyauth.StatusActive,
		Attributes:       attributes,
		Metadata:         metadata,
		CreatedAt:        modified,
		UpdatedAt:        modified,
		LastRefreshedAt:  time.Time{},
		NextRefreshAfter: time.Time{},
	}, nil
}
// normalizeLineEndingsBytes converts CRLF and lone CR line endings to LF.
// CRLF is collapsed first so the lone-CR pass does not double those line breaks.
func normalizeLineEndingsBytes(data []byte) []byte {
	unixed := bytes.ReplaceAll(data, []byte("\r\n"), []byte("\n"))
	return bytes.ReplaceAll(unixed, []byte("\r"), []byte("\n"))
}
// isObjectNotFound reports whether err is a MinIO/S3 "not found" response:
// either an HTTP 404 status or one of the NoSuchKey, NotFound or NoSuchBucket
// error codes. A nil error is never "not found".
func isObjectNotFound(err error) bool {
	if err == nil {
		return false
	}
	resp := minio.ToErrorResponse(err)
	return resp.StatusCode == http.StatusNotFound ||
		resp.Code == "NoSuchKey" ||
		resp.Code == "NotFound" ||
		resp.Code == "NoSuchBucket"
}
================================================
FILE: internal/store/postgresstore.go
================================================
package store
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"sync"
"time"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// Defaults for the Postgres-backed store.
const (
// defaultConfigTable is used by NewPostgresStore when ConfigTable is blank.
defaultConfigTable = "config_store"
// defaultAuthTable is used by NewPostgresStore when AuthTable is blank.
defaultAuthTable = "auth_store"
// defaultConfigKey is presumably the row key under which the config is stored; its use is not visible in this section.
defaultConfigKey = "config"
)
// PostgresStoreConfig captures configuration required to initialize a Postgres-backed store.
type PostgresStoreConfig struct {
// DSN is the pgx connection string; trimmed and required.
DSN string
// Schema, when non-blank, is created by EnsureSchema if it does not exist.
Schema string
// ConfigTable names the config table; defaults to defaultConfigTable.
ConfigTable string
// AuthTable names the auth table; defaults to defaultAuthTable.
AuthTable string
// SpoolDir is the local mirror root; when blank it defaults to ./pgstore under the working directory.
SpoolDir string
}
// PostgresStore persists configuration and authentication metadata using PostgreSQL as backend
// while mirroring data to a local workspace so existing file-based workflows continue to operate.
type PostgresStore struct {
// db is the pgx-backed connection pool opened by NewPostgresStore.
db *sql.DB
// cfg holds the normalized configuration (defaults applied).
cfg PostgresStoreConfig
// spoolRoot is the absolute root of the local mirror.
spoolRoot string
// configPath is <spoolRoot>/config/config.yaml.
configPath string
// authDir is <spoolRoot>/auths.
authDir string
// mu presumably serializes mutating operations; its use is not visible in this section.
mu sync.Mutex
}
// NewPostgresStore prepares the local mirror directories (<spool>/config and
// <spool>/auths), opens a pgx-backed *sql.DB and pings it.
//
// DSN is trimmed and required. ConfigTable and AuthTable are trimmed and fall
// back to defaultConfigTable / defaultAuthTable when blank, so a
// whitespace-only value cannot bypass the defaults and reach SQL as an invalid
// identifier. A blank SpoolDir defaults to "pgstore" under the working
// directory, or under os.TempDir() if that cannot be determined. If the ping
// fails, the connection pool is closed before the error is returned.
func NewPostgresStore(ctx context.Context, cfg PostgresStoreConfig) (*PostgresStore, error) {
	cfg.DSN = strings.TrimSpace(cfg.DSN)
	if cfg.DSN == "" {
		return nil, fmt.Errorf("postgres store: DSN is required")
	}
	cfg.ConfigTable = strings.TrimSpace(cfg.ConfigTable)
	if cfg.ConfigTable == "" {
		cfg.ConfigTable = defaultConfigTable
	}
	cfg.AuthTable = strings.TrimSpace(cfg.AuthTable)
	if cfg.AuthTable == "" {
		cfg.AuthTable = defaultAuthTable
	}
	spoolRoot := strings.TrimSpace(cfg.SpoolDir)
	if spoolRoot == "" {
		if cwd, err := os.Getwd(); err == nil {
			spoolRoot = filepath.Join(cwd, "pgstore")
		} else {
			spoolRoot = filepath.Join(os.TempDir(), "pgstore")
		}
	}
	absSpool, err := filepath.Abs(spoolRoot)
	if err != nil {
		return nil, fmt.Errorf("postgres store: resolve spool directory: %w", err)
	}
	configDir := filepath.Join(absSpool, "config")
	authDir := filepath.Join(absSpool, "auths")
	if err = os.MkdirAll(configDir, 0o700); err != nil {
		return nil, fmt.Errorf("postgres store: create config directory: %w", err)
	}
	if err = os.MkdirAll(authDir, 0o700); err != nil {
		return nil, fmt.Errorf("postgres store: create auth directory: %w", err)
	}
	db, err := sql.Open("pgx", cfg.DSN)
	if err != nil {
		return nil, fmt.Errorf("postgres store: open database connection: %w", err)
	}
	if err = db.PingContext(ctx); err != nil {
		_ = db.Close()
		return nil, fmt.Errorf("postgres store: ping database: %w", err)
	}
	return &PostgresStore{
		db:         db,
		cfg:        cfg,
		spoolRoot:  absSpool,
		configPath: filepath.Join(configDir, "config.yaml"),
		authDir:    authDir,
	}, nil
}
// Close releases the underlying database connection. Calling it on a nil
// store, or on a store without a connection, is a no-op that returns nil.
func (s *PostgresStore) Close() error {
	if s != nil && s.db != nil {
		return s.db.Close()
	}
	return nil
}
// EnsureSchema creates the required tables (and schema when provided).
//
// It is idempotent (every statement uses IF NOT EXISTS). The config table
// stores TEXT content and the auth table stores JSONB content; both are keyed
// by a TEXT id and carry created_at/updated_at timestamps.
func (s *PostgresStore) EnsureSchema(ctx context.Context) error {
	if s == nil || s.db == nil {
		return fmt.Errorf("postgres store: not initialized")
	}
	schema := strings.TrimSpace(s.cfg.Schema)
	if schema != "" {
		createSchema := fmt.Sprintf("CREATE SCHEMA IF NOT EXISTS %s", quoteIdentifier(schema))
		if _, errSchema := s.db.ExecContext(ctx, createSchema); errSchema != nil {
			return fmt.Errorf("postgres store: create schema: %w", errSchema)
		}
	}

	createConfigTable := fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
id TEXT PRIMARY KEY,
content TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)
`, s.fullTableName(s.cfg.ConfigTable))
	if _, errConfig := s.db.ExecContext(ctx, createConfigTable); errConfig != nil {
		return fmt.Errorf("postgres store: create config table: %w", errConfig)
	}

	createAuthTable := fmt.Sprintf(`
CREATE TABLE IF NOT EXISTS %s (
id TEXT PRIMARY KEY,
content JSONB NOT NULL,
created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
)
`, s.fullTableName(s.cfg.AuthTable))
	if _, errAuth := s.db.ExecContext(ctx, createAuthTable); errAuth != nil {
		return fmt.Errorf("postgres store: create auth table: %w", errAuth)
	}
	return nil
}
// Bootstrap synchronizes configuration and auth records between PostgreSQL and the local workspace.
//
// Steps, stopping at the first error: ensure the schema and tables exist;
// reconcile the config row with the spool config file (seeding the row from
// the local file, which is copied from exampleConfigPath when missing, if no
// row exists yet); then rebuild the spool auth directory from the auth table.
func (s *PostgresStore) Bootstrap(ctx context.Context, exampleConfigPath string) error {
	if err := s.EnsureSchema(ctx); err != nil {
		return err
	}
	if err := s.syncConfigFromDatabase(ctx, exampleConfigPath); err != nil {
		return err
	}
	return s.syncAuthFromDatabase(ctx)
}
// ConfigPath returns the absolute path of the mirrored config.yaml inside the
// spool directory, or "" when called on a nil store.
func (s *PostgresStore) ConfigPath() string {
	if s != nil {
		return s.configPath
	}
	return ""
}
// AuthDir returns the absolute spool directory that holds the mirrored auth
// files, or "" when called on a nil store.
func (s *PostgresStore) AuthDir() string {
	if s != nil {
		return s.authDir
	}
	return ""
}
// WorkDir returns the absolute root of the spool directory used for mirroring,
// or "" when called on a nil store.
func (s *PostgresStore) WorkDir() string {
	if s != nil {
		return s.spoolRoot
	}
	return ""
}
// SetBaseDir satisfies the optional base-directory setter used by
// authenticators. The argument is ignored: this store always works inside the
// spool directory chosen by NewPostgresStore.
func (s *PostgresStore) SetBaseDir(_ string) {
	// Intentionally a no-op.
}
// Save persists authentication metadata to disk and PostgreSQL.
//
// The target file is chosen by resolveAuthPath. If the auth is disabled and
// its file does not exist yet, nothing is written and ("", nil) is returned.
// When auth.Storage is set, it writes the token file itself. Otherwise
// auth.Metadata is JSON-encoded and written atomically (temp file + rename).
// If the encoded metadata is JSON-equal to the file already on disk, the path
// is returned immediately.
// NOTE(review): that early return skips the database upsert and the Attributes
// update below.
// On success auth.Attributes["path"] is set, auth.FileName defaults to auth.ID
// when blank, and the file content is upserted into the auth table under its
// spool-relative id. Returns the local file path.
func (s *PostgresStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (string, error) {
	if auth == nil {
		return "", fmt.Errorf("postgres store: auth is nil")
	}
	path, err := s.resolveAuthPath(auth)
	if err != nil {
		return "", err
	}
	if path == "" {
		return "", fmt.Errorf("postgres store: missing file path attribute for %s", auth.ID)
	}
	if auth.Disabled {
		// Do not materialize a file for an auth that is disabled before it was ever saved.
		if _, statErr := os.Stat(path); errors.Is(statErr, fs.ErrNotExist) {
			return "", nil
		}
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	if err = os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
		return "", fmt.Errorf("postgres store: create auth directory: %w", err)
	}
	switch {
	case auth.Storage != nil:
		if err = auth.Storage.SaveTokenToFile(path); err != nil {
			return "", err
		}
	case auth.Metadata != nil:
		raw, errMarshal := json.Marshal(auth.Metadata)
		if errMarshal != nil {
			return "", fmt.Errorf("postgres store: marshal metadata: %w", errMarshal)
		}
		existing, errRead := os.ReadFile(path)
		switch {
		case errRead == nil:
			if jsonEqual(existing, raw) {
				return path, nil
			}
		case !errors.Is(errRead, fs.ErrNotExist):
			return "", fmt.Errorf("postgres store: read existing metadata: %w", errRead)
		}
		tmp := path + ".tmp"
		if errWrite := os.WriteFile(tmp, raw, 0o600); errWrite != nil {
			// Do not leave a partially written temp file behind in the spool.
			_ = os.Remove(tmp)
			return "", fmt.Errorf("postgres store: write temp auth file: %w", errWrite)
		}
		if errRename := os.Rename(tmp, path); errRename != nil {
			_ = os.Remove(tmp)
			return "", fmt.Errorf("postgres store: rename auth file: %w", errRename)
		}
	default:
		return "", fmt.Errorf("postgres store: nothing to persist for %s", auth.ID)
	}
	if auth.Attributes == nil {
		auth.Attributes = make(map[string]string)
	}
	auth.Attributes["path"] = path
	if strings.TrimSpace(auth.FileName) == "" {
		auth.FileName = auth.ID
	}
	relID, err := s.relativeAuthID(path)
	if err != nil {
		return "", err
	}
	if err = s.upsertAuthRecord(ctx, relID, path); err != nil {
		return "", err
	}
	return path, nil
}
// List enumerates all auth records stored in PostgreSQL, ordered by id.
//
// Each row's content is decoded into the Auth metadata. Rows whose id would
// resolve outside the spool auth directory, or whose content cannot be
// unmarshalled into a map, are logged and skipped.
// Field mapping:
//   - Provider comes from metadata "type" ("unknown" when blank).
//   - The "email" attribute comes from metadata "email".
//   - Label comes from labelFor.
//   - Every auth is StatusActive and carries the row's created_at/updated_at.
func (s *PostgresStore) List(ctx context.Context) ([]*cliproxyauth.Auth, error) {
	query := fmt.Sprintf("SELECT id, content, created_at, updated_at FROM %s ORDER BY id", s.fullTableName(s.cfg.AuthTable))
	rows, err := s.db.QueryContext(ctx, query)
	if err != nil {
		return nil, fmt.Errorf("postgres store: list auth: %w", err)
	}
	defer rows.Close()

	result := make([]*cliproxyauth.Auth, 0, 32)
	for rows.Next() {
		var id, payload string
		var createdAt, updatedAt time.Time
		if errScan := rows.Scan(&id, &payload, &createdAt, &updatedAt); errScan != nil {
			return nil, fmt.Errorf("postgres store: scan auth row: %w", errScan)
		}
		localPath, errPath := s.absoluteAuthPath(id)
		if errPath != nil {
			log.WithError(errPath).Warnf("postgres store: skipping auth %s outside spool", id)
			continue
		}
		metadata := make(map[string]any)
		if errJSON := json.Unmarshal([]byte(payload), &metadata); errJSON != nil {
			log.WithError(errJSON).Warnf("postgres store: skipping auth %s with invalid json", id)
			continue
		}

		provider := strings.TrimSpace(valueAsString(metadata["type"]))
		if provider == "" {
			provider = "unknown"
		}
		attributes := map[string]string{"path": localPath}
		if email := strings.TrimSpace(valueAsString(metadata["email"])); email != "" {
			attributes["email"] = email
		}
		normalizedID := normalizeAuthID(id)
		result = append(result, &cliproxyauth.Auth{
			ID:               normalizedID,
			Provider:         provider,
			FileName:         normalizedID,
			Label:            labelFor(metadata),
			Status:           cliproxyauth.StatusActive,
			Attributes:       attributes,
			Metadata:         metadata,
			CreatedAt:        createdAt,
			UpdatedAt:        updatedAt,
			LastRefreshedAt:  time.Time{},
			NextRefreshAfter: time.Time{},
		})
	}
	if errIter := rows.Err(); errIter != nil {
		return nil, fmt.Errorf("postgres store: iterate auth rows: %w", errIter)
	}
	return result, nil
}
// Delete removes an auth file and the corresponding database record.
//
// id is whitespace-trimmed and must not be empty; it is mapped to a local file
// by resolveDeletePath. A file that is already gone is not an error. The
// database record is keyed by the file's path relative to the auth directory.
func (s *PostgresStore) Delete(ctx context.Context, id string) error {
	trimmedID := strings.TrimSpace(id)
	if trimmedID == "" {
		return fmt.Errorf("postgres store: id is empty")
	}
	path, err := s.resolveDeletePath(trimmedID)
	if err != nil {
		return err
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	if errRemove := os.Remove(path); errRemove != nil && !errors.Is(errRemove, fs.ErrNotExist) {
		return fmt.Errorf("postgres store: delete auth file: %w", errRemove)
	}
	relID, err := s.relativeAuthID(path)
	if err != nil {
		return err
	}
	return s.deleteAuthRecord(ctx, relID)
}
// PersistAuthFiles stores the provided auth file changes in PostgreSQL.
//
// Each path is trimmed and blank entries are ignored. Relative paths are
// resolved under the spool auth directory before anything is read, so the file
// that is read and the record id that is written always refer to the same
// auth. Paths outside the auth directory are logged and skipped. A missing or
// empty file deletes the corresponding record (see syncAuthFile). The second
// argument is ignored by this store.
func (s *PostgresStore) PersistAuthFiles(ctx context.Context, _ string, paths ...string) error {
	if len(paths) == 0 {
		return nil
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	for _, p := range paths {
		trimmed := strings.TrimSpace(p)
		if trimmed == "" {
			continue
		}
		// relativeAuthID treats relative paths as authDir-relative, so the file
		// must be read from the same location rather than from the process's
		// working directory.
		abs := trimmed
		if !filepath.IsAbs(abs) {
			abs = filepath.Join(s.authDir, abs)
		}
		relID, err := s.relativeAuthID(abs)
		if err != nil {
			log.WithError(err).Warnf("postgres store: ignoring auth path %s", trimmed)
			continue
		}
		if err = s.syncAuthFile(ctx, relID, abs); err != nil {
			return err
		}
	}
	return nil
}
// PersistConfig mirrors the local configuration file to PostgreSQL.
//
// If the spool config file no longer exists, the stored config row is deleted
// instead. Any other read failure is returned as an error.
func (s *PostgresStore) PersistConfig(ctx context.Context) error {
	s.mu.Lock()
	defer s.mu.Unlock()

	data, err := os.ReadFile(s.configPath)
	switch {
	case err == nil:
		return s.persistConfig(ctx, data)
	case errors.Is(err, fs.ErrNotExist):
		return s.deleteConfigRecord(ctx)
	default:
		return fmt.Errorf("postgres store: read config file: %w", err)
	}
}
// syncConfigFromDatabase writes the database-stored config to disk or seeds the database from template.
//
// When the config row exists, its content (line endings normalized to LF) is
// written to the spool config file. When no row exists:
//  1. The local file is created if missing, either by copying exampleConfigPath
//     or, when no example is given, as an empty file.
//  2. The local file's content is stored as the config row.
func (s *PostgresStore) syncConfigFromDatabase(ctx context.Context, exampleConfigPath string) error {
	query := fmt.Sprintf("SELECT content FROM %s WHERE id = $1", s.fullTableName(s.cfg.ConfigTable))
	var content string
	errQuery := s.db.QueryRowContext(ctx, query, defaultConfigKey).Scan(&content)

	if errQuery == nil {
		if errDir := os.MkdirAll(filepath.Dir(s.configPath), 0o700); errDir != nil {
			return fmt.Errorf("postgres store: prepare config directory: %w", errDir)
		}
		if errWrite := os.WriteFile(s.configPath, []byte(normalizeLineEndings(content)), 0o600); errWrite != nil {
			return fmt.Errorf("postgres store: write config to spool: %w", errWrite)
		}
		return nil
	}
	if !errors.Is(errQuery, sql.ErrNoRows) {
		return fmt.Errorf("postgres store: load config from database: %w", errQuery)
	}

	// No stored config yet: make sure a local file exists, then seed the row from it.
	if _, errStat := os.Stat(s.configPath); errors.Is(errStat, fs.ErrNotExist) {
		if exampleConfigPath == "" {
			if errDir := os.MkdirAll(filepath.Dir(s.configPath), 0o700); errDir != nil {
				return fmt.Errorf("postgres store: prepare config directory: %w", errDir)
			}
			if errWrite := os.WriteFile(s.configPath, []byte{}, 0o600); errWrite != nil {
				return fmt.Errorf("postgres store: create empty config: %w", errWrite)
			}
		} else if errCopy := misc.CopyConfigTemplate(exampleConfigPath, s.configPath); errCopy != nil {
			return fmt.Errorf("postgres store: copy example config: %w", errCopy)
		}
	}
	data, errRead := os.ReadFile(s.configPath)
	if errRead != nil {
		return fmt.Errorf("postgres store: read local config: %w", errRead)
	}
	return s.persistConfig(ctx, data)
}
// syncAuthFromDatabase populates the local auth directory from PostgreSQL data.
//
// The auth directory is removed and recreated. Each row's content is then
// written verbatim (mode 0600) to the spool path derived from its id, creating
// sub-directories as needed. Rows whose id resolves outside the auth directory
// are logged and skipped.
// NOTE(review): the directory is cleared before the rows are consumed, so a
// scan or write failure part-way through leaves a partially populated spool.
func (s *PostgresStore) syncAuthFromDatabase(ctx context.Context) error {
	query := fmt.Sprintf("SELECT id, content FROM %s", s.fullTableName(s.cfg.AuthTable))
	rows, err := s.db.QueryContext(ctx, query)
	if err != nil {
		return fmt.Errorf("postgres store: load auth from database: %w", err)
	}
	defer rows.Close()

	if errReset := os.RemoveAll(s.authDir); errReset != nil {
		return fmt.Errorf("postgres store: reset auth directory: %w", errReset)
	}
	if errMkdir := os.MkdirAll(s.authDir, 0o700); errMkdir != nil {
		return fmt.Errorf("postgres store: recreate auth directory: %w", errMkdir)
	}

	for rows.Next() {
		var id, payload string
		if errScan := rows.Scan(&id, &payload); errScan != nil {
			return fmt.Errorf("postgres store: scan auth row: %w", errScan)
		}
		target, errPath := s.absoluteAuthPath(id)
		if errPath != nil {
			log.WithError(errPath).Warnf("postgres store: skipping auth %s outside spool", id)
			continue
		}
		if errDir := os.MkdirAll(filepath.Dir(target), 0o700); errDir != nil {
			return fmt.Errorf("postgres store: create auth subdir: %w", errDir)
		}
		if errWrite := os.WriteFile(target, []byte(payload), 0o600); errWrite != nil {
			return fmt.Errorf("postgres store: write auth file: %w", errWrite)
		}
	}
	if errIter := rows.Err(); errIter != nil {
		return fmt.Errorf("postgres store: iterate auth rows: %w", errIter)
	}
	return nil
}
// syncAuthFile pushes the file at path into the auth table under relID.
// A missing or empty file deletes the record instead; other read errors are returned.
func (s *PostgresStore) syncAuthFile(ctx context.Context, relID, path string) error {
	data, err := os.ReadFile(path)
	if errors.Is(err, fs.ErrNotExist) {
		return s.deleteAuthRecord(ctx, relID)
	}
	if err != nil {
		return fmt.Errorf("postgres store: read auth file: %w", err)
	}
	if len(data) == 0 {
		return s.deleteAuthRecord(ctx, relID)
	}
	return s.persistAuth(ctx, relID, data)
}
// upsertAuthRecord stores the file at path in the auth table under relID.
// Any read failure (including a missing file) is an error; an empty file
// deletes the record.
func (s *PostgresStore) upsertAuthRecord(ctx context.Context, relID, path string) error {
	data, err := os.ReadFile(path)
	switch {
	case err != nil:
		return fmt.Errorf("postgres store: read auth file: %w", err)
	case len(data) == 0:
		return s.deleteAuthRecord(ctx, relID)
	default:
		return s.persistAuth(ctx, relID, data)
	}
}
// persistAuth upserts data into the auth table under relID, passing it as a
// json.RawMessage for the JSONB content column. On conflict the content is
// replaced and updated_at refreshed; created_at is kept.
func (s *PostgresStore) persistAuth(ctx context.Context, relID string, data []byte) error {
	table := s.fullTableName(s.cfg.AuthTable)
	upsert := fmt.Sprintf(`
INSERT INTO %s (id, content, created_at, updated_at)
VALUES ($1, $2, NOW(), NOW())
ON CONFLICT (id)
DO UPDATE SET content = EXCLUDED.content, updated_at = NOW()
`, table)
	if _, err := s.db.ExecContext(ctx, upsert, relID, json.RawMessage(data)); err != nil {
		return fmt.Errorf("postgres store: upsert auth record: %w", err)
	}
	return nil
}
// deleteAuthRecord removes the auth row whose id equals relID. Deleting a
// non-existent row is not an error.
func (s *PostgresStore) deleteAuthRecord(ctx context.Context, relID string) error {
	stmt := fmt.Sprintf("DELETE FROM %s WHERE id = $1", s.fullTableName(s.cfg.AuthTable))
	_, err := s.db.ExecContext(ctx, stmt, relID)
	if err != nil {
		return fmt.Errorf("postgres store: delete auth record: %w", err)
	}
	return nil
}
// persistConfig upserts the configuration row (id = defaultConfigKey) with
// data, after normalizing its line endings to LF. On conflict the content is
// replaced and updated_at refreshed.
func (s *PostgresStore) persistConfig(ctx context.Context, data []byte) error {
	table := s.fullTableName(s.cfg.ConfigTable)
	upsert := fmt.Sprintf(`
INSERT INTO %s (id, content, created_at, updated_at)
VALUES ($1, $2, NOW(), NOW())
ON CONFLICT (id)
DO UPDATE SET content = EXCLUDED.content, updated_at = NOW()
`, table)
	content := normalizeLineEndings(string(data))
	if _, err := s.db.ExecContext(ctx, upsert, defaultConfigKey, content); err != nil {
		return fmt.Errorf("postgres store: upsert config: %w", err)
	}
	return nil
}
// deleteConfigRecord removes the configuration row (id = defaultConfigKey).
// Deleting a non-existent row is not an error.
func (s *PostgresStore) deleteConfigRecord(ctx context.Context) error {
	stmt := fmt.Sprintf("DELETE FROM %s WHERE id = $1", s.fullTableName(s.cfg.ConfigTable))
	_, err := s.db.ExecContext(ctx, stmt, defaultConfigKey)
	if err != nil {
		return fmt.Errorf("postgres store: delete config: %w", err)
	}
	return nil
}
// resolveAuthPath picks the local file path for auth, in priority order:
//  1. a non-blank Attributes["path"], used as-is;
//  2. a non-blank FileName, joined onto the auth directory when relative;
//  3. the ID, slash-separated, joined onto the auth directory when relative.
//
// It errors when auth is nil or when all three sources are empty.
func (s *PostgresStore) resolveAuthPath(auth *cliproxyauth.Auth) (string, error) {
	if auth == nil {
		return "", fmt.Errorf("postgres store: auth is nil")
	}
	// Indexing a nil map yields "", so no separate nil check is needed.
	if explicit := strings.TrimSpace(auth.Attributes["path"]); explicit != "" {
		return explicit, nil
	}
	if name := strings.TrimSpace(auth.FileName); name != "" {
		if !filepath.IsAbs(name) {
			return filepath.Join(s.authDir, name), nil
		}
		return name, nil
	}
	switch {
	case auth.ID == "":
		return "", fmt.Errorf("postgres store: missing id")
	case filepath.IsAbs(auth.ID):
		return auth.ID, nil
	default:
		return filepath.Join(s.authDir, filepath.FromSlash(auth.ID)), nil
	}
}
// resolveDeletePath maps an auth id to the local file that Delete removes.
// Absolute ids are used as-is. Any other id is treated as a slash-separated
// path relative to the spool auth directory, which is the same rule
// relativeAuthID uses to derive record ids. The error result is always nil;
// it is kept for interface stability.
func (s *PostgresStore) resolveDeletePath(id string) (string, error) {
	if filepath.IsAbs(id) {
		return id, nil
	}
	// Resolve nested relative ids such as "provider/file.json" under authDir
	// too. Otherwise they would resolve against the process working directory
	// and the spool copy would never be removed.
	return filepath.Join(s.authDir, filepath.FromSlash(id)), nil
}
// relativeAuthID converts a file path into the record id used in the auth
// table. The id is the path relative to the spool auth directory, with forward
// slashes. Relative inputs are interpreted relative to the auth directory. It
// errors for a nil store, when filepath.Rel fails, or when the path lies
// outside the auth directory.
func (s *PostgresStore) relativeAuthID(path string) (string, error) {
	if s == nil {
		return "", fmt.Errorf("postgres store: store not initialized")
	}
	if !filepath.IsAbs(path) {
		path = filepath.Join(s.authDir, path)
	}
	clean := filepath.Clean(path)
	rel, err := filepath.Rel(s.authDir, clean)
	if err != nil {
		return "", fmt.Errorf("postgres store: compute relative path: %w", err)
	}
	// Only a leading ".." path element escapes authDir; a file whose name merely
	// starts with ".." (e.g. "..backup.json") is still inside it.
	if rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return "", fmt.Errorf("postgres store: path %s outside managed directory", path)
	}
	return filepath.ToSlash(rel), nil
}
// absoluteAuthPath maps a record id (a slash-separated path relative to the
// auth directory) to an absolute spool path. It rejects ids that clean to the
// auth directory itself ("", ".") or that climb out of it via a leading ".."
// element, so database content can never address files outside the spool.
func (s *PostgresStore) absoluteAuthPath(id string) (string, error) {
	if s == nil {
		return "", fmt.Errorf("postgres store: store not initialized")
	}
	parentPrefix := ".." + string(filepath.Separator)
	clean := filepath.Clean(filepath.FromSlash(id))
	// "." would resolve to authDir itself, which cannot hold an auth file.
	if clean == "." || clean == ".." || strings.HasPrefix(clean, parentPrefix) {
		return "", fmt.Errorf("postgres store: invalid auth identifier %s", id)
	}
	path := filepath.Join(s.authDir, clean)
	rel, err := filepath.Rel(s.authDir, path)
	if err != nil {
		return "", fmt.Errorf("postgres store: compute relative path for %s: %w", id, err)
	}
	if rel == ".." || strings.HasPrefix(rel, parentPrefix) {
		return "", fmt.Errorf("postgres store: resolved auth path escapes auth directory")
	}
	return path, nil
}
// fullTableName returns the quoted table identifier for name, qualified with
// the configured schema when one is set. The schema is whitespace-trimmed, as
// EnsureSchema trims it before creating it, so tables are addressed in the
// schema that actually exists.
func (s *PostgresStore) fullTableName(name string) string {
	schema := strings.TrimSpace(s.cfg.Schema)
	if schema == "" {
		return quoteIdentifier(name)
	}
	return quoteIdentifier(schema) + "." + quoteIdentifier(name)
}
// quoteIdentifier renders identifier as a double-quoted PostgreSQL identifier.
// Embedded double quotes are doubled so they cannot end the quoting early.
func quoteIdentifier(identifier string) string {
	var b strings.Builder
	b.Grow(len(identifier) + 2)
	b.WriteByte('"')
	b.WriteString(strings.ReplaceAll(identifier, `"`, `""`))
	b.WriteByte('"')
	return b.String()
}
// valueAsString returns v itself when it is a string, v.String() when it
// implements fmt.Stringer, and "" for anything else (nil, numbers, maps, ...).
func valueAsString(v any) string {
	if s, ok := v.(string); ok {
		return s
	}
	if stringer, ok := v.(fmt.Stringer); ok {
		return stringer.String()
	}
	return ""
}
// labelFor derives a display label from auth metadata. It returns the first
// non-blank (whitespace-trimmed) string value among "label", "email" and
// "project_id", or "" when none is set or metadata is nil.
func labelFor(metadata map[string]any) string {
	// Looking up a key in a nil map yields nil, which valueAsString maps to "".
	for _, key := range []string{"label", "email", "project_id"} {
		if candidate := strings.TrimSpace(valueAsString(metadata[key])); candidate != "" {
			return candidate
		}
	}
	return ""
}
// normalizeAuthID returns id in canonical form: lexically cleaned with
// filepath.Clean and using forward slashes as separators. An empty id becomes ".".
func normalizeAuthID(id string) string {
	cleaned := filepath.Clean(id)
	return filepath.ToSlash(cleaned)
}
// normalizeLineEndings converts CRLF and lone CR line endings in s to LF.
// The empty string is returned unchanged.
func normalizeLineEndings(s string) string {
	if s == "" {
		return s
	}
	// Replacer tries patterns in argument order at each position, so a CRLF
	// pair collapses to one LF before the lone-CR rule can match its CR.
	return strings.NewReplacer("\r\n", "\n", "\r", "\n").Replace(s)
}
================================================
FILE: internal/thinking/apply.go
================================================
// Package thinking provides unified thinking configuration processing.
package thinking
import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// providerAppliers maps provider names to their ProviderApplier implementations.
//
// Every known provider starts with a nil entry. GetProviderApplier returns that
// nil, and ApplyThinking treats a nil applier as an unknown provider, so a
// provider only takes effect once RegisterProvider stores a real applier.
// ApplyThinking looks providers up by lowercased, trimmed target format.
// NOTE(review): the map has no lock; presumably all registrations happen
// during package initialization before any lookups — confirm.
var providerAppliers = map[string]ProviderApplier{
	"gemini":      nil,
	"gemini-cli":  nil,
	"claude":      nil,
	"openai":      nil,
	"codex":       nil,
	"iflow":       nil,
	"antigravity": nil,
	"kimi":        nil,
}
// GetProviderApplier returns the ProviderApplier for the given provider name.
// The name is matched exactly. The result is nil both for unknown names and
// for known providers whose applier has not been registered yet.
func GetProviderApplier(provider string) ProviderApplier {
	applier, found := providerAppliers[provider]
	if !found {
		return nil
	}
	return applier
}
// RegisterProvider registers a provider applier by name, replacing any existing entry.
//
// The name is stored verbatim. ApplyThinking looks providers up by lowercased
// target format, so appliers should be registered under lowercase names.
// NOTE(review): this writes to a package-level map without synchronization and
// is only safe if registration finishes before concurrent lookups — confirm.
func RegisterProvider(name string, applier ProviderApplier) {
	providerAppliers[name] = applier
}
// IsUserDefinedModel reports whether the model is a user-defined model that should
// have thinking configuration passed through without validation.
//
// User-defined models are configured via config file's models[] array
// (e.g., openai-compatibility.*.models[], *-api-key.models[]). These models
// are marked with UserDefined=true at registration time.
//
// A nil modelInfo (a model unknown to the registry) is also treated as
// user-defined. In both cases thinking configuration is applied directly and
// the upstream service validates it.
func IsUserDefinedModel(modelInfo *registry.ModelInfo) bool {
	return modelInfo == nil || modelInfo.UserDefined
}
// ApplyThinking applies thinking configuration to a request body.
//
// This is the unified entry point for all providers. It follows the processing
// order defined in FR25: route check → model capability query → config extraction
// → validation → application.
//
// Suffix Priority: When the model name includes a thinking suffix (e.g., "gemini-2.5-pro(8192)"),
// the suffix configuration takes priority over any thinking parameters in the request body.
// This enables users to override thinking settings via the model name without modifying their
// request payload.
//
// Parameters:
//   - body: Original request body JSON
//   - model: Model name, optionally with thinking suffix (e.g., "claude-sonnet-4-5(16384)")
//   - fromFormat: Source request format (e.g., openai, codex, gemini); defaults to toFormat when blank
//   - toFormat: Target provider format for the request body (gemini, gemini-cli, antigravity, claude, openai, codex, iflow, kimi)
//   - providerKey: Provider identifier used for registry model lookups (may differ from toFormat, e.g., openrouter -> openai); defaults to toFormat when blank
//
// All three format/provider arguments are trimmed and lowercased before use.
//
// Returns:
//   - Modified request body JSON with thinking configuration applied
//   - Error if validation fails (ThinkingError). On error, the original body
//     is returned (not nil) to enable defensive programming patterns.
//
// Passthrough behavior (returns original body without error):
//   - Unknown provider (no non-nil applier in providerAppliers for toFormat)
//   - modelInfo.Thinking is nil (model doesn't support thinking) and the body
//     carries no thinking config; if the body does carry one, it is removed via
//     StripThinkingConfig instead
//   - No thinking config found in either the suffix or the body
//
// Note: Unknown models (modelInfo is nil) are treated as user-defined models: we skip
// validation and still apply the thinking config so the upstream can validate it.
//
// Example:
//
//	// With suffix - suffix config takes priority
//	result, err := thinking.ApplyThinking(body, "gemini-2.5-pro(8192)", "gemini", "gemini", "gemini")
//
//	// Without suffix - uses body config
//	result, err := thinking.ApplyThinking(body, "gemini-2.5-pro", "gemini", "gemini", "gemini")
func ApplyThinking(body []byte, model string, fromFormat string, toFormat string, providerKey string) ([]byte, error) {
	// Normalize identifiers; providerKey and fromFormat fall back to the target format.
	providerFormat := strings.ToLower(strings.TrimSpace(toFormat))
	providerKey = strings.ToLower(strings.TrimSpace(providerKey))
	if providerKey == "" {
		providerKey = providerFormat
	}
	fromFormat = strings.ToLower(strings.TrimSpace(fromFormat))
	if fromFormat == "" {
		fromFormat = providerFormat
	}
	// 1. Route check: Get provider applier
	applier := GetProviderApplier(providerFormat)
	if applier == nil {
		log.WithFields(log.Fields{
			"provider": providerFormat,
			"model":    model,
		}).Debug("thinking: unknown provider, passthrough |")
		return body, nil
	}
	// 2. Parse suffix and get modelInfo
	suffixResult := ParseSuffix(model)
	baseModel := suffixResult.ModelName
	// Use provider-specific lookup to handle capability differences across providers.
	modelInfo := registry.LookupModelInfo(baseModel, providerKey)
	// 3. Model capability check
	// Unknown models are treated as user-defined so thinking config can still be applied.
	// The upstream service is responsible for validating the configuration.
	if IsUserDefinedModel(modelInfo) {
		return applyUserDefinedModel(body, modelInfo, fromFormat, providerFormat, suffixResult)
	}
	if modelInfo.Thinking == nil {
		// Only the body is inspected here; a model-name suffix is simply dropped
		// because the model cannot use it anyway.
		config := extractThinkingConfig(body, providerFormat)
		if hasThinkingConfig(config) {
			log.WithFields(log.Fields{
				"model":    baseModel,
				"provider": providerFormat,
			}).Debug("thinking: model does not support thinking, stripping config |")
			return StripThinkingConfig(body, providerFormat), nil
		}
		log.WithFields(log.Fields{
			"provider": providerFormat,
			"model":    baseModel,
		}).Debug("thinking: model does not support thinking, passthrough |")
		return body, nil
	}
	// 4. Get config: suffix priority over body
	var config ThinkingConfig
	if suffixResult.HasSuffix {
		config = parseSuffixToConfig(suffixResult.RawSuffix, providerFormat, model)
		log.WithFields(log.Fields{
			"provider": providerFormat,
			"model":    model,
			"mode":     config.Mode,
			"budget":   config.Budget,
			"level":    config.Level,
		}).Debug("thinking: config from model suffix |")
	} else {
		// The body is read in the target provider's layout (providerFormat), not fromFormat.
		config = extractThinkingConfig(body, providerFormat)
		if hasThinkingConfig(config) {
			log.WithFields(log.Fields{
				"provider": providerFormat,
				"model":    modelInfo.ID,
				"mode":     config.Mode,
				"budget":   config.Budget,
				"level":    config.Level,
			}).Debug("thinking: original config from request |")
		}
	}
	if !hasThinkingConfig(config) {
		log.WithFields(log.Fields{
			"provider": providerFormat,
			"model":    modelInfo.ID,
		}).Debug("thinking: no config found, passthrough |")
		return body, nil
	}
	// 5. Validate and normalize configuration
	validated, err := ValidateConfig(config, modelInfo, fromFormat, providerFormat, suffixResult.HasSuffix)
	if err != nil {
		log.WithFields(log.Fields{
			"provider": providerFormat,
			"model":    modelInfo.ID,
			"error":    err.Error(),
		}).Warn("thinking: validation failed |")
		// Return original body on validation failure (defensive programming).
		// This ensures callers who ignore the error won't receive nil body.
		// The upstream service will decide how to handle the unmodified request.
		return body, err
	}
	// Defensive check: ValidateConfig should never return (nil, nil)
	if validated == nil {
		log.WithFields(log.Fields{
			"provider": providerFormat,
			"model":    modelInfo.ID,
		}).Warn("thinking: ValidateConfig returned nil config without error, passthrough |")
		return body, nil
	}
	log.WithFields(log.Fields{
		"provider": providerFormat,
		"model":    modelInfo.ID,
		"mode":     validated.Mode,
		"budget":   validated.Budget,
		"level":    validated.Level,
	}).Debug("thinking: processed config to apply |")
	// 6. Apply configuration using provider-specific applier
	return applier.Apply(body, *validated, modelInfo)
}
// parseSuffixToConfig converts a raw model-name suffix into a ThinkingConfig.
//
// Recognition order:
//  1. Special values: "none" → ModeNone, "auto"/"-1" → ModeAuto.
//  2. Level names ("minimal", "low", "medium", "high", "xhigh") → ModeLevel.
//  3. Numeric values: 0 → ModeNone, any other parsed number → ModeBudget.
//
// Anything else is logged at debug level and yields an empty ThinkingConfig,
// which callers treat as "no config". provider and model are used only for logging.
func parseSuffixToConfig(rawSuffix, provider, model string) ThinkingConfig {
	if mode, ok := ParseSpecialSuffix(rawSuffix); ok {
		if mode == ModeNone {
			return ThinkingConfig{Mode: ModeNone, Budget: 0}
		}
		if mode == ModeAuto {
			return ThinkingConfig{Mode: ModeAuto, Budget: -1}
		}
	}

	if level, ok := ParseLevelSuffix(rawSuffix); ok {
		return ThinkingConfig{Mode: ModeLevel, Level: level}
	}

	if budget, ok := ParseNumericSuffix(rawSuffix); ok {
		if budget != 0 {
			return ThinkingConfig{Mode: ModeBudget, Budget: budget}
		}
		return ThinkingConfig{Mode: ModeNone, Budget: 0}
	}

	log.WithFields(log.Fields{
		"provider":   provider,
		"model":      model,
		"raw_suffix": rawSuffix,
	}).Debug("thinking: unknown suffix format, treating as no config |")
	return ThinkingConfig{}
}
// applyUserDefinedModel applies thinking configuration for user-defined (or
// registry-unknown) models without capability validation.
//
// A model-name suffix takes priority. Without a suffix, the body is read first
// in the source format's layout, then (if different) in the target format's
// layout. With no config found, or no applier for toFormat, the body is
// returned unchanged. Otherwise the config is passed through
// normalizeUserDefinedConfig and applied by the target provider's applier.
func applyUserDefinedModel(body []byte, modelInfo *registry.ModelInfo, fromFormat, toFormat string, suffixResult SuffixResult) ([]byte, error) {
	// Prefer the registry ID for logging; fall back to the suffix-stripped name.
	modelID := suffixResult.ModelName
	if modelInfo != nil {
		modelID = modelInfo.ID
	}

	var config ThinkingConfig
	if suffixResult.HasSuffix {
		config = parseSuffixToConfig(suffixResult.RawSuffix, toFormat, modelID)
	} else {
		config = extractThinkingConfig(body, fromFormat)
		tryTargetLayout := !hasThinkingConfig(config) && fromFormat != toFormat
		if tryTargetLayout {
			config = extractThinkingConfig(body, toFormat)
		}
	}

	if !hasThinkingConfig(config) {
		log.WithFields(log.Fields{
			"model":    modelID,
			"provider": toFormat,
		}).Debug("thinking: user-defined model, passthrough (no config) |")
		return body, nil
	}

	applier := GetProviderApplier(toFormat)
	if applier == nil {
		log.WithFields(log.Fields{
			"model":    modelID,
			"provider": toFormat,
		}).Debug("thinking: user-defined model, passthrough (unknown provider) |")
		return body, nil
	}

	// Logged before normalization, i.e. as extracted.
	log.WithFields(log.Fields{
		"provider": toFormat,
		"model":    modelID,
		"mode":     config.Mode,
		"budget":   config.Budget,
		"level":    config.Level,
	}).Debug("thinking: applying config for user-defined model (skip validation)")
	return applier.Apply(body, normalizeUserDefinedConfig(config, fromFormat, toFormat), modelInfo)
}
// normalizeUserDefinedConfig converts a level-based config into a numeric
// budget for budget-capable target providers, using ConvertLevelToBudget. The
// config is returned unchanged when:
//   - it is not ModeLevel;
//   - the target is "claude";
//   - the target is not budget-capable;
//   - the level has no known budget.
//
// fromFormat is currently unused.
func normalizeUserDefinedConfig(config ThinkingConfig, fromFormat, toFormat string) ThinkingConfig {
	if config.Mode != ModeLevel || toFormat == "claude" || !isBudgetCapableProvider(toFormat) {
		return config
	}
	if budget, ok := ConvertLevelToBudget(string(config.Level)); ok {
		config.Mode = ModeBudget
		config.Budget = budget
		config.Level = ""
	}
	return config
}
// extractThinkingConfig reads the thinking configuration from body using the
// field layout of the given provider format. Empty or invalid JSON, and
// unrecognized providers, yield an empty ThinkingConfig.
func extractThinkingConfig(body []byte, provider string) ThinkingConfig {
	if len(body) == 0 || !gjson.ValidBytes(body) {
		return ThinkingConfig{}
	}
	switch provider {
	case "claude":
		return extractClaudeConfig(body)
	case "gemini", "gemini-cli", "antigravity":
		return extractGeminiConfig(body, provider)
	case "codex":
		return extractCodexConfig(body)
	case "iflow":
		// iFlow-specific fields win; otherwise fall back to OpenAI reasoning_effort.
		if cfg := extractIFlowConfig(body); hasThinkingConfig(cfg) {
			return cfg
		}
		return extractOpenAIConfig(body)
	case "openai", "kimi":
		// Kimi uses the OpenAI-compatible reasoning_effort format.
		return extractOpenAIConfig(body)
	}
	return ThinkingConfig{}
}
// hasThinkingConfig reports whether config carries any thinking setting. The
// "empty" config is ModeBudget with a zero budget and no level.
// NOTE(review): ThinkingConfig{} counts as empty only if ModeBudget is the zero
// value of the mode type — confirm against its declaration.
func hasThinkingConfig(config ThinkingConfig) bool {
	isEmpty := config.Mode == ModeBudget && config.Budget == 0 && config.Level == ""
	return !isEmpty
}
// extractClaudeConfig extracts thinking configuration from a Claude-format request body.
//
// Claude API format:
//   - thinking.type: "enabled", "disabled", or "adaptive"/"auto"
//   - thinking.budget_tokens: integer (-1=auto, 0=disabled, >0=budget)
//   - output_config.effort: effort level used with adaptive thinking
//
// Resolution order:
//  1. type="disabled" → ModeNone (takes precedence over budget_tokens).
//  2. type="adaptive"/"auto" → only a string output_config.effort is honored
//     (lowercased and trimmed): "none" → ModeNone, "auto" → ModeAuto, any
//     other non-blank value → ModeLevel. Otherwise no config, so upstream
//     defaults apply.
//  3. budget_tokens present → 0 → ModeNone, -1 → ModeAuto, else ModeBudget.
//  4. type="enabled" without budget_tokens → ModeAuto.
func extractClaudeConfig(body []byte) ThinkingConfig {
	thinkingType := gjson.GetBytes(body, "thinking.type").String()
	switch thinkingType {
	case "disabled":
		return ThinkingConfig{Mode: ModeNone, Budget: 0}
	case "adaptive", "auto":
		effort := gjson.GetBytes(body, "output_config.effort")
		if !effort.Exists() || effort.Type != gjson.String {
			return ThinkingConfig{}
		}
		switch value := strings.ToLower(strings.TrimSpace(effort.String())); value {
		case "":
			return ThinkingConfig{}
		case "none":
			return ThinkingConfig{Mode: ModeNone, Budget: 0}
		case "auto":
			return ThinkingConfig{Mode: ModeAuto, Budget: -1}
		default:
			return ThinkingConfig{Mode: ModeLevel, Level: ThinkingLevel(value)}
		}
	}

	if budget := gjson.GetBytes(body, "thinking.budget_tokens"); budget.Exists() {
		switch tokens := int(budget.Int()); tokens {
		case 0:
			return ThinkingConfig{Mode: ModeNone, Budget: 0}
		case -1:
			return ThinkingConfig{Mode: ModeAuto, Budget: -1}
		default:
			return ThinkingConfig{Mode: ModeBudget, Budget: tokens}
		}
	}

	// Thinking explicitly enabled without a budget: let the provider pick one.
	if thinkingType == "enabled" {
		return ThinkingConfig{Mode: ModeAuto, Budget: -1}
	}
	return ThinkingConfig{}
}
// extractGeminiConfig extracts thinking configuration from a Gemini-format request body.
//
// Gemini API format:
//   - generationConfig.thinkingConfig.thinkingLevel: "none", "auto", or level name (Gemini 3)
//   - generationConfig.thinkingConfig.thinkingBudget: integer (Gemini 2.5)
//
// For gemini-cli and antigravity the fields live under "request.". Both the
// camelCase and snake_case spellings are accepted, with camelCase checked first;
// snake_case is what Google's official Python SDK sends. The level is checked
// before the budget, so a Gemini 3 level takes precedence over a Gemini 2.5
// budget.
func extractGeminiConfig(body []byte, provider string) ThinkingConfig {
	prefix := "generationConfig.thinkingConfig"
	switch provider {
	case "gemini-cli", "antigravity":
		prefix = "request." + prefix
	}

	// lookup returns the camelCase field if present, otherwise the snake_case one.
	lookup := func(camel, snake string) gjson.Result {
		if r := gjson.GetBytes(body, prefix+"."+camel); r.Exists() {
			return r
		}
		return gjson.GetBytes(body, prefix+"."+snake)
	}

	if level := lookup("thinkingLevel", "thinking_level"); level.Exists() {
		switch value := level.String(); value {
		case "none":
			return ThinkingConfig{Mode: ModeNone, Budget: 0}
		case "auto":
			return ThinkingConfig{Mode: ModeAuto, Budget: -1}
		default:
			return ThinkingConfig{Mode: ModeLevel, Level: ThinkingLevel(value)}
		}
	}

	if budget := lookup("thinkingBudget", "thinking_budget"); budget.Exists() {
		switch tokens := int(budget.Int()); tokens {
		case 0:
			return ThinkingConfig{Mode: ModeNone, Budget: 0}
		case -1:
			return ThinkingConfig{Mode: ModeAuto, Budget: -1}
		default:
			return ThinkingConfig{Mode: ModeBudget, Budget: tokens}
		}
	}
	return ThinkingConfig{}
}
// extractOpenAIConfig extracts thinking configuration from OpenAI format request body.
//
// OpenAI API format:
//   - reasoning_effort: "none", "low", "medium", "high" (discrete levels)
//
// OpenAI uses level-based thinking configuration only, with no numeric budget
// support. "none" (matched case-insensitively) returns ModeNone. Any other
// non-blank value is returned as ModeLevel, with surrounding whitespace trimmed.
// A missing, null, or blank reasoning_effort is treated as absent and yields the
// zero ThinkingConfig (passthrough) rather than a ModeLevel with an empty level.
func extractOpenAIConfig(body []byte) ThinkingConfig {
	// Check reasoning_effort (OpenAI Chat Completions format). gjson reports ""
	// for both missing and null values, so a single emptiness check covers
	// missing, null and blank.
	value := strings.TrimSpace(gjson.GetBytes(body, "reasoning_effort").String())
	if value == "" {
		return ThinkingConfig{}
	}
	if strings.EqualFold(value, "none") {
		return ThinkingConfig{Mode: ModeNone, Budget: 0}
	}
	return ThinkingConfig{Mode: ModeLevel, Level: ThinkingLevel(value)}
}
// extractCodexConfig extracts thinking configuration from Codex format request body.
//
// Codex API format (OpenAI Responses API):
//   - reasoning.effort: "none", "low", "medium", "high"
//
// This is similar to OpenAI but uses nested field "reasoning.effort" instead of
// "reasoning_effort". "none" (matched case-insensitively) returns ModeNone. Any
// other non-blank value is returned as ModeLevel, with surrounding whitespace
// trimmed. A missing, null, or blank reasoning.effort is treated as absent and
// yields the zero ThinkingConfig (passthrough).
func extractCodexConfig(body []byte) ThinkingConfig {
	// Check reasoning.effort (Codex / OpenAI Responses API format). gjson reports
	// "" for both missing and null values, so one check handles all absent forms.
	value := strings.TrimSpace(gjson.GetBytes(body, "reasoning.effort").String())
	if value == "" {
		return ThinkingConfig{}
	}
	if strings.EqualFold(value, "none") {
		return ThinkingConfig{Mode: ModeNone, Budget: 0}
	}
	return ThinkingConfig{Mode: ModeLevel, Level: ThinkingLevel(value)}
}
// extractIFlowConfig reads the on/off thinking switch from an iFlow request body.
//
// iFlow fronts several model families, each with its own boolean toggle:
//   - GLM:     chat_template_kwargs.enable_thinking
//   - MiniMax: reasoning_split
//
// The GLM toggle is consulted first, and the first toggle present in the body
// decides the result:
//   - true  → ModeBudget with Budget=1. The 1 is only a sentinel meaning
//     "enabled", because iFlow models have no numeric budget. The iFlow applier
//     decides the real configuration from the model's capabilities.
//   - false → ModeNone with Budget=0.
//
// When neither toggle is present, the zero ThinkingConfig is returned.
func extractIFlowConfig(body []byte) ThinkingConfig {
	toggles := []string{
		"chat_template_kwargs.enable_thinking", // GLM format
		"reasoning_split",                      // MiniMax format
	}
	for _, path := range toggles {
		flag := gjson.GetBytes(body, path)
		if !flag.Exists() {
			continue
		}
		if !flag.Bool() {
			return ThinkingConfig{Mode: ModeNone, Budget: 0}
		}
		// Budget=1 is a sentinel meaning "enabled" (iFlow doesn't use numeric budgets)
		return ThinkingConfig{Mode: ModeBudget, Budget: 1}
	}
	return ThinkingConfig{}
}
================================================
FILE: internal/thinking/apply_user_defined_test.go
================================================
package thinking_test
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/claude"
"github.com/tidwall/gjson"
)
// TestApplyThinking_UserDefinedClaudePreservesAdaptiveLevel checks that, for a
// user-defined Claude model, both an explicit adaptive body and a "(high)" model
// suffix end up as thinking.type="adaptive" with output_config.effort="high"
// and no thinking.budget_tokens.
func TestApplyThinking_UserDefinedClaudePreservesAdaptiveLevel(t *testing.T) {
	const modelID = "custom-claude-4-6"
	reg := registry.GetGlobalRegistry()
	clientID := "test-user-defined-claude-" + t.Name()

	// Register a user-defined Claude model for the duration of the test.
	reg.RegisterClient(clientID, "claude", []*registry.ModelInfo{{ID: modelID, UserDefined: true}})
	t.Cleanup(func() { reg.UnregisterClient(clientID) })

	cases := []struct {
		name, model string
		body        []byte
	}{
		{"claude adaptive effort body", modelID, []byte(`{"thinking":{"type":"adaptive"},"output_config":{"effort":"high"}}`)},
		{"suffix level", modelID + "(high)", []byte(`{}`)},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			out, err := thinking.ApplyThinking(tc.body, tc.model, "openai", "claude", "claude")
			if err != nil {
				t.Fatalf("ApplyThinking() error = %v", err)
			}
			gotType := gjson.GetBytes(out, "thinking.type").String()
			if gotType != "adaptive" {
				t.Fatalf("thinking.type = %q, want %q, body=%s", gotType, "adaptive", string(out))
			}
			gotEffort := gjson.GetBytes(out, "output_config.effort").String()
			if gotEffort != "high" {
				t.Fatalf("output_config.effort = %q, want %q, body=%s", gotEffort, "high", string(out))
			}
			if gjson.GetBytes(out, "thinking.budget_tokens").Exists() {
				t.Fatalf("thinking.budget_tokens should be removed, body=%s", string(out))
			}
		})
	}
}
================================================
FILE: internal/thinking/convert.go
================================================
package thinking
import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
// levelToBudgetMap is the canonical Level → Budget table behind
// ConvertLevelToBudget. Every key is lowercase, so lookups must lowercase first.
var levelToBudgetMap = map[string]int{
	// Sentinel levels: 0 disables thinking, -1 is the "auto" value.
	"none": 0,
	"auto": -1,
	// Graduated levels, in increasing order.
	"minimal": 512,
	"low":     1024,
	"medium":  8192,
	"high":    24576,
	"xhigh":   32768,
	// "max" is used by Claude adaptive thinking effort. We map it to a large budget
	// and rely on per-model clamping when converting to budget-only providers.
	"max": 128000,
}

// ConvertLevelToBudget maps a discrete thinking level to its numeric budget.
//
// The lookup is case-insensitive but does not trim whitespace:
//
//	none → 0, auto → -1, minimal → 512, low → 1024, medium → 8192,
//	high → 24576, xhigh → 32768, max → 128000
//
// It returns the budget and true for a known level, or 0 and false otherwise.
func ConvertLevelToBudget(level string) (int, bool) {
	key := strings.ToLower(level)
	if budget, found := levelToBudgetMap[key]; found {
		return budget, true
	}
	return 0, false
}
// Budget thresholds give the inclusive upper bound of each thinking level.
// ConvertBudgetToLevel uses them to bucket a numeric budget into a level.
const (
	// ThresholdMinimal caps the "minimal" level (budgets 1-512).
	ThresholdMinimal = 512
	// ThresholdLow caps the "low" level (budgets 513-1024).
	ThresholdLow = 1024
	// ThresholdMedium caps the "medium" level (budgets 1025-8192).
	ThresholdMedium = 8 * 1024
	// ThresholdHigh caps the "high" level (budgets 8193-24576).
	ThresholdHigh = 24 * 1024
)
// ConvertBudgetToLevel converts a budget value to the nearest thinking level.
//
// This is a semantic conversion that maps numeric budgets to discrete levels.
// Uses threshold-based mapping for range conversion.
//
// Budget → Level thresholds:
// - -1 → auto
// - 0 → none
// - 1-512 → minimal
// - 513-1024 → low
// - 1025-8192 → medium
// - 8193-24576 → high
// - 24577+ → xhigh
//
// Returns:
// - level: The converted thinking level string
// - ok: true if budget is valid, false for invalid negatives (< -1)
func ConvertBudgetToLevel(budget int) (string, bool) {
switch {
case budget < -1:
// Invalid negative values
return "", false
case budget == -1:
return string(LevelAuto), true
case budget == 0:
return string(LevelNone), true
case budget <= ThresholdMinimal:
return string(LevelMinimal), true
case budget <= ThresholdLow:
return string(LevelLow), true
case budget <= ThresholdMedium:
return string(LevelMedium), true
case budget <= ThresholdHigh:
return string(LevelHigh), true
default:
return string(LevelXHigh), true
}
}
// HasLevel reports whether target appears in levels.
//
// Leading and trailing whitespace is trimmed from both target and every
// candidate level, and the comparison is case-insensitive. Previously only the
// candidates were trimmed, so a padded target such as " high " never matched.
func HasLevel(levels []string, target string) bool {
	// Trim the target once, outside the loop.
	want := strings.TrimSpace(target)
	for _, level := range levels {
		if strings.EqualFold(strings.TrimSpace(level), want) {
			return true
		}
	}
	return false
}
// MapToClaudeEffort translates a generic thinking level into a Claude adaptive
// thinking effort (one of low/medium/high/max).
//
// The input is trimmed and lowercased first. "minimal" becomes "low", and
// "auto" becomes "high". "xhigh" and "max" become "max" when supportsMax is set,
// otherwise "high". Empty or unrecognised levels, including "none", yield
// ("", false).
func MapToClaudeEffort(level string, supportsMax bool) (string, bool) {
	normalized := strings.ToLower(strings.TrimSpace(level))
	switch normalized {
	case "low", "medium", "high":
		return normalized, true
	case "minimal":
		return "low", true
	case "auto":
		return "high", true
	case "xhigh", "max":
		if !supportsMax {
			return "high", true
		}
		return "max", true
	}
	return "", false
}
// ModelCapability classifies which thinking formats a model accepts.
type ModelCapability int

// Model capability values. They are numbered explicitly so that the -1
// sentinel for unknown models is visible at a glance.
const (
	// CapabilityUnknown is the internal sentinel for a nil modelInfo (passthrough behavior).
	CapabilityUnknown ModelCapability = -1
	// CapabilityNone means the model does not support thinking. This is the case
	// when Thinking is nil, or when it declares neither a budget range nor levels.
	CapabilityNone ModelCapability = 0
	// CapabilityBudgetOnly means the model accepts numeric budgets only.
	CapabilityBudgetOnly ModelCapability = 1
	// CapabilityLevelOnly means the model accepts discrete levels only.
	CapabilityLevelOnly ModelCapability = 2
	// CapabilityHybrid means the model accepts both budgets and levels.
	CapabilityHybrid ModelCapability = 3
)
// detectModelCapability classifies a model's thinking support for the internal
// validation and conversion helpers.
//
//   - nil modelInfo           → CapabilityUnknown (passthrough sentinel)
//   - nil modelInfo.Thinking  → CapabilityNone
//   - Min/Max and Levels set  → CapabilityHybrid (e.g. Gemini 3)
//   - only Min/Max set        → CapabilityBudgetOnly (e.g. Claude, Gemini 2.5)
//   - only Levels set         → CapabilityLevelOnly (e.g. OpenAI, iFlow)
//   - neither set             → CapabilityNone
//
// A budget range counts as present when Min or Max is positive.
func detectModelCapability(modelInfo *registry.ModelInfo) ModelCapability {
	switch {
	case modelInfo == nil:
		return CapabilityUnknown
	case modelInfo.Thinking == nil:
		return CapabilityNone
	}

	support := modelInfo.Thinking
	budgetRange := support.Min > 0 || support.Max > 0
	levelList := len(support.Levels) > 0

	if budgetRange {
		if levelList {
			return CapabilityHybrid
		}
		return CapabilityBudgetOnly
	}
	if levelList {
		return CapabilityLevelOnly
	}
	return CapabilityNone
}
================================================
FILE: internal/thinking/errors.go
================================================
// Package thinking provides unified thinking configuration processing logic.
package thinking
import "net/http"
// ErrorCode is a machine-readable identifier for a kind of thinking
// configuration error.
type ErrorCode string

// Error codes reported while processing thinking configuration.
const (
	// ErrInvalidSuffix means the model-name suffix cannot be parsed,
	// e.g. "model(abc" (missing closing parenthesis).
	ErrInvalidSuffix ErrorCode = "INVALID_SUFFIX"
	// ErrUnknownLevel means the level value is not in the valid list,
	// e.g. "model(ultra)" where "ultra" is not a recognised level.
	ErrUnknownLevel ErrorCode = "UNKNOWN_LEVEL"
	// ErrThinkingNotSupported means the model has no thinking capability,
	// e.g. claude-haiku-4-5.
	ErrThinkingNotSupported ErrorCode = "THINKING_NOT_SUPPORTED"
	// ErrLevelNotSupported means a level was supplied for a model that does not
	// accept levels, e.g. a budget-only model.
	ErrLevelNotSupported ErrorCode = "LEVEL_NOT_SUPPORTED"
	// ErrBudgetOutOfRange means the budget falls outside the model's range,
	// e.g. a budget of 64000 against a maximum of 20000.
	ErrBudgetOutOfRange ErrorCode = "BUDGET_OUT_OF_RANGE"
	// ErrProviderMismatch means the provider does not match the model,
	// e.g. applying Claude format to a Gemini model.
	ErrProviderMismatch ErrorCode = "PROVIDER_MISMATCH"
)

// ThinkingError is the structured error produced by thinking configuration
// processing.
type ThinkingError struct {
	// Code lets callers branch on the failure programmatically.
	Code ErrorCode
	// Message is the human-readable text returned by Error(). By convention it
	// is lowercase, has no trailing period, and includes context where useful.
	Message string
	// Model optionally names the model the error concerns.
	Model string
	// Details optionally carries extra context; the constructors leave it nil.
	Details map[string]any
}

// NewThinkingError builds a ThinkingError from a code and a message.
func NewThinkingError(code ErrorCode, message string) *ThinkingError {
	return &ThinkingError{Code: code, Message: message}
}

// NewThinkingErrorWithModel builds a ThinkingError that also records the model.
func NewThinkingErrorWithModel(code ErrorCode, message, model string) *ThinkingError {
	e := NewThinkingError(code, message)
	e.Model = model
	return e
}

// Error implements the error interface. It returns Message verbatim, without
// the code; use the Code field for programmatic handling.
func (e *ThinkingError) Error() string {
	return e.Message
}

// StatusCode exposes the HTTP status handlers should report for this error,
// which is always 400 Bad Request.
func (e *ThinkingError) StatusCode() int {
	return http.StatusBadRequest
}
================================================
FILE: internal/thinking/provider/antigravity/apply.go
================================================
// Package antigravity implements thinking configuration for Antigravity API format.
//
// Antigravity uses request.generationConfig.thinkingConfig.* path (same as gemini-cli)
// but requires additional normalization for Claude models:
// - Ensure thinking budget < max_tokens
// - Remove thinkingConfig if budget < minimum allowed
package antigravity
import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Applier applies thinking configuration for the Antigravity API format.
// It carries no fields, so a single shared instance is registered in init.
type Applier struct{}

// Compile-time check that *Applier satisfies thinking.ProviderApplier.
var _ thinking.ProviderApplier = (*Applier)(nil)

// NewApplier returns a new Antigravity thinking applier.
func NewApplier() *Applier {
	return new(Applier)
}

// init registers the Antigravity applier under the "antigravity" provider key.
func init() {
	thinking.RegisterProvider("antigravity", NewApplier())
}
// Apply applies thinking configuration to an Antigravity request body.
//
// User-defined models are delegated to applyCompatible. For registry models:
//   - a nil modelInfo or a model without thinking support leaves the body
//     unchanged (previously a nil modelInfo caused a nil-pointer panic here);
//   - modes other than Budget/Level/None/Auto leave the body unchanged;
//   - an empty or invalid body is replaced with "{}";
//   - ModeAuto and ModeBudget are always written as thinkingBudget;
//   - ModeLevel/ModeNone use thinkingLevel when the model advertises levels,
//     and thinkingBudget otherwise.
//
// For Claude models (model ID containing "claude"), additional constraints are
// applied during budget formatting:
//   - Ensure thinking budget < max_tokens
//   - Remove thinkingConfig if budget < minimum allowed
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
	if thinking.IsUserDefinedModel(modelInfo) {
		return a.applyCompatible(body, config, modelInfo)
	}
	if modelInfo == nil || modelInfo.Thinking == nil {
		return body, nil
	}
	switch config.Mode {
	case thinking.ModeBudget, thinking.ModeLevel, thinking.ModeNone, thinking.ModeAuto:
		// Supported; handled below.
	default:
		return body, nil
	}
	if len(body) == 0 || !gjson.ValidBytes(body) {
		body = []byte(`{}`)
	}

	isClaude := strings.Contains(strings.ToLower(modelInfo.ID), "claude")

	// Auto (thinkingBudget=-1) and explicit budgets always use the budget format.
	if config.Mode == thinking.ModeAuto || config.Mode == thinking.ModeBudget {
		return a.applyBudgetFormat(body, config, modelInfo, isClaude)
	}

	// ModeLevel / ModeNone: choose the format from the model's capabilities.
	if len(modelInfo.Thinking.Levels) > 0 {
		return a.applyLevelFormat(body, config)
	}
	return a.applyBudgetFormat(body, config, modelInfo, isClaude)
}
// applyCompatible handles user-defined models, which carry no capability
// metadata. The output format is chosen from the config alone:
//   - unsupported modes leave the body unchanged;
//   - ModeLevel, and ModeNone with an explicit level, use thinkingLevel;
//   - ModeAuto, ModeBudget and level-less ModeNone use thinkingBudget.
//
// modelInfo may be nil; it is then treated as a non-Claude model.
func (a *Applier) applyCompatible(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
	switch config.Mode {
	case thinking.ModeBudget, thinking.ModeLevel, thinking.ModeNone, thinking.ModeAuto:
		// Supported; handled below.
	default:
		return body, nil
	}
	if len(body) == 0 || !gjson.ValidBytes(body) {
		body = []byte(`{}`)
	}

	isClaude := modelInfo != nil && strings.Contains(strings.ToLower(modelInfo.ID), "claude")

	useLevel := config.Mode == thinking.ModeLevel ||
		(config.Mode == thinking.ModeNone && config.Level != "")
	if useLevel {
		return a.applyLevelFormat(body, config)
	}
	return a.applyBudgetFormat(body, config, modelInfo, isClaude)
}
// applyLevelFormat writes a level-based thinking config
// (request.generationConfig.thinkingConfig.thinkingLevel) into an Antigravity body.
//
// thinkingBudget, thinking_budget, thinking_level and include_thoughts are
// deleted first. This keeps a level and a budget from appearing together, and
// leaves only the camelCase includeThoughts spelling. Behavior by mode:
//   - ModeNone: includeThoughts is forced to false; thinkingLevel is written only
//     when config.Level is non-empty.
//   - ModeLevel: thinkingLevel is set to config.Level; includeThoughts keeps an
//     explicit includeThoughts/include_thoughts from the original body, else true.
//   - any other mode: the original body is returned, discarding the deletions
//     above (budget conversion is left to the caller).
//
// The error result is always nil; sjson errors are ignored.
func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget")
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_budget")
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_level")
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
if config.Mode == thinking.ModeNone {
// NOTE(review): with an empty config.Level, a camelCase thinkingLevel already in
// the body is not deleted above and stays in the output — confirm that is intended.
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", false)
if config.Level != "" {
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", string(config.Level))
}
return result, nil
}
// Only handle ModeLevel - budget conversion should be done by upper layer
if config.Mode != thinking.ModeLevel {
return body, nil
}
level := string(config.Level)
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", level)
// Respect user's explicit includeThoughts setting from original body; default to true if not set
// Support both camelCase and snake_case variants
// Read from body, not result: include_thoughts has already been deleted from result.
includeThoughts := true
if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
includeThoughts = inc.Bool()
} else if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
includeThoughts = inc.Bool()
}
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", includeThoughts)
return result, nil
}
// applyBudgetFormat writes a budget-based thinking config
// (request.generationConfig.thinkingConfig.thinkingBudget plus includeThoughts)
// into an Antigravity request body.
//
// thinkingLevel, thinking_level, thinking_budget and include_thoughts are
// deleted first, so a level and a budget never appear together. For Claude
// models (isClaude with a non-nil modelInfo), the budget first goes through
// normalizeClaudeBudget. If that function removed thinkingConfig entirely
// (sentinel -2), the body is returned without writing any thinking fields.
//
// includeThoughts is resolved in this order:
//   - ModeNone: always false.
//   - an explicit includeThoughts / include_thoughts in the original body wins.
//   - otherwise true for ModeAuto, and budget > 0 for every other mode.
//
// The error result is always nil; sjson errors are ignored.
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo, isClaude bool) ([]byte, error) {
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingLevel")
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_level")
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_budget")
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
budget := config.Budget
// Apply Claude-specific constraints first to get the final budget value
if isClaude && modelInfo != nil {
budget, result = a.normalizeClaudeBudget(budget, result, modelInfo)
// Check if budget was removed entirely
if budget == -2 {
return result, nil
}
}
// For ModeNone, always set includeThoughts to false regardless of user setting.
// This ensures that when user requests budget=0 (disable thinking output),
// the includeThoughts is correctly set to false even if budget is clamped to min.
if config.Mode == thinking.ModeNone {
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", false)
return result, nil
}
// Determine includeThoughts: respect user's explicit setting from original body if provided
// Support both camelCase and snake_case variants
// Read from body, not result: include_thoughts has already been deleted from result.
var includeThoughts bool
var userSetIncludeThoughts bool
if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
includeThoughts = inc.Bool()
userSetIncludeThoughts = true
} else if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
includeThoughts = inc.Bool()
userSetIncludeThoughts = true
}
if !userSetIncludeThoughts {
// No explicit setting, use default logic based on mode
switch config.Mode {
case thinking.ModeAuto:
includeThoughts = true
default:
includeThoughts = budget > 0
}
}
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", includeThoughts)
return result, nil
}
// normalizeClaudeBudget enforces Claude constraints on an Antigravity thinking budget.
//
// The steps are:
//  1. Cap the budget to one below the effective max output tokens (the request's
//     maxOutputTokens, else modelInfo.MaxCompletionTokens).
//  2. If the capped budget is non-negative and below modelInfo.Thinking.Min,
//     delete request.generationConfig.thinkingConfig and return -2 as a
//     "removed" sentinel. Negative budgets such as -1 are never removed.
//  3. Otherwise, if the max came from the model default, write it back to
//     request.generationConfig.maxOutputTokens.
//
// It returns the (possibly adjusted) budget and the updated payload. A nil
// modelInfo leaves both untouched.
func (a *Applier) normalizeClaudeBudget(budget int, payload []byte, modelInfo *registry.ModelInfo) (int, []byte) {
	if modelInfo == nil {
		return budget, payload
	}

	ceiling, writeDefault := a.effectiveMaxTokens(payload, modelInfo)
	if ceiling > 0 && budget >= ceiling {
		budget = ceiling - 1
	}

	var floor int
	if modelInfo.Thinking != nil {
		floor = modelInfo.Thinking.Min
	}
	if floor > 0 && budget >= 0 && budget < floor {
		stripped, _ := sjson.DeleteBytes(payload, "request.generationConfig.thinkingConfig")
		return -2, stripped
	}

	if writeDefault && ceiling > 0 {
		payload, _ = sjson.SetBytes(payload, "request.generationConfig.maxOutputTokens", ceiling)
	}
	return budget, payload
}
// effectiveMaxTokens reports the output-token ceiling used to cap thinking.
//
// A positive request.generationConfig.maxOutputTokens in the payload wins, with
// fromModel=false. Otherwise a positive modelInfo.MaxCompletionTokens is used,
// with fromModel=true, which tells the caller to write it back. It returns
// (0, false) when neither value is available. The result is named limit rather
// than max to avoid shadowing the builtin.
func (a *Applier) effectiveMaxTokens(payload []byte, modelInfo *registry.ModelInfo) (limit int, fromModel bool) {
	requested := gjson.GetBytes(payload, "request.generationConfig.maxOutputTokens")
	if requested.Exists() && requested.Int() > 0 {
		return int(requested.Int()), false
	}
	if modelInfo == nil || modelInfo.MaxCompletionTokens <= 0 {
		return 0, false
	}
	return modelInfo.MaxCompletionTokens, true
}
================================================
FILE: internal/thinking/provider/claude/apply.go
================================================
// Package claude implements thinking configuration scaffolding for Claude models.
//
// Claude models support two thinking control styles:
// - Manual thinking: thinking.type="enabled" with thinking.budget_tokens (token budget)
// - Adaptive thinking (Claude 4.6): thinking.type="adaptive" with output_config.effort (low/medium/high/max)
//
// Some Claude models support ZeroAllowed (sonnet-4-5, opus-4-5), while older models do not.
// See: _bmad-output/planning-artifacts/architecture.md#Epic-6
package claude
import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Applier implements thinking.ProviderApplier for Claude models.
// This applier is stateless and holds no configuration.
type Applier struct{}

// Compile-time check that *Applier satisfies thinking.ProviderApplier. Without
// it, a drift in the interface would only surface where the applier is
// registered.
var _ thinking.ProviderApplier = (*Applier)(nil)

// NewApplier creates a new Claude thinking applier.
func NewApplier() *Applier {
	return &Applier{}
}

// init registers the Claude applier under the "claude" provider key.
func init() {
	thinking.RegisterProvider("claude", NewApplier())
}
// Apply applies thinking configuration to a Claude request body.
//
// User-defined models are delegated to applyCompatibleClaude. For registry
// models, config is expected to be pre-validated by thinking.ValidateConfig
// (budget clamping, ZeroAllowed enforcement). A nil modelInfo, or a model
// without thinking support, leaves the body unchanged; previously a nil
// modelInfo caused a nil-pointer panic here. An empty or invalid body is
// replaced with "{}" before editing.
//
// Apply processes:
//   - ModeNone: thinking disabled.
//   - ModeLevel: adaptive thinking effort when the model advertises levels
//     (Claude 4.6). Otherwise the level is converted to budget_tokens and
//     handled like ModeBudget; an unknown level leaves the body unchanged.
//   - ModeBudget: manual thinking budget_tokens; a budget of 0 disables thinking.
//     max_tokens > budget_tokens is then enforced via normalizeClaudeBudget.
//   - ModeAuto: adaptive thinking with no explicit effort when the model
//     advertises levels, otherwise type "enabled" with no budget_tokens.
//
// Every non-adaptive-effort path removes output_config.effort, and drops
// output_config if that leaves it empty.
//
// Expected output format when enabled:
//
//	{
//	  "thinking": {
//	    "type": "enabled",
//	    "budget_tokens": 16384
//	  }
//	}
//
// Expected output format for adaptive:
//
//	{
//	  "thinking": {
//	    "type": "adaptive"
//	  },
//	  "output_config": {
//	    "effort": "high"
//	  }
//	}
//
// Expected output format when disabled:
//
//	{
//	  "thinking": {
//	    "type": "disabled"
//	  }
//	}
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
	if thinking.IsUserDefinedModel(modelInfo) {
		return applyCompatibleClaude(body, config)
	}
	if modelInfo == nil || modelInfo.Thinking == nil {
		return body, nil
	}
	if len(body) == 0 || !gjson.ValidBytes(body) {
		body = []byte(`{}`)
	}
	// Adaptive thinking (output_config.effort) is only valid when the model
	// advertises discrete levels.
	supportsAdaptive := len(modelInfo.Thinking.Levels) > 0

	switch config.Mode {
	case thinking.ModeNone:
		return disableClaudeThinking(body), nil
	case thinking.ModeLevel:
		if supportsAdaptive && config.Level != "" {
			result, _ := sjson.SetBytes(body, "thinking.type", "adaptive")
			result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
			result, _ = sjson.SetBytes(result, "output_config.effort", string(config.Level))
			return result, nil
		}
		// Fallback for non-adaptive Claude models: convert level to budget_tokens.
		budget, ok := thinking.ConvertLevelToBudget(string(config.Level))
		if !ok {
			return body, nil
		}
		config.Budget = budget
		fallthrough
	case thinking.ModeBudget:
		// Budget is expected to be pre-validated by ValidateConfig (clamped, ZeroAllowed enforced).
		if config.Budget == 0 {
			return disableClaudeThinking(body), nil
		}
		result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
		result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
		result = stripClaudeOutputEffort(result)
		// Ensure max_tokens > thinking.budget_tokens (Anthropic API constraint).
		return a.normalizeClaudeBudget(result, config.Budget, modelInfo), nil
	case thinking.ModeAuto:
		// Claude 4.6 models use adaptive thinking and the upstream default effort;
		// older models fall back to "enabled" without budget_tokens.
		thinkingType := "enabled"
		if supportsAdaptive {
			thinkingType = "adaptive"
		}
		result, _ := sjson.SetBytes(body, "thinking.type", thinkingType)
		result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
		return stripClaudeOutputEffort(result), nil
	default:
		return body, nil
	}
}

// disableClaudeThinking sets thinking.type to "disabled" and strips any manual
// budget or adaptive effort left in the body.
func disableClaudeThinking(body []byte) []byte {
	result, _ := sjson.SetBytes(body, "thinking.type", "disabled")
	result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
	return stripClaudeOutputEffort(result)
}

// stripClaudeOutputEffort deletes output_config.effort, then removes
// output_config itself if that leaves it an empty object.
func stripClaudeOutputEffort(body []byte) []byte {
	result, _ := sjson.DeleteBytes(body, "output_config.effort")
	if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
		result, _ = sjson.DeleteBytes(result, "output_config")
	}
	return result
}
// normalizeClaudeBudget keeps thinking.budget_tokens below max_tokens, which the
// Anthropic API requires (violations are rejected with a 400).
//
// The steps, in the order they run, are:
//  1. Resolve the effective max_tokens (request value, else the model default).
//     A model default is written into max_tokens right away.
//  2. Cap the budget to max_tokens-1.
//  3. If the capped budget would fall below modelInfo.Thinking.Min, leave
//     budget_tokens as it was. max_tokens may still have been set in step 1.
//  4. Otherwise write the capped budget back when it changed.
//
// Non-positive budgets are returned unchanged.
func (a *Applier) normalizeClaudeBudget(body []byte, budgetTokens int, modelInfo *registry.ModelInfo) []byte {
	if budgetTokens <= 0 {
		return body
	}

	limit, fromModel := a.effectiveMaxTokens(body, modelInfo)
	if fromModel && limit > 0 {
		body, _ = sjson.SetBytes(body, "max_tokens", limit)
	}

	capped := budgetTokens
	if limit > 0 && capped >= limit {
		capped = limit - 1
	}

	floor := 0
	if modelInfo != nil && modelInfo.Thinking != nil {
		floor = modelInfo.Thinking.Min
	}

	switch {
	case floor > 0 && capped > 0 && capped < floor:
		// Capping would violate the model minimum; keep the original budget.
		return body
	case capped == budgetTokens:
		return body
	}
	body, _ = sjson.SetBytes(body, "thinking.budget_tokens", capped)
	return body
}
// effectiveMaxTokens reports the max_tokens ceiling used to cap thinking.
//
// A positive max_tokens in the request wins, with fromModel=false. Otherwise a
// positive modelInfo.MaxCompletionTokens is used, with fromModel=true, which
// tells the caller to write it back. It returns (0, false) when neither value
// is available. The result is named limit rather than max to avoid shadowing
// the builtin.
func (a *Applier) effectiveMaxTokens(body []byte, modelInfo *registry.ModelInfo) (limit int, fromModel bool) {
	if requested := gjson.GetBytes(body, "max_tokens"); requested.Exists() && requested.Int() > 0 {
		return int(requested.Int()), false
	}
	if modelInfo == nil || modelInfo.MaxCompletionTokens <= 0 {
		return 0, false
	}
	return modelInfo.MaxCompletionTokens, true
}
// applyCompatibleClaude applies thinking config for user-defined Claude models,
// which carry no capability metadata:
//   - ModeLevel: thinking.type "adaptive" plus output_config.effort. An empty
//     level leaves the body as-is, and the upstream validates support.
//   - ModeNone: thinking.type "disabled".
//   - ModeAuto: thinking.type "enabled" without budget_tokens.
//   - ModeBudget: thinking.type "enabled" with budget_tokens = config.Budget.
//
// Apart from ModeLevel, every path removes output_config.effort and drops
// output_config when that leaves it empty. Other modes return the body
// untouched, before any "{}" substitution for empty or invalid bodies.
func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
	switch config.Mode {
	case thinking.ModeBudget, thinking.ModeNone, thinking.ModeAuto, thinking.ModeLevel:
		// Supported; handled below.
	default:
		return body, nil
	}
	if len(body) == 0 || !gjson.ValidBytes(body) {
		body = []byte(`{}`)
	}

	var out []byte
	switch config.Mode {
	case thinking.ModeLevel:
		if config.Level == "" {
			return body, nil
		}
		out, _ = sjson.SetBytes(body, "thinking.type", "adaptive")
		out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
		out, _ = sjson.SetBytes(out, "output_config.effort", string(config.Level))
		return out, nil
	case thinking.ModeNone:
		out, _ = sjson.SetBytes(body, "thinking.type", "disabled")
		out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
	case thinking.ModeAuto:
		out, _ = sjson.SetBytes(body, "thinking.type", "enabled")
		out, _ = sjson.DeleteBytes(out, "thinking.budget_tokens")
	default: // thinking.ModeBudget
		out, _ = sjson.SetBytes(body, "thinking.type", "enabled")
		out, _ = sjson.SetBytes(out, "thinking.budget_tokens", config.Budget)
	}

	out, _ = sjson.DeleteBytes(out, "output_config.effort")
	if oc := gjson.GetBytes(out, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
		out, _ = sjson.DeleteBytes(out, "output_config")
	}
	return out, nil
}
================================================
FILE: internal/thinking/provider/codex/apply.go
================================================
// Package codex implements thinking configuration for Codex (OpenAI Responses API) models.
//
// Codex models use the nested "reasoning.effort" field with discrete levels
// (low/medium/high; some models also accept xhigh or none). This mirrors the
// OpenAI provider, which writes the flat "reasoning_effort" field instead.
// See: _bmad-output/planning-artifacts/architecture.md#Epic-8
package codex
import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Applier is the thinking.ProviderApplier for Codex (OpenAI Responses API) models.
//
// Codex specifics:
//   - Writes reasoning.effort as a string level (low/medium/high/xhigh).
//   - Registry models are handled in level form only; numeric budgets are not written.
//   - Some models support ZeroAllowed (gpt-5.1, gpt-5.2), which maps to effort "none".
type Applier struct{}

// Compile-time check that *Applier satisfies thinking.ProviderApplier.
var _ thinking.ProviderApplier = (*Applier)(nil)

// NewApplier returns a ready-to-use Codex thinking applier.
func NewApplier() *Applier { return new(Applier) }

// init registers the Codex applier under the "codex" provider key.
func init() {
	thinking.RegisterProvider("codex", NewApplier())
}
// Apply writes the thinking configuration into a Codex request body.
//
// Expected output format:
//
//	{
//	  "reasoning": {
//	    "effort": "high"
//	  }
//	}
//
// User-defined models are delegated to applyCompatibleCodex. For registry
// models the body is returned unchanged when modelInfo is nil, the model has no
// thinking support, or the mode is neither ModeLevel nor ModeNone (budget and
// auto modes pass through). ModeLevel with an empty Level is also a no-op, so an
// empty reasoning.effort is never emitted.
//
// For ModeNone the effort is chosen in this order:
//  1. "none" when Budget == 0 and the model allows zero or lists the "none" level;
//  2. config.Level, when non-empty;
//  3. the first level in the model's Levels list.
//
// If none of these yields a value, the body is returned unchanged.
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
	if thinking.IsUserDefinedModel(modelInfo) {
		return applyCompatibleCodex(body, config)
	}
	// Guard nil explicitly instead of depending on how IsUserDefinedModel treats a nil pointer.
	if modelInfo == nil || modelInfo.Thinking == nil {
		return body, nil
	}
	// Only handle ModeLevel and ModeNone; other modes pass through unchanged.
	if config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone {
		return body, nil
	}
	if len(body) == 0 || !gjson.ValidBytes(body) {
		body = []byte(`{}`)
	}
	if config.Mode == thinking.ModeLevel {
		// Never write an empty effort string (consistent with applyCompatibleCodex).
		if config.Level == "" {
			return body, nil
		}
		result, _ := sjson.SetBytes(body, "reasoning.effort", string(config.Level))
		return result, nil
	}

	support := modelInfo.Thinking
	effort := ""
	if config.Budget == 0 && (support.ZeroAllowed || thinking.HasLevel(support.Levels, string(thinking.LevelNone))) {
		effort = string(thinking.LevelNone)
	}
	if effort == "" && config.Level != "" {
		effort = string(config.Level)
	}
	if effort == "" && len(support.Levels) > 0 {
		effort = support.Levels[0]
	}
	if effort == "" {
		return body, nil
	}
	result, _ := sjson.SetBytes(body, "reasoning.effort", effort)
	return result, nil
}
// applyCompatibleCodex writes reasoning.effort for user-defined Codex models,
// for which no registry capability data exists.
//
// Effort selection by mode:
//   - ModeLevel: config.Level (an empty Level leaves the body unchanged).
//   - ModeNone: config.Level if set, otherwise "none".
//   - ModeAuto: "auto".
//   - ModeBudget: the level from thinking.ConvertBudgetToLevel; a failed
//     conversion leaves the body unchanged.
//
// Any other mode leaves the body unchanged. An empty or invalid body is first
// replaced with "{}", so even the no-op paths may return "{}".
func applyCompatibleCodex(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
	if len(body) == 0 || !gjson.ValidBytes(body) {
		body = []byte(`{}`)
	}

	effort, ok := "", false
	switch config.Mode {
	case thinking.ModeLevel:
		effort, ok = string(config.Level), config.Level != ""
	case thinking.ModeNone:
		effort, ok = string(thinking.LevelNone), true
		if config.Level != "" {
			effort = string(config.Level)
		}
	case thinking.ModeAuto:
		// Auto is passed through to user-defined models as the literal "auto".
		effort, ok = string(thinking.LevelAuto), true
	case thinking.ModeBudget:
		effort, ok = thinking.ConvertBudgetToLevel(config.Budget)
	}
	if !ok {
		return body, nil
	}

	out, _ := sjson.SetBytes(body, "reasoning.effort", effort)
	return out, nil
}
================================================
FILE: internal/thinking/provider/gemini/apply.go
================================================
// Package gemini implements thinking configuration for Gemini models.
//
// Gemini models have two formats:
// - Gemini 2.5: Uses thinkingBudget (numeric)
// - Gemini 3.x: Uses thinkingLevel (string: minimal/low/medium/high)
// or thinkingBudget=-1 for auto/dynamic mode
//
// Output format is determined by ThinkingConfig.Mode (and, for ModeNone, ThinkingSupport.Levels):
// - ModeAuto and ModeBudget: Use thinkingBudget (ModeAuto uses thinkingBudget=-1)
// - ModeLevel: Uses thinkingLevel (Gemini 3.x discrete levels)
// - ModeNone: Uses thinkingLevel when len(Levels) > 0, otherwise thinkingBudget (Gemini 2.5)
package gemini
import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Applier applies thinking configuration for Gemini models.
//
// Gemini-specific behavior:
//   - Gemini 2.5: thinkingBudget format; the flash series supports ZeroAllowed.
//   - Gemini 3.x: thinkingLevel format; thinking cannot be disabled.
//   - ThinkingSupport.Levels helps decide which output format is used.
type Applier struct{}

// Compile-time check that *Applier satisfies thinking.ProviderApplier, matching
// the assertion every other provider package declares.
var _ thinking.ProviderApplier = (*Applier)(nil)

// NewApplier creates a new Gemini thinking applier.
func NewApplier() *Applier {
	return &Applier{}
}

// init registers the Gemini applier under the "gemini" provider key.
func init() {
	thinking.RegisterProvider("gemini", NewApplier())
}
// Apply writes the thinking configuration into a Gemini request body under
// generationConfig.thinkingConfig.
//
// Budget format (Gemini 2.5):
//
//	{"generationConfig": {"thinkingConfig": {"thinkingBudget": 8192, "includeThoughts": true}}}
//
// Level format (Gemini 3.x):
//
//	{"generationConfig": {"thinkingConfig": {"thinkingLevel": "high", "includeThoughts": true}}}
//
// User-defined models go through applyCompatible. For registry models the body
// is returned unchanged when the model has no thinking support or the mode is
// not one of ModeBudget/ModeLevel/ModeNone/ModeAuto. Format selection:
//   - ModeLevel: level format (validation is expected to reject unsupported levels).
//   - ModeNone: level format when the model lists Levels, budget format otherwise.
//   - ModeBudget, ModeAuto: budget format.
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
	if thinking.IsUserDefinedModel(modelInfo) {
		return a.applyCompatible(body, config)
	}
	support := modelInfo.Thinking
	if support == nil {
		return body, nil
	}

	var levelFormat bool
	switch config.Mode {
	case thinking.ModeLevel:
		levelFormat = true
	case thinking.ModeNone:
		levelFormat = len(support.Levels) > 0
	case thinking.ModeBudget, thinking.ModeAuto:
		levelFormat = false
	default:
		return body, nil
	}

	if len(body) == 0 || !gjson.ValidBytes(body) {
		body = []byte(`{}`)
	}
	if levelFormat {
		return a.applyLevelFormat(body, config)
	}
	return a.applyBudgetFormat(body, config)
}
// applyCompatible handles user-defined Gemini models. They carry no registry
// capability data, so the output format is inferred from the config alone:
//   - ModeLevel, or ModeNone with a non-empty Level: level format.
//   - ModeAuto, ModeBudget, and ModeNone without a Level: budget format.
//
// Modes outside ModeBudget/ModeLevel/ModeNone/ModeAuto return body unchanged;
// otherwise an empty or invalid body is replaced with "{}" first.
func (a *Applier) applyCompatible(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
	switch config.Mode {
	case thinking.ModeBudget, thinking.ModeLevel, thinking.ModeNone, thinking.ModeAuto:
	default:
		return body, nil
	}
	if len(body) == 0 || !gjson.ValidBytes(body) {
		body = []byte(`{}`)
	}
	levelFormat := config.Mode == thinking.ModeLevel ||
		(config.Mode == thinking.ModeNone && config.Level != "")
	if levelFormat {
		return a.applyLevelFormat(body, config)
	}
	return a.applyBudgetFormat(body, config)
}
// applyLevelFormat writes thinkingLevel/includeThoughts under
// generationConfig.thinkingConfig (Gemini 3.x style).
//
// thinkingBudget, thinking_budget, thinking_level and the snake_case
// include_thoughts are removed first. This keeps the output from carrying both
// a level and a budget, or both spellings of includeThoughts.
//
//   - ModeNone: includeThoughts=false; thinkingLevel is written only when
//     config.Level is non-empty (ValidateConfig sets the lowest level when
//     ModeNone comes with Budget > 0).
//   - ModeLevel: thinkingLevel=config.Level; includeThoughts keeps the caller's
//     explicit value from the original body (camelCase first, then snake_case)
//     and defaults to true.
//   - Any other mode: the original body is returned untouched; budget-to-level
//     conversion is the caller's job.
func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
	const cfg = "generationConfig.thinkingConfig."

	out := body
	for _, stale := range []string{"thinkingBudget", "thinking_budget", "thinking_level", "include_thoughts"} {
		out, _ = sjson.DeleteBytes(out, cfg+stale)
	}

	switch config.Mode {
	case thinking.ModeNone:
		out, _ = sjson.SetBytes(out, cfg+"includeThoughts", false)
		if config.Level != "" {
			out, _ = sjson.SetBytes(out, cfg+"thinkingLevel", string(config.Level))
		}
		return out, nil
	case thinking.ModeLevel:
		show := true
		if v := gjson.GetBytes(body, cfg+"includeThoughts"); v.Exists() {
			show = v.Bool()
		} else if v := gjson.GetBytes(body, cfg+"include_thoughts"); v.Exists() {
			show = v.Bool()
		}
		out, _ = sjson.SetBytes(out, cfg+"thinkingLevel", string(config.Level))
		out, _ = sjson.SetBytes(out, cfg+"includeThoughts", show)
		return out, nil
	default:
		return body, nil
	}
}
// applyBudgetFormat writes thinkingBudget/includeThoughts under
// generationConfig.thinkingConfig (Gemini 2.5 style).
//
// thinkingLevel, thinking_level, thinking_budget and the snake_case
// include_thoughts are removed first. thinkingBudget is then set to
// config.Budget. includeThoughts is chosen as follows:
//   - ModeNone: always false, even if the budget was clamped up to a minimum.
//   - otherwise the caller's explicit value from the original body
//     (camelCase first, then snake_case);
//   - with no explicit value: true for ModeAuto, else Budget > 0.
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
	const cfg = "generationConfig.thinkingConfig."

	out := body
	for _, stale := range []string{"thinkingLevel", "thinking_level", "thinking_budget", "include_thoughts"} {
		out, _ = sjson.DeleteBytes(out, cfg+stale)
	}

	budget := config.Budget
	var show bool
	if config.Mode == thinking.ModeNone {
		show = false
	} else if v := gjson.GetBytes(body, cfg+"includeThoughts"); v.Exists() {
		show = v.Bool()
	} else if v := gjson.GetBytes(body, cfg+"include_thoughts"); v.Exists() {
		show = v.Bool()
	} else {
		show = config.Mode == thinking.ModeAuto || budget > 0
	}

	out, _ = sjson.SetBytes(out, cfg+"thinkingBudget", budget)
	out, _ = sjson.SetBytes(out, cfg+"includeThoughts", show)
	return out, nil
}
================================================
FILE: internal/thinking/provider/geminicli/apply.go
================================================
// Package geminicli implements thinking configuration for Gemini CLI API format.
//
// Gemini CLI uses request.generationConfig.thinkingConfig.* path instead of
// generationConfig.thinkingConfig.* used by standard Gemini API.
package geminicli
import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Applier is the thinking.ProviderApplier for the Gemini CLI request format,
// which nests the config under request.generationConfig.thinkingConfig.
type Applier struct{}

// Compile-time check that *Applier satisfies thinking.ProviderApplier.
var _ thinking.ProviderApplier = (*Applier)(nil)

// NewApplier returns a ready-to-use Gemini CLI thinking applier.
func NewApplier() *Applier { return new(Applier) }

// init registers the applier under the "gemini-cli" provider key.
func init() {
	thinking.RegisterProvider("gemini-cli", NewApplier())
}
// Apply writes the thinking configuration into a Gemini CLI request body.
//
// User-defined models go through applyCompatible. For registry models the body
// is returned unchanged when the model has no thinking support or the mode is
// not one of ModeBudget/ModeLevel/ModeNone/ModeAuto. Format selection:
//   - ModeAuto, ModeBudget: budget format.
//   - ModeLevel, ModeNone: level format when the model lists Levels, budget
//     format otherwise.
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
	if thinking.IsUserDefinedModel(modelInfo) {
		return a.applyCompatible(body, config)
	}
	support := modelInfo.Thinking
	if support == nil {
		return body, nil
	}

	var levelFormat bool
	switch config.Mode {
	case thinking.ModeAuto, thinking.ModeBudget:
		// Always budget format; auto is expected to arrive as thinkingBudget=-1.
		levelFormat = false
	case thinking.ModeLevel, thinking.ModeNone:
		levelFormat = len(support.Levels) > 0
	default:
		return body, nil
	}

	if len(body) == 0 || !gjson.ValidBytes(body) {
		body = []byte(`{}`)
	}
	if levelFormat {
		return a.applyLevelFormat(body, config)
	}
	return a.applyBudgetFormat(body, config)
}
// applyCompatible handles user-defined models in Gemini CLI format. Without
// registry capability data, the output format is inferred from the config:
//   - ModeLevel, or ModeNone with a non-empty Level: level format.
//   - ModeAuto, ModeBudget, and ModeNone without a Level: budget format.
//
// Modes outside ModeBudget/ModeLevel/ModeNone/ModeAuto return body unchanged;
// otherwise an empty or invalid body is replaced with "{}" first.
func (a *Applier) applyCompatible(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
	switch config.Mode {
	case thinking.ModeBudget, thinking.ModeLevel, thinking.ModeNone, thinking.ModeAuto:
	default:
		return body, nil
	}
	if len(body) == 0 || !gjson.ValidBytes(body) {
		body = []byte(`{}`)
	}
	levelFormat := config.Mode == thinking.ModeLevel ||
		(config.Mode == thinking.ModeNone && config.Level != "")
	if levelFormat {
		return a.applyLevelFormat(body, config)
	}
	return a.applyBudgetFormat(body, config)
}
// applyLevelFormat writes thinkingLevel/includeThoughts under
// request.generationConfig.thinkingConfig (Gemini 3.x style).
//
// thinkingBudget, thinking_budget, thinking_level and the snake_case
// include_thoughts are removed first. This keeps the output from holding both a
// level and a budget, or both spellings of includeThoughts.
//
//   - ModeNone: includeThoughts=false; thinkingLevel is written only when
//     config.Level is non-empty.
//   - ModeLevel: thinkingLevel=config.Level; includeThoughts keeps the caller's
//     explicit value from the original body (camelCase first, then snake_case)
//     and defaults to true.
//   - Any other mode: the original body is returned untouched; budget-to-level
//     conversion is the caller's job.
func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
	const cfg = "request.generationConfig.thinkingConfig."

	out := body
	for _, stale := range []string{"thinkingBudget", "thinking_budget", "thinking_level", "include_thoughts"} {
		out, _ = sjson.DeleteBytes(out, cfg+stale)
	}

	switch config.Mode {
	case thinking.ModeNone:
		out, _ = sjson.SetBytes(out, cfg+"includeThoughts", false)
		if config.Level != "" {
			out, _ = sjson.SetBytes(out, cfg+"thinkingLevel", string(config.Level))
		}
		return out, nil
	case thinking.ModeLevel:
		show := true
		if v := gjson.GetBytes(body, cfg+"includeThoughts"); v.Exists() {
			show = v.Bool()
		} else if v := gjson.GetBytes(body, cfg+"include_thoughts"); v.Exists() {
			show = v.Bool()
		}
		out, _ = sjson.SetBytes(out, cfg+"thinkingLevel", string(config.Level))
		out, _ = sjson.SetBytes(out, cfg+"includeThoughts", show)
		return out, nil
	default:
		return body, nil
	}
}
// applyBudgetFormat writes thinkingBudget/includeThoughts under
// request.generationConfig.thinkingConfig (Gemini 2.5 style).
//
// thinkingLevel, thinking_level, thinking_budget and the snake_case
// include_thoughts are removed first. thinkingBudget is then set to
// config.Budget. includeThoughts is chosen as follows:
//   - ModeNone: always false, even if the budget was clamped up to a minimum.
//   - otherwise the caller's explicit value from the original body
//     (camelCase first, then snake_case);
//   - with no explicit value: true for ModeAuto, else Budget > 0.
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
	const cfg = "request.generationConfig.thinkingConfig."

	out := body
	for _, stale := range []string{"thinkingLevel", "thinking_level", "thinking_budget", "include_thoughts"} {
		out, _ = sjson.DeleteBytes(out, cfg+stale)
	}

	budget := config.Budget
	var show bool
	if config.Mode == thinking.ModeNone {
		show = false
	} else if v := gjson.GetBytes(body, cfg+"includeThoughts"); v.Exists() {
		show = v.Bool()
	} else if v := gjson.GetBytes(body, cfg+"include_thoughts"); v.Exists() {
		show = v.Bool()
	} else {
		show = config.Mode == thinking.ModeAuto || budget > 0
	}

	out, _ = sjson.SetBytes(out, cfg+"thinkingBudget", budget)
	out, _ = sjson.SetBytes(out, cfg+"includeThoughts", show)
	return out, nil
}
================================================
FILE: internal/thinking/provider/iflow/apply.go
================================================
// Package iflow implements thinking configuration for iFlow models.
//
// iFlow models use boolean toggle semantics:
// - Models using chat_template_kwargs.enable_thinking (boolean toggle)
// - MiniMax models: reasoning_split (boolean)
//
// Level values are converted to boolean: none=false, all others=true
// See: _bmad-output/planning-artifacts/architecture.md#Epic-9
package iflow
import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Applier is the thinking.ProviderApplier for iFlow models.
//
// iFlow specifics:
//   - enable_thinking models: chat_template_kwargs.enable_thinking boolean.
//   - GLM models: enable_thinking plus clear_thinking=false when enabled.
//   - MiniMax models: reasoning_split boolean.
//   - Levels collapse to a boolean: none=false, anything else=true.
//   - Only on/off is supported; there is no graded budget.
type Applier struct{}

// Compile-time check that *Applier satisfies thinking.ProviderApplier.
var _ thinking.ProviderApplier = (*Applier)(nil)

// NewApplier returns a ready-to-use iFlow thinking applier.
func NewApplier() *Applier { return new(Applier) }

// init registers the iFlow applier under the "iflow" provider key.
func init() {
	thinking.RegisterProvider("iflow", NewApplier())
}
// Apply writes the thinking toggle into an iFlow request body.
//
// GLM output:
//
//	{"chat_template_kwargs": {"enable_thinking": true, "clear_thinking": false}}
//
// MiniMax output:
//
//	{"reasoning_split": true}
//
// User-defined models, models without thinking support, and models that are
// neither enable_thinking nor MiniMax models are returned unchanged.
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
	if thinking.IsUserDefinedModel(modelInfo) || modelInfo.Thinking == nil {
		return body, nil
	}
	id := modelInfo.ID
	switch {
	case isEnableThinkingModel(id):
		return applyEnableThinking(body, config, isGLMModel(id)), nil
	case isMiniMaxModel(id):
		return applyMiniMax(body, config), nil
	default:
		return body, nil
	}
}
// configToBoolean collapses a ThinkingConfig to the on/off switch iFlow models use.
//
// It is false for ModeNone, for ModeBudget with Budget <= 0, and for ModeLevel
// with Level "none". Every other case is true, including ModeAuto and any
// unrecognised mode.
func configToBoolean(config thinking.ThinkingConfig) bool {
	switch config.Mode {
	case thinking.ModeNone:
		return false
	case thinking.ModeBudget:
		return config.Budget > 0
	case thinking.ModeLevel:
		return config.Level != thinking.LevelNone
	}
	// ModeAuto and any unknown mode enable thinking.
	return true
}
// applyEnableThinking sets chat_template_kwargs.enable_thinking for models
// that use that toggle.
//
// Enabled output (with setClearThinking):
//
//	{"chat_template_kwargs": {"enable_thinking": true, "clear_thinking": false}}
//
// Disabled output:
//
//	{"chat_template_kwargs": {"enable_thinking": false}}
//
// Any incoming clear_thinking is always removed. It is re-added as false only
// when thinking is enabled and setClearThinking is true (Apply passes true for
// GLM models). An empty or invalid body is replaced with "{}" first.
func applyEnableThinking(body []byte, config thinking.ThinkingConfig, setClearThinking bool) []byte {
	enabled := configToBoolean(config)
	if len(body) == 0 || !gjson.ValidBytes(body) {
		body = []byte(`{}`)
	}
	out, _ := sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enabled)
	out, _ = sjson.DeleteBytes(out, "chat_template_kwargs.clear_thinking")
	if !enabled || !setClearThinking {
		return out
	}
	out, _ = sjson.SetBytes(out, "chat_template_kwargs.clear_thinking", false)
	return out
}
// applyMiniMax sets the top-level reasoning_split boolean for MiniMax models:
//
//	{"reasoning_split": true/false}
//
// An empty or invalid body is replaced with "{}" first.
func applyMiniMax(body []byte, config thinking.ThinkingConfig) []byte {
	if len(body) == 0 || !gjson.ValidBytes(body) {
		body = []byte(`{}`)
	}
	out, _ := sjson.SetBytes(body, "reasoning_split", configToBoolean(config))
	return out
}
// isEnableThinkingModel reports whether the model is toggled through
// chat_template_kwargs.enable_thinking. This covers every GLM model plus a
// fixed, case-insensitive list of Qwen/DeepSeek IDs.
func isEnableThinkingModel(modelID string) bool {
	lower := strings.ToLower(modelID)
	return isGLMModel(modelID) ||
		lower == "qwen3-max-preview" ||
		lower == "deepseek-v3.2" ||
		lower == "deepseek-v3.1"
}
// isGLMModel reports whether modelID names a GLM-series model, i.e. starts with
// "glm" case-insensitively.
func isGLMModel(modelID string) bool {
	lowered := strings.ToLower(modelID)
	return strings.HasPrefix(lowered, "glm")
}
// isMiniMaxModel reports whether modelID names a MiniMax-series model (prefix
// "minimax", case-insensitive). Such models are toggled via reasoning_split.
func isMiniMaxModel(modelID string) bool {
	lowered := strings.ToLower(modelID)
	return strings.HasPrefix(lowered, "minimax")
}
================================================
FILE: internal/thinking/provider/kimi/apply.go
================================================
// Package kimi implements thinking configuration for Kimi (Moonshot AI) models.
//
// Kimi models use the OpenAI-compatible reasoning_effort format for enabled thinking
// levels, but use thinking.type=disabled when thinking is explicitly turned off.
package kimi
import (
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Applier is the thinking.ProviderApplier for Kimi (Moonshot AI) models.
//
// Kimi specifics:
//   - Thinking on: reasoning_effort carries a string level.
//   - Thinking off: thinking.type="disabled".
//   - Budgets are converted to levels.
type Applier struct{}

// Compile-time check that *Applier satisfies thinking.ProviderApplier.
var _ thinking.ProviderApplier = (*Applier)(nil)

// NewApplier returns a ready-to-use Kimi thinking applier.
func NewApplier() *Applier { return new(Applier) }

// init registers the Kimi applier under the "kimi" provider key.
func init() {
	thinking.RegisterProvider("kimi", NewApplier())
}
// Apply writes the thinking configuration into a Kimi request body.
//
// Thinking on:
//
//	{"reasoning_effort": "high"}
//
// Thinking off:
//
//	{"thinking": {"type": "disabled"}}
//
// User-defined models are delegated to applyCompatibleKimi. Models without
// thinking support are returned unchanged. Otherwise an empty or invalid body
// is normalized to "{}" and the effort is derived per mode:
//   - ModeLevel: config.Level.
//   - ModeNone: disabled thinking, unless config.Level holds a concrete
//     (non-"none") fallback level, which is then used as the effort.
//   - ModeBudget: the level from thinking.ConvertBudgetToLevel.
//   - ModeAuto: "auto".
//
// If no effort results (unknown mode, empty level, failed conversion), the
// normalized body is returned unchanged.
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
	if thinking.IsUserDefinedModel(modelInfo) {
		return applyCompatibleKimi(body, config)
	}
	if modelInfo.Thinking == nil {
		return body, nil
	}
	if len(body) == 0 || !gjson.ValidBytes(body) {
		body = []byte(`{}`)
	}

	effort := ""
	switch config.Mode {
	case thinking.ModeLevel:
		effort = string(config.Level)
	case thinking.ModeNone:
		if config.Level == "" || config.Level == thinking.LevelNone {
			// Kimi needs an explicit disabled thinking object.
			return applyDisabledThinking(body)
		}
		// Clamped fallback for models that cannot switch thinking off.
		effort = string(config.Level)
	case thinking.ModeBudget:
		if lvl, ok := thinking.ConvertBudgetToLevel(config.Budget); ok {
			effort = lvl
		}
	case thinking.ModeAuto:
		effort = string(thinking.LevelAuto)
	}
	if effort == "" {
		return body, nil
	}
	return applyReasoningEffort(body, effort)
}
// applyCompatibleKimi applies thinking config for user-defined Kimi models,
// which have no registry capability data.
//
// An empty or invalid body is first normalized to "{}". Then:
//   - ModeLevel: reasoning_effort=config.Level (no-op when Level is empty).
//   - ModeNone: disabled thinking when Level is empty or "none", otherwise
//     reasoning_effort=config.Level.
//   - ModeAuto: reasoning_effort="auto".
//   - ModeBudget: reasoning_effort from thinking.ConvertBudgetToLevel (no-op
//     when the conversion fails).
//   - Any other mode: no-op.
func applyCompatibleKimi(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
	if len(body) == 0 || !gjson.ValidBytes(body) {
		body = []byte(`{}`)
	}

	effort, apply := "", false
	switch config.Mode {
	case thinking.ModeLevel:
		effort, apply = string(config.Level), config.Level != ""
	case thinking.ModeNone:
		if config.Level == "" || config.Level == thinking.LevelNone {
			return applyDisabledThinking(body)
		}
		effort, apply = string(config.Level), true
	case thinking.ModeAuto:
		effort, apply = string(thinking.LevelAuto), true
	case thinking.ModeBudget:
		effort, apply = thinking.ConvertBudgetToLevel(config.Budget)
	}
	if !apply {
		return body, nil
	}
	return applyReasoningEffort(body, effort)
}
// applyReasoningEffort removes any "thinking" object and sets the top-level
// reasoning_effort to effort. On an sjson failure it returns the original body
// together with a wrapped error.
func applyReasoningEffort(body []byte, effort string) ([]byte, error) {
	cleared, err := sjson.DeleteBytes(body, "thinking")
	if err != nil {
		return body, fmt.Errorf("kimi thinking: failed to clear thinking object: %w", err)
	}
	updated, err := sjson.SetBytes(cleared, "reasoning_effort", effort)
	if err != nil {
		return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
	}
	return updated, nil
}
// applyDisabledThinking replaces any existing thinking object and
// reasoning_effort with {"thinking": {"type": "disabled"}}. On an sjson failure
// it returns the original body together with a wrapped error.
func applyDisabledThinking(body []byte) ([]byte, error) {
	out, err := sjson.DeleteBytes(body, "thinking")
	if err != nil {
		return body, fmt.Errorf("kimi thinking: failed to clear thinking object: %w", err)
	}
	if out, err = sjson.DeleteBytes(out, "reasoning_effort"); err != nil {
		return body, fmt.Errorf("kimi thinking: failed to clear reasoning_effort: %w", err)
	}
	if out, err = sjson.SetBytes(out, "thinking.type", "disabled"); err != nil {
		return body, fmt.Errorf("kimi thinking: failed to set thinking.type: %w", err)
	}
	return out, nil
}
================================================
FILE: internal/thinking/provider/kimi/apply_test.go
================================================
package kimi
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
)
func TestApply_ModeNone_UsesDisabledThinking(t *testing.T) {
applier := NewApplier()
modelInfo := ®istry.ModelInfo{
ID: "kimi-k2.5",
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
}
body := []byte(`{"model":"kimi-k2.5","reasoning_effort":"none","thinking":{"type":"enabled","budget_tokens":2048}}`)
out, errApply := applier.Apply(body, thinking.ThinkingConfig{Mode: thinking.ModeNone}, modelInfo)
if errApply != nil {
t.Fatalf("Apply() error = %v", errApply)
}
if got := gjson.GetBytes(out, "thinking.type").String(); got != "disabled" {
t.Fatalf("thinking.type = %q, want %q, body=%s", got, "disabled", string(out))
}
if gjson.GetBytes(out, "thinking.budget_tokens").Exists() {
t.Fatalf("thinking.budget_tokens should be removed, body=%s", string(out))
}
if gjson.GetBytes(out, "reasoning_effort").Exists() {
t.Fatalf("reasoning_effort should be removed in ModeNone, body=%s", string(out))
}
}
func TestApply_ModeLevel_UsesReasoningEffort(t *testing.T) {
applier := NewApplier()
modelInfo := ®istry.ModelInfo{
ID: "kimi-k2.5",
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
}
body := []byte(`{"model":"kimi-k2.5","thinking":{"type":"disabled"}}`)
out, errApply := applier.Apply(body, thinking.ThinkingConfig{Mode: thinking.ModeLevel, Level: thinking.LevelHigh}, modelInfo)
if errApply != nil {
t.Fatalf("Apply() error = %v", errApply)
}
if got := gjson.GetBytes(out, "reasoning_effort").String(); got != "high" {
t.Fatalf("reasoning_effort = %q, want %q, body=%s", got, "high", string(out))
}
if gjson.GetBytes(out, "thinking").Exists() {
t.Fatalf("thinking should be removed when reasoning_effort is used, body=%s", string(out))
}
}
func TestApply_UserDefinedModeNone_UsesDisabledThinking(t *testing.T) {
applier := NewApplier()
modelInfo := ®istry.ModelInfo{
ID: "custom-kimi-model",
UserDefined: true,
}
body := []byte(`{"model":"custom-kimi-model","reasoning_effort":"none"}`)
out, errApply := applier.Apply(body, thinking.ThinkingConfig{Mode: thinking.ModeNone}, modelInfo)
if errApply != nil {
t.Fatalf("Apply() error = %v", errApply)
}
if got := gjson.GetBytes(out, "thinking.type").String(); got != "disabled" {
t.Fatalf("thinking.type = %q, want %q, body=%s", got, "disabled", string(out))
}
if gjson.GetBytes(out, "reasoning_effort").Exists() {
t.Fatalf("reasoning_effort should be removed in ModeNone, body=%s", string(out))
}
}
================================================
FILE: internal/thinking/provider/openai/apply.go
================================================
// Package openai implements thinking configuration for OpenAI/Codex models.
//
// OpenAI models use the reasoning_effort format with discrete levels
// (low/medium/high). Some models support xhigh and none levels.
// See: _bmad-output/planning-artifacts/architecture.md#Epic-8
package openai
import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Applier is the thinking.ProviderApplier for OpenAI models.
//
// OpenAI specifics:
//   - Writes the flat reasoning_effort field as a string level (low/medium/high/xhigh).
//   - Registry models are handled in level form only; numeric budgets are not written.
//   - Some models support ZeroAllowed (gpt-5.1, gpt-5.2), which maps to effort "none".
type Applier struct{}

// Compile-time check that *Applier satisfies thinking.ProviderApplier.
var _ thinking.ProviderApplier = (*Applier)(nil)

// NewApplier returns a ready-to-use OpenAI thinking applier.
func NewApplier() *Applier { return new(Applier) }

// init registers the OpenAI applier under the "openai" provider key.
func init() {
	thinking.RegisterProvider("openai", NewApplier())
}
// Apply writes the thinking configuration into an OpenAI request body.
//
// Expected output format:
//
//	{
//	  "reasoning_effort": "high"
//	}
//
// User-defined models are delegated to applyCompatibleOpenAI. For registry
// models the body is returned unchanged when modelInfo is nil, the model has no
// thinking support, or the mode is neither ModeLevel nor ModeNone (budget and
// auto modes pass through). ModeLevel with an empty Level is also a no-op, so an
// empty reasoning_effort is never emitted.
//
// For ModeNone the effort is chosen in this order:
//  1. "none" when Budget == 0 and the model allows zero or lists the "none" level;
//  2. config.Level, when non-empty;
//  3. the first level in the model's Levels list.
//
// If none of these yields a value, the body is returned unchanged.
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
	if thinking.IsUserDefinedModel(modelInfo) {
		return applyCompatibleOpenAI(body, config)
	}
	// Guard nil explicitly instead of depending on how IsUserDefinedModel treats a nil pointer.
	if modelInfo == nil || modelInfo.Thinking == nil {
		return body, nil
	}
	// Only handle ModeLevel and ModeNone; other modes pass through unchanged.
	if config.Mode != thinking.ModeLevel && config.Mode != thinking.ModeNone {
		return body, nil
	}
	if len(body) == 0 || !gjson.ValidBytes(body) {
		body = []byte(`{}`)
	}
	if config.Mode == thinking.ModeLevel {
		// Never write an empty effort string (consistent with applyCompatibleOpenAI).
		if config.Level == "" {
			return body, nil
		}
		result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level))
		return result, nil
	}

	support := modelInfo.Thinking
	effort := ""
	if config.Budget == 0 && (support.ZeroAllowed || thinking.HasLevel(support.Levels, string(thinking.LevelNone))) {
		effort = string(thinking.LevelNone)
	}
	if effort == "" && config.Level != "" {
		effort = string(config.Level)
	}
	if effort == "" && len(support.Levels) > 0 {
		effort = support.Levels[0]
	}
	if effort == "" {
		return body, nil
	}
	result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
	return result, nil
}
func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
if len(body) == 0 || !gjson.ValidBytes(body) {
body = []byte(`{}`)
}
var effort string
switch config.Mode {
case thinking.ModeLevel:
if config.Level == "" {
return body, nil
}
effort = string(config.Level)
case thinking.ModeNone:
effort = string(thinking.LevelNone)
if config.Level != "" {
effort = string(config.Level)
}
case thinking.ModeAuto:
// Auto mode for user-defined models: pass through as "auto"
effort = string(thinking.LevelAuto)
case thinking.ModeBudget:
// Budget mode: convert budget to level using threshold mapping
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
if !ok {
return body, nil
}
effort = level
default:
return body, nil
}
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
return result, nil
}
================================================
FILE: internal/thinking/strip.go
================================================
// Package thinking provides unified thinking configuration processing.
package thinking
import (
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// StripThinkingConfig removes thinking configuration fields from request body.
//
// This function is used when a model doesn't support thinking but the request
// contains thinking configuration. The configuration is silently removed to
// prevent upstream API errors.
//
// Parameters:
// - body: Original request body JSON
// - provider: Provider name (determines which fields to strip)
//
// Returns:
// - Modified request body JSON with thinking configuration removed
// - Original body is returned unchanged if:
// - body is empty or invalid JSON
// - provider is unknown
// - no thinking configuration found
func StripThinkingConfig(body []byte, provider string) []byte {
if len(body) == 0 || !gjson.ValidBytes(body) {
return body
}
var paths []string
switch provider {
case "claude":
paths = []string{"thinking", "output_config.effort"}
case "gemini":
paths = []string{"generationConfig.thinkingConfig"}
case "gemini-cli", "antigravity":
paths = []string{"request.generationConfig.thinkingConfig"}
case "openai":
paths = []string{"reasoning_effort"}
case "kimi":
paths = []string{
"reasoning_effort",
"thinking",
}
case "codex":
paths = []string{"reasoning.effort"}
case "iflow":
paths = []string{
"chat_template_kwargs.enable_thinking",
"chat_template_kwargs.clear_thinking",
"reasoning_split",
"reasoning_effort",
}
default:
return body
}
result := body
for _, path := range paths {
result, _ = sjson.DeleteBytes(result, path)
}
// Avoid leaving an empty output_config object for Claude when effort was the only field.
if provider == "claude" {
if oc := gjson.GetBytes(result, "output_config"); oc.Exists() && oc.IsObject() && len(oc.Map()) == 0 {
result, _ = sjson.DeleteBytes(result, "output_config")
}
}
return result
}
================================================
FILE: internal/thinking/suffix.go
================================================
// Package thinking provides unified thinking configuration processing.
//
// This file implements suffix parsing functionality for extracting
// thinking configuration from model names in the format model(value).
package thinking
import (
"strconv"
"strings"
)
// ParseSuffix extracts thinking suffix from a model name.
//
// The suffix format is: model-name(value)
// Examples:
// - "claude-sonnet-4-5(16384)" -> ModelName="claude-sonnet-4-5", RawSuffix="16384"
// - "gpt-5.2(high)" -> ModelName="gpt-5.2", RawSuffix="high"
// - "gemini-2.5-pro" -> ModelName="gemini-2.5-pro", HasSuffix=false
//
// This function only extracts the suffix; it does not validate or interpret
// the suffix content. Use ParseNumericSuffix, ParseLevelSuffix, etc. for
// content interpretation.
func ParseSuffix(model string) SuffixResult {
// Find the last opening parenthesis
lastOpen := strings.LastIndex(model, "(")
if lastOpen == -1 {
return SuffixResult{ModelName: model, HasSuffix: false}
}
// Check if the string ends with a closing parenthesis
if !strings.HasSuffix(model, ")") {
return SuffixResult{ModelName: model, HasSuffix: false}
}
// Extract components
modelName := model[:lastOpen]
rawSuffix := model[lastOpen+1 : len(model)-1]
return SuffixResult{
ModelName: modelName,
HasSuffix: true,
RawSuffix: rawSuffix,
}
}
// ParseNumericSuffix attempts to parse a raw suffix as a numeric budget value.
//
// This function parses the raw suffix content (from ParseSuffix.RawSuffix) as an integer.
// Only non-negative integers are considered valid numeric suffixes.
//
// Platform note: The budget value uses Go's int type, which is 32-bit on 32-bit
// systems and 64-bit on 64-bit systems. Values exceeding the platform's int range
// will return ok=false.
//
// Leading zeros are accepted: "08192" parses as 8192.
//
// Examples:
// - "8192" -> budget=8192, ok=true
// - "0" -> budget=0, ok=true (represents ModeNone)
// - "08192" -> budget=8192, ok=true (leading zeros accepted)
// - "-1" -> budget=0, ok=false (negative numbers are not valid numeric suffixes)
// - "high" -> budget=0, ok=false (not a number)
// - "9223372036854775808" -> budget=0, ok=false (overflow on 64-bit systems)
//
// For special handling of -1 as auto mode, use ParseSpecialSuffix instead.
func ParseNumericSuffix(rawSuffix string) (budget int, ok bool) {
if rawSuffix == "" {
return 0, false
}
value, err := strconv.Atoi(rawSuffix)
if err != nil {
return 0, false
}
// Negative numbers are not valid numeric suffixes
// -1 should be handled by special value parsing as "auto"
if value < 0 {
return 0, false
}
return value, true
}
// ParseSpecialSuffix attempts to parse a raw suffix as a special thinking mode value.
//
// This function handles special strings that represent a change in thinking mode:
// - "none" -> ModeNone (disables thinking)
// - "auto" -> ModeAuto (automatic/dynamic thinking)
// - "-1" -> ModeAuto (numeric representation of auto mode)
//
// String values are case-insensitive.
func ParseSpecialSuffix(rawSuffix string) (mode ThinkingMode, ok bool) {
if rawSuffix == "" {
return ModeBudget, false
}
// Case-insensitive matching
switch strings.ToLower(rawSuffix) {
case "none":
return ModeNone, true
case "auto", "-1":
return ModeAuto, true
default:
return ModeBudget, false
}
}
// ParseLevelSuffix attempts to parse a raw suffix as a discrete thinking level.
//
// This function parses the raw suffix content (from ParseSuffix.RawSuffix) as a level.
// Only discrete effort levels are valid: minimal, low, medium, high, xhigh, max.
// Level matching is case-insensitive.
//
// Special values (none, auto) are NOT handled by this function; use ParseSpecialSuffix
// instead. This separation allows callers to prioritize special value handling.
//
// Examples:
// - "high" -> level=LevelHigh, ok=true
// - "HIGH" -> level=LevelHigh, ok=true (case insensitive)
// - "medium" -> level=LevelMedium, ok=true
// - "none" -> level="", ok=false (special value, use ParseSpecialSuffix)
// - "auto" -> level="", ok=false (special value, use ParseSpecialSuffix)
// - "8192" -> level="", ok=false (numeric, use ParseNumericSuffix)
// - "ultra" -> level="", ok=false (unknown level)
func ParseLevelSuffix(rawSuffix string) (level ThinkingLevel, ok bool) {
if rawSuffix == "" {
return "", false
}
// Case-insensitive matching
switch strings.ToLower(rawSuffix) {
case "minimal":
return LevelMinimal, true
case "low":
return LevelLow, true
case "medium":
return LevelMedium, true
case "high":
return LevelHigh, true
case "xhigh":
return LevelXHigh, true
case "max":
return LevelMax, true
default:
return "", false
}
}
================================================
FILE: internal/thinking/text.go
================================================
package thinking
import (
"github.com/tidwall/gjson"
)
// GetThinkingText extracts the thinking text from a content part.
// Handles various formats:
// - Simple string: { "thinking": "text" } or { "text": "text" }
// - Wrapped object: { "thinking": { "text": "text", "cache_control": {...} } }
// - Gemini-style: { "thought": true, "text": "text" }
// Returns the extracted text string.
func GetThinkingText(part gjson.Result) string {
// Try direct text field first (Gemini-style)
if text := part.Get("text"); text.Exists() && text.Type == gjson.String {
return text.String()
}
// Try thinking field
thinkingField := part.Get("thinking")
if !thinkingField.Exists() {
return ""
}
// thinking is a string
if thinkingField.Type == gjson.String {
return thinkingField.String()
}
// thinking is an object with inner text/thinking
if thinkingField.IsObject() {
if inner := thinkingField.Get("text"); inner.Exists() && inner.Type == gjson.String {
return inner.String()
}
if inner := thinkingField.Get("thinking"); inner.Exists() && inner.Type == gjson.String {
return inner.String()
}
}
return ""
}
================================================
FILE: internal/thinking/types.go
================================================
// Package thinking provides unified thinking configuration processing.
//
// This package offers a unified interface for parsing, validating, and applying
// thinking configurations across various AI providers (Claude, Gemini, OpenAI, iFlow).
package thinking
import "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
// ThinkingMode represents the type of thinking configuration mode.
type ThinkingMode int
const (
// ModeBudget indicates using a numeric budget (corresponds to suffix "(1000)" etc.)
ModeBudget ThinkingMode = iota
// ModeLevel indicates using a discrete level (corresponds to suffix "(high)" etc.)
ModeLevel
// ModeNone indicates thinking is disabled (corresponds to suffix "(none)" or budget=0)
ModeNone
// ModeAuto indicates automatic/dynamic thinking (corresponds to suffix "(auto)" or budget=-1)
ModeAuto
)
// String returns the string representation of ThinkingMode.
func (m ThinkingMode) String() string {
switch m {
case ModeBudget:
return "budget"
case ModeLevel:
return "level"
case ModeNone:
return "none"
case ModeAuto:
return "auto"
default:
return "unknown"
}
}
// ThinkingLevel represents a discrete thinking level.
type ThinkingLevel string
const (
// LevelNone disables thinking
LevelNone ThinkingLevel = "none"
// LevelAuto enables automatic/dynamic thinking
LevelAuto ThinkingLevel = "auto"
// LevelMinimal sets minimal thinking effort
LevelMinimal ThinkingLevel = "minimal"
// LevelLow sets low thinking effort
LevelLow ThinkingLevel = "low"
// LevelMedium sets medium thinking effort
LevelMedium ThinkingLevel = "medium"
// LevelHigh sets high thinking effort
LevelHigh ThinkingLevel = "high"
// LevelXHigh sets extra-high thinking effort
LevelXHigh ThinkingLevel = "xhigh"
// LevelMax sets maximum thinking effort.
// This is currently used by Claude 4.6 adaptive thinking (opus supports "max").
LevelMax ThinkingLevel = "max"
)
// ThinkingConfig represents a unified thinking configuration.
//
// This struct is used to pass thinking configuration information between components.
// Depending on Mode, either Budget or Level field is effective:
// - ModeNone: Budget=0, Level is ignored
// - ModeAuto: Budget=-1, Level is ignored
// - ModeBudget: Budget is a positive integer, Level is ignored
// - ModeLevel: Budget is ignored, Level is a valid level
type ThinkingConfig struct {
// Mode specifies the configuration mode
Mode ThinkingMode
// Budget is the thinking budget (token count), only effective when Mode is ModeBudget.
// Special values: 0 means disabled, -1 means automatic
Budget int
// Level is the thinking level, only effective when Mode is ModeLevel
Level ThinkingLevel
}
// SuffixResult represents the result of parsing a model name for thinking suffix.
//
// A thinking suffix is specified in the format model-name(value), where value
// can be a numeric budget (e.g., "16384") or a level name (e.g., "high").
type SuffixResult struct {
// ModelName is the model name with the suffix removed.
// If no suffix was found, this equals the original input.
ModelName string
// HasSuffix indicates whether a valid suffix was found.
HasSuffix bool
// RawSuffix is the content inside the parentheses, without the parentheses.
// Empty string if HasSuffix is false.
RawSuffix string
}
// ProviderApplier defines the interface for provider-specific thinking configuration application.
//
// Types implementing this interface are responsible for converting a unified ThinkingConfig
// into provider-specific format and applying it to the request body.
//
// Implementation requirements:
// - Apply method must be idempotent
// - Must not modify the input config or modelInfo
// - Returns a modified copy of the request body
// - Returns appropriate ThinkingError for unsupported configurations
type ProviderApplier interface {
// Apply applies the thinking configuration to the request body.
//
// Parameters:
// - body: Original request body JSON
// - config: Unified thinking configuration
// - modelInfo: Model registry information containing ThinkingSupport properties
//
// Returns:
// - Modified request body JSON
// - ThinkingError if the configuration is invalid or unsupported
Apply(body []byte, config ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error)
}
================================================
FILE: internal/thinking/validate.go
================================================
// Package thinking provides unified thinking configuration processing logic.
package thinking
import (
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
log "github.com/sirupsen/logrus"
)
// ValidateConfig validates a thinking configuration against model capabilities.
//
// This function performs comprehensive validation:
// - Checks if the model supports thinking
// - Auto-converts between Budget and Level formats based on model capability
// - Validates that requested level is in the model's supported levels list
// - Clamps budget values to model's allowed range
// - When converting Budget -> Level for level-only models, clamps the derived standard level to the nearest supported level
// (special values none/auto are preserved)
// - When config comes from a model suffix, strict budget validation is disabled (we clamp instead of error)
//
// Parameters:
// - config: The thinking configuration to validate
// - support: Model's ThinkingSupport properties (nil means no thinking support)
// - fromFormat: Source provider format (used to determine strict validation rules)
// - toFormat: Target provider format
// - fromSuffix: Whether config was sourced from model suffix
//
// Returns:
// - Normalized ThinkingConfig with clamped values
// - ThinkingError if validation fails (ErrThinkingNotSupported, ErrLevelNotSupported, etc.)
//
// Auto-conversion behavior:
// - Budget-only model + Level config → Level converted to Budget
// - Level-only model + Budget config → Budget converted to Level
// - Hybrid model → preserve original format
func ValidateConfig(config ThinkingConfig, modelInfo *registry.ModelInfo, fromFormat, toFormat string, fromSuffix bool) (*ThinkingConfig, error) {
fromFormat, toFormat = strings.ToLower(strings.TrimSpace(fromFormat)), strings.ToLower(strings.TrimSpace(toFormat))
model := "unknown"
support := (*registry.ThinkingSupport)(nil)
if modelInfo != nil {
if modelInfo.ID != "" {
model = modelInfo.ID
}
support = modelInfo.Thinking
}
if support == nil {
if config.Mode != ModeNone {
return nil, NewThinkingErrorWithModel(ErrThinkingNotSupported, "thinking not supported for this model", model)
}
return &config, nil
}
// allowClampUnsupported determines whether to clamp unsupported levels instead of returning an error.
// This applies when crossing provider families (e.g., openai→gemini, claude→gemini) and the target
// model supports discrete levels. Same-family conversions require strict validation.
toCapability := detectModelCapability(modelInfo)
toHasLevelSupport := toCapability == CapabilityLevelOnly || toCapability == CapabilityHybrid
allowClampUnsupported := toHasLevelSupport && !isSameProviderFamily(fromFormat, toFormat)
// strictBudget determines whether to enforce strict budget range validation.
// This applies when: (1) config comes from request body (not suffix), (2) source format is known,
// and (3) source and target are in the same provider family. Cross-family or suffix-based configs
// are clamped instead of rejected to improve interoperability.
strictBudget := !fromSuffix && fromFormat != "" && isSameProviderFamily(fromFormat, toFormat)
budgetDerivedFromLevel := false
capability := detectModelCapability(modelInfo)
switch capability {
case CapabilityBudgetOnly:
if config.Mode == ModeLevel {
if config.Level == LevelAuto {
break
}
budget, ok := ConvertLevelToBudget(string(config.Level))
if !ok {
return nil, NewThinkingError(ErrUnknownLevel, fmt.Sprintf("unknown level: %s", config.Level))
}
config.Mode = ModeBudget
config.Budget = budget
config.Level = ""
budgetDerivedFromLevel = true
}
case CapabilityLevelOnly:
if config.Mode == ModeBudget {
level, ok := ConvertBudgetToLevel(config.Budget)
if !ok {
return nil, NewThinkingError(ErrUnknownLevel, fmt.Sprintf("budget %d cannot be converted to a valid level", config.Budget))
}
// When converting Budget -> Level for level-only models, clamp the derived standard level
// to the nearest supported level. Special values (none/auto) are preserved.
config.Mode = ModeLevel
config.Level = clampLevel(ThinkingLevel(level), modelInfo, toFormat)
config.Budget = 0
}
case CapabilityHybrid:
}
if config.Mode == ModeLevel && config.Level == LevelNone {
config.Mode = ModeNone
config.Budget = 0
config.Level = ""
}
if config.Mode == ModeLevel && config.Level == LevelAuto {
config.Mode = ModeAuto
config.Budget = -1
config.Level = ""
}
if config.Mode == ModeBudget && config.Budget == 0 {
config.Mode = ModeNone
config.Level = ""
}
if len(support.Levels) > 0 && config.Mode == ModeLevel {
if !isLevelSupported(string(config.Level), support.Levels) {
if allowClampUnsupported {
config.Level = clampLevel(config.Level, modelInfo, toFormat)
}
if !isLevelSupported(string(config.Level), support.Levels) {
// User explicitly specified an unsupported level - return error
// (budget-derived levels may be clamped based on source format)
validLevels := normalizeLevels(support.Levels)
message := fmt.Sprintf("level %q not supported, valid levels: %s", strings.ToLower(string(config.Level)), strings.Join(validLevels, ", "))
return nil, NewThinkingError(ErrLevelNotSupported, message)
}
}
}
if strictBudget && config.Mode == ModeBudget && !budgetDerivedFromLevel {
min, max := support.Min, support.Max
if min != 0 || max != 0 {
if config.Budget < min || config.Budget > max || (config.Budget == 0 && !support.ZeroAllowed) {
message := fmt.Sprintf("budget %d out of range [%d,%d]", config.Budget, min, max)
return nil, NewThinkingError(ErrBudgetOutOfRange, message)
}
}
}
// Convert ModeAuto to mid-range if dynamic not allowed
if config.Mode == ModeAuto && !support.DynamicAllowed {
config = convertAutoToMidRange(config, support, toFormat, model)
}
if config.Mode == ModeNone && toFormat == "claude" {
// Claude supports explicit disable via thinking.type="disabled".
// Keep Budget=0 so applier can omit budget_tokens.
config.Budget = 0
config.Level = ""
} else {
switch config.Mode {
case ModeBudget, ModeAuto, ModeNone:
config.Budget = clampBudget(config.Budget, modelInfo, toFormat)
}
// ModeNone with clamped Budget > 0: set Level to lowest for Level-only/Hybrid models
// This ensures Apply layer doesn't need to access support.Levels
if config.Mode == ModeNone && config.Budget > 0 && len(support.Levels) > 0 {
config.Level = ThinkingLevel(support.Levels[0])
}
}
return &config, nil
}
// convertAutoToMidRange converts ModeAuto to a mid-range value when dynamic is not allowed.
//
// This function handles the case where a model does not support dynamic/auto thinking.
// The auto mode is silently converted to a fixed value based on model capability:
// - Level-only models: convert to ModeLevel with LevelMedium
// - Budget models: convert to ModeBudget with mid = (Min + Max) / 2
//
// Logging:
// - Debug level when conversion occurs
// - Fields: original_mode, clamped_to, reason
func convertAutoToMidRange(config ThinkingConfig, support *registry.ThinkingSupport, provider, model string) ThinkingConfig {
// For level-only models (has Levels but no Min/Max range), use ModeLevel with medium
if len(support.Levels) > 0 && support.Min == 0 && support.Max == 0 {
config.Mode = ModeLevel
config.Level = LevelMedium
config.Budget = 0
log.WithFields(log.Fields{
"provider": provider,
"model": model,
"original_mode": "auto",
"clamped_to": string(LevelMedium),
}).Debug("thinking: mode converted, dynamic not allowed, using medium level |")
return config
}
// For budget models, use mid-range budget
mid := (support.Min + support.Max) / 2
if mid <= 0 && support.ZeroAllowed {
config.Mode = ModeNone
config.Budget = 0
} else if mid <= 0 {
config.Mode = ModeBudget
config.Budget = support.Min
} else {
config.Mode = ModeBudget
config.Budget = mid
}
log.WithFields(log.Fields{
"provider": provider,
"model": model,
"original_mode": "auto",
"clamped_to": config.Budget,
}).Debug("thinking: mode converted, dynamic not allowed |")
return config
}
// standardLevelOrder defines the canonical ordering of thinking levels from lowest to highest.
var standardLevelOrder = []ThinkingLevel{LevelMinimal, LevelLow, LevelMedium, LevelHigh, LevelXHigh, LevelMax}
// clampLevel clamps the given level to the nearest supported level.
// On tie, prefers the lower level.
func clampLevel(level ThinkingLevel, modelInfo *registry.ModelInfo, provider string) ThinkingLevel {
model := "unknown"
var supported []string
if modelInfo != nil {
if modelInfo.ID != "" {
model = modelInfo.ID
}
if modelInfo.Thinking != nil {
supported = modelInfo.Thinking.Levels
}
}
if len(supported) == 0 || isLevelSupported(string(level), supported) {
return level
}
pos := levelIndex(string(level))
if pos == -1 {
return level
}
bestIdx, bestDist := -1, len(standardLevelOrder)+1
for _, s := range supported {
if idx := levelIndex(strings.TrimSpace(s)); idx != -1 {
if dist := abs(pos - idx); dist < bestDist || (dist == bestDist && idx < bestIdx) {
bestIdx, bestDist = idx, dist
}
}
}
if bestIdx >= 0 {
clamped := standardLevelOrder[bestIdx]
log.WithFields(log.Fields{
"provider": provider,
"model": model,
"original_value": string(level),
"clamped_to": string(clamped),
}).Debug("thinking: level clamped |")
return clamped
}
return level
}
// clampBudget clamps a budget value to the model's supported range.
func clampBudget(value int, modelInfo *registry.ModelInfo, provider string) int {
model := "unknown"
support := (*registry.ThinkingSupport)(nil)
if modelInfo != nil {
if modelInfo.ID != "" {
model = modelInfo.ID
}
support = modelInfo.Thinking
}
if support == nil {
return value
}
// Auto value (-1) passes through without clamping.
if value == -1 {
return value
}
min, max := support.Min, support.Max
if value == 0 && !support.ZeroAllowed {
log.WithFields(log.Fields{
"provider": provider,
"model": model,
"original_value": value,
"clamped_to": min,
"min": min,
"max": max,
}).Warn("thinking: budget zero not allowed |")
return min
}
// Some models are level-only and do not define numeric budget ranges.
if min == 0 && max == 0 {
return value
}
if value < min {
if value == 0 && support.ZeroAllowed {
return 0
}
logClamp(provider, model, value, min, min, max)
return min
}
if value > max {
logClamp(provider, model, value, max, min, max)
return max
}
return value
}
func isLevelSupported(level string, supported []string) bool {
for _, s := range supported {
if strings.EqualFold(level, strings.TrimSpace(s)) {
return true
}
}
return false
}
func levelIndex(level string) int {
for i, l := range standardLevelOrder {
if strings.EqualFold(level, string(l)) {
return i
}
}
return -1
}
func normalizeLevels(levels []string) []string {
out := make([]string, len(levels))
for i, l := range levels {
out[i] = strings.ToLower(strings.TrimSpace(l))
}
return out
}
// isBudgetCapableProvider returns true if the provider supports budget-based thinking.
// These providers may also support level-based thinking (hybrid models).
func isBudgetCapableProvider(provider string) bool {
switch provider {
case "gemini", "gemini-cli", "antigravity", "claude":
return true
default:
return false
}
}
func isGeminiFamily(provider string) bool {
switch provider {
case "gemini", "gemini-cli", "antigravity":
return true
default:
return false
}
}
func isOpenAIFamily(provider string) bool {
switch provider {
case "openai", "openai-response", "codex":
return true
default:
return false
}
}
func isSameProviderFamily(from, to string) bool {
if from == to {
return true
}
return (isGeminiFamily(from) && isGeminiFamily(to)) ||
(isOpenAIFamily(from) && isOpenAIFamily(to))
}
func abs(x int) int {
if x < 0 {
return -x
}
return x
}
func logClamp(provider, model string, original, clampedTo, min, max int) {
log.WithFields(log.Fields{
"provider": provider,
"model": model,
"original_value": original,
"min": min,
"max": max,
"clamped_to": clampedTo,
}).Debug("thinking: budget clamped |")
}
================================================
FILE: internal/translator/antigravity/claude/antigravity_claude_request.go
================================================
// Package claude provides request translation functionality for Claude Code API compatibility.
// This package handles the conversion of Claude Code API requests into Gemini CLI-compatible
// JSON format, transforming message contents, system instructions, and tool declarations
// into the format expected by Gemini CLI API clients. It performs JSON data transformation
// to ensure compatibility between Claude Code API format and Gemini CLI API's expected format.
package claude
import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format.
// It extracts the model name, system instruction, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the Gemini CLI API.
// The function performs the following transformations:
// 1. Extracts the model information from the request
// 2. Restructures the JSON to match Gemini CLI API format
// 3. Converts system instructions to the expected format
// 4. Maps message contents with proper role transformations
// 5. Handles tool declarations and tool choices
// 6. Maps generation configuration parameters
//
// Parameters:
// - modelName: The name of the model to use for the request
// - rawJSON: The raw JSON request data from the Claude Code API
// - stream: A boolean indicating if the request is for a streaming response (unused in current implementation)
//
// Returns:
// - []byte: The transformed request data in Gemini CLI API format
func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
enableThoughtTranslate := true
rawJSON := inputRawJSON
// system instruction
systemInstructionJSON := ""
hasSystemInstruction := false
systemResult := gjson.GetBytes(rawJSON, "system")
if systemResult.IsArray() {
systemResults := systemResult.Array()
systemInstructionJSON = `{"role":"user","parts":[]}`
for i := 0; i < len(systemResults); i++ {
systemPromptResult := systemResults[i]
systemTypePromptResult := systemPromptResult.Get("type")
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
systemPrompt := systemPromptResult.Get("text").String()
partJSON := `{}`
if systemPrompt != "" {
partJSON, _ = sjson.Set(partJSON, "text", systemPrompt)
}
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", partJSON)
hasSystemInstruction = true
}
}
} else if systemResult.Type == gjson.String {
systemInstructionJSON = `{"role":"user","parts":[{"text":""}]}`
systemInstructionJSON, _ = sjson.Set(systemInstructionJSON, "parts.0.text", systemResult.String())
hasSystemInstruction = true
}
// contents
contentsJSON := "[]"
hasContents := false
// tool_use_id → tool_name lookup, populated incrementally during the main loop.
// Claude's tool_result references tool_use by ID; Gemini requires functionResponse.name.
toolNameByID := make(map[string]string)
messagesResult := gjson.GetBytes(rawJSON, "messages")
if messagesResult.IsArray() {
messageResults := messagesResult.Array()
numMessages := len(messageResults)
for i := 0; i < numMessages; i++ {
messageResult := messageResults[i]
roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String {
continue
}
originalRole := roleResult.String()
role := originalRole
if role == "assistant" {
role = "model"
}
clientContentJSON := `{"role":"","parts":[]}`
clientContentJSON, _ = sjson.Set(clientContentJSON, "role", role)
contentsResult := messageResult.Get("content")
if contentsResult.IsArray() {
contentResults := contentsResult.Array()
numContents := len(contentResults)
var currentMessageThinkingSignature string
for j := 0; j < numContents; j++ {
contentResult := contentResults[j]
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" {
// Use GetThinkingText to handle wrapped thinking objects
thinkingText := thinking.GetThinkingText(contentResult)
// Always try cached signature first (more reliable than client-provided)
// Client may send stale or invalid signatures from different sessions
signature := ""
if thinkingText != "" {
if cachedSig := cache.GetCachedSignature(modelName, thinkingText); cachedSig != "" {
signature = cachedSig
// log.Debugf("Using cached signature for thinking block")
}
}
// Fallback to client signature only if cache miss and client signature is valid
if signature == "" {
signatureResult := contentResult.Get("signature")
clientSignature := ""
if signatureResult.Exists() && signatureResult.String() != "" {
arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2)
if len(arrayClientSignatures) == 2 {
if cache.GetModelGroup(modelName) == arrayClientSignatures[0] {
clientSignature = arrayClientSignatures[1]
}
}
}
if cache.HasValidSignature(modelName, clientSignature) {
signature = clientSignature
}
// log.Debugf("Using client-provided signature for thinking block")
}
// Store for subsequent tool_use in the same message
if cache.HasValidSignature(modelName, signature) {
currentMessageThinkingSignature = signature
}
// Skip trailing unsigned thinking blocks on last assistant message
isUnsigned := !cache.HasValidSignature(modelName, signature)
// If unsigned, skip entirely (don't convert to text)
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
// Converting to text would break this requirement
if isUnsigned {
// log.Debugf("Dropping unsigned thinking block (no valid signature)")
enableThoughtTranslate = false
continue
}
// Valid signature, send as thought block
partJSON := `{}`
partJSON, _ = sjson.Set(partJSON, "thought", true)
if thinkingText != "" {
partJSON, _ = sjson.Set(partJSON, "text", thinkingText)
}
if signature != "" {
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", signature)
}
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
prompt := contentResult.Get("text").String()
// Skip empty text parts to avoid Gemini API error:
// "required oneof field 'data' must have one initialized field"
if prompt == "" {
continue
}
partJSON := `{}`
partJSON, _ = sjson.Set(partJSON, "text", prompt)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
// NOTE: Do NOT inject dummy thinking blocks here.
// Antigravity API validates signatures, so dummy values are rejected.
functionName := contentResult.Get("name").String()
argsResult := contentResult.Get("input")
functionID := contentResult.Get("id").String()
if functionID != "" && functionName != "" {
toolNameByID[functionID] = functionName
}
// Handle both object and string input formats
var argsRaw string
if argsResult.IsObject() {
argsRaw = argsResult.Raw
} else if argsResult.Type == gjson.String {
// Input is a JSON string, parse and validate it
parsed := gjson.Parse(argsResult.String())
if parsed.IsObject() {
argsRaw = parsed.Raw
}
}
if argsRaw != "" {
partJSON := `{}`
// Use skip_thought_signature_validator for tool calls without valid thinking signature
// This is the approach used in opencode-google-antigravity-auth for Gemini
// and also works for Claude through Antigravity API
const skipSentinel = "skip_thought_signature_validator"
if cache.HasValidSignature(modelName, currentMessageThinkingSignature) {
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
} else {
// No valid signature - use skip sentinel to bypass validation
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", skipSentinel)
}
if functionID != "" {
partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID)
}
partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName)
partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsRaw)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" {
funcName, ok := toolNameByID[toolCallID]
if !ok {
// Fallback: derive a semantic name from the ID by stripping
// the last two dash-separated segments (e.g. "get_weather-call-123" → "get_weather").
// Only use the raw ID as a last resort when the heuristic produces an empty string.
parts := strings.Split(toolCallID, "-")
if len(parts) > 2 {
funcName = strings.Join(parts[:len(parts)-2], "-")
}
if funcName == "" {
funcName = toolCallID
}
log.Warnf("antigravity claude request: tool_result references unknown tool_use_id=%s, derived function name=%s", toolCallID, funcName)
}
functionResponseResult := contentResult.Get("content")
functionResponseJSON := `{}`
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "id", toolCallID)
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "name", funcName)
responseData := ""
if functionResponseResult.Type == gjson.String {
responseData = functionResponseResult.String()
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
} else if functionResponseResult.IsArray() {
frResults := functionResponseResult.Array()
nonImageCount := 0
lastNonImageRaw := ""
filteredJSON := "[]"
imagePartsJSON := "[]"
for _, fr := range frResults {
if fr.Get("type").String() == "image" && fr.Get("source.type").String() == "base64" {
inlineDataJSON := `{}`
if mimeType := fr.Get("source.media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
}
if data := fr.Get("source.data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
}
imagePartJSON := `{}`
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
continue
}
nonImageCount++
lastNonImageRaw = fr.Raw
filteredJSON, _ = sjson.SetRaw(filteredJSON, "-1", fr.Raw)
}
if nonImageCount == 1 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", lastNonImageRaw)
} else if nonImageCount > 1 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", filteredJSON)
} else {
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
}
// Place image data inside functionResponse.parts as inlineData
// instead of as sibling parts in the outer content, to avoid
// base64 data bloating the text context.
if gjson.Get(imagePartsJSON, "#").Int() > 0 {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
}
} else if functionResponseResult.IsObject() {
if functionResponseResult.Get("type").String() == "image" && functionResponseResult.Get("source.type").String() == "base64" {
inlineDataJSON := `{}`
if mimeType := functionResponseResult.Get("source.media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
}
if data := functionResponseResult.Get("source.data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
}
imagePartJSON := `{}`
imagePartJSON, _ = sjson.SetRaw(imagePartJSON, "inlineData", inlineDataJSON)
imagePartsJSON := "[]"
imagePartsJSON, _ = sjson.SetRaw(imagePartsJSON, "-1", imagePartJSON)
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "parts", imagePartsJSON)
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
} else {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
}
} else if functionResponseResult.Raw != "" {
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
} else {
// Content field is missing entirely — .Raw is empty which
// causes sjson.SetRaw to produce invalid JSON (e.g. "result":}).
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", "")
}
partJSON := `{}`
partJSON, _ = sjson.SetRaw(partJSON, "functionResponse", functionResponseJSON)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image" {
sourceResult := contentResult.Get("source")
if sourceResult.Get("type").String() == "base64" {
inlineDataJSON := `{}`
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mimeType", mimeType)
}
if data := sourceResult.Get("data").String(); data != "" {
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
}
partJSON := `{}`
partJSON, _ = sjson.SetRaw(partJSON, "inlineData", inlineDataJSON)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
}
}
}
// Reorder parts for 'model' role to ensure thinking block is first
if role == "model" {
partsResult := gjson.Get(clientContentJSON, "parts")
if partsResult.IsArray() {
parts := partsResult.Array()
var thinkingParts []gjson.Result
var otherParts []gjson.Result
for _, part := range parts {
if part.Get("thought").Bool() {
thinkingParts = append(thinkingParts, part)
} else {
otherParts = append(otherParts, part)
}
}
if len(thinkingParts) > 0 {
firstPartIsThinking := parts[0].Get("thought").Bool()
if !firstPartIsThinking || len(thinkingParts) > 1 {
var newParts []interface{}
for _, p := range thinkingParts {
newParts = append(newParts, p.Value())
}
for _, p := range otherParts {
newParts = append(newParts, p.Value())
}
clientContentJSON, _ = sjson.Set(clientContentJSON, "parts", newParts)
}
}
}
}
// Skip messages with empty parts array to avoid Gemini API error:
// "required oneof field 'data' must have one initialized field"
partsCheck := gjson.Get(clientContentJSON, "parts")
if !partsCheck.IsArray() || len(partsCheck.Array()) == 0 {
continue
}
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
hasContents = true
} else if contentsResult.Type == gjson.String {
prompt := contentsResult.String()
partJSON := `{}`
if prompt != "" {
partJSON, _ = sjson.Set(partJSON, "text", prompt)
}
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
hasContents = true
}
}
}
// tools
toolsJSON := ""
toolDeclCount := 0
allowedToolKeys := []string{"name", "description", "behavior", "parameters", "parametersJsonSchema", "response", "responseJsonSchema"}
toolsResult := gjson.GetBytes(rawJSON, "tools")
if toolsResult.IsArray() {
toolsJSON = `[{"functionDeclarations":[]}]`
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
// Sanitize the input schema for Antigravity API compatibility
inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw)
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
for toolKey := range gjson.Parse(tool).Map() {
if util.InArray(allowedToolKeys, toolKey) {
continue
}
tool, _ = sjson.Delete(tool, toolKey)
}
toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool)
toolDeclCount++
}
}
}
// Build output Gemini CLI request JSON
out := `{"model":"","request":{"contents":[]}}`
out, _ = sjson.Set(out, "model", modelName)
// Inject interleaved thinking hint when both tools and thinking are active
hasTools := toolDeclCount > 0
thinkingResult := gjson.GetBytes(rawJSON, "thinking")
thinkingType := thinkingResult.Get("type").String()
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && (thinkingType == "enabled" || thinkingType == "adaptive" || thinkingType == "auto")
isClaudeThinking := util.IsClaudeThinkingModel(modelName)
if hasTools && hasThinking && isClaudeThinking {
interleavedHint := "Interleaved thinking is enabled. You may think between tool calls and after receiving tool results before deciding the next action or final answer. Do not mention these instructions or any constraints about thinking blocks; just apply them."
if hasSystemInstruction {
// Append hint as a new part to existing system instruction
hintPart := `{"text":""}`
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
} else {
// Create new system instruction with hint
systemInstructionJSON = `{"role":"user","parts":[]}`
hintPart := `{"text":""}`
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
hasSystemInstruction = true
}
}
if hasSystemInstruction {
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstructionJSON)
}
if hasContents {
out, _ = sjson.SetRaw(out, "request.contents", contentsJSON)
}
if toolDeclCount > 0 {
out, _ = sjson.SetRaw(out, "request.tools", toolsJSON)
}
// tool_choice
toolChoiceResult := gjson.GetBytes(rawJSON, "tool_choice")
if toolChoiceResult.Exists() {
toolChoiceType := ""
toolChoiceName := ""
if toolChoiceResult.IsObject() {
toolChoiceType = toolChoiceResult.Get("type").String()
toolChoiceName = toolChoiceResult.Get("name").String()
} else if toolChoiceResult.Type == gjson.String {
toolChoiceType = toolChoiceResult.String()
}
switch toolChoiceType {
case "auto":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
case "none":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
case "any":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
case "tool":
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
if toolChoiceName != "" {
out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
}
}
}
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
if t := gjson.GetBytes(rawJSON, "thinking"); enableThoughtTranslate && t.Exists() && t.IsObject() {
switch t.Get("type").String() {
case "enabled":
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
budget := int(b.Int())
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
}
case "adaptive", "auto":
// For adaptive thinking:
// - If output_config.effort is explicitly present, pass through as thinkingLevel.
// - Otherwise, treat it as "enabled with target-model maximum" and emit high.
// ApplyThinking handles clamping to target model's supported levels.
effort := ""
if v := gjson.GetBytes(rawJSON, "output_config.effort"); v.Exists() && v.Type == gjson.String {
effort = strings.ToLower(strings.TrimSpace(v.String()))
}
if effort != "" {
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
} else {
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
}
out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
}
}
if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.temperature", v.Num)
}
if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.topP", v.Num)
}
if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.topK", v.Num)
}
if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.maxOutputTokens", v.Num)
}
outBytes := []byte(out)
outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings")
return outBytes
}
================================================
FILE: internal/translator/antigravity/claude/antigravity_claude_request_test.go
================================================
package claude
import (
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
"github.com/tidwall/gjson"
)
// TestConvertClaudeRequestToAntigravity_BasicStructure verifies the top-level
// shape of a converted request: the output model is the modelName argument
// (not the body's "model"), a user message lands in request.contents with the
// "user" role, and the Claude "system" array becomes request.systemInstruction.
func TestConvertClaudeRequestToAntigravity_BasicStructure(t *testing.T) {
	inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "Hello"}
]
}
],
"system": [
{"type": "text", "text": "You are helpful"}
]
}`)
	output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
	outputStr := string(output)

	// The output model comes from the modelName argument, not the request body.
	if got := gjson.Get(outputStr, "model").String(); got != "claude-sonnet-4-5" {
		t.Errorf("Expected model 'claude-sonnet-4-5', got '%s'", got)
	}

	// The remaining checks index into request.contents, so stop early if it is
	// missing or has the wrong type.
	contents := gjson.Get(outputStr, "request.contents")
	if !contents.Exists() || !contents.IsArray() {
		t.Fatalf("request.contents should exist and be an array, got %s", contents.Raw)
	}

	// A user message keeps the "user" role. The assistant -> model mapping is
	// covered by TestConvertClaudeRequestToAntigravity_RoleMapping.
	firstContent := gjson.Get(outputStr, "request.contents.0")
	if got := firstContent.Get("role").String(); got != "user" {
		t.Errorf("Expected role 'user', got '%s'", got)
	}

	// The Claude "system" text is carried into systemInstruction.parts.
	sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
	if !sysInstruction.Exists() {
		t.Fatal("systemInstruction should exist")
	}
	if got := sysInstruction.Get("parts.0.text").String(); got != "You are helpful" {
		t.Errorf("systemInstruction text mismatch: got '%s'", got)
	}
}
func TestConvertClaudeRequestToAntigravity_RoleMapping(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{"role": "user", "content": [{"type": "text", "text": "Hi"}]},
{"role": "assistant", "content": [{"type": "text", "text": "Hello"}]}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
// assistant should be mapped to model
secondContent := gjson.Get(outputStr, "request.contents.1")
if secondContent.Get("role").String() != "model" {
t.Errorf("Expected role 'model' (mapped from 'assistant'), got '%s'", secondContent.Get("role").String())
}
}
// TestConvertClaudeRequestToAntigravity_ThinkingBlocks checks that a thinking
// block whose signature is already in the signature cache is emitted as a
// Gemini part with thought=true, the thinking text, and the cached
// thoughtSignature.
func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
	cache.ClearSignatureCache("")

	// Valid signature must be at least 50 characters.
	const (
		validSignature = "abc123validSignature1234567890123456789012345678901234567890"
		thinkingText   = "Let me think..."
	)

	// Pre-cache the signature, simulating a previous response that produced
	// the same thinking text. The converter prefers the cached signature.
	cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)

	payload := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "Test user message"}]
},
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
{"type": "text", "text": "Answer"}
]
}
]
}`)
	converted := gjson.ParseBytes(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", payload, false))

	// contents.0 is the user turn; the thinking part leads contents.1.
	thoughtPart := converted.Get("request.contents.1.parts.0")
	if !thoughtPart.Get("thought").Bool() {
		t.Error("thinking block should have thought: true")
	}
	if thoughtPart.Get("text").String() != thinkingText {
		t.Error("thinking text mismatch")
	}
	if sig := thoughtPart.Get("thoughtSignature").String(); sig != validSignature {
		t.Errorf("Expected thoughtSignature '%s', got '%s'", validSignature, sig)
	}
}
// TestConvertClaudeRequestToAntigravity_ThinkingBlockWithoutSignature checks
// that a thinking block with no usable signature is dropped outright rather
// than being converted into a plain text part, leaving only the text part.
func TestConvertClaudeRequestToAntigravity_ThinkingBlockWithoutSignature(t *testing.T) {
	cache.ClearSignatureCache("")

	request := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Let me think..."},
{"type": "text", "text": "Answer"}
]
}
]
}`)
	converted := gjson.ParseBytes(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", request, false))

	parts := converted.Get("request.contents.0.parts").Array()
	if n := len(parts); n != 1 {
		t.Fatalf("Expected 1 part (thinking removed), got %d", n)
	}

	// The survivor must be the text part, not a preserved thought.
	survivor := parts[0]
	if survivor.Get("thought").Bool() {
		t.Error("Thinking block should be removed, not preserved")
	}
	if text := survivor.Get("text").String(); text != "Answer" {
		t.Errorf("Expected text 'Answer', got '%s'", text)
	}
}
// TestConvertClaudeRequestToAntigravity_ToolDeclarations checks that a Claude
// tool definition becomes a Gemini functionDeclaration. The name and
// description must be kept. input_schema must be replaced by
// parametersJsonSchema.
func TestConvertClaudeRequestToAntigravity_ToolDeclarations(t *testing.T) {
	inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [],
"tools": [
{
"name": "test_tool",
"description": "A test tool",
"input_schema": {
"type": "object",
"properties": {
"name": {"type": "string"}
},
"required": ["name"]
}
}
]
}`)
	output := ConvertClaudeRequestToAntigravity("gemini-1.5-pro", inputJSON, false)
	outputStr := string(output)

	// Every later check reads from request.tools, so stop if it is absent.
	tools := gjson.Get(outputStr, "request.tools")
	if !tools.Exists() {
		t.Fatal("Tools should exist in output")
	}

	funcDecl := gjson.Get(outputStr, "request.tools.0.functionDeclarations.0")
	if got := funcDecl.Get("name").String(); got != "test_tool" {
		t.Errorf("Expected tool name 'test_tool', got '%s'", got)
	}
	// "description" is one of the allowed declaration keys and must survive filtering.
	if got := funcDecl.Get("description").String(); got != "A test tool" {
		t.Errorf("Expected tool description 'A test tool', got '%s'", got)
	}

	// input_schema must be renamed to parametersJsonSchema. Previously this was
	// only logged, so a missing schema would not fail the test.
	if !funcDecl.Get("parametersJsonSchema").IsObject() {
		t.Errorf("parametersJsonSchema should exist as an object, got declaration %s", funcDecl.Raw)
	}
	if funcDecl.Get("input_schema").Exists() {
		t.Error("input_schema should be removed")
	}
}
func TestConvertClaudeRequestToAntigravity_ToolChoice_SpecificTool(t *testing.T) {
inputJSON := []byte(`{
"model": "gemini-3-flash-preview",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "hi"}
]
}
],
"tools": [
{
"name": "json",
"description": "A JSON tool",
"input_schema": {
"type": "object",
"properties": {}
}
}
],
"tool_choice": {"type": "tool", "name": "json"}
}`)
output := ConvertClaudeRequestToAntigravity("gemini-3-flash-preview", inputJSON, false)
outputStr := string(output)
if got := gjson.Get(outputStr, "request.toolConfig.functionCallingConfig.mode").String(); got != "ANY" {
t.Fatalf("Expected toolConfig.functionCallingConfig.mode 'ANY', got '%s'", got)
}
allowed := gjson.Get(outputStr, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Array()
if len(allowed) != 1 || allowed[0].String() != "json" {
t.Fatalf("Expected allowedFunctionNames ['json'], got %s", gjson.Get(outputStr, "request.toolConfig.functionCallingConfig.allowedFunctionNames").Raw)
}
}
func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": "{\"location\": \"Paris\"}"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
// Now we expect only 1 part (tool_use), no dummy thinking block injected
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 1 {
t.Fatalf("Expected 1 part (tool only, no dummy injection), got %d", len(parts))
}
// Check function call conversion at parts[0]
funcCall := parts[0].Get("functionCall")
if !funcCall.Exists() {
t.Error("functionCall should exist at parts[0]")
}
if funcCall.Get("name").String() != "get_weather" {
t.Errorf("Expected function name 'get_weather', got '%s'", funcCall.Get("name").String())
}
if funcCall.Get("id").String() != "call_123" {
t.Errorf("Expected function id 'call_123', got '%s'", funcCall.Get("id").String())
}
// Verify skip_thought_signature_validator is added (bypass for tools without valid thinking)
expectedSig := "skip_thought_signature_validator"
actualSig := parts[0].Get("thoughtSignature").String()
if actualSig != expectedSig {
t.Errorf("Expected thoughtSignature '%s', got '%s'", expectedSig, actualSig)
}
}
// TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature checks that a
// tool_use following a validly signed thinking block in the same assistant
// message reuses that thinking block's signature as its thoughtSignature.
func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
	cache.ClearSignatureCache("")

	const (
		validSignature = "abc123validSignature1234567890123456789012345678901234567890"
		thinkingText   = "Let me think..."
	)
	cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)

	request := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "Test user message"}]
},
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"},
{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": "{\"location\": \"Paris\"}"
}
]
}
]
}`)
	converted := gjson.ParseBytes(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", request, false))

	// contents.1 holds the assistant turn: parts.0 is the thought, parts.1 the call.
	callPart := converted.Get("request.contents.1.parts.1")
	if callPart.Get("functionCall.name").String() != "get_weather" {
		t.Errorf("Expected functionCall, got %s", callPart.Raw)
	}
	if sig := callPart.Get("thoughtSignature").String(); sig != validSignature {
		t.Errorf("Expected thoughtSignature '%s' on tool_use, got '%s'", validSignature, sig)
	}
}
// TestConvertClaudeRequestToAntigravity_ReorderThinking checks that, for a
// model turn where a text block precedes a signed thinking block, the
// converter moves the thought part to the front of the parts array.
func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
	cache.ClearSignatureCache("")

	const (
		validSignature = "abc123validSignature1234567890123456789012345678901234567890"
		thinkingText   = "Planning..."
	)
	cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)

	// Text first, thinking second: the wrong order for a thinking-enabled turn.
	request := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "Test user message"}]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Here is the plan."},
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"}
]
}
]
}`)
	converted := gjson.ParseBytes(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", request, false))

	parts := converted.Get("request.contents.1.parts").Array()
	if n := len(parts); n != 2 {
		t.Fatalf("Expected 2 parts, got %d", n)
	}

	lead, trailing := parts[0], parts[1]
	if !lead.Get("thought").Bool() {
		t.Error("First part should be thinking block after reordering")
	}
	if trailing.Get("text").String() != "Here is the plan." {
		t.Error("Second part should be text block")
	}
}
func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "get_weather-call-123",
"name": "get_weather",
"input": {"location": "Paris"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "get_weather-call-123",
"content": "22C sunny"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
// Check function response conversion
funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
if !funcResp.Exists() {
t.Error("functionResponse should exist")
}
if funcResp.Get("id").String() != "get_weather-call-123" {
t.Errorf("Expected function id, got '%s'", funcResp.Get("id").String())
}
if funcResp.Get("name").String() != "get_weather" {
t.Errorf("Expected function name 'get_weather', got '%s'", funcResp.Get("name").String())
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultName_TouluFormat(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-haiku-4-5-20251001",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
"name": "Glob",
"input": {"pattern": "**/*.py"}
},
{
"type": "tool_use",
"id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
"name": "Bash",
"input": {"command": "ls"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
"content": "file1.py\nfile2.py"
},
{
"type": "tool_result",
"tool_use_id": "toolu_tool-cf2d061f75f845c49aacc18ee75ee708",
"content": "total 10"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
outputStr := string(output)
funcResp0 := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
if !funcResp0.Exists() {
t.Fatal("first functionResponse should exist")
}
if got := funcResp0.Get("name").String(); got != "Glob" {
t.Errorf("Expected name 'Glob' for toolu_ format, got '%s'", got)
}
funcResp1 := gjson.Get(outputStr, "request.contents.1.parts.1.functionResponse")
if !funcResp1.Exists() {
t.Fatal("second functionResponse should exist")
}
if got := funcResp1.Get("name").String(); got != "Bash" {
t.Errorf("Expected name 'Bash' for toolu_ format, got '%s'", got)
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultName_CustomFormat(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-haiku-4-5-20251001",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "Read-1773420180464065165-1327",
"name": "Read",
"input": {"file_path": "/tmp/test.py"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "Read-1773420180464065165-1327",
"content": "file content here"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-haiku-4-5-20251001", inputJSON, false)
outputStr := string(output)
funcResp := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
if got := funcResp.Get("name").String(); got != "Read" {
t.Errorf("Expected name 'Read', got '%s'", got)
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_Heuristic(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "get_weather-call-123",
"content": "22C sunny"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
if got := funcResp.Get("name").String(); got != "get_weather" {
t.Errorf("Expected heuristic-derived name 'get_weather', got '%s'", got)
}
}
func TestConvertClaudeRequestToAntigravity_ToolResultName_NoMatchingToolUse_RawID(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "toolu_tool-48fca351f12844eabf49dad8b63886d2",
"content": "result data"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Fatal("functionResponse should exist")
}
got := funcResp.Get("name").String()
if got == "" {
t.Error("functionResponse.name must not be empty")
}
if got != "toolu_tool-48fca351f12844eabf49dad8b63886d2" {
t.Errorf("Expected raw ID as last-resort name, got '%s'", got)
}
}
func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) {
// Note: This test requires the model to be registered in the registry
// with Thinking metadata. If the registry is not populated in test environment,
// thinkingConfig won't be added. We'll test the basic structure only.
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [],
"thinking": {
"type": "enabled",
"budget_tokens": 8000
}
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// Check thinking config conversion (only if model supports thinking in registry)
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
if thinkingConfig.Exists() {
if thinkingConfig.Get("thinkingBudget").Int() != 8000 {
t.Errorf("Expected thinkingBudget 8000, got %d", thinkingConfig.Get("thinkingBudget").Int())
}
if !thinkingConfig.Get("includeThoughts").Bool() {
t.Error("includeThoughts should be true")
}
} else {
t.Log("thinkingConfig not present - model may not be registered in test registry")
}
}
func TestConvertClaudeRequestToAntigravity_ImageContent(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": "iVBORw0KGgoAAAANSUhEUg=="
}
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
// Check inline data conversion
inlineData := gjson.Get(outputStr, "request.contents.0.parts.0.inlineData")
if !inlineData.Exists() {
t.Error("inlineData should exist")
}
if inlineData.Get("mimeType").String() != "image/png" {
t.Error("mimeType mismatch")
}
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
t.Error("data mismatch")
}
}
// TestConvertClaudeRequestToAntigravity_GenerationConfig verifies that the
// Claude sampling parameters (temperature, top_p, top_k, max_tokens) are copied
// into request.generationConfig under the names temperature, topP, topK and
// maxOutputTokens.
func TestConvertClaudeRequestToAntigravity_GenerationConfig(t *testing.T) {
	inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [],
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"max_tokens": 2000
}`)
	output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
	outputStr := string(output)
	genConfig := gjson.Get(outputStr, "request.generationConfig")
	if !genConfig.Exists() {
		// Fail once with the real cause. Otherwise each check below would
		// report a misleading "got 0".
		t.Fatalf("request.generationConfig should exist, output: %s", outputStr)
	}
	if got := genConfig.Get("temperature").Float(); got != 0.7 {
		t.Errorf("Expected temperature 0.7, got %f", got)
	}
	if got := genConfig.Get("topP").Float(); got != 0.9 {
		t.Errorf("Expected topP 0.9, got %f", got)
	}
	if got := genConfig.Get("topK").Float(); got != 40 {
		t.Errorf("Expected topK 40, got %f", got)
	}
	if got := genConfig.Get("maxOutputTokens").Float(); got != 2000 {
		t.Errorf("Expected maxOutputTokens 2000, got %f", got)
	}
}
// ============================================================================
// Trailing Unsigned Thinking Block Removal
// ============================================================================
// TestConvertClaudeRequestToAntigravity_TrailingUnsignedThinking_Removed
// checks an assistant turn whose final content block is a thinking block
// without a signature. After conversion, that turn must not end with a thought
// part.
func TestConvertClaudeRequestToAntigravity_TrailingUnsignedThinking_Removed(t *testing.T) {
	// The assistant turn ends with an unsigned thinking block.
	inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "Hello"}]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Here is my answer"},
{"type": "thinking", "thinking": "I should think more..."}
]
}
]
}`)
	outputStr := string(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false))

	partsResult := gjson.Get(outputStr, "request.contents.1.parts")
	if !partsResult.IsArray() {
		t.Fatal("Last message should have parts array")
	}
	partList := partsResult.Array()
	count := len(partList)
	if count == 0 {
		t.Fatal("Last message should have at least one part")
	}

	// Only the text should survive, so the final part must not be a thought.
	if final := partList[count-1]; final.Get("thought").Bool() {
		t.Error("Trailing unsigned thinking block should be removed")
	}
}
// TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept checks that
// a trailing thinking block carrying a signature (also registered in the
// signature cache) is not stripped from the last assistant turn.
func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testing.T) {
	// The signature cache is package-global state. Reset it before the test,
	// and again on exit, so the entry seeded below cannot leak into other
	// tests. NOTE(review): this assumes an empty argument clears every entry;
	// confirm against cache.ClearSignatureCache.
	cache.ClearSignatureCache("")
	t.Cleanup(func() { cache.ClearSignatureCache("") })
	// Last assistant message ends with signed thinking block - should be kept
	validSignature := "abc123validSignature1234567890123456789012345678901234567890"
	thinkingText := "Valid thinking..."
	inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "Hello"}]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Here is my answer"},
{"type": "thinking", "thinking": "` + thinkingText + `", "signature": "` + validSignature + `"}
]
}
]
}`)
	cache.CacheSignature("claude-sonnet-4-5-thinking", thinkingText, validSignature)
	output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
	outputStr := string(output)
	// The signed thinking block should be preserved next to the text part.
	parts := gjson.Get(outputStr, "request.contents.1.parts").Array()
	if len(parts) < 2 {
		t.Errorf("Signed thinking block should be preserved, got %d part(s)", len(parts))
	}
}
// TestConvertClaudeRequestToAntigravity_MiddleUnsignedThinking_Removed checks
// that an unsigned thinking block inside an earlier (non-final) assistant turn
// is dropped entirely, leaving only that turn's text part.
func TestConvertClaudeRequestToAntigravity_MiddleUnsignedThinking_Removed(t *testing.T) {
	inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Middle thinking..."},
{"type": "text", "text": "Answer"}
]
},
{
"role": "user",
"content": [{"type": "text", "text": "Follow up"}]
}
]
}`)
	outputStr := string(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false))

	firstTurn := gjson.Get(outputStr, "request.contents.0.parts").Array()
	if got := len(firstTurn); got != 1 {
		t.Fatalf("Expected 1 part (thinking removed), got %d", got)
	}

	// The single survivor must be the plain text part.
	survivor := firstTurn[0]
	if survivor.Get("thought").Bool() {
		t.Error("Thinking block should be removed, not preserved")
	}
	if text := survivor.Get("text").String(); text != "Answer" {
		t.Errorf("Expected text 'Answer', got '%s'", text)
	}
}
// ============================================================================
// Tool + Thinking System Hint Injection
// ============================================================================
// TestConvertClaudeRequestToAntigravity_ToolAndThinking_HintInjected checks
// that when a request has both tools and thinking enabled, the converted
// systemInstruction contains the "Interleaved thinking is enabled" hint
// alongside the caller's own system text.
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_HintInjected(t *testing.T) {
	inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"system": [{"type": "text", "text": "You are helpful."}],
"tools": [
{
"name": "get_weather",
"description": "Get weather",
"input_schema": {"type": "object", "properties": {"location": {"type": "string"}}}
}
],
"thinking": {"type": "enabled", "budget_tokens": 8000}
}`)
	outputStr := string(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false))

	sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
	if !sysInstruction.Exists() {
		t.Fatal("systemInstruction should exist")
	}

	// Scan the system parts until one carries the hint.
	hasHint := false
	for _, part := range sysInstruction.Get("parts").Array() {
		if hasHint = strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled"); hasHint {
			break
		}
	}
	if !hasHint {
		t.Errorf("Interleaved thinking hint should be injected when tools and thinking are both active, got: %v", sysInstruction.Raw)
	}
}
// TestConvertClaudeRequestToAntigravity_ToolsOnly_NoHint checks that a request
// with tools but without thinking does not get the interleaved-thinking hint
// in its systemInstruction.
func TestConvertClaudeRequestToAntigravity_ToolsOnly_NoHint(t *testing.T) {
	inputJSON := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"system": [{"type": "text", "text": "You are helpful."}],
"tools": [
{
"name": "get_weather",
"description": "Get weather",
"input_schema": {"type": "object", "properties": {"location": {"type": "string"}}}
}
]
}`)
	outputStr := string(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false))

	sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
	if !sysInstruction.Exists() {
		// No system instruction at all means no hint either.
		return
	}
	// Report every part that unexpectedly carries the hint.
	for _, part := range sysInstruction.Get("parts").Array() {
		if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
			t.Error("Hint should NOT be injected when only tools are present (no thinking)")
		}
	}
}
// TestConvertClaudeRequestToAntigravity_ThinkingOnly_NoHint checks that a
// request with thinking enabled but no tools does not get the
// interleaved-thinking hint in its systemInstruction.
func TestConvertClaudeRequestToAntigravity_ThinkingOnly_NoHint(t *testing.T) {
	inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"system": [{"type": "text", "text": "You are helpful."}],
"thinking": {"type": "enabled", "budget_tokens": 8000}
}`)
	outputStr := string(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false))

	sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
	if !sysInstruction.Exists() {
		// Without a system instruction there is nowhere for the hint to be.
		return
	}
	// Report every part that unexpectedly carries the hint.
	for _, part := range sysInstruction.Get("parts").Array() {
		if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
			t.Error("Hint should NOT be injected when only thinking is present (no tools)")
		}
	}
}
// TestConvertClaudeRequestToAntigravity_ToolResultNoContent is a regression
// test: a tool_result block with no "content" field used to produce invalid
// JSON. The converted request must be valid, and its functionResponse must
// still carry a response.result value.
func TestConvertClaudeRequestToAntigravity_ToolResultNoContent(t *testing.T) {
	// Bug repro: tool_result with no content field produces invalid JSON
	inputJSON := []byte(`{
"model": "claude-opus-4-6-thinking",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "MyTool-123-456",
"name": "MyTool",
"input": {"key": "value"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "MyTool-123-456"
}
]
}
]
}`)
	output := ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", inputJSON, true)
	outputStr := string(output)
	if !gjson.Valid(outputStr) {
		// Querying invalid JSON below would only add noise, so stop here.
		// This matches the other tool_result tests.
		t.Fatalf("Result is not valid JSON:\n%s", outputStr)
	}
	// Verify the functionResponse has a valid result value
	fr := gjson.Get(outputStr, "request.contents.1.parts.0.functionResponse.response.result")
	if !fr.Exists() {
		t.Errorf("functionResponse.response.result should exist, output: %s", outputStr)
	}
}
// TestConvertClaudeRequestToAntigravity_ToolResultNullContent is a regression
// test for tool_result blocks whose "content" is an explicit JSON null. The
// converted request must still be valid JSON.
func TestConvertClaudeRequestToAntigravity_ToolResultNullContent(t *testing.T) {
	requestBody := []byte(`{
"model": "claude-opus-4-6-thinking",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "MyTool-123-456",
"name": "MyTool",
"input": {"key": "value"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "MyTool-123-456",
"content": null
}
]
}
]
}`)
	if converted := string(ConvertClaudeRequestToAntigravity("claude-opus-4-6-thinking", requestBody, true)); !gjson.Valid(converted) {
		t.Errorf("Result is not valid JSON:\n%s", converted)
	}
}
// TestConvertClaudeRequestToAntigravity_ToolResultWithImage checks a
// tool_result whose content array mixes text and a base64 image. The text must
// stay in functionResponse.response.result, and the image must move into
// functionResponse.parts as inlineData. The image must not appear as an extra
// sibling part of the outer message, which would bloat the context with
// base64 data.
func TestConvertClaudeRequestToAntigravity_ToolResultWithImage(t *testing.T) {
	inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "Read-123-456",
"content": [
{
"type": "text",
"text": "File content here"
},
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": "iVBORw0KGgoAAAANSUhEUg=="
}
}
]
}
]
}
]
}`)
	outputStr := string(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false))
	if !gjson.Valid(outputStr) {
		t.Fatalf("Result is not valid JSON:\n%s", outputStr)
	}

	const partsPath = "request.contents.0.parts"
	funcResp := gjson.Get(outputStr, partsPath+".0.functionResponse")
	if !funcResp.Exists() {
		t.Fatal("functionResponse should exist")
	}

	// The text item is reported through response.result.
	if text := funcResp.Get("response.result.text").String(); text != "File content here" {
		t.Errorf("Expected response.result.text = 'File content here', got '%s'", text)
	}

	// The image becomes the first nested inlineData part.
	img := funcResp.Get("parts.0.inlineData")
	if !img.Exists() {
		t.Fatal("functionResponse.parts[0].inlineData should exist")
	}
	if mime := img.Get("mimeType").String(); mime != "image/png" {
		t.Errorf("Expected mimeType 'image/png', got '%s'", mime)
	}
	if !strings.Contains(img.Get("data").String(), "iVBORw0KGgo") {
		t.Error("data mismatch")
	}

	// Only the functionResponse itself may appear at the outer level.
	if outer := gjson.Get(outputStr, partsPath); outer.IsArray() {
		if count := len(outer.Array()); count > 1 {
			t.Errorf("Expected only 1 outer part (functionResponse), got %d", count)
		}
	}
}
// TestConvertClaudeRequestToAntigravity_ToolResultWithSingleImage checks a
// tool_result whose content is a single image object rather than an array.
// The image must land in functionResponse.parts, response.result must stay
// empty, and no extra outer sibling part may be emitted.
func TestConvertClaudeRequestToAntigravity_ToolResultWithSingleImage(t *testing.T) {
	inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "Read-789-012",
"content": {
"type": "image",
"source": {
"type": "base64",
"media_type": "image/jpeg",
"data": "/9j/4AAQSkZJRgABAQ=="
}
}
}
]
}
]
}`)
	outputStr := string(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false))
	if !gjson.Valid(outputStr) {
		t.Fatalf("Result is not valid JSON:\n%s", outputStr)
	}

	const partsPath = "request.contents.0.parts"
	funcResp := gjson.Get(outputStr, partsPath+".0.functionResponse")
	if !funcResp.Exists() {
		t.Fatal("functionResponse should exist")
	}

	// The tool output has no text, so response.result stays empty.
	if result := funcResp.Get("response.result").String(); result != "" {
		t.Errorf("Expected empty response.result for image-only content, got '%s'", result)
	}

	img := funcResp.Get("parts.0.inlineData")
	if !img.Exists() {
		t.Fatal("functionResponse.parts[0].inlineData should exist")
	}
	if mime := img.Get("mimeType").String(); mime != "image/jpeg" {
		t.Errorf("Expected mimeType 'image/jpeg', got '%s'", mime)
	}

	if outer := gjson.Get(outputStr, partsPath); outer.IsArray() {
		if count := len(outer.Array()); count > 1 {
			t.Errorf("Expected only 1 outer part, got %d", count)
		}
	}
}
// TestConvertClaudeRequestToAntigravity_ToolResultWithMultipleImagesAndTexts
// checks a tool_result whose content interleaves two text items and two base64
// images. The texts must form a two-element response.result array. The images
// must appear, in order, as inlineData entries in functionResponse.parts. The
// outer message must hold only the functionResponse part.
func TestConvertClaudeRequestToAntigravity_ToolResultWithMultipleImagesAndTexts(t *testing.T) {
	inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "Multi-001",
"content": [
{"type": "text", "text": "First text"},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/png", "data": "AAAA"}
},
{"type": "text", "text": "Second text"},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/jpeg", "data": "BBBB"}
}
]
}
]
}
]
}`)
	outputStr := string(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false))
	if !gjson.Valid(outputStr) {
		t.Fatalf("Result is not valid JSON:\n%s", outputStr)
	}

	funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
	if !funcResp.Exists() {
		t.Fatal("functionResponse should exist")
	}

	// Two text items produce an array-valued response.result.
	resultArr := funcResp.Get("response.result")
	if !resultArr.IsArray() {
		t.Fatalf("Expected response.result to be an array, got: %s", resultArr.Raw)
	}
	if count := len(resultArr.Array()); count != 2 {
		t.Fatalf("Expected 2 result items, got %d", count)
	}

	imgParts := funcResp.Get("parts").Array()
	if count := len(imgParts); count != 2 {
		t.Fatalf("Expected 2 image parts in functionResponse.parts, got %d", count)
	}

	// Images must keep their original order and payloads.
	wantImages := []struct{ ordinal, mime, data string }{
		{"first", "image/png", "AAAA"},
		{"second", "image/jpeg", "BBBB"},
	}
	for i, want := range wantImages {
		inline := imgParts[i].Get("inlineData")
		if got := inline.Get("mimeType").String(); got != want.mime {
			t.Errorf("Expected %s image mimeType '%s', got '%s'", want.ordinal, want.mime, got)
		}
		if got := inline.Get("data").String(); got != want.data {
			t.Errorf("Expected %s image data '%s', got '%s'", want.ordinal, want.data, got)
		}
	}

	// The functionResponse is the only outer part.
	if count := len(gjson.Get(outputStr, "request.contents.0.parts").Array()); count != 1 {
		t.Errorf("Expected 1 outer part, got %d", count)
	}
}
// TestConvertClaudeRequestToAntigravity_ToolResultWithOnlyMultipleImages
// checks a tool_result whose content holds only base64 images. With no text,
// response.result must be the empty string. Both images must appear in order
// in functionResponse.parts, and the outer message must hold a single part.
func TestConvertClaudeRequestToAntigravity_ToolResultWithOnlyMultipleImages(t *testing.T) {
	inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "ImgOnly-001",
"content": [
{
"type": "image",
"source": {"type": "base64", "media_type": "image/png", "data": "PNG1"}
},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/gif", "data": "GIF1"}
}
]
}
]
}
]
}`)
	outputStr := string(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false))
	if !gjson.Valid(outputStr) {
		t.Fatalf("Result is not valid JSON:\n%s", outputStr)
	}

	funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
	if !funcResp.Exists() {
		t.Fatal("functionResponse should exist")
	}

	if result := funcResp.Get("response.result").String(); result != "" {
		t.Errorf("Expected empty response.result, got '%s'", result)
	}

	imgParts := funcResp.Get("parts").Array()
	if count := len(imgParts); count != 2 {
		t.Fatalf("Expected 2 image parts, got %d", count)
	}
	expected := []struct{ label, mime string }{
		{"first", "image/png"},
		{"second", "image/gif"},
	}
	for i, want := range expected {
		if imgParts[i].Get("inlineData.mimeType").String() != want.mime {
			t.Error(want.label + " image mimeType mismatch")
		}
	}

	if count := len(gjson.Get(outputStr, "request.contents.0.parts").Array()); count != 1 {
		t.Errorf("Expected 1 outer part, got %d", count)
	}
}
// TestConvertClaudeRequestToAntigravity_ToolResultImageNotBase64 checks an
// image whose source.type is not "base64" (a URL here). It is not collected
// as inline image data. Instead it is kept next to the text item, so
// response.result holds two entries and functionResponse.parts is absent.
func TestConvertClaudeRequestToAntigravity_ToolResultImageNotBase64(t *testing.T) {
	inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "NotB64-001",
"content": [
{"type": "text", "text": "some output"},
{
"type": "image",
"source": {"type": "url", "url": "https://example.com/img.png"}
}
]
}
]
}
]
}`)
	outputStr := string(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false))
	if !gjson.Valid(outputStr) {
		t.Fatalf("Result is not valid JSON:\n%s", outputStr)
	}

	funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
	if !funcResp.Exists() {
		t.Fatal("functionResponse should exist")
	}

	// Two non-image items (text + url image) make response.result an array.
	resultArr := funcResp.Get("response.result")
	if !resultArr.IsArray() {
		t.Fatalf("Expected response.result to be an array (2 non-image items), got: %s", resultArr.Raw)
	}
	if count := len(resultArr.Array()); count != 2 {
		t.Fatalf("Expected 2 result items, got %d", count)
	}

	// No base64 image was collected, so there are no nested parts.
	if funcResp.Get("parts").Exists() {
		t.Error("functionResponse.parts should NOT exist when no base64 images")
	}
}
// TestConvertClaudeRequestToAntigravity_ToolResultImageMissingData checks a
// base64 image whose source has no "data" field. The type check still
// classifies it as an image, so it lands in functionResponse.parts. Its
// inlineData carries the mimeType but no data key.
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingData(t *testing.T) {
	inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "NoData-001",
"content": [
{"type": "text", "text": "output"},
{
"type": "image",
"source": {"type": "base64", "media_type": "image/png"}
}
]
}
]
}
]
}`)
	outputStr := string(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false))
	if !gjson.Valid(outputStr) {
		t.Fatalf("Result is not valid JSON:\n%s", outputStr)
	}

	funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
	if !funcResp.Exists() {
		t.Fatal("functionResponse should exist")
	}

	imgParts := funcResp.Get("parts").Array()
	if count := len(imgParts); count != 1 {
		t.Fatalf("Expected 1 image part, got %d", count)
	}
	inline := imgParts[0].Get("inlineData")
	if inline.Get("mimeType").String() != "image/png" {
		t.Error("mimeType should still be set")
	}
	if inline.Get("data").Exists() {
		t.Error("data should not exist when source.data is missing")
	}
}
// TestConvertClaudeRequestToAntigravity_ToolResultImageMissingMediaType checks
// a base64 image whose source has no "media_type". It is still treated as an
// image, so its inlineData carries the payload but no mimeType key.
func TestConvertClaudeRequestToAntigravity_ToolResultImageMissingMediaType(t *testing.T) {
	inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "NoMime-001",
"content": [
{"type": "text", "text": "output"},
{
"type": "image",
"source": {"type": "base64", "data": "AAAA"}
}
]
}
]
}
]
}`)
	outputStr := string(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false))
	if !gjson.Valid(outputStr) {
		t.Fatalf("Result is not valid JSON:\n%s", outputStr)
	}

	funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
	if !funcResp.Exists() {
		t.Fatal("functionResponse should exist")
	}

	imgParts := funcResp.Get("parts").Array()
	if count := len(imgParts); count != 1 {
		t.Fatalf("Expected 1 image part, got %d", count)
	}
	inline := imgParts[0].Get("inlineData")
	if inline.Get("mimeType").Exists() {
		t.Error("mimeType should not exist when media_type is missing")
	}
	if inline.Get("data").String() != "AAAA" {
		t.Error("data should still be set")
	}
}
// TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem
// checks that when tools and thinking are both active but the request has no
// system prompt, a systemInstruction is created to carry the
// "Interleaved thinking is enabled" hint.
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
	inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"tools": [
{
"name": "get_weather",
"description": "Get weather",
"input_schema": {"type": "object", "properties": {"location": {"type": "string"}}}
}
],
"thinking": {"type": "enabled", "budget_tokens": 8000}
}`)
	outputStr := string(ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false))

	sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
	if !sysInstruction.Exists() {
		t.Fatal("systemInstruction should be created when tools + thinking are active")
	}

	// Look for the hint among the created system parts.
	hasHint := false
	for _, part := range sysInstruction.Get("parts").Array() {
		if hasHint = strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled"); hasHint {
			break
		}
	}
	if !hasHint {
		t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw)
	}
}
================================================
FILE: internal/translator/antigravity/claude/antigravity_claude_response.go
================================================
// Package claude provides response translation functionality for Claude Code API compatibility.
// This package handles the conversion of backend client responses into Claude Code-compatible
// Server-Sent Events (SSE) format, implementing a sophisticated state machine that manages
// different response types including text content, thinking processes, and function calls.
// The translation ensures proper sequencing of SSE events and maintains state across
// multiple response chunks to provide a seamless streaming experience.
package claude
import (
"bytes"
"context"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Params holds parameters for response conversion and maintains state across streaming chunks.
// This structure tracks the current state of the response translation process to ensure
// proper sequencing of SSE events and transitions between different content types.
//
// ConvertAntigravityResponseToClaude allocates a *Params on its first call and keeps it
// in the caller-supplied *any, so the same value is reused for later chunks of the stream.
// Always handle it by pointer: CurrentThinkingText is a strings.Builder, which must not be
// copied once it has been written to.
type Params struct {
	HasFirstResponse     bool  // Indicates if the initial message_start event has been sent
	ResponseType         int   // Current response type: 0=none, 1=content, 2=thinking, 3=function
	ResponseIndex        int   // Index of the current content block; incremented each time a block is closed with content_block_stop
	HasFinishReason      bool  // Tracks whether a finish reason has been observed
	FinishReason         string // The finish reason string returned by the provider
	HasUsageMetadata     bool  // Tracks whether usage metadata has been observed
	PromptTokenCount     int64 // Cached prompt token count from usage metadata
	CandidatesTokenCount int64 // Cached candidate token count from usage metadata
	ThoughtsTokenCount   int64 // Cached thinking token count from usage metadata
	TotalTokenCount      int64 // Cached total token count from usage metadata
	CachedTokenCount     int64 // Cached content token count (indicates prompt caching)
	HasSentFinalEvents   bool  // Indicates if final content/message events have been sent
	HasToolUse           bool  // Indicates if tool use was observed in the stream
	HasContent           bool  // Tracks whether any content (text, thinking, or tool use) has been output; on [DONE], final events and message_stop are only emitted when this is true
	// Signature caching support
	CurrentThinkingText strings.Builder // Accumulates the text of the current thinking block; when a thoughtSignature arrives, the text is stored with it via cache.CacheSignature and the builder is reset
}
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
// NOTE(review): this is shared package-level state and the file imports sync/atomic, so
// every access should go through atomic operations (e.g. atomic.AddUint64). A plain
// read or write from concurrent streams would be a data race. Confirm at the use sites.
var toolUseIDCounter uint64
// ConvertAntigravityResponseToClaude performs sophisticated streaming response format conversion.
// This function implements a complex state machine that translates backend client responses
// into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types
// and handles state transitions between content blocks, thinking processes, and function calls.
//
// Response type states: 0=none, 1=content, 2=thinking, 3=function
// The function maintains state across multiple calls to ensure proper SSE event sequencing.
//
// Parameters:
// - ctx: The context for the request, used for cancellation and timeout handling
// - modelName: The name of the model being used for the response (unused in current implementation)
// - rawJSON: The raw JSON response from the Gemini CLI API
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response
func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
if *param == nil {
*param = &Params{
HasFirstResponse: false,
ResponseType: 0,
ResponseIndex: 0,
}
}
modelName := gjson.GetBytes(requestRawJSON, "model").String()
params := (*param).(*Params)
if bytes.Equal(rawJSON, []byte("[DONE]")) {
output := ""
// Only send final events if we have actually output content
if params.HasContent {
appendFinalEvents(params, &output, true)
return []string{
output + "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n\n",
}
}
return []string{}
}
output := ""
// Initialize the streaming session with a message_start event
// This is only sent for the very first response chunk to establish the streaming session
if !params.HasFirstResponse {
output = "event: message_start\n"
// Create the initial message structure with default values according to Claude Code API specification
// This follows the Claude Code API specification for streaming message initialization
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
// Use cpaUsageMetadata within the message_start event for Claude.
if promptTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.promptTokenCount"); promptTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
}
if candidatesTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.candidatesTokenCount"); candidatesTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
}
// Override default values with actual response metadata if available from the Gemini CLI response
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
}
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
}
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
params.HasFirstResponse = true
}
// Process the response parts array from the backend client
// Each part can contain text content, thinking content, or function calls
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partResults := partsResult.Array()
for i := 0; i < len(partResults); i++ {
partResult := partResults[i]
// Extract the different types of content from each part
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
// Handle text content (both regular content and thinking)
if partTextResult.Exists() {
// Process thinking content (internal reasoning)
if partResult.Get("thought").Bool() {
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
// log.Debug("Branch: signature_delta")
if params.CurrentThinkingText.Len() > 0 {
cache.CacheSignature(modelName, params.CurrentThinkingText.String(), thoughtSignature.String())
// log.Debugf("Cached signature for thinking block (textLen=%d)", params.CurrentThinkingText.Len())
params.CurrentThinkingText.Reset()
}
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thoughtSignature.String()))
output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.HasContent = true
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
params.CurrentThinkingText.WriteString(partTextResult.String())
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.HasContent = true
} else {
// Transition from another state to thinking
// First, close any existing content block
if params.ResponseType != 0 {
if params.ResponseType == 2 {
// output = output + "event: content_block_delta\n"
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
// output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
params.ResponseIndex++
}
// Start a new thinking content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, params.ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.ResponseType = 2 // Set state to thinking
params.HasContent = true
// Start accumulating thinking text for signature caching
params.CurrentThinkingText.Reset()
params.CurrentThinkingText.WriteString(partTextResult.String())
}
} else {
finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason")
if partTextResult.String() != "" || !finishReasonResult.Exists() {
// Process regular text content (user-visible output)
// Continue existing text block if already in content state
if params.ResponseType == 1 {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.HasContent = true
} else {
// Transition from another state to text content
// First, close any existing content block
if params.ResponseType != 0 {
if params.ResponseType == 2 {
// output = output + "event: content_block_delta\n"
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
// output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
params.ResponseIndex++
}
if partTextResult.String() != "" {
// Start a new text content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.ResponseType = 1 // Set state to content
params.HasContent = true
}
}
}
}
} else if functionCallResult.Exists() {
// Handle function/tool calls from the AI model
// This processes tool usage requests and formats them for Claude Code API compatibility
params.HasToolUse = true
fcName := functionCallResult.Get("name").String()
// Handle state transitions when switching to function calls
// Close any existing function call block first
if params.ResponseType == 3 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
params.ResponseIndex++
params.ResponseType = 0
}
// Special handling for thinking state transition
if params.ResponseType == 2 {
// output = output + "event: content_block_delta\n"
// output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, params.ResponseIndex)
// output = output + "\n\n\n"
}
// Close any other existing content block
if params.ResponseType != 0 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
output = output + "\n\n\n"
params.ResponseIndex++
}
// Start a new tool use content block
// This creates the structure for a function call in Claude Code format
output = output + "event: content_block_start\n"
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, params.ResponseIndex)
data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
output = output + "event: content_block_delta\n"
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
}
params.ResponseType = 3
params.HasContent = true
}
}
}
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
params.HasFinishReason = true
params.FinishReason = finishReasonResult.String()
}
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
params.HasUsageMetadata = true
params.CachedTokenCount = usageResult.Get("cachedContentTokenCount").Int()
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int() - params.CachedTokenCount
params.CandidatesTokenCount = usageResult.Get("candidatesTokenCount").Int()
params.ThoughtsTokenCount = usageResult.Get("thoughtsTokenCount").Int()
params.TotalTokenCount = usageResult.Get("totalTokenCount").Int()
if params.CandidatesTokenCount == 0 && params.TotalTokenCount > 0 {
params.CandidatesTokenCount = params.TotalTokenCount - params.PromptTokenCount - params.ThoughtsTokenCount
if params.CandidatesTokenCount < 0 {
params.CandidatesTokenCount = 0
}
}
}
if params.HasUsageMetadata && params.HasFinishReason {
appendFinalEvents(params, &output, false)
}
return []string{output}
}
// appendFinalEvents appends the closing events of a streamed Claude message to *output:
// a content_block_stop for the content block that is still open (if any), followed by a
// message_delta carrying the stop reason and token usage.
//
// It does nothing when the final events were already sent, when no content has been
// emitted, or when usage metadata has not been received and force is false. Once the
// events are written, params.HasSentFinalEvents is set so they are emitted at most once.
//
// Parameters:
//   - params: Streaming conversion state (open block, token counters, stop information)
//   - output: The SSE buffer the events are appended to
//   - force: Emit the events even if no usage metadata has been seen
func appendFinalEvents(params *Params, output *string, force bool) {
	if params.HasSentFinalEvents {
		return
	}
	if !params.HasUsageMetadata && !force {
		return
	}
	// Only send final events if we have actually output content
	if !params.HasContent {
		return
	}
	if params.ResponseType != 0 {
		*output = *output + "event: content_block_stop\n"
		*output = *output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, params.ResponseIndex)
		*output = *output + "\n\n\n"
		params.ResponseType = 0
	}
	stopReason := resolveStopReason(params)
	usageOutputTokens := params.CandidatesTokenCount + params.ThoughtsTokenCount
	if usageOutputTokens == 0 && params.TotalTokenCount > 0 {
		// PromptTokenCount is stored net of cached tokens, while TotalTokenCount still
		// counts them; subtract the cached tokens too so they are not reported as output.
		usageOutputTokens = params.TotalTokenCount - params.PromptTokenCount - params.CachedTokenCount
		if usageOutputTokens < 0 {
			usageOutputTokens = 0
		}
	}
	delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens)
	// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
	if params.CachedTokenCount > 0 {
		// sjson.Set does not return the original document on failure, so only adopt the
		// result on success; otherwise keep the delta without the cache field.
		if updated, err := sjson.Set(delta, "usage.cache_read_input_tokens", params.CachedTokenCount); err != nil {
			log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
		} else {
			delta = updated
		}
	}
	*output = *output + "event: message_delta\n" + "data: " + delta + "\n\n\n"
	params.HasSentFinalEvents = true
}
// resolveStopReason maps the streaming state to a Claude stop_reason.
// Any emitted tool call wins ("tool_use"); a MAX_TOKENS finish becomes "max_tokens";
// every other finish reason (STOP, FINISH_REASON_UNSPECIFIED, UNKNOWN, empty, or
// unrecognized) is reported as "end_turn".
func resolveStopReason(params *Params) string {
	switch {
	case params.HasToolUse:
		return "tool_use"
	case params.FinishReason == "MAX_TOKENS":
		return "max_tokens"
	default:
		return "end_turn"
	}
}
// ConvertAntigravityResponseToClaudeNonStream converts a complete (non-streaming) Antigravity
// response into a single Claude Messages API response object.
//
// Consecutive thought parts are merged into one "thinking" block. That block carries the last
// non-empty thought signature, prefixed with the model group of the requested model.
// Consecutive non-thought text parts are merged into one "text" block, and every functionCall
// part becomes a "tool_use" block. Tool IDs use the same process-unique scheme as the streaming
// converter.
//
// Token usage follows Claude semantics: input_tokens excludes cached prompt tokens, which are
// reported separately as cache_read_input_tokens. This mirrors the streaming path.
//
// Parameters:
//   - ctx: The context for the request (unused).
//   - modelName: The upstream model name (unused; the requested model is read from requestRawJSON).
//   - originalRequestRawJSON: The client's original request (unused).
//   - requestRawJSON: The translated request; its "model" field selects the signature model group.
//   - rawJSON: The raw JSON response from the Antigravity API.
//   - param: Conversion state (unused for non-streaming responses).
//
// Returns:
//   - string: A Claude-compatible JSON response. "content" is always an array, and "usage" is
//     omitted when the upstream response has no usageMetadata.
func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
	_ = originalRequestRawJSON
	modelName := gjson.GetBytes(requestRawJSON, "model").String()
	root := gjson.ParseBytes(rawJSON)

	usageMeta := root.Get("response.usageMetadata")
	promptTokens := usageMeta.Get("promptTokenCount").Int()
	candidateTokens := usageMeta.Get("candidatesTokenCount").Int()
	thoughtTokens := usageMeta.Get("thoughtsTokenCount").Int()
	totalTokens := usageMeta.Get("totalTokenCount").Int()
	cachedTokens := usageMeta.Get("cachedContentTokenCount").Int()

	// Gemini's promptTokenCount includes cached content tokens, whereas Claude reports cache
	// reads separately from input_tokens; report only the uncached remainder.
	inputTokens := promptTokens - cachedTokens
	if inputTokens < 0 {
		inputTokens = 0
	}
	outputTokens := candidateTokens + thoughtTokens
	if outputTokens == 0 && totalTokens > 0 {
		// totalTokenCount covers the full (cached + uncached) prompt, so subtract the gross count.
		outputTokens = totalTokens - promptTokens
		if outputTokens < 0 {
			outputTokens = 0
		}
	}

	// "content" starts as an empty array so the response stays schema-valid even without parts.
	responseJSON := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
	responseJSON, _ = sjson.Set(responseJSON, "id", root.Get("response.responseId").String())
	responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String())
	responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", inputTokens)
	responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens)
	// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
	if cachedTokens > 0 {
		// sjson.Set does not return the original document on failure; keep the previous JSON then.
		if updated, err := sjson.Set(responseJSON, "usage.cache_read_input_tokens", cachedTokens); err != nil {
			log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
		} else {
			responseJSON = updated
		}
	}

	var textBuilder, thinkingBuilder strings.Builder
	thinkingSignature := ""
	hasToolCall := false

	// flushText appends the buffered visible text as one "text" block.
	flushText := func() {
		if textBuilder.Len() == 0 {
			return
		}
		block, _ := sjson.Set(`{"type":"text","text":""}`, "text", textBuilder.String())
		responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
		textBuilder.Reset()
	}
	// flushThinking appends the buffered thinking text (and signature, if any) as one "thinking" block.
	flushThinking := func() {
		if thinkingBuilder.Len() == 0 && thinkingSignature == "" {
			return
		}
		block, _ := sjson.Set(`{"type":"thinking","thinking":""}`, "thinking", thinkingBuilder.String())
		if thinkingSignature != "" {
			block, _ = sjson.Set(block, "signature", fmt.Sprintf("%s#%s", cache.GetModelGroup(modelName), thinkingSignature))
		}
		responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
		thinkingBuilder.Reset()
		thinkingSignature = ""
	}

	if parts := root.Get("response.candidates.0.content.parts"); parts.IsArray() {
		for _, part := range parts.Array() {
			isThought := part.Get("thought").Bool()
			if isThought {
				sig := part.Get("thoughtSignature")
				if !sig.Exists() {
					sig = part.Get("thought_signature")
				}
				if sig.Exists() && sig.String() != "" {
					thinkingSignature = sig.String()
				}
			}
			if text := part.Get("text"); text.Exists() && text.String() != "" {
				if isThought {
					flushText()
					thinkingBuilder.WriteString(text.String())
					continue
				}
				flushThinking()
				textBuilder.WriteString(text.String())
				continue
			}
			if functionCall := part.Get("functionCall"); functionCall.Exists() {
				flushThinking()
				flushText()
				hasToolCall = true
				name := functionCall.Get("name").String()
				// Same ID scheme as the streaming converter, so IDs never repeat across
				// responses within one conversation.
				toolID := util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", name, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
				toolBlock, _ := sjson.Set(`{"type":"tool_use","id":"","name":"","input":{}}`, "id", toolID)
				toolBlock, _ = sjson.Set(toolBlock, "name", name)
				if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) && args.IsObject() {
					toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw)
				}
				responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", toolBlock)
				continue
			}
		}
	}
	flushThinking()
	flushText()

	// Share the streaming path's stop-reason mapping so both modes agree.
	stopReason := resolveStopReason(&Params{
		HasToolUse:   hasToolCall,
		FinishReason: root.Get("response.candidates.0.finishReason").String(),
	})
	responseJSON, _ = sjson.Set(responseJSON, "stop_reason", stopReason)

	// Without usageMetadata every counter above is zero; omit usage rather than report zeros.
	if !usageMeta.Exists() {
		responseJSON, _ = sjson.Delete(responseJSON, "usage")
	}
	return responseJSON
}
// ClaudeTokenCount renders a token count as a Claude count_tokens response body,
// for example {"input_tokens":42}. The context is accepted only to satisfy the
// translator interface and is ignored.
func ClaudeTokenCount(_ context.Context, count int64) string {
	const body = `{"input_tokens":%d}`
	return fmt.Sprintf(body, count)
}
================================================
FILE: internal/translator/antigravity/claude/antigravity_claude_response_test.go
================================================
package claude
import (
"context"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
)
// ============================================================================
// Signature Caching Tests
// ============================================================================
func TestConvertAntigravityResponseToClaude_ParamsInitialized(t *testing.T) {
cache.ClearSignatureCache("")
// Request with user message - should initialize params
requestJSON := []byte(`{
"messages": [
{"role": "user", "content": [{"type": "text", "text": "Hello world"}]}
]
}`)
// First response chunk with thinking
responseJSON := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "Let me think...", "thought": true}]
}
}]
}
}`)
var param any
ctx := context.Background()
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, responseJSON, ¶m)
params := param.(*Params)
if !params.HasFirstResponse {
t.Error("HasFirstResponse should be set after first chunk")
}
if params.CurrentThinkingText.Len() == 0 {
t.Error("Thinking text should be accumulated")
}
}
func TestConvertAntigravityResponseToClaude_ThinkingTextAccumulated(t *testing.T) {
cache.ClearSignatureCache("")
requestJSON := []byte(`{
"messages": [{"role": "user", "content": [{"type": "text", "text": "Test"}]}]
}`)
// First thinking chunk
chunk1 := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "First part of thinking...", "thought": true}]
}
}]
}
}`)
// Second thinking chunk (continuation)
chunk2 := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": " Second part of thinking...", "thought": true}]
}
}]
}
}`)
var param any
ctx := context.Background()
// Process first chunk - starts new thinking block
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, chunk1, ¶m)
params := param.(*Params)
if params.CurrentThinkingText.Len() == 0 {
t.Error("Thinking text should be accumulated after first chunk")
}
// Process second chunk - continues thinking block
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, chunk2, ¶m)
text := params.CurrentThinkingText.String()
if !strings.Contains(text, "First part") || !strings.Contains(text, "Second part") {
t.Errorf("Thinking text should accumulate both parts, got: %s", text)
}
}
func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
cache.ClearSignatureCache("")
requestJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Cache test"}]}]
}`)
// Thinking chunk
thinkingChunk := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "My thinking process here", "thought": true}]
}
}]
}
}`)
// Signature chunk
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
signatureChunk := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSignature + `"}]
}
}]
}
}`)
var param any
ctx := context.Background()
// Process thinking chunk
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, thinkingChunk, ¶m)
params := param.(*Params)
thinkingText := params.CurrentThinkingText.String()
if thinkingText == "" {
t.Fatal("Thinking text should be accumulated")
}
// Process signature chunk - should cache the signature
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, signatureChunk, ¶m)
// Verify signature was cached
cachedSig := cache.GetCachedSignature("claude-sonnet-4-5-thinking", thinkingText)
if cachedSig != validSignature {
t.Errorf("Expected cached signature '%s', got '%s'", validSignature, cachedSig)
}
// Verify thinking text was reset after caching
if params.CurrentThinkingText.Len() != 0 {
t.Error("Thinking text should be reset after signature is cached")
}
}
func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T) {
cache.ClearSignatureCache("")
requestJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Multi block test"}]}]
}`)
validSig1 := "signature1_12345678901234567890123456789012345678901234567"
validSig2 := "signature2_12345678901234567890123456789012345678901234567"
// First thinking block with signature
block1Thinking := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "First thinking block", "thought": true}]
}
}]
}
}`)
block1Sig := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSig1 + `"}]
}
}]
}
}`)
// Text content (breaks thinking)
textBlock := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "Regular text output"}]
}
}]
}
}`)
// Second thinking block with signature
block2Thinking := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "Second thinking block", "thought": true}]
}
}]
}
}`)
block2Sig := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSig2 + `"}]
}
}]
}
}`)
var param any
ctx := context.Background()
// Process first thinking block
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Thinking, ¶m)
params := param.(*Params)
firstThinkingText := params.CurrentThinkingText.String()
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Sig, ¶m)
// Verify first signature cached
if cache.GetCachedSignature("claude-sonnet-4-5-thinking", firstThinkingText) != validSig1 {
t.Error("First thinking block signature should be cached")
}
// Process text (transitions out of thinking)
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, textBlock, ¶m)
// Process second thinking block
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Thinking, ¶m)
secondThinkingText := params.CurrentThinkingText.String()
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Sig, ¶m)
// Verify second signature cached
if cache.GetCachedSignature("claude-sonnet-4-5-thinking", secondThinkingText) != validSig2 {
t.Error("Second thinking block signature should be cached")
}
}
================================================
FILE: internal/translator/antigravity/claude/init.go
================================================
package claude
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
Claude,
Antigravity,
ConvertClaudeRequestToAntigravity,
interfaces.TranslateResponse{
Stream: ConvertAntigravityResponseToClaude,
NonStream: ConvertAntigravityResponseToClaudeNonStream,
TokenCount: ClaudeTokenCount,
},
)
}
================================================
FILE: internal/translator/antigravity/gemini/antigravity_gemini_request.go
================================================
// Package gemini provides request translation from the Gemini API format into the
// Antigravity request envelope ({"project": ..., "request": ..., "model": ...}).
// It wraps the client's Gemini request, moves the model name into the envelope, and
// normalizes system instructions, message roles, tool declarations and the grouping of
// function calls with their function responses before the request is sent upstream.
package gemini
import (
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertGeminiRequestToAntigravity wraps a Gemini API request in the Antigravity request
// envelope {"project":"","request":{...},"model":"..."} and normalizes it for the upstream.
//
// The transformation:
//  1. embeds the client request under "request", stores modelName in "model" and removes
//     "request.model";
//  2. regroups function calls with their responses via fixCLIToolResponse; when
//     request.contents is missing, an empty body is returned;
//  3. moves "request.system_instruction" to "request.systemInstruction";
//  4. replaces missing or invalid content roles (anything other than "user"/"model"),
//     alternating from the previous turn and starting with "user";
//  5. renames "parameters" to "parametersJsonSchema" in every
//     request.tools[*].function_declarations[*] entry;
//  6. for non-Claude models, stamps every thought part, and every functionCall part whose
//     thoughtSignature is shorter than 50 characters, with "skip_thought_signature_validator";
//  7. attaches the default safety settings under "request.safetySettings".
//
// Parameters:
//   - modelName: The upstream model name; written to the envelope and used to detect Claude models
//   - inputRawJSON: The raw Gemini API request body
//   - stream: Unused
//
// Returns:
//   - []byte: The Antigravity request body
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
	envelope := `{"project":"","request":{},"model":""}`
	envelope, _ = sjson.SetRaw(envelope, "request", string(inputRawJSON))
	envelope, _ = sjson.Set(envelope, "model", modelName)
	envelope, _ = sjson.Delete(envelope, "request.model")

	envelope, err := fixCLIToolResponse(envelope)
	if err != nil {
		return []byte{}
	}

	if sysInstruction := gjson.Get(envelope, "request.system_instruction"); sysInstruction.Exists() {
		envelope, _ = sjson.SetRaw(envelope, "request.systemInstruction", sysInstruction.Raw)
		envelope, _ = sjson.Delete(envelope, "request.system_instruction")
	}

	out := []byte(envelope)

	// Repair roles so the conversation alternates: an invalid first turn becomes "user",
	// and an invalid later turn takes the opposite role of its predecessor.
	if contents := gjson.GetBytes(out, "request.contents"); contents.Exists() {
		previous := ""
		position := 0
		contents.ForEach(func(_, entry gjson.Result) bool {
			role := entry.Get("role").String()
			if role != "user" && role != "model" {
				if previous == "user" {
					role = "model"
				} else {
					role = "user"
				}
				out, _ = sjson.SetBytes(out, fmt.Sprintf("request.contents.%d.role", position), role)
			}
			previous = role
			position++
			return true
		})
	}

	if tools := gjson.GetBytes(out, "request.tools"); tools.Exists() && tools.IsArray() {
		toolCount := len(tools.Array())
		for toolIdx := 0; toolIdx < toolCount; toolIdx++ {
			declsPath := fmt.Sprintf("request.tools.%d.function_declarations", toolIdx)
			decls := gjson.GetBytes(out, declsPath)
			if !decls.Exists() || !decls.IsArray() {
				continue
			}
			declCount := len(decls.Array())
			for declIdx := 0; declIdx < declCount; declIdx++ {
				paramsPath := fmt.Sprintf("%s.%d.parameters", declsPath, declIdx)
				if !gjson.GetBytes(out, paramsPath).Exists() {
					continue
				}
				renamed, _ := util.RenameKey(string(out), paramsPath, fmt.Sprintf("%s.%d.parametersJsonSchema", declsPath, declIdx))
				out = []byte(renamed)
			}
		}
	}

	// Claude models keep their signatures untouched.
	if strings.Contains(modelName, "claude") {
		return common.AttachDefaultSafetySettings(out, "request.safetySettings")
	}

	// Other models: stamp functionCall parts lacking a plausible signature, and every thought
	// part, with the skip sentinel so the upstream bypasses signature validation.
	const skipSentinel = "skip_thought_signature_validator"
	gjson.GetBytes(out, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
		if content.Get("role").String() != "model" {
			return true
		}
		var thoughtParts []int64
		content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
			if part.Get("thought").Bool() {
				thoughtParts = append(thoughtParts, partIdx.Int())
			}
			if part.Get("functionCall").Exists() {
				if sig := part.Get("thoughtSignature").String(); len(sig) < 50 {
					out, _ = sjson.SetBytes(out, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
				}
			}
			return true
		})
		for k := len(thoughtParts) - 1; k >= 0; k-- {
			out, _ = sjson.SetBytes(out, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), thoughtParts[k]), skipSentinel)
		}
		return true
	})
	return common.AttachDefaultSafetySettings(out, "request.safetySettings")
}
// FunctionCallGroup represents one model turn that issued function calls; fixCLIToolResponse
// keeps it pending until enough functionResponse parts have been collected to answer it.
type FunctionCallGroup struct {
	// ResponsesNeeded is the number of functionCall parts in the model turn, i.e. how many
	// functionResponse parts must be taken from the backlog to complete the group.
	ResponsesNeeded int
	// CallNames holds the functionCall names in their original order; the i-th name is used
	// to backfill an empty name on the i-th response matched to this group.
	CallNames []string
}
// parseFunctionResponseRaw returns the JSON for a single functionResponse part.
//
// A well-formed JSON object is returned as-is, except that an empty (or whitespace-only)
// functionResponse.name is replaced by fallbackName when fallbackName is non-empty.
// Anything else is rebuilt as a minimal
// {"functionResponse":{"name":...,"response":{"result":...}}} object. When the part still
// exposes a functionResponse field, the rebuilt object takes its name (falling back to
// fallbackName), its response stringified into result, and its id if present. Otherwise the
// whole part's string value becomes the result, named fallbackName or "unknown".
func parseFunctionResponseRaw(response gjson.Result, fallbackName string) string {
	if response.IsObject() && gjson.Valid(response.Raw) {
		ownName := response.Get("functionResponse.name").String()
		if strings.TrimSpace(ownName) != "" || fallbackName == "" {
			return response.Raw
		}
		patched, _ := sjson.Set(response.Raw, "functionResponse.name", fallbackName)
		return patched
	}

	log.Debugf("parse function response failed, using fallback")
	const skeleton = `{"functionResponse":{"name":"","response":{"result":""}}}`

	if inner := response.Get("functionResponse"); inner.Exists() {
		innerName := inner.Get("name").String()
		if strings.TrimSpace(innerName) == "" {
			innerName = fallbackName
		}
		rebuilt, _ := sjson.Set(skeleton, "functionResponse.name", innerName)
		rebuilt, _ = sjson.Set(rebuilt, "functionResponse.response.result", inner.Get("response").String())
		if id := inner.Get("id").String(); id != "" {
			rebuilt, _ = sjson.Set(rebuilt, "functionResponse.id", id)
		}
		return rebuilt
	}

	name := fallbackName
	if name == "" {
		name = "unknown"
	}
	rebuilt, _ := sjson.Set(skeleton, "functionResponse.name", name)
	rebuilt, _ = sjson.Set(rebuilt, "functionResponse.response.result", response.String())
	return rebuilt
}
// fixCLIToolResponse regroups function responses so that each model turn containing
// functionCall parts is followed by one merged content holding its responses.
//
// Contents are walked in order:
//   - A content with at least one functionResponse part is consumed. Its functionResponse
//     parts join a backlog, and the content itself is not copied (any other parts it carries
//     are therefore dropped). While the backlog holds enough responses for the oldest pending
//     model turn, that many responses are taken FIFO and emitted as one {"role":"function"}
//     content. Empty response names are backfilled from the turn's functionCall names.
//   - Every other object content is copied unchanged. A model content with functionCall parts
//     also opens a new pending group. Non-object contents are logged and skipped.
//
// Groups still pending after the walk are satisfied from the remaining backlog when possible;
// responses that never match a group are discarded.
//
// Parameters:
//   - input: The Antigravity envelope JSON whose request.contents will be rewritten
//
// Returns:
//   - string: The JSON with request.contents replaced by the regrouped contents
//   - error: Non-nil when request.contents is missing (input is then returned unchanged)
func fixCLIToolResponse(input string) (string, error) {
	contents := gjson.Parse(input).Get("request.contents")
	if !contents.Exists() {
		return input, fmt.Errorf("contents not found in input")
	}

	regrouped := `{"contents":[]}`
	var pending []*FunctionCallGroup // model turns still waiting for their responses (FIFO)
	var backlog []gjson.Result       // collected functionResponse parts not yet assigned

	// takeResponses removes and returns the oldest n responses from the backlog.
	takeResponses := func(n int) []gjson.Result {
		taken := backlog[:n]
		backlog = backlog[n:]
		return taken
	}

	// emitGroup appends one merged "function" content built from the group's responses;
	// nothing is appended if no part survives normalization.
	emitGroup := func(group *FunctionCallGroup, responses []gjson.Result) {
		merged := `{"parts":[],"role":"function"}`
		for i, response := range responses {
			if partRaw := parseFunctionResponseRaw(response, group.CallNames[i]); partRaw != "" {
				merged, _ = sjson.SetRaw(merged, "parts.-1", partRaw)
			}
		}
		if gjson.Get(merged, "parts.#").Int() > 0 {
			regrouped, _ = sjson.SetRaw(regrouped, "contents.-1", merged)
		}
	}

	contents.ForEach(func(_, value gjson.Result) bool {
		role := value.Get("role").String()

		var responses []gjson.Result
		var callNames []string
		value.Get("parts").ForEach(func(_, part gjson.Result) bool {
			if part.Get("functionResponse").Exists() {
				responses = append(responses, part)
			}
			if role == "model" && part.Get("functionCall").Exists() {
				callNames = append(callNames, part.Get("functionCall.name").String())
			}
			return true
		})

		if len(responses) > 0 {
			backlog = append(backlog, responses...)
			for len(pending) > 0 && len(backlog) >= pending[0].ResponsesNeeded {
				group := pending[0]
				pending = pending[1:]
				emitGroup(group, takeResponses(group.ResponsesNeeded))
			}
			return true // the responses surface through merged groups, not this content
		}

		if !value.IsObject() {
			if len(callNames) > 0 {
				log.Warnf("failed to parse model content")
			} else {
				log.Warnf("failed to parse content")
			}
			return true
		}
		regrouped, _ = sjson.SetRaw(regrouped, "contents.-1", value.Raw)
		if len(callNames) > 0 {
			pending = append(pending, &FunctionCallGroup{
				ResponsesNeeded: len(callNames),
				CallNames:       callNames,
			})
		}
		return true
	})

	for _, group := range pending {
		if len(backlog) >= group.ResponsesNeeded {
			emitGroup(group, takeResponses(group.ResponsesNeeded))
		}
	}

	result, _ := sjson.SetRaw(input, "request.contents", gjson.Get(regrouped, "contents").Raw)
	return result, nil
}
================================================
FILE: internal/translator/antigravity/gemini/antigravity_gemini_request_test.go
================================================
package gemini
import (
"fmt"
"testing"
"github.com/tidwall/gjson"
)
// TestConvertGeminiRequestToAntigravity_PreserveValidSignature checks that a functionCall
// part already carrying a long (valid-looking) thoughtSignature keeps it unchanged.
func TestConvertGeminiRequestToAntigravity_PreserveValidSignature(t *testing.T) {
	const validSignature = "abc123validSignature1234567890123456789012345678901234567890"
	input := []byte(fmt.Sprintf(`{
		"model": "gemini-3-pro-preview",
		"contents": [
			{
				"role": "model",
				"parts": [
					{"functionCall": {"name": "test_tool", "args": {}}, "thoughtSignature": "%s"}
				]
			}
		]
	}`, validSignature))

	out := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", input, false)

	parts := gjson.GetBytes(out, "request.contents.0.parts").Array()
	if got := len(parts); got != 1 {
		t.Fatalf("Expected 1 part, got %d", got)
	}
	if got := parts[0].Get("thoughtSignature").String(); got != validSignature {
		t.Errorf("Expected thoughtSignature '%s', got '%s'", validSignature, got)
	}
}
func TestConvertGeminiRequestToAntigravity_AddSkipSentinelToFunctionCall(t *testing.T) {
	// A functionCall part that carries no thoughtSignature should receive the skip sentinel.
	const want = "skip_thought_signature_validator"
	payload := []byte(`{
		"model": "gemini-3-pro-preview",
		"contents": [
			{
				"role": "model",
				"parts": [
					{"functionCall": {"name": "test_tool", "args": {}}}
				]
			}
		]
	}`)

	translated := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", payload, false)

	got := gjson.GetBytes(translated, "request.contents.0.parts.0.thoughtSignature").String()
	if got != want {
		t.Errorf("Expected skip sentinel '%s', got '%s'", want, got)
	}
}
func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
	// Every functionCall in a parallel batch should be stamped with the skip sentinel.
	const want = "skip_thought_signature_validator"
	payload := []byte(`{
		"model": "gemini-3-pro-preview",
		"contents": [
			{
				"role": "model",
				"parts": [
					{"functionCall": {"name": "tool_one", "args": {"a": "1"}}},
					{"functionCall": {"name": "tool_two", "args": {"b": "2"}}}
				]
			}
		]
	}`)

	translated := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", payload, false)

	gotParts := gjson.GetBytes(translated, "request.contents.0.parts").Array()
	if count := len(gotParts); count != 2 {
		t.Fatalf("Expected 2 parts, got %d", count)
	}
	for idx := range gotParts {
		if got := gotParts[idx].Get("thoughtSignature").String(); got != want {
			t.Errorf("Part %d: Expected '%s', got '%s'", idx, want, got)
		}
	}
}
func TestFixCLIToolResponse_PreservesFunctionResponseParts(t *testing.T) {
	// A functionResponse that carries its own "parts" array (e.g. inlineData embedded by the
	// Claude translator for images) must come out of fixCLIToolResponse unchanged, because
	// parseFunctionResponseRaw keeps the raw object for valid JSON.
	payload := `{
		"model": "claude-opus-4-6-thinking",
		"request": {
			"contents": [
				{
					"role": "model",
					"parts": [
						{
							"functionCall": {"name": "screenshot", "args": {}}
						}
					]
				},
				{
					"role": "function",
					"parts": [
						{
							"functionResponse": {
								"id": "tool-001",
								"name": "screenshot",
								"response": {"result": "Screenshot taken"},
								"parts": [
									{"inlineData": {"mimeType": "image/png", "data": "iVBOR"}}
								]
							}
						}
					]
				}
			]
		}
	}`

	fixed, err := fixCLIToolResponse(payload)
	if err != nil {
		t.Fatalf("fixCLIToolResponse failed: %v", err)
	}

	// First content whose role is "function".
	functionContent := gjson.Get(fixed, `request.contents.#(role=="function")`)
	if !functionContent.Exists() {
		t.Fatal("function role content should exist in output")
	}

	response := functionContent.Get("parts.0.functionResponse")
	if !response.Exists() {
		t.Fatal("functionResponse should exist in output")
	}

	embedded := response.Get("parts").Array()
	if count := len(embedded); count != 1 {
		t.Fatalf("Expected 1 inlineData part in functionResponse.parts, got %d", count)
	}
	if mime := embedded[0].Get("inlineData.mimeType").String(); mime != "image/png" {
		t.Errorf("Expected mimeType 'image/png', got '%s'", mime)
	}
	if data := embedded[0].Get("inlineData.data").String(); data != "iVBOR" {
		t.Errorf("Expected data 'iVBOR', got '%s'", data)
	}
	if result := response.Get("response.result").String(); result != "Screenshot taken" {
		t.Errorf("Expected response.result 'Screenshot taken', got '%s'", result)
	}
}
func TestFixCLIToolResponse_BackfillsEmptyFunctionResponseName(t *testing.T) {
	// Amp may send a functionResponse with an empty name; fixCLIToolResponse is expected to
	// fill it in from the matching functionCall.
	payload := `{
		"model": "gemini-3-pro-preview",
		"request": {
			"contents": [
				{
					"role": "model",
					"parts": [
						{"functionCall": {"name": "Bash", "args": {"cmd": "ls"}}}
					]
				},
				{
					"role": "function",
					"parts": [
						{"functionResponse": {"name": "", "response": {"output": "file1.txt"}}}
					]
				}
			]
		}
	}`

	fixed, err := fixCLIToolResponse(payload)
	if err != nil {
		t.Fatalf("fixCLIToolResponse failed: %v", err)
	}

	functionContent := gjson.Get(fixed, `request.contents.#(role=="function")`)
	if !functionContent.Exists() {
		t.Fatal("function role content should exist in output")
	}

	if got := functionContent.Get("parts.0.functionResponse.name").String(); got != "Bash" {
		t.Errorf("Expected backfilled name 'Bash', got '%s'", got)
	}
}
func TestFixCLIToolResponse_BackfillsMultipleEmptyNames(t *testing.T) {
	// Two parallel calls whose responses both arrive without names: each response should be
	// named after the call in the same position.
	payload := `{
		"model": "gemini-3-pro-preview",
		"request": {
			"contents": [
				{
					"role": "model",
					"parts": [
						{"functionCall": {"name": "Read", "args": {"path": "/a"}}},
						{"functionCall": {"name": "Grep", "args": {"pattern": "x"}}}
					]
				},
				{
					"role": "function",
					"parts": [
						{"functionResponse": {"name": "", "response": {"result": "content a"}}},
						{"functionResponse": {"name": "", "response": {"result": "match x"}}}
					]
				}
			]
		}
	}`

	fixed, err := fixCLIToolResponse(payload)
	if err != nil {
		t.Fatalf("fixCLIToolResponse failed: %v", err)
	}

	functionContent := gjson.Get(fixed, `request.contents.#(role=="function")`)
	if !functionContent.Exists() {
		t.Fatal("function role content should exist in output")
	}

	responses := functionContent.Get("parts").Array()
	if count := len(responses); count != 2 {
		t.Fatalf("Expected 2 function response parts, got %d", count)
	}

	first := responses[0].Get("functionResponse.name").String()
	second := responses[1].Get("functionResponse.name").String()
	if first != "Read" {
		t.Errorf("Expected first response name 'Read', got '%s'", first)
	}
	if second != "Grep" {
		t.Errorf("Expected second response name 'Grep', got '%s'", second)
	}
}
func TestFixCLIToolResponse_PreservesExistingName(t *testing.T) {
	// A functionResponse that already has a name must keep it.
	payload := `{
		"model": "gemini-3-pro-preview",
		"request": {
			"contents": [
				{
					"role": "model",
					"parts": [
						{"functionCall": {"name": "Bash", "args": {}}}
					]
				},
				{
					"role": "function",
					"parts": [
						{"functionResponse": {"name": "Bash", "response": {"result": "ok"}}}
					]
				}
			]
		}
	}`

	fixed, err := fixCLIToolResponse(payload)
	if err != nil {
		t.Fatalf("fixCLIToolResponse failed: %v", err)
	}

	functionContent := gjson.Get(fixed, `request.contents.#(role=="function")`)
	if !functionContent.Exists() {
		t.Fatal("function role content should exist in output")
	}

	if got := functionContent.Get("parts.0.functionResponse.name").String(); got != "Bash" {
		t.Errorf("Expected preserved name 'Bash', got '%s'", got)
	}
}
func TestFixCLIToolResponse_MoreResponsesThanCalls(t *testing.T) {
	// One call but two responses: the first response must still be backfilled from the call.
	// Extra responses are not matched to any call group.
	payload := `{
		"model": "gemini-3-pro-preview",
		"request": {
			"contents": [
				{
					"role": "model",
					"parts": [
						{"functionCall": {"name": "Bash", "args": {}}}
					]
				},
				{
					"role": "function",
					"parts": [
						{"functionResponse": {"name": "", "response": {"result": "ok"}}},
						{"functionResponse": {"name": "", "response": {"result": "extra"}}}
					]
				}
			]
		}
	}`

	fixed, err := fixCLIToolResponse(payload)
	if err != nil {
		t.Fatalf("fixCLIToolResponse failed: %v", err)
	}

	functionContent := gjson.Get(fixed, `request.contents.#(role=="function")`)
	if !functionContent.Exists() {
		t.Fatal("function role content should exist in output")
	}

	if got := functionContent.Get("parts.0.functionResponse.name").String(); got != "Bash" {
		t.Errorf("Expected first response name 'Bash', got '%s'", got)
	}
}
func TestFixCLIToolResponse_MultipleGroupsFIFO(t *testing.T) {
	// Two call/response rounds in sequence: names must be matched in first-in, first-out order.
	payload := `{
		"model": "gemini-3-pro-preview",
		"request": {
			"contents": [
				{
					"role": "model",
					"parts": [
						{"functionCall": {"name": "Read", "args": {}}}
					]
				},
				{
					"role": "function",
					"parts": [
						{"functionResponse": {"name": "", "response": {"result": "file content"}}}
					]
				},
				{
					"role": "model",
					"parts": [
						{"functionCall": {"name": "Grep", "args": {}}}
					]
				},
				{
					"role": "function",
					"parts": [
						{"functionResponse": {"name": "", "response": {"result": "match"}}}
					]
				}
			]
		}
	}`

	fixed, err := fixCLIToolResponse(payload)
	if err != nil {
		t.Fatalf("fixCLIToolResponse failed: %v", err)
	}

	// All contents whose role is "function", in document order.
	functionContents := gjson.Get(fixed, `request.contents.#(role=="function")#`).Array()
	if count := len(functionContents); count != 2 {
		t.Fatalf("Expected 2 function contents, got %d", count)
	}

	first := functionContents[0].Get("parts.0.functionResponse.name").String()
	second := functionContents[1].Get("parts.0.functionResponse.name").String()
	if first != "Read" {
		t.Errorf("Expected first group name 'Read', got '%s'", first)
	}
	if second != "Grep" {
		t.Errorf("Expected second group name 'Grep', got '%s'", second)
	}
}
================================================
FILE: internal/translator/antigravity/gemini/antigravity_gemini_response.go
================================================
// Package gemini provides translation between the Gemini API format and the Antigravity
// (Gemini CLI-style) API format.
// This file handles the response direction. It unwraps the "response" envelope returned by
// the Antigravity backend into standard Gemini API JSON, for both streaming chunks and
// non-streaming responses.
package gemini
import (
"bytes"
"context"
"fmt"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertAntigravityResponseToGemini converts one streaming response chunk from the
// Antigravity backend (Gemini CLI envelope) into standard Gemini API format.
//
// A leading "data:" SSE prefix is stripped first. The output then depends on the string
// value stored under the "alt" key of ctx:
//   - alt == "": the envelope's "response" field is unwrapped and returned as a single
//     chunk. The chunk is empty when "response" is missing.
//   - alt != "": rawJSON is treated as a JSON array of envelopes. The "response" field of
//     every element is collected into one JSON array; non-array input yields "[]".
//
// In both modes cpaUsageMetadata is renamed back to usageMetadata, because the executor
// hides usage data under that name.
//
// Parameters:
//   - ctx: Carries the "alt" value; when it is absent (or not a string) nothing is emitted
//   - rawJSON: The raw streaming chunk from the Antigravity backend
//   - The model name, original/translated request bodies and param are unused
//
// Returns:
//   - []string: Zero or one chunk in Gemini API format
func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []string {
	if bytes.HasPrefix(rawJSON, []byte("data:")) {
		rawJSON = bytes.TrimSpace(rawJSON[5:])
	}

	alt, ok := ctx.Value("alt").(string)
	if !ok {
		return []string{}
	}

	if alt == "" {
		var chunk []byte
		if responseResult := gjson.GetBytes(rawJSON, "response"); responseResult.Exists() {
			chunk = restoreUsageMetadata([]byte(responseResult.Raw))
		}
		return []string{string(chunk)}
	}

	// Array mode: parse the incoming payload itself. The previous code parsed an
	// always-nil buffer here and therefore always produced "[]".
	chunkTemplate := "[]"
	responseResult := gjson.ParseBytes(rawJSON)
	if responseResult.IsArray() {
		for _, item := range responseResult.Array() {
			response := item.Get("response")
			if !response.Exists() {
				continue
			}
			restored := restoreUsageMetadata([]byte(response.Raw))
			chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", string(restored))
		}
	}
	return []string{chunkTemplate}
}
// ConvertAntigravityResponseToGeminiNonStream converts a complete (non-streaming) response
// from the Antigravity backend into a Gemini API response body.
//
// When rawJSON has a "response" envelope field, that field is returned with
// cpaUsageMetadata renamed back to usageMetadata. Otherwise rawJSON is returned unchanged.
//
// Parameters:
//   - rawJSON: The raw response body from the Antigravity backend
//   - The context, model name, original/translated request bodies and param are unused
//
// Returns:
//   - string: The Gemini-compatible JSON response
func ConvertAntigravityResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
	envelope := gjson.GetBytes(rawJSON, "response")
	if !envelope.Exists() {
		return string(rawJSON)
	}
	return string(restoreUsageMetadata([]byte(envelope.Raw)))
}
func GeminiTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
}
// restoreUsageMetadata moves a top-level "cpaUsageMetadata" field back to "usageMetadata".
// The executor stores usage under cpaUsageMetadata in non-terminal chunks so that clients
// not expecting usage there do not see it. Standard Gemini API output must use the original
// field name. Chunks without cpaUsageMetadata are returned untouched; sjson errors are
// ignored.
func restoreUsageMetadata(chunk []byte) []byte {
	hidden := gjson.GetBytes(chunk, "cpaUsageMetadata")
	if !hidden.Exists() {
		return chunk
	}
	chunk, _ = sjson.SetRawBytes(chunk, "usageMetadata", []byte(hidden.Raw))
	chunk, _ = sjson.DeleteBytes(chunk, "cpaUsageMetadata")
	return chunk
}
================================================
FILE: internal/translator/antigravity/gemini/antigravity_gemini_response_test.go
================================================
package gemini
import (
"context"
"testing"
)
func TestRestoreUsageMetadata(t *testing.T) {
	cases := []struct {
		label string
		in    string
		want  string
	}{
		{
			label: "cpaUsageMetadata renamed to usageMetadata",
			in:    `{"modelVersion":"gemini-3-pro","cpaUsageMetadata":{"promptTokenCount":100,"candidatesTokenCount":200}}`,
			want:  `{"modelVersion":"gemini-3-pro","usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":200}}`,
		},
		{
			label: "no cpaUsageMetadata unchanged",
			in:    `{"modelVersion":"gemini-3-pro","usageMetadata":{"promptTokenCount":100}}`,
			want:  `{"modelVersion":"gemini-3-pro","usageMetadata":{"promptTokenCount":100}}`,
		},
		{
			label: "empty input",
			in:    `{}`,
			want:  `{}`,
		},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			if got := string(restoreUsageMetadata([]byte(tc.in))); got != tc.want {
				t.Errorf("restoreUsageMetadata() = %s, want %s", got, tc.want)
			}
		})
	}
}
func TestConvertAntigravityResponseToGeminiNonStream(t *testing.T) {
	cases := []struct {
		label string
		in    string
		want  string
	}{
		{
			label: "cpaUsageMetadata restored in response",
			in:    `{"response":{"modelVersion":"gemini-3-pro","cpaUsageMetadata":{"promptTokenCount":100}}}`,
			want:  `{"modelVersion":"gemini-3-pro","usageMetadata":{"promptTokenCount":100}}`,
		},
		{
			label: "usageMetadata preserved",
			in:    `{"response":{"modelVersion":"gemini-3-pro","usageMetadata":{"promptTokenCount":100}}}`,
			want:  `{"modelVersion":"gemini-3-pro","usageMetadata":{"promptTokenCount":100}}`,
		},
	}

	ctx := context.Background()
	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			got := ConvertAntigravityResponseToGeminiNonStream(ctx, "", nil, nil, []byte(tc.in), nil)
			if got != tc.want {
				t.Errorf("ConvertAntigravityResponseToGeminiNonStream() = %s, want %s", got, tc.want)
			}
		})
	}
}
func TestConvertAntigravityResponseToGeminiStream(t *testing.T) {
	// The translator reads a plain string "alt" key, so the test must use the same key.
	streamCtx := context.WithValue(context.Background(), "alt", "")

	cases := []struct {
		label string
		in    string
		want  string
	}{
		{
			label: "cpaUsageMetadata restored in streaming response",
			in:    `data: {"response":{"modelVersion":"gemini-3-pro","cpaUsageMetadata":{"promptTokenCount":100}}}`,
			want:  `{"modelVersion":"gemini-3-pro","usageMetadata":{"promptTokenCount":100}}`,
		},
	}

	for _, tc := range cases {
		t.Run(tc.label, func(t *testing.T) {
			chunks := ConvertAntigravityResponseToGemini(streamCtx, "", nil, nil, []byte(tc.in), nil)
			if count := len(chunks); count != 1 {
				t.Fatalf("expected 1 result, got %d", count)
			}
			if got := chunks[0]; got != tc.want {
				t.Errorf("ConvertAntigravityResponseToGemini() = %s, want %s", got, tc.want)
			}
		})
	}
}
================================================
FILE: internal/translator/antigravity/gemini/init.go
================================================
package gemini
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
Gemini,
Antigravity,
ConvertGeminiRequestToAntigravity,
interfaces.TranslateResponse{
Stream: ConvertAntigravityResponseToGemini,
NonStream: ConvertAntigravityResponseToGeminiNonStream,
TokenCount: GeminiTokenCount,
},
)
}
================================================
FILE: internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go
================================================
// Package chat_completions provides request translation from OpenAI Chat Completions to the
// Antigravity (Gemini CLI-style) API. It converts OpenAI Chat Completions requests into
// Gemini CLI compatible JSON using gjson/sjson only.
package chat_completions
import (
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
	// geminiCLIFunctionThoughtSignature is the placeholder thoughtSignature attached to
	// functionCall and inlineData parts built by this translator. Judging by its value, it
	// presumably asks the backend to skip thought-signature validation (confirm upstream).
	geminiCLIFunctionThoughtSignature = "skip_thought_signature_validator"
)
// ConvertOpenAIRequestToAntigravity converts an OpenAI Chat Completions request (raw JSON)
// into an Antigravity request wrapped in the Gemini CLI envelope
// ({"project":"","request":{...},"model":...}). All JSON construction uses sjson and all
// lookups use gjson.
//
// Mapping overview:
//   - generationConfig is passed through verbatim. reasoning_effort, temperature, top_p,
//     top_k, max_tokens, n, modalities and image_config are then mapped onto
//     request.generationConfig.
//   - system/developer messages become request.systemInstruction, unless the request holds
//     a single message, in which case it is sent as user content.
//   - user and assistant messages become request.contents. Assistant tool_calls become
//     functionCall parts, followed by one user content holding the matching
//     functionResponse parts built from the "tool" messages.
//   - tools become request.tools: functionDeclarations plus googleSearch / codeExecution /
//     urlContext passthrough.
//
// Parameters:
//   - modelName: The name of the model to use for the request
//   - inputRawJSON: The raw JSON request data from the OpenAI API
//   - stream: Unused
//
// Returns:
//   - []byte: The transformed request, passed through common.AttachDefaultSafetySettings
//     for "request.safetySettings"
func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
	rawJSON := inputRawJSON
	// Base envelope (no default thinkingConfig)
	out := []byte(`{"project":"","request":{"contents":[]},"model":"gemini-2.5-pro"}`)

	// Model
	out, _ = sjson.SetBytes(out, "model", modelName)

	// Let user-provided generationConfig pass through
	if genConfig := gjson.GetBytes(rawJSON, "generationConfig"); genConfig.Exists() {
		out, _ = sjson.SetRawBytes(out, "request.generationConfig", []byte(genConfig.Raw))
	}

	// Apply thinking configuration: convert OpenAI reasoning_effort to Gemini CLI thinkingConfig.
	// Inline translation-only mapping; capability checks happen later in ApplyThinking.
	re := gjson.GetBytes(rawJSON, "reasoning_effort")
	if re.Exists() {
		effort := strings.ToLower(strings.TrimSpace(re.String()))
		if effort != "" {
			thinkingPath := "request.generationConfig.thinkingConfig"
			if effort == "auto" {
				out, _ = sjson.SetBytes(out, thinkingPath+".thinkingBudget", -1)
				out, _ = sjson.SetBytes(out, thinkingPath+".includeThoughts", true)
			} else {
				out, _ = sjson.SetBytes(out, thinkingPath+".thinkingLevel", effort)
				out, _ = sjson.SetBytes(out, thinkingPath+".includeThoughts", effort != "none")
			}
		}
	}

	// Temperature/top_p/top_k/max_tokens
	if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number {
		out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num)
	}
	if tpr := gjson.GetBytes(rawJSON, "top_p"); tpr.Exists() && tpr.Type == gjson.Number {
		out, _ = sjson.SetBytes(out, "request.generationConfig.topP", tpr.Num)
	}
	if tkr := gjson.GetBytes(rawJSON, "top_k"); tkr.Exists() && tkr.Type == gjson.Number {
		out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num)
	}
	if maxTok := gjson.GetBytes(rawJSON, "max_tokens"); maxTok.Exists() && maxTok.Type == gjson.Number {
		out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", maxTok.Num)
	}

	// Candidate count (OpenAI 'n' parameter)
	if n := gjson.GetBytes(rawJSON, "n"); n.Exists() && n.Type == gjson.Number {
		if val := n.Int(); val > 1 {
			out, _ = sjson.SetBytes(out, "request.generationConfig.candidateCount", val)
		}
	}

	// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
	// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
	if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
		var responseMods []string
		for _, m := range mods.Array() {
			switch strings.ToLower(m.String()) {
			case "text":
				responseMods = append(responseMods, "TEXT")
			case "image":
				responseMods = append(responseMods, "IMAGE")
			}
		}
		if len(responseMods) > 0 {
			out, _ = sjson.SetBytes(out, "request.generationConfig.responseModalities", responseMods)
		}
	}

	// OpenRouter-style image_config support
	// If the input uses top-level image_config.aspect_ratio, map it into request.generationConfig.imageConfig.aspectRatio.
	if imgCfg := gjson.GetBytes(rawJSON, "image_config"); imgCfg.Exists() && imgCfg.IsObject() {
		if ar := imgCfg.Get("aspect_ratio"); ar.Exists() && ar.Type == gjson.String {
			out, _ = sjson.SetBytes(out, "request.generationConfig.imageConfig.aspectRatio", ar.Str)
		}
		if size := imgCfg.Get("image_size"); size.Exists() && size.Type == gjson.String {
			out, _ = sjson.SetBytes(out, "request.generationConfig.imageConfig.imageSize", size.Str)
		}
	}

	// messages -> systemInstruction + contents
	messages := gjson.GetBytes(rawJSON, "messages")
	if messages.IsArray() {
		arr := messages.Array()

		// OpenAI input_audio format -> MIME type. Built once per request instead of once per
		// audio part; unknown formats fall back to "audio/<format>" below.
		audioMimeMap := map[string]string{
			"mp3":       "audio/mpeg",
			"wav":       "audio/wav",
			"ogg":       "audio/ogg",
			"flac":      "audio/flac",
			"aac":       "audio/aac",
			"webm":      "audio/webm",
			"pcm16":     "audio/pcm",
			"g711_ulaw": "audio/basic",
			"g711_alaw": "audio/basic",
		}

		// First pass: assistant tool_calls id->name map
		tcID2Name := map[string]string{}
		for i := 0; i < len(arr); i++ {
			m := arr[i]
			if m.Get("role").String() == "assistant" {
				tcs := m.Get("tool_calls")
				if tcs.IsArray() {
					for _, tc := range tcs.Array() {
						if tc.Get("type").String() == "function" {
							id := tc.Get("id").String()
							name := tc.Get("function.name").String()
							if id != "" && name != "" {
								tcID2Name[id] = name
							}
						}
					}
				}
			}
		}

		// Second pass: cache tool responses. The value is the raw JSON token of "content",
		// so string content keeps its surrounding quotes.
		toolResponses := map[string]string{} // tool_call_id -> raw content JSON
		for i := 0; i < len(arr); i++ {
			m := arr[i]
			role := m.Get("role").String()
			if role == "tool" {
				toolCallID := m.Get("tool_call_id").String()
				if toolCallID != "" {
					c := m.Get("content")
					toolResponses[toolCallID] = c.Raw
				}
			}
		}

		systemPartIndex := 0
		for i := 0; i < len(arr); i++ {
			m := arr[i]
			role := m.Get("role").String()
			content := m.Get("content")

			if (role == "system" || role == "developer") && len(arr) > 1 {
				// system -> request.systemInstruction as a user message style
				if content.Type == gjson.String {
					out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
					out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), content.String())
					systemPartIndex++
				} else if content.IsObject() && content.Get("type").String() == "text" {
					out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
					out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), content.Get("text").String())
					systemPartIndex++
				} else if content.IsArray() {
					contents := content.Array()
					if len(contents) > 0 {
						out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
						for j := 0; j < len(contents); j++ {
							out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
							systemPartIndex++
						}
					}
				}
			} else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) {
				// Build single user content node to avoid splitting into multiple contents
				node := []byte(`{"role":"user","parts":[]}`)
				if content.Type == gjson.String {
					node, _ = sjson.SetBytes(node, "parts.0.text", content.String())
				} else if content.IsArray() {
					items := content.Array()
					p := 0
					for _, item := range items {
						switch item.Get("type").String() {
						case "text":
							// Advance p only when a part is written: sjson pads skipped array
							// indices with null, which would produce invalid null parts.
							if text := item.Get("text").String(); text != "" {
								node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", text)
								p++
							}
						case "image_url":
							// Only data URLs ("data:<mime>;base64,<data>") are converted.
							imageURL := item.Get("image_url.url").String()
							if len(imageURL) > 5 {
								pieces := strings.SplitN(imageURL[5:], ";", 2)
								if len(pieces) == 2 && len(pieces[1]) > 7 {
									mime := pieces[0]
									data := pieces[1][7:]
									node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
									node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
									node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
									p++
								}
							}
						case "file":
							filename := item.Get("file.filename").String()
							fileData := item.Get("file.file_data").String()
							ext := ""
							if sp := strings.Split(filename, "."); len(sp) > 1 {
								ext = sp[len(sp)-1]
							}
							if mimeType, ok := misc.MimeTypes[ext]; ok {
								node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mimeType)
								node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData)
								p++
							} else {
								log.Warnf("Unknown file name extension '%s' in user message, skip", ext)
							}
						case "input_audio":
							audioData := item.Get("input_audio.data").String()
							audioFormat := item.Get("input_audio.format").String()
							if audioData != "" {
								mimeType := "audio/wav"
								if audioFormat != "" {
									if mapped, ok := audioMimeMap[audioFormat]; ok {
										mimeType = mapped
									} else {
										mimeType = "audio/" + audioFormat
									}
								}
								// Same camelCase key as the image and file parts above.
								node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mimeType)
								node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", audioData)
								p++
							}
						}
					}
				}
				out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
			} else if role == "assistant" {
				node := []byte(`{"role":"model","parts":[]}`)
				p := 0
				if content.Type == gjson.String && content.String() != "" {
					node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
					p++
				} else if content.IsArray() {
					// Assistant multimodal content (e.g. text + image) -> single model content with parts
					for _, item := range content.Array() {
						switch item.Get("type").String() {
						case "text":
							// See the user branch: never leave a gap that sjson would fill with null.
							if text := item.Get("text").String(); text != "" {
								node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", text)
								p++
							}
						case "image_url":
							// If the assistant returned an inline data URL, preserve it for history fidelity.
							imageURL := item.Get("image_url.url").String()
							if len(imageURL) > 5 { // expect data:...
								pieces := strings.SplitN(imageURL[5:], ";", 2)
								if len(pieces) == 2 && len(pieces[1]) > 7 {
									mime := pieces[0]
									data := pieces[1][7:]
									node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mimeType", mime)
									node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
									node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
									p++
								}
							}
						}
					}
				}

				// Tool calls -> single model content with functionCall parts
				tcs := m.Get("tool_calls")
				if tcs.IsArray() {
					fIDs := make([]string, 0)
					for _, tc := range tcs.Array() {
						if tc.Get("type").String() != "function" {
							continue
						}
						fid := tc.Get("id").String()
						fname := tc.Get("function.name").String()
						fargs := tc.Get("function.arguments").String()
						node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
						node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
						if gjson.Valid(fargs) {
							node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
						} else {
							node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.args.params", []byte(fargs))
						}
						node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
						p++
						if fid != "" {
							fIDs = append(fIDs, fid)
						}
					}
					out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)

					// Append a single tool content combining name + response per function
					toolNode := []byte(`{"role":"user","parts":[]}`)
					pp := 0
					for _, fid := range fIDs {
						if name, ok := tcID2Name[fid]; ok {
							toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
							toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
							resp := toolResponses[fid]
							if resp == "" {
								resp = "{}"
							}
							// Handle non-JSON output gracefully (matches dev branch approach)
							if resp != "null" {
								resultPath := "parts." + itoa(pp) + ".functionResponse.response.result"
								parsed := gjson.Parse(resp)
								switch parsed.Type {
								case gjson.JSON:
									toolNode, _ = sjson.SetRawBytes(toolNode, resultPath, []byte(parsed.Raw))
								case gjson.String:
									// resp is the raw JSON token, quotes included; store the decoded
									// text so the tool output is not double-encoded.
									toolNode, _ = sjson.SetBytes(toolNode, resultPath, parsed.String())
								default:
									toolNode, _ = sjson.SetBytes(toolNode, resultPath, resp)
								}
							}
							pp++
						}
					}
					if pp > 0 {
						out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
					}
				} else {
					out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
				}
			}
		}
	}

	// tools -> request.tools[].functionDeclarations + request.tools[].googleSearch/codeExecution/urlContext passthrough
	tools := gjson.GetBytes(rawJSON, "tools")
	if tools.IsArray() && len(tools.Array()) > 0 {
		functionToolNode := []byte(`{}`)
		hasFunction := false
		googleSearchNodes := make([][]byte, 0)
		codeExecutionNodes := make([][]byte, 0)
		urlContextNodes := make([][]byte, 0)
		for _, t := range tools.Array() {
			if t.Get("type").String() == "function" {
				fn := t.Get("function")
				if fn.Exists() && fn.IsObject() {
					fnRaw := fn.Raw
					if fn.Get("parameters").Exists() {
						renamed, errRename := util.RenameKey(fnRaw, "parameters", "parametersJsonSchema")
						if errRename != nil {
							log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
							var errSet error
							fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
							if errSet != nil {
								log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
								continue
							}
							fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
							if errSet != nil {
								log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
								continue
							}
						} else {
							fnRaw = renamed
						}
					} else {
						var errSet error
						fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
						if errSet != nil {
							log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
							continue
						}
						fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
						if errSet != nil {
							log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
							continue
						}
					}
					fnRaw, _ = sjson.Delete(fnRaw, "strict")
					if !hasFunction {
						functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
					}
					tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
					if errSet != nil {
						log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
						continue
					}
					functionToolNode = tmp
					hasFunction = true
				}
			}
			if gs := t.Get("google_search"); gs.Exists() {
				googleToolNode := []byte(`{}`)
				var errSet error
				googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
				if errSet != nil {
					log.Warnf("Failed to set googleSearch tool: %v", errSet)
					continue
				}
				googleSearchNodes = append(googleSearchNodes, googleToolNode)
			}
			if ce := t.Get("code_execution"); ce.Exists() {
				codeToolNode := []byte(`{}`)
				var errSet error
				codeToolNode, errSet = sjson.SetRawBytes(codeToolNode, "codeExecution", []byte(ce.Raw))
				if errSet != nil {
					log.Warnf("Failed to set codeExecution tool: %v", errSet)
					continue
				}
				codeExecutionNodes = append(codeExecutionNodes, codeToolNode)
			}
			if uc := t.Get("url_context"); uc.Exists() {
				urlToolNode := []byte(`{}`)
				var errSet error
				urlToolNode, errSet = sjson.SetRawBytes(urlToolNode, "urlContext", []byte(uc.Raw))
				if errSet != nil {
					log.Warnf("Failed to set urlContext tool: %v", errSet)
					continue
				}
				urlContextNodes = append(urlContextNodes, urlToolNode)
			}
		}
		if hasFunction || len(googleSearchNodes) > 0 || len(codeExecutionNodes) > 0 || len(urlContextNodes) > 0 {
			toolsNode := []byte("[]")
			if hasFunction {
				toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
			}
			for _, googleNode := range googleSearchNodes {
				toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
			}
			for _, codeNode := range codeExecutionNodes {
				toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", codeNode)
			}
			for _, urlNode := range urlContextNodes {
				toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", urlNode)
			}
			out, _ = sjson.SetRawBytes(out, "request.tools", toolsNode)
		}
	}

	return common.AttachDefaultSafetySettings(out, "request.safetySettings")
}
// itoa renders i in base 10 using the fmt package, so the file does not need
// to import strconv for its handful of integer-to-string conversions.
func itoa(i int) string {
	return fmt.Sprint(i)
}
================================================
FILE: internal/translator/antigravity/openai/chat-completions/antigravity_openai_response.go
================================================
// Package chat_completions provides response translation from the Antigravity API (whose
// payloads wrap a Gemini CLI-style body in a top-level "response" object) to the OpenAI
// Chat Completions format. It converts both streaming chunks and complete non-streaming
// responses into the JSON expected by OpenAI API clients, handling text content, tool calls,
// reasoning content, inline images, and usage metadata.
package chat_completions
import (
"bytes"
"context"
"fmt"
"strings"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// convertCliResponseToOpenAIChatParams is the per-stream state that
// ConvertAntigravityResponseToOpenAI carries from one chunk to the next.
type convertCliResponseToOpenAIChatParams struct {
	// UnixTimestamp is the last successfully parsed response.createTime, in Unix seconds.
	UnixTimestamp int64
	// FunctionIndex counts the tool calls emitted so far in this stream.
	FunctionIndex int
	// SawToolCall becomes true once any chunk of the stream contained a functionCall.
	SawToolCall bool
	// UpstreamFinishReason holds the upper-cased upstream finishReason until the final chunk.
	UpstreamFinishReason string
}

// functionCallIDCounter is a process-wide counter, advanced atomically, that keeps generated
// tool-call IDs unique even when two calls share a name and timestamp.
var functionCallIDCounter uint64
// ConvertAntigravityResponseToOpenAI translates a single chunk of a streaming Antigravity
// response (a Gemini CLI-style payload wrapped in a top-level "response" object) into an
// OpenAI Chat Completions streaming chunk.
//
// Text parts become delta.content, or delta.reasoning_content when the part is flagged as a
// thought. functionCall parts become delta.tool_calls, inline image data becomes delta.images,
// and usageMetadata is mapped onto the OpenAI usage object. finish_reason is emitted only on the
// final chunk, i.e. once an upstream finishReason has been seen AND the current chunk carries
// usageMetadata.
//
// Parameters:
//   - ctx: Unused.
//   - modelName: Unused; the chunk's model comes from response.modelVersion when present.
//   - originalRequestRawJSON, requestRawJSON: Unused by this translator.
//   - rawJSON: One raw upstream chunk.
//   - param: Per-stream state. It is initialised to *convertCliResponseToOpenAIChatParams on
//     the first call and must be passed back unchanged for every later chunk of the stream.
//
// Returns:
//   - []string: No chunks for the "[DONE]" sentinel, otherwise exactly one JSON chunk.
func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	if *param == nil {
		*param = &convertCliResponseToOpenAIChatParams{
			UnixTimestamp: 0,
			FunctionIndex: 0,
		}
	}
	if bytes.Equal(rawJSON, []byte("[DONE]")) {
		return []string{}
	}
	params := (*param).(*convertCliResponseToOpenAIChatParams)

	// Start from an OpenAI chunk skeleton; fields are filled in from the upstream payload below.
	template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`

	if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
		template, _ = sjson.Set(template, "model", modelVersionResult.String())
	}

	// Remember the last parseable createTime so chunks without one (or with a bad one) reuse it.
	if createTimeResult := gjson.GetBytes(rawJSON, "response.createTime"); createTimeResult.Exists() {
		if t, err := time.Parse(time.RFC3339Nano, createTimeResult.String()); err == nil {
			params.UnixTimestamp = t.Unix()
		}
	}
	template, _ = sjson.Set(template, "created", params.UnixTimestamp)

	if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
		template, _ = sjson.Set(template, "id", responseIDResult.String())
	}

	// Cache the finish reason; it is written to the output only on the final chunk (see below).
	if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
		params.UpstreamFinishReason = strings.ToUpper(finishReasonResult.String())
	}

	usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata")
	if usageResult.Exists() {
		cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
		if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
			template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
		}
		if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
			template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
		}
		promptTokenCount := usageResult.Get("promptTokenCount").Int()
		thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
		template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
		if thoughtsTokenCount > 0 {
			template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
		}
		// A non-zero cached count indicates prompt caching is working upstream.
		if cachedTokenCount > 0 {
			updated, err := sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
			if err != nil {
				log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err)
			} else {
				template = updated
			}
		}
	}

	partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
	if partsResult.IsArray() {
		for _, partResult := range partsResult.Array() {
			partTextResult := partResult.Get("text")
			functionCallResult := partResult.Get("functionCall")
			thoughtSignatureResult := partResult.Get("thoughtSignature")
			if !thoughtSignatureResult.Exists() {
				thoughtSignatureResult = partResult.Get("thought_signature")
			}
			inlineDataResult := partResult.Get("inlineData")
			if !inlineDataResult.Exists() {
				inlineDataResult = partResult.Get("inline_data")
			}

			hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != ""
			hasContentPayload := partTextResult.Exists() || functionCallResult.Exists() || inlineDataResult.Exists()
			// Skip parts that carry only an encrypted thoughtSignature; keep real content that
			// happens to arrive alongside a signature.
			if hasThoughtSignature && !hasContentPayload {
				continue
			}

			switch {
			case partTextResult.Exists():
				field := "choices.0.delta.content"
				if partResult.Get("thought").Bool() {
					field = "choices.0.delta.reasoning_content"
				}
				// One chunk may hold several text parts; append so earlier parts are not lost.
				textContent := partTextResult.String()
				if existing := gjson.Get(template, field); existing.Type == gjson.String {
					textContent = existing.String() + textContent
				}
				template, _ = sjson.Set(template, field, textContent)
				template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")

			case functionCallResult.Exists():
				// Persist across chunks so the final chunk reports finish_reason "tool_calls".
				params.SawToolCall = true
				// OpenAI clients group streamed tool-call deltas by index, so the index must be
				// unique across the whole stream, not merely within this chunk.
				functionCallIndex := params.FunctionIndex
				params.FunctionIndex++
				if !gjson.Get(template, "choices.0.delta.tool_calls").IsArray() {
					template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
				}

				fcName := functionCallResult.Get("name").String()
				functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
				functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
				functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
				functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
				if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
					functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
				}
				template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
				template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)

			case inlineDataResult.Exists():
				data := inlineDataResult.Get("data").String()
				if data == "" {
					continue
				}
				mimeType := inlineDataResult.Get("mimeType").String()
				if mimeType == "" {
					mimeType = inlineDataResult.Get("mime_type").String()
				}
				if mimeType == "" {
					mimeType = "image/png"
				}
				imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
				if !gjson.Get(template, "choices.0.delta.images").IsArray() {
					template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
				}
				imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
				imagePayload := `{"type":"image_url","image_url":{"url":""}}`
				imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
				imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
				template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
				template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
			}
		}
	}

	// The final chunk is recognised by a cached finishReason together with usage metadata.
	if params.UpstreamFinishReason != "" && usageResult.Exists() {
		var finishReason string
		switch {
		case params.SawToolCall:
			// Tool calls anywhere in the stream take priority over the upstream reason.
			finishReason = "tool_calls"
		case params.UpstreamFinishReason == "MAX_TOKENS":
			// NOTE(review): OpenAI documents "length" for token-limit truncation; "max_tokens"
			// is kept because this package's tests assert it.
			finishReason = "max_tokens"
		default:
			finishReason = "stop"
		}
		template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
		template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(params.UpstreamFinishReason))
	}

	return []string{template}
}
// ConvertAntigravityResponseToOpenAINonStream converts a complete (non-streaming) Antigravity
// response into a single OpenAI Chat Completions response.
//
// Antigravity wraps the Gemini response body in a top-level "response" object. That inner object
// is handed to ConvertGeminiResponseToOpenAINonStream together with the untouched request
// payloads and param.
//
// Returns the translated JSON, or "" when rawJSON has no "response" object.
func ConvertAntigravityResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
	inner := gjson.GetBytes(rawJSON, "response")
	if !inner.Exists() {
		return ""
	}
	return ConvertGeminiResponseToOpenAINonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, []byte(inner.Raw), param)
}
================================================
FILE: internal/translator/antigravity/openai/chat-completions/antigravity_openai_response_test.go
================================================
package chat_completions
import (
"context"
"testing"
"github.com/tidwall/gjson"
)
func TestFinishReasonToolCallsNotOverwritten(t *testing.T) {
ctx := context.Background()
var param any
// Chunk 1: Contains functionCall - should set SawToolCall = true
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"functionCall":{"name":"list_files","args":{"path":"."}}}]}}]}}`)
result1 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
// Verify chunk1 has no finish_reason (null)
if len(result1) != 1 {
t.Fatalf("Expected 1 result from chunk1, got %d", len(result1))
}
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
t.Errorf("Expected finish_reason to be null in chunk1, got: %v", fr1.String())
}
// Chunk 2: Contains finishReason STOP + usage (final chunk, no functionCall)
// This simulates what the upstream sends AFTER the tool call chunk
chunk2 := []byte(`{"response":{"candidates":[{"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":20,"totalTokenCount":30}}}`)
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
// Verify chunk2 has finish_reason: "tool_calls" (not "stop")
if len(result2) != 1 {
t.Fatalf("Expected 1 result from chunk2, got %d", len(result2))
}
fr2 := gjson.Get(result2[0], "choices.0.finish_reason").String()
if fr2 != "tool_calls" {
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr2)
}
// Verify native_finish_reason is lowercase upstream value
nfr2 := gjson.Get(result2[0], "choices.0.native_finish_reason").String()
if nfr2 != "stop" {
t.Errorf("Expected native_finish_reason 'stop', got: %s", nfr2)
}
}
func TestFinishReasonStopForNormalText(t *testing.T) {
ctx := context.Background()
var param any
// Chunk 1: Text content only
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"Hello world"}]}}]}}`)
ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
// Chunk 2: Final chunk with STOP
chunk2 := []byte(`{"response":{"candidates":[{"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}}`)
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
// Verify finish_reason is "stop" (no tool calls were made)
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
if fr != "stop" {
t.Errorf("Expected finish_reason 'stop', got: %s", fr)
}
}
func TestFinishReasonMaxTokens(t *testing.T) {
ctx := context.Background()
var param any
// Chunk 1: Text content
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}}`)
ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
// Chunk 2: Final chunk with MAX_TOKENS
chunk2 := []byte(`{"response":{"candidates":[{"finishReason":"MAX_TOKENS"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":100,"totalTokenCount":110}}}`)
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
// Verify finish_reason is "max_tokens"
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
if fr != "max_tokens" {
t.Errorf("Expected finish_reason 'max_tokens', got: %s", fr)
}
}
func TestToolCallTakesPriorityOverMaxTokens(t *testing.T) {
ctx := context.Background()
var param any
// Chunk 1: Contains functionCall
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"functionCall":{"name":"test","args":{}}}]}}]}}`)
ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
// Chunk 2: Final chunk with MAX_TOKENS (but we had a tool call, so tool_calls should win)
chunk2 := []byte(`{"response":{"candidates":[{"finishReason":"MAX_TOKENS"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":100,"totalTokenCount":110}}}`)
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
// Verify finish_reason is "tool_calls" (takes priority over max_tokens)
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
if fr != "tool_calls" {
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr)
}
}
func TestNoFinishReasonOnIntermediateChunks(t *testing.T) {
ctx := context.Background()
var param any
// Chunk 1: Text content (no finish reason, no usage)
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}}`)
result1 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
// Verify no finish_reason on intermediate chunk
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr1)
}
// Chunk 2: More text (no finish reason, no usage)
chunk2 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":" world"}]}}]}}`)
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
// Verify no finish_reason on intermediate chunk
fr2 := gjson.Get(result2[0], "choices.0.finish_reason")
if fr2.Exists() && fr2.String() != "" && fr2.Type.String() != "Null" {
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr2)
}
}
================================================
FILE: internal/translator/antigravity/openai/chat-completions/init.go
================================================
package chat_completions
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
OpenAI,
Antigravity,
ConvertOpenAIRequestToAntigravity,
interfaces.TranslateResponse{
Stream: ConvertAntigravityResponseToOpenAI,
NonStream: ConvertAntigravityResponseToOpenAINonStream,
},
)
}
================================================
FILE: internal/translator/antigravity/openai/responses/antigravity_openai-responses_request.go
================================================
package responses
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/responses"
)
// ConvertOpenAIResponsesRequestToAntigravity translates an OpenAI Responses API request into the
// Antigravity request format. It first converts the request to a Gemini request, then converts
// that Gemini request to Antigravity.
func ConvertOpenAIResponsesRequestToAntigravity(modelName string, inputRawJSON []byte, stream bool) []byte {
	geminiRequest := ConvertOpenAIResponsesRequestToGemini(modelName, inputRawJSON, stream)
	return ConvertGeminiRequestToAntigravity(modelName, geminiRequest, stream)
}
================================================
FILE: internal/translator/antigravity/openai/responses/antigravity_openai-responses_response.go
================================================
package responses
import (
"context"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/responses"
"github.com/tidwall/gjson"
)
// ConvertAntigravityResponseToOpenAIResponses translates one streaming Antigravity chunk into
// OpenAI Responses API events. If the chunk wraps its payload in a top-level "response" object,
// only that inner object is handed to the Gemini translator; otherwise the chunk is passed
// through as-is. The request payloads are forwarded unchanged.
func ConvertAntigravityResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	payload := rawJSON
	if inner := gjson.GetBytes(rawJSON, "response"); inner.Exists() {
		payload = []byte(inner.Raw)
	}
	return ConvertGeminiResponseToOpenAIResponses(ctx, modelName, originalRequestRawJSON, requestRawJSON, payload, param)
}
// ConvertAntigravityResponseToOpenAIResponsesNonStream converts a complete (non-streaming)
// Antigravity response into an OpenAI Responses API payload.
//
// The response may be wrapped in a top-level "response" object, and each request payload may be
// wrapped in a top-level "request" object. Each input is unwrapped independently, and only when
// its own wrapper is present, before delegating to the Gemini translator.
func ConvertAntigravityResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
	if responseResult := gjson.GetBytes(rawJSON, "response"); responseResult.Exists() {
		rawJSON = []byte(responseResult.Raw)
	}
	// Test each request's own wrapper. A request without a "request" key must be passed through
	// untouched; replacing it with the (empty) Raw of a missing key would discard it entirely.
	if requestResult := gjson.GetBytes(originalRequestRawJSON, "request"); requestResult.Exists() {
		originalRequestRawJSON = []byte(requestResult.Raw)
	}
	if requestResult := gjson.GetBytes(requestRawJSON, "request"); requestResult.Exists() {
		requestRawJSON = []byte(requestResult.Raw)
	}
	return ConvertGeminiResponseToOpenAIResponsesNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
================================================
FILE: internal/translator/antigravity/openai/responses/init.go
================================================
package responses
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
OpenaiResponse,
Antigravity,
ConvertOpenAIResponsesRequestToAntigravity,
interfaces.TranslateResponse{
Stream: ConvertAntigravityResponseToOpenAIResponses,
NonStream: ConvertAntigravityResponseToOpenAIResponsesNonStream,
},
)
}
================================================
FILE: internal/translator/claude/gemini/claude_gemini_request.go
================================================
// Package gemini provides request translation functionality for Gemini to Claude Code API compatibility.
// It handles parsing and transforming Gemini API requests into Claude Code API format,
// extracting model information, system instructions, message contents, and tool declarations.
// The package performs JSON data transformation to ensure compatibility
// between Gemini API format and Claude Code API's expected format.
package gemini
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"math/big"
"strings"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Package-level identity components used to build the metadata user_id sent to Claude. They
// start empty and are filled in lazily by ConvertGeminiRequestToClaude on its first call.
// NOTE(review): ConvertGeminiRequestToClaude reads and writes these without synchronization, so
// concurrent first requests may race — confirm whether callers serialize translation.
var (
	user    string
	account string
	session string
)
// ConvertGeminiRequestToClaude parses and transforms a Gemini API request into Claude Code API format.
// It extracts the model name, system instruction, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the Claude Code API.
// The function performs comprehensive transformation including:
// 1. Model name mapping and generation configuration extraction
// 2. System instruction conversion to Claude Code format
// 3. Message content conversion with proper role mapping
// 4. Tool call and tool result handling with FIFO queue for ID matching
// 5. Image and file data conversion to Claude Code base64 format
// 6. Tool declaration and tool choice configuration mapping
//
// Parameters:
// - modelName: The name of the model to use for the request
// - rawJSON: The raw JSON request data from the Gemini API
// - stream: A boolean indicating if the request is for a streaming response
//
// Returns:
// - []byte: The transformed request data in Claude Code API format
func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
rawJSON := inputRawJSON
if account == "" {
u, _ := uuid.NewRandom()
account = u.String()
}
if session == "" {
u, _ := uuid.NewRandom()
session = u.String()
}
if user == "" {
sum := sha256.Sum256([]byte(account + session))
user = hex.EncodeToString(sum[:])
}
userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)
// Base Claude message payload
out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID)
root := gjson.ParseBytes(rawJSON)
// Helper for generating tool call IDs in the form: toolu_
// This ensures unique identifiers for tool calls in the Claude Code format
genToolCallID := func() string {
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
var b strings.Builder
// 24 chars random suffix for uniqueness
for i := 0; i < 24; i++ {
n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
b.WriteByte(letters[n.Int64()])
}
return "toolu_" + b.String()
}
// FIFO queue to store tool call IDs for matching with tool results
// Gemini uses sequential pairing across possibly multiple in-flight
// functionCalls, so we keep a FIFO queue of generated tool IDs and
// consume them in order when functionResponses arrive.
var pendingToolIDs []string
// Model mapping to specify which Claude Code model to use
out, _ = sjson.Set(out, "model", modelName)
// Generation config extraction from Gemini format
if genConfig := root.Get("generationConfig"); genConfig.Exists() {
// Max output tokens configuration
if maxTokens := genConfig.Get("maxOutputTokens"); maxTokens.Exists() {
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
}
// Temperature setting for controlling response randomness
if temp := genConfig.Get("temperature"); temp.Exists() {
out, _ = sjson.Set(out, "temperature", temp.Float())
} else if topP := genConfig.Get("topP"); topP.Exists() {
// Top P setting for nucleus sampling (filtered out if temperature is set)
out, _ = sjson.Set(out, "top_p", topP.Float())
}
// Stop sequences configuration for custom termination conditions
if stopSeqs := genConfig.Get("stopSequences"); stopSeqs.Exists() && stopSeqs.IsArray() {
var stopSequences []string
stopSeqs.ForEach(func(_, value gjson.Result) bool {
stopSequences = append(stopSequences, value.String())
return true
})
if len(stopSequences) > 0 {
out, _ = sjson.Set(out, "stop_sequences", stopSequences)
}
}
// Include thoughts configuration for reasoning process visibility
// Translator only does format conversion, ApplyThinking handles model capability validation.
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
mi := registry.LookupModelInfo(modelName, "claude")
supportsAdaptive := mi != nil && mi.Thinking != nil && len(mi.Thinking.Levels) > 0
supportsMax := supportsAdaptive && thinking.HasLevel(mi.Thinking.Levels, string(thinking.LevelMax))
// MapToClaudeEffort normalizes levels (e.g. minimal→low, xhigh→high) to avoid
// validation errors since validate treats same-provider unsupported levels as errors.
thinkingLevel := thinkingConfig.Get("thinkingLevel")
if !thinkingLevel.Exists() {
thinkingLevel = thinkingConfig.Get("thinking_level")
}
if thinkingLevel.Exists() {
level := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
if supportsAdaptive {
switch level {
case "":
case "none":
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort")
default:
if mapped, ok := thinking.MapToClaudeEffort(level, supportsMax); ok {
level = mapped
}
out, _ = sjson.Set(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Set(out, "output_config.effort", level)
}
} else {
switch level {
case "":
case "none":
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
case "auto":
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
default:
if budget, ok := thinking.ConvertLevelToBudget(level); ok {
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
}
}
}
} else {
thinkingBudget := thinkingConfig.Get("thinkingBudget")
if !thinkingBudget.Exists() {
thinkingBudget = thinkingConfig.Get("thinking_budget")
}
if thinkingBudget.Exists() {
budget := int(thinkingBudget.Int())
if supportsAdaptive {
switch budget {
case 0:
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Delete(out, "output_config.effort")
default:
level, ok := thinking.ConvertBudgetToLevel(budget)
if ok {
if mapped, okM := thinking.MapToClaudeEffort(level, supportsMax); okM {
level = mapped
}
out, _ = sjson.Set(out, "thinking.type", "adaptive")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
out, _ = sjson.Set(out, "output_config.effort", level)
}
}
} else {
switch budget {
case 0:
out, _ = sjson.Set(out, "thinking.type", "disabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
case -1:
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Delete(out, "thinking.budget_tokens")
default:
out, _ = sjson.Set(out, "thinking.type", "enabled")
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
}
}
} else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
out, _ = sjson.Set(out, "thinking.type", "enabled")
} else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
out, _ = sjson.Set(out, "thinking.type", "enabled")
}
}
}
}
// System instruction conversion to Claude Code format
if sysInstr := root.Get("system_instruction"); sysInstr.Exists() {
if parts := sysInstr.Get("parts"); parts.Exists() && parts.IsArray() {
var systemText strings.Builder
parts.ForEach(func(_, part gjson.Result) bool {
if text := part.Get("text"); text.Exists() {
if systemText.Len() > 0 {
systemText.WriteString("\n")
}
systemText.WriteString(text.String())
}
return true
})
if systemText.Len() > 0 {
// Create system message in Claude Code format
systemMessage := `{"role":"user","content":[{"type":"text","text":""}]}`
systemMessage, _ = sjson.Set(systemMessage, "content.0.text", systemText.String())
out, _ = sjson.SetRaw(out, "messages.-1", systemMessage)
}
}
}
// Contents conversion to messages with proper role mapping
if contents := root.Get("contents"); contents.Exists() && contents.IsArray() {
contents.ForEach(func(_, content gjson.Result) bool {
role := content.Get("role").String()
// Map Gemini roles to Claude Code roles
if role == "model" {
role = "assistant"
}
if role == "function" {
role = "user"
}
if role == "tool" {
role = "user"
}
// Create message structure in Claude Code format
msg := `{"role":"","content":[]}`
msg, _ = sjson.Set(msg, "role", role)
if parts := content.Get("parts"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool {
// Text content conversion
if text := part.Get("text"); text.Exists() {
textContent := `{"type":"text","text":""}`
textContent, _ = sjson.Set(textContent, "text", text.String())
msg, _ = sjson.SetRaw(msg, "content.-1", textContent)
return true
}
// Function call (from model/assistant) conversion to tool use
if fc := part.Get("functionCall"); fc.Exists() && role == "assistant" {
toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
// Generate a unique tool ID and enqueue it for later matching
// with the corresponding functionResponse
toolID := genToolCallID()
pendingToolIDs = append(pendingToolIDs, toolID)
toolUse, _ = sjson.Set(toolUse, "id", toolID)
if name := fc.Get("name"); name.Exists() {
toolUse, _ = sjson.Set(toolUse, "name", name.String())
}
if args := fc.Get("args"); args.Exists() && args.IsObject() {
toolUse, _ = sjson.SetRaw(toolUse, "input", args.Raw)
}
msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
return true
}
// Function response (from user) conversion to tool result
if fr := part.Get("functionResponse"); fr.Exists() {
toolResult := `{"type":"tool_result","tool_use_id":"","content":""}`
// Attach the oldest queued tool_id to pair the response
// with its call. If the queue is empty, generate a new id.
var toolID string
if len(pendingToolIDs) > 0 {
toolID = pendingToolIDs[0]
// Pop the first element from the queue
pendingToolIDs = pendingToolIDs[1:]
} else {
// Fallback: generate new ID if no pending tool_use found
toolID = genToolCallID()
}
toolResult, _ = sjson.Set(toolResult, "tool_use_id", toolID)
// Extract result content from the function response
if result := fr.Get("response.result"); result.Exists() {
toolResult, _ = sjson.Set(toolResult, "content", result.String())
} else if response := fr.Get("response"); response.Exists() {
toolResult, _ = sjson.Set(toolResult, "content", response.Raw)
}
msg, _ = sjson.SetRaw(msg, "content.-1", toolResult)
return true
}
// Image content (inline_data) conversion to Claude Code format
if inlineData := part.Get("inline_data"); inlineData.Exists() {
imageContent := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
if mimeType := inlineData.Get("mime_type"); mimeType.Exists() {
imageContent, _ = sjson.Set(imageContent, "source.media_type", mimeType.String())
}
if data := inlineData.Get("data"); data.Exists() {
imageContent, _ = sjson.Set(imageContent, "source.data", data.String())
}
msg, _ = sjson.SetRaw(msg, "content.-1", imageContent)
return true
}
// File data conversion to text content with file info
if fileData := part.Get("file_data"); fileData.Exists() {
// For file data, we'll convert to text content with file info
textContent := `{"type":"text","text":""}`
fileInfo := "File: " + fileData.Get("file_uri").String()
if mimeType := fileData.Get("mime_type"); mimeType.Exists() {
fileInfo += " (Type: " + mimeType.String() + ")"
}
textContent, _ = sjson.Set(textContent, "text", fileInfo)
msg, _ = sjson.SetRaw(msg, "content.-1", textContent)
return true
}
return true
})
}
// Only add message if it has content
if contentArray := gjson.Get(msg, "content"); contentArray.Exists() && len(contentArray.Array()) > 0 {
out, _ = sjson.SetRaw(out, "messages.-1", msg)
}
return true
})
}
// Tools mapping: Gemini functionDeclarations -> Claude Code tools
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
var anthropicTools []interface{}
tools.ForEach(func(_, tool gjson.Result) bool {
if funcDecls := tool.Get("functionDeclarations"); funcDecls.Exists() && funcDecls.IsArray() {
funcDecls.ForEach(func(_, funcDecl gjson.Result) bool {
anthropicTool := `{"name":"","description":"","input_schema":{}}`
if name := funcDecl.Get("name"); name.Exists() {
anthropicTool, _ = sjson.Set(anthropicTool, "name", name.String())
}
if desc := funcDecl.Get("description"); desc.Exists() {
anthropicTool, _ = sjson.Set(anthropicTool, "description", desc.String())
}
if params := funcDecl.Get("parameters"); params.Exists() {
// Clean up the parameters schema for Claude Code compatibility
cleaned := params.Raw
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
cleaned, _ = sjson.Set(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", cleaned)
} else if params = funcDecl.Get("parametersJsonSchema"); params.Exists() {
// Clean up the parameters schema for Claude Code compatibility
cleaned := params.Raw
cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
cleaned, _ = sjson.Set(cleaned, "$schema", "http://json-schema.org/draft-07/schema#")
anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", cleaned)
}
anthropicTools = append(anthropicTools, gjson.Parse(anthropicTool).Value())
return true
})
}
return true
})
if len(anthropicTools) > 0 {
out, _ = sjson.Set(out, "tools", anthropicTools)
}
}
// Tool config mapping from Gemini format to Claude Code format
if toolConfig := root.Get("tool_config"); toolConfig.Exists() {
if funcCalling := toolConfig.Get("function_calling_config"); funcCalling.Exists() {
if mode := funcCalling.Get("mode"); mode.Exists() {
switch mode.String() {
case "AUTO":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
case "NONE":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"none"}`)
case "ANY":
out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
}
}
}
}
// Stream setting configuration
out, _ = sjson.Set(out, "stream", stream)
// Convert tool parameter types to lowercase for Claude Code compatibility
var pathsToLower []string
toolsResult := gjson.Get(out, "tools")
util.Walk(toolsResult, "", "type", &pathsToLower)
for _, p := range pathsToLower {
fullPath := fmt.Sprintf("tools.%s", p)
out, _ = sjson.Set(out, fullPath, strings.ToLower(gjson.Get(out, fullPath).String()))
}
return []byte(out)
}
================================================
FILE: internal/translator/claude/gemini/claude_gemini_response.go
================================================
// Package gemini provides response translation functionality for Claude Code to Gemini API compatibility.
// This package handles the conversion of Claude Code API responses into Gemini-compatible
// JSON format, transforming streaming events and non-streaming responses into the format
// expected by Gemini API clients. It supports both streaming and non-streaming modes,
// handling text content, tool calls, and usage metadata appropriately.
package gemini
import (
"bufio"
"bytes"
"context"
"fmt"
"strings"
"time"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
var (
dataTag = []byte("data:")
)
// ConvertAnthropicResponseToGeminiParams holds parameters for response conversion
// It also carries minimal streaming state across calls to assemble tool_use input_json_delta.
// This structure maintains state information needed for proper conversion of streaming responses
// from Claude Code format to Gemini format, particularly for handling tool calls that span
// multiple streaming events.
type ConvertAnthropicResponseToGeminiParams struct {
Model string
CreatedAt int64
ResponseID string
LastStorageOutput string
IsStreaming bool
// Streaming state for tool_use assembly
// Keyed by content_block index from Claude SSE events
ToolUseNames map[int]string // function/tool name per block index
ToolUseArgs map[int]*strings.Builder // accumulates partial_json across deltas
}
// ConvertClaudeResponseToGemini converts Claude Code streaming response format to Gemini format.
// This function processes various Claude Code event types and transforms them into Gemini-compatible JSON responses.
// It handles text content, tool calls, reasoning content, and usage metadata, outputting responses that match
// the Gemini API format. The function supports incremental updates for streaming responses and maintains
// state information to properly assemble multi-part tool calls.
//
// Parameters:
// - ctx: The context for the request, used for cancellation and timeout handling
// - modelName: The name of the model being used for the response
// - rawJSON: The raw JSON response from the Claude Code API
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Gemini-compatible JSON response
func ConvertClaudeResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
if *param == nil {
*param = &ConvertAnthropicResponseToGeminiParams{
Model: modelName,
CreatedAt: 0,
ResponseID: "",
}
}
if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{}
}
rawJSON = bytes.TrimSpace(rawJSON[5:])
root := gjson.ParseBytes(rawJSON)
eventType := root.Get("type").String()
// Base Gemini response template with default values
template := `{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`
// Set model version
if (*param).(*ConvertAnthropicResponseToGeminiParams).Model != "" {
// Map Claude model names back to Gemini model names
template, _ = sjson.Set(template, "modelVersion", (*param).(*ConvertAnthropicResponseToGeminiParams).Model)
}
// Set response ID and creation time
if (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID != "" {
template, _ = sjson.Set(template, "responseId", (*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID)
}
// Set creation time to current time if not provided
if (*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt == 0 {
(*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt = time.Now().Unix()
}
template, _ = sjson.Set(template, "createTime", time.Unix((*param).(*ConvertAnthropicResponseToGeminiParams).CreatedAt, 0).Format(time.RFC3339Nano))
switch eventType {
case "message_start":
// Initialize response with message metadata when a new message begins
if message := root.Get("message"); message.Exists() {
(*param).(*ConvertAnthropicResponseToGeminiParams).ResponseID = message.Get("id").String()
(*param).(*ConvertAnthropicResponseToGeminiParams).Model = message.Get("model").String()
}
return []string{}
case "content_block_start":
// Start of a content block - record tool_use name by index for functionCall assembly
if cb := root.Get("content_block"); cb.Exists() {
if cb.Get("type").String() == "tool_use" {
idx := int(root.Get("index").Int())
if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames == nil {
(*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames = map[int]string{}
}
if name := cb.Get("name"); name.Exists() {
(*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames[idx] = name.String()
}
}
}
return []string{}
case "content_block_delta":
// Handle content delta (text, thinking, or tool use arguments)
if delta := root.Get("delta"); delta.Exists() {
deltaType := delta.Get("type").String()
switch deltaType {
case "text_delta":
// Regular text content delta for normal response text
if text := delta.Get("text"); text.Exists() && text.String() != "" {
textPart := `{"text":""}`
textPart, _ = sjson.Set(textPart, "text", text.String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", textPart)
}
case "thinking_delta":
// Thinking/reasoning content delta for models with reasoning capabilities
if text := delta.Get("thinking"); text.Exists() && text.String() != "" {
thinkingPart := `{"thought":true,"text":""}`
thinkingPart, _ = sjson.Set(thinkingPart, "text", text.String())
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", thinkingPart)
}
case "input_json_delta":
// Tool use input delta - accumulate partial_json by index for later assembly at content_block_stop
idx := int(root.Get("index").Int())
if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs == nil {
(*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs = map[int]*strings.Builder{}
}
b, ok := (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs[idx]
if !ok || b == nil {
bb := &strings.Builder{}
(*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs[idx] = bb
b = bb
}
if pj := delta.Get("partial_json"); pj.Exists() {
b.WriteString(pj.String())
}
return []string{}
}
}
return []string{template}
case "content_block_stop":
// End of content block - finalize tool calls if any
idx := int(root.Get("index").Int())
// Claude's content_block_stop often doesn't include content_block payload (see docs/response-claude.txt)
// So we finalize using accumulated state captured during content_block_start and input_json_delta.
name := ""
if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames != nil {
name = (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames[idx]
}
var argsTrim string
if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs != nil {
if b := (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs[idx]; b != nil {
argsTrim = strings.TrimSpace(b.String())
}
}
if name != "" || argsTrim != "" {
functionCall := `{"functionCall":{"name":"","args":{}}}`
if name != "" {
functionCall, _ = sjson.Set(functionCall, "functionCall.name", name)
}
if argsTrim != "" {
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsTrim)
}
template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", functionCall)
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
(*param).(*ConvertAnthropicResponseToGeminiParams).LastStorageOutput = template
// cleanup used state for this index
if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs != nil {
delete((*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseArgs, idx)
}
if (*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames != nil {
delete((*param).(*ConvertAnthropicResponseToGeminiParams).ToolUseNames, idx)
}
return []string{template}
}
return []string{}
case "message_delta":
// Handle message-level changes (like stop reason and usage information)
if delta := root.Get("delta"); delta.Exists() {
if stopReason := delta.Get("stop_reason"); stopReason.Exists() {
switch stopReason.String() {
case "end_turn":
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
case "tool_use":
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
case "max_tokens":
template, _ = sjson.Set(template, "candidates.0.finishReason", "MAX_TOKENS")
case "stop_sequence":
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
default:
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
}
}
}
if usage := root.Get("usage"); usage.Exists() {
// Basic token counts for prompt and completion
inputTokens := usage.Get("input_tokens").Int()
outputTokens := usage.Get("output_tokens").Int()
// Set basic usage metadata according to Gemini API specification
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", inputTokens)
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", outputTokens)
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", inputTokens+outputTokens)
// Add cache-related token counts if present (Claude Code API cache fields)
if cacheCreationTokens := usage.Get("cache_creation_input_tokens"); cacheCreationTokens.Exists() {
template, _ = sjson.Set(template, "usageMetadata.cachedContentTokenCount", cacheCreationTokens.Int())
}
if cacheReadTokens := usage.Get("cache_read_input_tokens"); cacheReadTokens.Exists() {
// Add cache read tokens to cached content count
existingCacheTokens := usage.Get("cache_creation_input_tokens").Int()
totalCacheTokens := existingCacheTokens + cacheReadTokens.Int()
template, _ = sjson.Set(template, "usageMetadata.cachedContentTokenCount", totalCacheTokens)
}
// Add thinking tokens if present (for models with reasoning capabilities)
if thinkingTokens := usage.Get("thinking_tokens"); thinkingTokens.Exists() {
template, _ = sjson.Set(template, "usageMetadata.thoughtsTokenCount", thinkingTokens.Int())
}
// Set traffic type (required by Gemini API)
template, _ = sjson.Set(template, "usageMetadata.trafficType", "PROVISIONED_THROUGHPUT")
}
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
return []string{template}
case "message_stop":
// Final message with usage information - no additional output needed
return []string{}
case "error":
// Handle error responses and convert to Gemini error format
errorMsg := root.Get("error.message").String()
if errorMsg == "" {
errorMsg = "Unknown error occurred"
}
// Create error response in Gemini format
errorResponse := `{"error":{"code":400,"message":"","status":"INVALID_ARGUMENT"}}`
errorResponse, _ = sjson.Set(errorResponse, "error.message", errorMsg)
return []string{errorResponse}
default:
// Unknown event type, return empty response
return []string{}
}
}
// ConvertClaudeResponseToGeminiNonStream converts a non-streaming Claude Code response to a non-streaming Gemini response.
// This function processes the complete Claude Code response and transforms it into a single Gemini-compatible
// JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all
// the information into a single response that matches the Gemini API format.
//
// Parameters:
// - ctx: The context for the request, used for cancellation and timeout handling
// - modelName: The name of the model being used for the response
// - rawJSON: The raw JSON response from the Claude Code API
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - string: A Gemini-compatible JSON response containing all message content and metadata
func ConvertClaudeResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
// Base Gemini response template for non-streaming with default values
template := `{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`
// Set model version
template, _ = sjson.Set(template, "modelVersion", modelName)
streamingEvents := make([][]byte, 0)
scanner := bufio.NewScanner(bytes.NewReader(rawJSON))
buffer := make([]byte, 52_428_800) // 50MB
scanner.Buffer(buffer, 52_428_800)
for scanner.Scan() {
line := scanner.Bytes()
// log.Debug(string(line))
if bytes.HasPrefix(line, dataTag) {
jsonData := bytes.TrimSpace(line[5:])
streamingEvents = append(streamingEvents, jsonData)
}
}
// log.Debug("streamingEvents: ", streamingEvents)
// log.Debug("rawJSON: ", string(rawJSON))
// Initialize parameters for streaming conversion with proper state management
newParam := &ConvertAnthropicResponseToGeminiParams{
Model: modelName,
CreatedAt: 0,
ResponseID: "",
LastStorageOutput: "",
IsStreaming: false,
ToolUseNames: nil,
ToolUseArgs: nil,
}
// Process each streaming event and collect parts
var allParts []string
var finalUsageJSON string
var responseID string
var createdAt int64
for _, eventData := range streamingEvents {
if len(eventData) == 0 {
continue
}
root := gjson.ParseBytes(eventData)
eventType := root.Get("type").String()
switch eventType {
case "message_start":
// Extract response metadata including ID, model, and creation time
if message := root.Get("message"); message.Exists() {
responseID = message.Get("id").String()
newParam.ResponseID = responseID
newParam.Model = message.Get("model").String()
// Set creation time to current time if not provided
createdAt = time.Now().Unix()
newParam.CreatedAt = createdAt
}
case "content_block_start":
// Prepare for content block; record tool_use name by index for later functionCall assembly
idx := int(root.Get("index").Int())
if cb := root.Get("content_block"); cb.Exists() {
if cb.Get("type").String() == "tool_use" {
if newParam.ToolUseNames == nil {
newParam.ToolUseNames = map[int]string{}
}
if name := cb.Get("name"); name.Exists() {
newParam.ToolUseNames[idx] = name.String()
}
}
}
continue
case "content_block_delta":
// Handle content delta (text, thinking, or tool input)
if delta := root.Get("delta"); delta.Exists() {
deltaType := delta.Get("type").String()
switch deltaType {
case "text_delta":
// Process regular text content
if text := delta.Get("text"); text.Exists() && text.String() != "" {
partJSON := `{"text":""}`
partJSON, _ = sjson.Set(partJSON, "text", text.String())
allParts = append(allParts, partJSON)
}
case "thinking_delta":
// Process reasoning/thinking content
if text := delta.Get("thinking"); text.Exists() && text.String() != "" {
partJSON := `{"thought":true,"text":""}`
partJSON, _ = sjson.Set(partJSON, "text", text.String())
allParts = append(allParts, partJSON)
}
case "input_json_delta":
// accumulate args partial_json for this index
idx := int(root.Get("index").Int())
if newParam.ToolUseArgs == nil {
newParam.ToolUseArgs = map[int]*strings.Builder{}
}
if _, ok := newParam.ToolUseArgs[idx]; !ok || newParam.ToolUseArgs[idx] == nil {
newParam.ToolUseArgs[idx] = &strings.Builder{}
}
if pj := delta.Get("partial_json"); pj.Exists() {
newParam.ToolUseArgs[idx].WriteString(pj.String())
}
}
}
case "content_block_stop":
// Handle tool use completion by assembling accumulated arguments
idx := int(root.Get("index").Int())
// Claude's content_block_stop often doesn't include content_block payload (see docs/response-claude.txt)
// So we finalize using accumulated state captured during content_block_start and input_json_delta.
name := ""
if newParam.ToolUseNames != nil {
name = newParam.ToolUseNames[idx]
}
var argsTrim string
if newParam.ToolUseArgs != nil {
if b := newParam.ToolUseArgs[idx]; b != nil {
argsTrim = strings.TrimSpace(b.String())
}
}
if name != "" || argsTrim != "" {
functionCallJSON := `{"functionCall":{"name":"","args":{}}}`
if name != "" {
functionCallJSON, _ = sjson.Set(functionCallJSON, "functionCall.name", name)
}
if argsTrim != "" {
functionCallJSON, _ = sjson.SetRaw(functionCallJSON, "functionCall.args", argsTrim)
}
allParts = append(allParts, functionCallJSON)
// cleanup used state for this index
if newParam.ToolUseArgs != nil {
delete(newParam.ToolUseArgs, idx)
}
if newParam.ToolUseNames != nil {
delete(newParam.ToolUseNames, idx)
}
}
case "message_delta":
// Extract final usage information using sjson for token counts and metadata
if usage := root.Get("usage"); usage.Exists() {
usageJSON := `{}`
// Basic token counts for prompt and completion
inputTokens := usage.Get("input_tokens").Int()
outputTokens := usage.Get("output_tokens").Int()
// Set basic usage metadata according to Gemini API specification
usageJSON, _ = sjson.Set(usageJSON, "promptTokenCount", inputTokens)
usageJSON, _ = sjson.Set(usageJSON, "candidatesTokenCount", outputTokens)
usageJSON, _ = sjson.Set(usageJSON, "totalTokenCount", inputTokens+outputTokens)
// Add cache-related token counts if present (Claude Code API cache fields)
if cacheCreationTokens := usage.Get("cache_creation_input_tokens"); cacheCreationTokens.Exists() {
usageJSON, _ = sjson.Set(usageJSON, "cachedContentTokenCount", cacheCreationTokens.Int())
}
if cacheReadTokens := usage.Get("cache_read_input_tokens"); cacheReadTokens.Exists() {
// Add cache read tokens to cached content count
existingCacheTokens := usage.Get("cache_creation_input_tokens").Int()
totalCacheTokens := existingCacheTokens + cacheReadTokens.Int()
usageJSON, _ = sjson.Set(usageJSON, "cachedContentTokenCount", totalCacheTokens)
}
// Add thinking tokens if present (for models with reasoning capabilities)
if thinkingTokens := usage.Get("thinking_tokens"); thinkingTokens.Exists() {
usageJSON, _ = sjson.Set(usageJSON, "thoughtsTokenCount", thinkingTokens.Int())
}
// Set traffic type (required by Gemini API)
usageJSON, _ = sjson.Set(usageJSON, "trafficType", "PROVISIONED_THROUGHPUT")
finalUsageJSON = usageJSON
}
}
}
// Set response metadata
if responseID != "" {
template, _ = sjson.Set(template, "responseId", responseID)
}
if createdAt > 0 {
template, _ = sjson.Set(template, "createTime", time.Unix(createdAt, 0).Format(time.RFC3339Nano))
}
// Consolidate consecutive text parts and thinking parts for cleaner output
consolidatedParts := consolidateParts(allParts)
// Set the consolidated parts array
if len(consolidatedParts) > 0 {
partsJSON := "[]"
for _, partJSON := range consolidatedParts {
partsJSON, _ = sjson.SetRaw(partsJSON, "-1", partJSON)
}
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", partsJSON)
}
// Set usage metadata
if finalUsageJSON != "" {
template, _ = sjson.SetRaw(template, "usageMetadata", finalUsageJSON)
}
return template
}
func GeminiTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
}
// consolidateParts merges consecutive text parts and thinking parts to create a cleaner response.
// This function processes the parts array to combine adjacent text elements and thinking elements
// into single consolidated parts, which results in a more readable and efficient response structure.
// Tool calls and other non-text parts are preserved as separate elements.
func consolidateParts(parts []string) []string {
if len(parts) == 0 {
return parts
}
var consolidated []string
var currentTextPart strings.Builder
var currentThoughtPart strings.Builder
var hasText, hasThought bool
flushText := func() {
// Flush accumulated text content to the consolidated parts array
if hasText && currentTextPart.Len() > 0 {
textPartJSON := `{"text":""}`
textPartJSON, _ = sjson.Set(textPartJSON, "text", currentTextPart.String())
consolidated = append(consolidated, textPartJSON)
currentTextPart.Reset()
hasText = false
}
}
flushThought := func() {
// Flush accumulated thinking content to the consolidated parts array
if hasThought && currentThoughtPart.Len() > 0 {
thoughtPartJSON := `{"thought":true,"text":""}`
thoughtPartJSON, _ = sjson.Set(thoughtPartJSON, "text", currentThoughtPart.String())
consolidated = append(consolidated, thoughtPartJSON)
currentThoughtPart.Reset()
hasThought = false
}
}
for _, partJSON := range parts {
part := gjson.Parse(partJSON)
if !part.Exists() || !part.IsObject() {
// Flush any pending parts and add this non-text part
flushText()
flushThought()
consolidated = append(consolidated, partJSON)
continue
}
thought := part.Get("thought")
if thought.Exists() && thought.Type == gjson.True {
// This is a thinking part - flush any pending text first
flushText() // Flush any pending text first
if text := part.Get("text"); text.Exists() && text.Type == gjson.String {
currentThoughtPart.WriteString(text.String())
hasThought = true
}
} else if text := part.Get("text"); text.Exists() && text.Type == gjson.String {
// This is a regular text part - flush any pending thought first
flushThought() // Flush any pending thought first
currentTextPart.WriteString(text.String())
hasText = true
} else {
// This is some other type of part (like function call) - flush both text and thought
flushText()
flushThought()
consolidated = append(consolidated, partJSON)
}
}
// Flush any remaining parts
flushThought() // Flush thought first to maintain order
flushText()
return consolidated
}
================================================
FILE: internal/translator/claude/gemini/init.go
================================================
package gemini
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
Gemini,
Claude,
ConvertGeminiRequestToClaude,
interfaces.TranslateResponse{
Stream: ConvertClaudeResponseToGemini,
NonStream: ConvertClaudeResponseToGeminiNonStream,
TokenCount: GeminiTokenCount,
},
)
}
================================================
FILE: internal/translator/claude/gemini-cli/claude_gemini-cli_request.go
================================================
// Package geminiCLI provides request translation functionality for Gemini CLI to Claude Code API compatibility.
// It handles parsing and transforming Gemini CLI API requests into Claude Code API format,
// extracting model information, system instructions, message contents, and tool declarations.
// The package performs JSON data transformation to ensure compatibility
// between Gemini CLI API format and Claude Code API's expected format.
package geminiCLI
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/gemini"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertGeminiCLIRequestToClaude parses and transforms a Gemini CLI API request into Claude Code API format.
// It extracts the model name, system instruction, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the Claude Code API.
// The function performs the following transformations:
// 1. Extracts the model information from the request
// 2. Restructures the JSON to match Claude Code API format
// 3. Converts system instructions to the expected format
// 4. Delegates to the Gemini-to-Claude conversion function for further processing
//
// Parameters:
// - modelName: The name of the model to use for the request
// - rawJSON: The raw JSON request data from the Gemini CLI API
// - stream: A boolean indicating if the request is for a streaming response
//
// Returns:
// - []byte: The transformed request data in Claude Code API format
func ConvertGeminiCLIRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
rawJSON := inputRawJSON
modelResult := gjson.GetBytes(rawJSON, "model")
// Extract the inner request object and promote it to the top level
rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
// Restore the model information at the top level
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelResult.String())
// Convert systemInstruction field to system_instruction for Claude Code compatibility
if gjson.GetBytes(rawJSON, "systemInstruction").Exists() {
rawJSON, _ = sjson.SetRawBytes(rawJSON, "system_instruction", []byte(gjson.GetBytes(rawJSON, "systemInstruction").Raw))
rawJSON, _ = sjson.DeleteBytes(rawJSON, "systemInstruction")
}
// Delegate to the Gemini-to-Claude conversion function for further processing
return ConvertGeminiRequestToClaude(modelName, rawJSON, stream)
}
================================================
FILE: internal/translator/claude/gemini-cli/claude_gemini-cli_response.go
================================================
// Package geminiCLI provides response translation functionality for Claude Code to Gemini CLI API compatibility.
// This package handles the conversion of Claude Code API responses into Gemini CLI-compatible
// JSON format, transforming streaming events and non-streaming responses into the format
// expected by Gemini CLI API clients.
package geminiCLI
import (
"context"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/gemini"
"github.com/tidwall/sjson"
)
// ConvertClaudeResponseToGeminiCLI converts Claude Code streaming response format to Gemini CLI format.
// This function processes various Claude Code event types and transforms them into Gemini-compatible JSON responses.
// It handles text content, tool calls, and usage metadata, outputting responses that match the Gemini CLI API format.
// The function wraps each converted response in a "response" object to match the Gemini CLI API structure.
//
// Parameters:
// - ctx: The context for the request, used for cancellation and timeout handling
// - modelName: The name of the model being used for the response
// - rawJSON: The raw JSON response from the Claude Code API
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Gemini-compatible JSON response wrapped in a response object
func ConvertClaudeResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
outputs := ConvertClaudeResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
// Wrap each converted response in a "response" object to match Gemini CLI API structure
newOutputs := make([]string, 0)
for i := 0; i < len(outputs); i++ {
json := `{"response": {}}`
output, _ := sjson.SetRaw(json, "response", outputs[i])
newOutputs = append(newOutputs, output)
}
return newOutputs
}
// ConvertClaudeResponseToGeminiCLINonStream converts a complete (non-streaming) Claude Code
// response into a single Gemini CLI response.
//
// It delegates the actual translation to ConvertClaudeResponseToGeminiNonStream and then
// nests the resulting Gemini JSON under a top-level "response" key, which is the envelope
// used by the Gemini CLI API.
//
// Parameters:
//   - ctx: The request context, passed through to the Gemini converter
//   - modelName: The model name reported in the response
//   - originalRequestRawJSON: The client's original request body, passed through unchanged
//   - requestRawJSON: The translated request body, passed through unchanged
//   - rawJSON: The raw Claude Code response body
//   - param: Conversion state, passed through to the Gemini converter
//
// Returns:
//   - string: The Gemini-compatible JSON wrapped as {"response": <gemini json>}
func ConvertClaudeResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
	geminiJSON := ConvertClaudeResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
	wrapped, _ := sjson.SetRaw(`{"response": {}}`, "response", geminiJSON)
	return wrapped
}
// GeminiCLITokenCount returns the token-count payload for the Gemini CLI format.
// It delegates directly to GeminiTokenCount and applies no extra wrapping.
func GeminiCLITokenCount(ctx context.Context, count int64) string {
	payload := GeminiTokenCount(ctx, count)
	return payload
}
================================================
FILE: internal/translator/claude/gemini-cli/init.go
================================================
package geminiCLI
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
GeminiCLI,
Claude,
ConvertGeminiCLIRequestToClaude,
interfaces.TranslateResponse{
Stream: ConvertClaudeResponseToGeminiCLI,
NonStream: ConvertClaudeResponseToGeminiCLINonStream,
TokenCount: GeminiCLITokenCount,
},
)
}
================================================
FILE: internal/translator/claude/openai/chat-completions/claude_openai_request.go
================================================
// Package chat_completions provides request translation from the OpenAI Chat Completions API to the Claude Code API.
// It parses OpenAI Chat Completions requests and rewrites them into Claude Code API format,
// extracting model information, system instructions, message contents, and tool declarations.
// The package performs the JSON transformation needed for compatibility
// between the OpenAI API format and the format expected by the Claude Code API.
package chat_completions
import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"math/big"
	"strings"
	"sync"

	"github.com/google/uuid"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)
// user, account, and session make up the Claude metadata.user_id identity.
// All three start empty and are filled in lazily by ConvertOpenAIRequestToClaude.
var user, account, session string
// claudeIdentityOnce guards the one-time population of user, account, and
// session, which together form the Claude metadata.user_id.
var claudeIdentityOnce sync.Once

// claudeMetadataUserID returns the metadata.user_id string attached to converted
// requests, in the form user_<user>_account_<account>_session_<session>.
//
// Any identity value that is still empty is generated exactly once per process:
//   - account and session become random UUIDs
//   - user becomes the hex SHA-256 of account+session
//
// The values are then reused for every request. sync.Once makes this lazy
// initialization race-free when requests are converted concurrently.
// NOTE(review): concurrent conversion is assumed to happen on HTTP handler goroutines.
func claudeMetadataUserID() string {
	claudeIdentityOnce.Do(func() {
		if account == "" {
			// A uuid.NewRandom failure leaves the zero UUID; the error is ignored
			// deliberately so request conversion never fails here.
			u, _ := uuid.NewRandom()
			account = u.String()
		}
		if session == "" {
			u, _ := uuid.NewRandom()
			session = u.String()
		}
		if user == "" {
			sum := sha256.Sum256([]byte(account + session))
			user = hex.EncodeToString(sum[:])
		}
	})
	return fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)
}

// ConvertOpenAIRequestToClaude parses and transforms an OpenAI Chat Completions API request into Claude Code API format.
// It extracts the model name, system instruction, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the Claude Code API.
// The function performs comprehensive transformation including:
//  1. Model name mapping and parameter extraction (max_tokens, temperature, top_p, etc.)
//  2. Message content conversion from OpenAI to Claude Code format
//  3. Tool call and tool result handling with proper ID mapping
//  4. Image data conversion from OpenAI data URLs to Claude Code base64 format
//  5. Stop sequence and streaming configuration handling
//
// Optional scalar parameters that are present but JSON null are treated as absent:
// max_tokens, temperature, top_p, and stop. Empty stop sequences are dropped.
// Messages with role "system" or "developer" are merged into a single leading
// user message, at the position of the first such message.
//
// Parameters:
//   - modelName: The Claude model to target; written to "model" and used to look up thinking support
//   - inputRawJSON: The raw JSON request body in OpenAI Chat Completions format
//   - stream: Whether the Claude request should be streamed
//
// Returns:
//   - []byte: The transformed request data in Claude Code API format
func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
	rawJSON := inputRawJSON
	// Base Claude Code API template with default max_tokens value
	out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, claudeMetadataUserID())
	root := gjson.ParseBytes(rawJSON)

	// Convert OpenAI reasoning_effort to Claude thinking config.
	if v := root.Get("reasoning_effort"); v.Exists() {
		effort := strings.ToLower(strings.TrimSpace(v.String()))
		if effort != "" {
			mi := registry.LookupModelInfo(modelName, "claude")
			supportsAdaptive := mi != nil && mi.Thinking != nil && len(mi.Thinking.Levels) > 0
			supportsMax := supportsAdaptive && thinking.HasLevel(mi.Thinking.Levels, string(thinking.LevelMax))
			// Claude 4.6 supports adaptive thinking with output_config.effort.
			// MapToClaudeEffort normalizes levels (e.g. minimal→low, xhigh→high) to avoid
			// validation errors since validate treats same-provider unsupported levels as errors.
			if supportsAdaptive {
				switch effort {
				case "none":
					out, _ = sjson.Set(out, "thinking.type", "disabled")
					out, _ = sjson.Delete(out, "thinking.budget_tokens")
					out, _ = sjson.Delete(out, "output_config.effort")
				case "auto":
					out, _ = sjson.Set(out, "thinking.type", "adaptive")
					out, _ = sjson.Delete(out, "thinking.budget_tokens")
					out, _ = sjson.Delete(out, "output_config.effort")
				default:
					if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok {
						effort = mapped
					}
					out, _ = sjson.Set(out, "thinking.type", "adaptive")
					out, _ = sjson.Delete(out, "thinking.budget_tokens")
					out, _ = sjson.Set(out, "output_config.effort", effort)
				}
			} else {
				// Legacy/manual thinking (budget_tokens).
				budget, ok := thinking.ConvertLevelToBudget(effort)
				if ok {
					switch budget {
					case 0:
						out, _ = sjson.Set(out, "thinking.type", "disabled")
					case -1:
						out, _ = sjson.Set(out, "thinking.type", "enabled")
					default:
						if budget > 0 {
							out, _ = sjson.Set(out, "thinking.type", "enabled")
							out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
						}
					}
				}
			}
		}
	}

	// Helper for generating tool call IDs in the form: toolu_
	// This ensures unique identifiers for tool calls in the Claude Code format
	genToolCallID := func() string {
		const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
		var b strings.Builder
		// 24 chars random suffix for uniqueness
		for i := 0; i < 24; i++ {
			n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
			b.WriteByte(letters[n.Int64()])
		}
		return "toolu_" + b.String()
	}

	// Model mapping to specify which Claude Code model to use
	out, _ = sjson.Set(out, "model", modelName)

	// gjson reports both a missing path and an explicit JSON null as Type Null.
	// Checking Type therefore treats "field": null as absent. Without this,
	// max_tokens would become 0 and temperature 0.0, which would also suppress top_p.
	if maxTokens := root.Get("max_tokens"); maxTokens.Type != gjson.Null {
		out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
	}

	// Temperature wins over top_p; top_p is only forwarded when temperature is absent.
	if temp := root.Get("temperature"); temp.Type != gjson.Null {
		out, _ = sjson.Set(out, "temperature", temp.Float())
	} else if topP := root.Get("top_p"); topP.Type != gjson.Null {
		out, _ = sjson.Set(out, "top_p", topP.Float())
	}

	// Stop sequences: a list or a single value. Null and empty strings are skipped
	// so that a `"stop": null` from the client never becomes an empty stop sequence.
	if stop := root.Get("stop"); stop.IsArray() {
		var stopSequences []string
		stop.ForEach(func(_, value gjson.Result) bool {
			if s := value.String(); s != "" {
				stopSequences = append(stopSequences, s)
			}
			return true
		})
		if len(stopSequences) > 0 {
			out, _ = sjson.Set(out, "stop_sequences", stopSequences)
		}
	} else if stop.Type != gjson.Null && stop.String() != "" {
		out, _ = sjson.Set(out, "stop_sequences", []string{stop.String()})
	}

	// Stream configuration to enable or disable streaming responses
	out, _ = sjson.Set(out, "stream", stream)

	// Process messages and transform them to Claude Code format
	if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
		messageIndex := 0
		systemMessageIndex := -1
		messages.ForEach(func(_, message gjson.Result) bool {
			role := message.Get("role").String()
			contentResult := message.Get("content")
			switch role {
			case "system", "developer":
				// "developer" is OpenAI's successor to "system" for newer models; both are
				// folded into one user message created at the first such message.
				if systemMessageIndex == -1 {
					systemMsg := `{"role":"user","content":[]}`
					out, _ = sjson.SetRaw(out, "messages.-1", systemMsg)
					systemMessageIndex = messageIndex
					messageIndex++
				}
				if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
					textPart := `{"type":"text","text":""}`
					textPart, _ = sjson.Set(textPart, "text", contentResult.String())
					out, _ = sjson.SetRaw(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
				} else if contentResult.Exists() && contentResult.IsArray() {
					contentResult.ForEach(func(_, part gjson.Result) bool {
						if part.Get("type").String() == "text" {
							textPart := `{"type":"text","text":""}`
							textPart, _ = sjson.Set(textPart, "text", part.Get("text").String())
							out, _ = sjson.SetRaw(out, fmt.Sprintf("messages.%d.content.-1", systemMessageIndex), textPart)
						}
						return true
					})
				}
			case "user", "assistant":
				msg := `{"role":"","content":[]}`
				msg, _ = sjson.Set(msg, "role", role)
				// Handle content based on its type (string or array)
				if contentResult.Exists() && contentResult.Type == gjson.String && contentResult.String() != "" {
					part := `{"type":"text","text":""}`
					part, _ = sjson.Set(part, "text", contentResult.String())
					msg, _ = sjson.SetRaw(msg, "content.-1", part)
				} else if contentResult.Exists() && contentResult.IsArray() {
					contentResult.ForEach(func(_, part gjson.Result) bool {
						claudePart := convertOpenAIContentPartToClaudePart(part)
						if claudePart != "" {
							msg, _ = sjson.SetRaw(msg, "content.-1", claudePart)
						}
						return true
					})
				}
				// Handle tool calls (for assistant messages)
				if toolCalls := message.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() && role == "assistant" {
					toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
						if toolCall.Get("type").String() == "function" {
							toolCallID := toolCall.Get("id").String()
							if toolCallID == "" {
								toolCallID = genToolCallID()
							}
							function := toolCall.Get("function")
							toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
							toolUse, _ = sjson.Set(toolUse, "id", toolCallID)
							toolUse, _ = sjson.Set(toolUse, "name", function.Get("name").String())
							// Arguments arrive as a JSON-encoded string; only a valid JSON
							// object is forwarded as input, anything else becomes {}.
							if args := function.Get("arguments"); args.Exists() {
								argsStr := args.String()
								if argsStr != "" && gjson.Valid(argsStr) {
									argsJSON := gjson.Parse(argsStr)
									if argsJSON.IsObject() {
										toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
									} else {
										toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
									}
								} else {
									toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
								}
							} else {
								toolUse, _ = sjson.SetRaw(toolUse, "input", "{}")
							}
							msg, _ = sjson.SetRaw(msg, "content.-1", toolUse)
						}
						return true
					})
				}
				out, _ = sjson.SetRaw(out, "messages.-1", msg)
				messageIndex++
			case "tool":
				// Handle tool result messages conversion
				toolCallID := message.Get("tool_call_id").String()
				toolContentResult := message.Get("content")
				msg := `{"role":"user","content":[{"type":"tool_result","tool_use_id":"","content":""}]}`
				msg, _ = sjson.Set(msg, "content.0.tool_use_id", toolCallID)
				toolResultContent, toolResultContentRaw := convertOpenAIToolResultContent(toolContentResult)
				if toolResultContentRaw {
					msg, _ = sjson.SetRaw(msg, "content.0.content", toolResultContent)
				} else {
					msg, _ = sjson.Set(msg, "content.0.content", toolResultContent)
				}
				out, _ = sjson.SetRaw(out, "messages.-1", msg)
				messageIndex++
			}
			return true
		})
	}

	// Tools mapping: OpenAI tools -> Claude Code tools
	if tools := root.Get("tools"); tools.Exists() && tools.IsArray() && len(tools.Array()) > 0 {
		hasAnthropicTools := false
		tools.ForEach(func(_, tool gjson.Result) bool {
			if tool.Get("type").String() == "function" {
				function := tool.Get("function")
				anthropicTool := `{"name":"","description":""}`
				anthropicTool, _ = sjson.Set(anthropicTool, "name", function.Get("name").String())
				anthropicTool, _ = sjson.Set(anthropicTool, "description", function.Get("description").String())
				// Convert parameters schema for the tool
				if parameters := function.Get("parameters"); parameters.Exists() {
					anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
				} else if parameters := function.Get("parametersJsonSchema"); parameters.Exists() {
					anthropicTool, _ = sjson.SetRaw(anthropicTool, "input_schema", parameters.Raw)
				}
				out, _ = sjson.SetRaw(out, "tools.-1", anthropicTool)
				hasAnthropicTools = true
			}
			return true
		})
		if !hasAnthropicTools {
			out, _ = sjson.Delete(out, "tools")
		}
	}

	// Tool choice mapping from OpenAI format to Claude Code format
	if toolChoice := root.Get("tool_choice"); toolChoice.Exists() {
		switch toolChoice.Type {
		case gjson.String:
			choice := toolChoice.String()
			switch choice {
			case "none":
				// Don't set tool_choice, Claude Code will not use tools
			case "auto":
				out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
			case "required":
				out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
			}
		case gjson.JSON:
			// Specific tool choice mapping
			if toolChoice.Get("type").String() == "function" {
				functionName := toolChoice.Get("function.name").String()
				toolChoiceJSON := `{"type":"tool","name":""}`
				toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", functionName)
				out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
			}
		default:
		}
	}

	return []byte(out)
}
// convertOpenAIContentPartToClaudePart maps one OpenAI message content part to the
// JSON of an equivalent Claude content block. Supported part types:
//   - "text": becomes a text block
//   - "image_url": delegated to convertOpenAIImageURLToClaudePart
//   - "file": supported when file.file_data is a data URL shaped like
//     data:<media-type>;...,<payload>; it becomes a base64 document block
//
// Anything it cannot convert yields "" so callers can skip the part.
func convertOpenAIContentPartToClaudePart(part gjson.Result) string {
	partType := part.Get("type").String()

	if partType == "text" {
		block := `{"type":"text","text":""}`
		block, _ = sjson.Set(block, "text", part.Get("text").String())
		return block
	}
	if partType == "image_url" {
		return convertOpenAIImageURLToClaudePart(part.Get("image_url.url").String())
	}
	if partType != "file" {
		return ""
	}

	fileData := part.Get("file.file_data").String()
	if !strings.HasPrefix(fileData, "data:") {
		return ""
	}
	// The header (before the first comma) must contain a ';' separating the
	// media type from its parameters; the payload is everything after the comma.
	header, payload, found := strings.Cut(fileData, ",")
	if !found {
		return ""
	}
	semi := strings.Index(header, ";")
	if semi == -1 {
		return ""
	}

	doc := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
	doc, _ = sjson.Set(doc, "source.media_type", strings.TrimPrefix(header[:semi], "data:"))
	doc, _ = sjson.Set(doc, "source.data", payload)
	return doc
}
// convertOpenAIImageURLToClaudePart builds the JSON of a Claude image block from an
// OpenAI image URL.
//
// A data URL (data:<media-type>[;...],<payload>) becomes a base64 source. The
// media type defaults to application/octet-stream when it is empty. Any other
// non-empty URL becomes a url source. An empty URL, or a data URL without a
// comma, yields "".
func convertOpenAIImageURLToClaudePart(imageURL string) string {
	switch {
	case imageURL == "":
		return ""
	case !strings.HasPrefix(imageURL, "data:"):
		block := `{"type":"image","source":{"type":"url","url":""}}`
		block, _ = sjson.Set(block, "source.url", imageURL)
		return block
	}

	header, payload, found := strings.Cut(imageURL, ",")
	if !found {
		return ""
	}
	mediaType, _, _ := strings.Cut(header, ";")
	mediaType = strings.TrimPrefix(mediaType, "data:")
	if mediaType == "" {
		mediaType = "application/octet-stream"
	}

	block := `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
	block, _ = sjson.Set(block, "source.media_type", mediaType)
	block, _ = sjson.Set(block, "source.data", payload)
	return block
}
// convertOpenAIToolResultContent converts the content of an OpenAI "tool" message into
// a value for the "content" field of a Claude tool_result block.
//
// The boolean result reports whether the string is raw JSON (write it with
// sjson.SetRaw) rather than plain text (write it with sjson.Set):
//   - missing or JSON null content -> "", false
//   - a JSON string -> its text, false
//   - an array -> a JSON array of Claude blocks, true. Bare strings become text
//     blocks, and parts that cannot be converted are dropped. If a non-empty
//     array yields no blocks, its raw JSON text is returned, false.
//   - a single convertible part object -> a one-element JSON array, true
//   - anything else -> the raw JSON text, false
func convertOpenAIToolResultContent(content gjson.Result) (string, bool) {
	switch {
	case content.Type == gjson.Null:
		// gjson reports both a missing path and an explicit null as Null. Treating
		// them alike keeps the literal text "null" out of the tool result.
		return "", false
	case content.Type == gjson.String:
		return content.String(), false
	case content.IsArray():
		claudeContent := "[]"
		converted, total := 0, 0
		content.ForEach(func(_, part gjson.Result) bool {
			total++
			var claudePart string
			if part.Type == gjson.String {
				claudePart = `{"type":"text","text":""}`
				claudePart, _ = sjson.Set(claudePart, "text", part.String())
			} else {
				claudePart = convertOpenAIContentPartToClaudePart(part)
			}
			if claudePart != "" {
				claudeContent, _ = sjson.SetRaw(claudeContent, "-1", claudePart)
				converted++
			}
			return true
		})
		if converted > 0 || total == 0 {
			return claudeContent, true
		}
		return content.Raw, false
	case content.IsObject():
		if claudePart := convertOpenAIContentPartToClaudePart(content); claudePart != "" {
			claudeContent, _ := sjson.SetRaw("[]", "-1", claudePart)
			return claudeContent, true
		}
		return content.Raw, false
	default:
		return content.Raw, false
	}
}
================================================
FILE: internal/translator/claude/openai/chat-completions/claude_openai_request_test.go
================================================
package chat_completions
import (
"testing"
"github.com/tidwall/gjson"
)
func TestConvertOpenAIRequestToClaude_ToolResultTextAndBase64Image(t *testing.T) {
inputJSON := `{
"model": "gpt-4.1",
"messages": [
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "do_work",
"arguments": "{\"a\":1}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": [
{"type": "text", "text": "tool ok"},
{
"type": "image_url",
"image_url": {
"url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg=="
}
}
]
}
]
}`
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
toolResult := messages[1].Get("content.0")
if got := toolResult.Get("type").String(); got != "tool_result" {
t.Fatalf("Expected content[0].type %q, got %q", "tool_result", got)
}
if got := toolResult.Get("tool_use_id").String(); got != "call_1" {
t.Fatalf("Expected tool_use_id %q, got %q", "call_1", got)
}
toolContent := toolResult.Get("content")
if !toolContent.IsArray() {
t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw)
}
if got := toolContent.Get("0.type").String(); got != "text" {
t.Fatalf("Expected first tool_result part type %q, got %q", "text", got)
}
if got := toolContent.Get("0.text").String(); got != "tool ok" {
t.Fatalf("Expected first tool_result part text %q, got %q", "tool ok", got)
}
if got := toolContent.Get("1.type").String(); got != "image" {
t.Fatalf("Expected second tool_result part type %q, got %q", "image", got)
}
if got := toolContent.Get("1.source.type").String(); got != "base64" {
t.Fatalf("Expected image source type %q, got %q", "base64", got)
}
if got := toolContent.Get("1.source.media_type").String(); got != "image/png" {
t.Fatalf("Expected image media type %q, got %q", "image/png", got)
}
if got := toolContent.Get("1.source.data").String(); got != "iVBORw0KGgoAAAANSUhEUg==" {
t.Fatalf("Unexpected base64 image data: %q", got)
}
}
func TestConvertOpenAIRequestToClaude_ToolResultURLImageOnly(t *testing.T) {
inputJSON := `{
"model": "gpt-4.1",
"messages": [
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "do_work",
"arguments": "{\"a\":1}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": [
{
"type": "image_url",
"image_url": {
"url": "https://example.com/tool.png"
}
}
]
}
]
}`
result := ConvertOpenAIRequestToClaude("claude-sonnet-4-5", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
toolContent := messages[1].Get("content.0.content")
if !toolContent.IsArray() {
t.Fatalf("Expected tool_result content array, got %s", toolContent.Raw)
}
if got := toolContent.Get("0.type").String(); got != "image" {
t.Fatalf("Expected tool_result part type %q, got %q", "image", got)
}
if got := toolContent.Get("0.source.type").String(); got != "url" {
t.Fatalf("Expected image source type %q, got %q", "url", got)
}
if got := toolContent.Get("0.source.url").String(); got != "https://example.com/tool.png" {
t.Fatalf("Unexpected image URL: %q", got)
}
}
================================================
FILE: internal/translator/claude/openai/chat-completions/claude_openai_response.go
================================================
// Package chat_completions provides response translation from the Claude Code API to the OpenAI Chat Completions API.
// It converts Claude Code API responses into JSON compatible with OpenAI Chat Completions,
// transforming both streaming events and non-streaming responses into the format
// expected by OpenAI API clients. Both modes handle text content, tool calls,
// reasoning content, and usage metadata.
package chat_completions
import (
"bytes"
"context"
"fmt"
"strings"
"time"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// dataTag is the Server-Sent Events field prefix that marks a payload line.
var dataTag = []byte("data:")
type (
	// ConvertAnthropicResponseToOpenAIParams holds the streaming state that
	// ConvertClaudeResponseToOpenAI carries between calls through its param argument.
	ConvertAnthropicResponseToOpenAIParams struct {
		// CreatedAt is the Unix timestamp recorded at message_start; 0 until then.
		CreatedAt int64
		// ResponseID is the Claude message id, reused as the OpenAI chunk id.
		ResponseID string
		// FinishReason is the OpenAI finish_reason mapped from the latest stop_reason.
		FinishReason string
		// ToolCallsAccumulator buffers in-flight tool_use blocks, keyed by their
		// Claude content-block index, until the block is finished.
		ToolCallsAccumulator map[int]*ToolCallAccumulator
	}

	// ToolCallAccumulator collects a single tool call's identity together with its
	// streamed JSON argument fragments.
	ToolCallAccumulator struct {
		ID        string
		Name      string
		Arguments strings.Builder
	}
)
// ConvertClaudeResponseToOpenAI converts Claude Code streaming response format to OpenAI Chat Completions format.
// This function processes various Claude Code event types and transforms them into OpenAI-compatible JSON responses.
// It handles text content, tool calls, reasoning content, and usage metadata, outputting responses that match
// the OpenAI API format.
//
// Only lines starting with "data:" are processed; all other input yields no output.
// Tool calls are buffered per content-block index and emitted as a single chunk
// when their content_block_stop arrives.
// NOTE(review): the emitted tool_calls[].index is the Claude content-block index,
// not a dense per-tool-call counter.
//
// Parameters:
//   - ctx: Unused
//   - modelName: The model name written into each chunk
//   - originalRequestRawJSON, requestRawJSON: Unused by this converter
//   - rawJSON: One raw SSE line from the Claude Code API
//   - param: Pointer to the per-stream *ConvertAnthropicResponseToOpenAIParams state; initialized when *param is nil
//
// Returns:
//   - []string: Zero or more OpenAI-compatible JSON chunks
func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	if *param == nil {
		*param = &ConvertAnthropicResponseToOpenAIParams{
			CreatedAt:    0,
			ResponseID:   "",
			FinishReason: "",
		}
	}

	if !bytes.HasPrefix(rawJSON, dataTag) {
		return []string{}
	}
	// Bind the shared stream state once instead of repeating the type assertion.
	state := (*param).(*ConvertAnthropicResponseToOpenAIParams)

	rawJSON = bytes.TrimSpace(rawJSON[len(dataTag):])
	root := gjson.ParseBytes(rawJSON)
	eventType := root.Get("type").String()

	// Base OpenAI streaming response template
	template := `{"id":"","object":"chat.completion.chunk","created":0,"model":"","choices":[{"index":0,"delta":{},"finish_reason":null}]}`

	if modelName != "" {
		template, _ = sjson.Set(template, "model", modelName)
	}
	// Reuse the response ID and creation time captured at message_start.
	if state.ResponseID != "" {
		template, _ = sjson.Set(template, "id", state.ResponseID)
	}
	if state.CreatedAt > 0 {
		template, _ = sjson.Set(template, "created", state.CreatedAt)
	}

	switch eventType {
	case "message_start":
		// Initialize response with message metadata when a new message begins
		if message := root.Get("message"); message.Exists() {
			state.ResponseID = message.Get("id").String()
			state.CreatedAt = time.Now().Unix()

			template, _ = sjson.Set(template, "id", state.ResponseID)
			template, _ = sjson.Set(template, "model", modelName)
			template, _ = sjson.Set(template, "created", state.CreatedAt)
			// Set initial role to assistant for the response
			template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")

			if state.ToolCallsAccumulator == nil {
				state.ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
			}
		}
		return []string{template}

	case "content_block_start":
		// Only tool_use blocks need bookkeeping here; arguments arrive later as deltas.
		if contentBlock := root.Get("content_block"); contentBlock.Exists() {
			if contentBlock.Get("type").String() == "tool_use" {
				index := int(root.Get("index").Int())
				if state.ToolCallsAccumulator == nil {
					state.ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
				}
				state.ToolCallsAccumulator[index] = &ToolCallAccumulator{
					ID:   contentBlock.Get("id").String(),
					Name: contentBlock.Get("name").String(),
				}
			}
		}
		// Nothing is emitted until the tool call is complete.
		return []string{}

	case "content_block_delta":
		hasContent := false
		if delta := root.Get("delta"); delta.Exists() {
			switch delta.Get("type").String() {
			case "text_delta":
				if text := delta.Get("text"); text.Exists() {
					template, _ = sjson.Set(template, "choices.0.delta.content", text.String())
					hasContent = true
				}
			case "thinking_delta":
				if thinking := delta.Get("thinking"); thinking.Exists() {
					template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", thinking.String())
					hasContent = true
				}
			case "input_json_delta":
				// Tool arguments are buffered; they are emitted at content_block_stop.
				if partialJSON := delta.Get("partial_json"); partialJSON.Exists() {
					index := int(root.Get("index").Int())
					if state.ToolCallsAccumulator != nil {
						if accumulator, exists := state.ToolCallsAccumulator[index]; exists {
							accumulator.Arguments.WriteString(partialJSON.String())
						}
					}
				}
				return []string{}
			}
		}
		if !hasContent {
			return []string{}
		}
		return []string{template}

	case "content_block_stop":
		// End of content block - output complete tool call if it's a tool_use block
		index := int(root.Get("index").Int())
		if state.ToolCallsAccumulator != nil {
			if accumulator, exists := state.ToolCallsAccumulator[index]; exists {
				arguments := accumulator.Arguments.String()
				if arguments == "" {
					arguments = "{}"
				}
				template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.index", index)
				template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.id", accumulator.ID)
				template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.type", "function")
				template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.name", accumulator.Name)
				template, _ = sjson.Set(template, "choices.0.delta.tool_calls.0.function.arguments", arguments)

				delete(state.ToolCallsAccumulator, index)
				return []string{template}
			}
		}
		return []string{}

	case "message_delta":
		// Handle message-level changes including stop reason and usage
		if delta := root.Get("delta"); delta.Exists() {
			if stopReason := delta.Get("stop_reason"); stopReason.Exists() {
				state.FinishReason = mapAnthropicStopReasonToOpenAI(stopReason.String())
				template, _ = sjson.Set(template, "choices.0.finish_reason", state.FinishReason)
			}
		}
		if usage := root.Get("usage"); usage.Exists() {
			inputTokens := usage.Get("input_tokens").Int()
			outputTokens := usage.Get("output_tokens").Int()
			cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
			cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()

			// total_tokens must equal prompt_tokens + completion_tokens. Previously
			// cache-creation tokens were counted in the prompt but not in the total.
			promptTokens := inputTokens + cacheCreationInputTokens
			template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokens)
			template, _ = sjson.Set(template, "usage.completion_tokens", outputTokens)
			template, _ = sjson.Set(template, "usage.total_tokens", promptTokens+outputTokens)
			template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
		}
		return []string{template}

	case "message_stop", "ping":
		// Final message and keep-alive events produce no output.
		return []string{}

	case "error":
		if errorData := root.Get("error"); errorData.Exists() {
			errorJSON := `{"error":{"message":"","type":""}}`
			errorJSON, _ = sjson.Set(errorJSON, "error.message", errorData.Get("message").String())
			errorJSON, _ = sjson.Set(errorJSON, "error.type", errorData.Get("type").String())
			return []string{errorJSON}
		}
		return []string{}

	default:
		// Unknown event type - ignore
		return []string{}
	}
}
// mapAnthropicStopReasonToOpenAI translates an Anthropic stop_reason into the closest
// OpenAI finish_reason:
//   - "tool_use" -> "tool_calls"
//   - "max_tokens" and "model_context_window_exceeded" -> "length" (output was cut off by a limit)
//   - "refusal" -> "content_filter" (the model declined for safety reasons)
//   - "end_turn", "stop_sequence", and any other or empty value -> "stop"
func mapAnthropicStopReasonToOpenAI(anthropicReason string) string {
	switch anthropicReason {
	case "tool_use":
		return "tool_calls"
	case "max_tokens", "model_context_window_exceeded":
		return "length"
	case "refusal":
		return "content_filter"
	default:
		return "stop"
	}
}
// ConvertClaudeResponseToOpenAINonStream converts a non-streaming Claude Code response to a non-streaming OpenAI response.
// This function processes the complete Claude Code response and transforms it into a single OpenAI-compatible
// JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all
// the information into a single response that matches the OpenAI API format.
//
// Parameters:
// - ctx: The context for the request, used for cancellation and timeout handling
// - modelName: The name of the model being used for the response (unused in current implementation)
// - rawJSON: The raw JSON response from the Claude Code API
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - string: An OpenAI-compatible JSON response containing all message content and metadata
func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
chunks := make([][]byte, 0)
lines := bytes.Split(rawJSON, []byte("\n"))
for _, line := range lines {
if !bytes.HasPrefix(line, dataTag) {
continue
}
chunks = append(chunks, bytes.TrimSpace(line[5:]))
}
// Base OpenAI non-streaming response template
out := `{"id":"","object":"chat.completion","created":0,"model":"","choices":[{"index":0,"message":{"role":"assistant","content":""},"finish_reason":"stop"}],"usage":{"prompt_tokens":0,"completion_tokens":0,"total_tokens":0}}`
var messageID string
var model string
var createdAt int64
var stopReason string
var contentParts []string
var reasoningParts []string
toolCallsAccumulator := make(map[int]*ToolCallAccumulator)
for _, chunk := range chunks {
root := gjson.ParseBytes(chunk)
eventType := root.Get("type").String()
switch eventType {
case "message_start":
// Extract initial message metadata including ID, model, and input token count
if message := root.Get("message"); message.Exists() {
messageID = message.Get("id").String()
model = message.Get("model").String()
createdAt = time.Now().Unix()
}
case "content_block_start":
// Handle different content block types at the beginning
if contentBlock := root.Get("content_block"); contentBlock.Exists() {
blockType := contentBlock.Get("type").String()
if blockType == "thinking" {
// Start of thinking/reasoning content - skip for now as it's handled in delta
continue
} else if blockType == "tool_use" {
// Initialize tool call accumulator for this index
index := int(root.Get("index").Int())
toolCallsAccumulator[index] = &ToolCallAccumulator{
ID: contentBlock.Get("id").String(),
Name: contentBlock.Get("name").String(),
}
}
}
case "content_block_delta":
// Process incremental content updates
if delta := root.Get("delta"); delta.Exists() {
deltaType := delta.Get("type").String()
switch deltaType {
case "text_delta":
// Accumulate text content
if text := delta.Get("text"); text.Exists() {
contentParts = append(contentParts, text.String())
}
case "thinking_delta":
// Accumulate reasoning/thinking content
if thinking := delta.Get("thinking"); thinking.Exists() {
reasoningParts = append(reasoningParts, thinking.String())
}
case "input_json_delta":
// Accumulate tool call arguments
if partialJSON := delta.Get("partial_json"); partialJSON.Exists() {
index := int(root.Get("index").Int())
if accumulator, exists := toolCallsAccumulator[index]; exists {
accumulator.Arguments.WriteString(partialJSON.String())
}
}
}
}
case "content_block_stop":
// Finalize tool call arguments for this index when content block ends
index := int(root.Get("index").Int())
if accumulator, exists := toolCallsAccumulator[index]; exists {
if accumulator.Arguments.Len() == 0 {
accumulator.Arguments.WriteString("{}")
}
}
case "message_delta":
// Extract stop reason and output token count when message ends
if delta := root.Get("delta"); delta.Exists() {
if sr := delta.Get("stop_reason"); sr.Exists() {
stopReason = sr.String()
}
}
if usage := root.Get("usage"); usage.Exists() {
inputTokens := usage.Get("input_tokens").Int()
outputTokens := usage.Get("output_tokens").Int()
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
out, _ = sjson.Set(out, "usage.total_tokens", inputTokens+outputTokens)
out, _ = sjson.Set(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
}
}
}
// Set basic response fields including message ID, creation time, and model
out, _ = sjson.Set(out, "id", messageID)
out, _ = sjson.Set(out, "created", createdAt)
out, _ = sjson.Set(out, "model", model)
// Set message content by combining all text parts
messageContent := strings.Join(contentParts, "")
out, _ = sjson.Set(out, "choices.0.message.content", messageContent)
// Add reasoning content if available (following OpenAI reasoning format)
if len(reasoningParts) > 0 {
reasoningContent := strings.Join(reasoningParts, "")
// Add reasoning as a separate field in the message
out, _ = sjson.Set(out, "choices.0.message.reasoning", reasoningContent)
}
// Set tool calls if any were accumulated during processing
if len(toolCallsAccumulator) > 0 {
toolCallsCount := 0
maxIndex := -1
for index := range toolCallsAccumulator {
if index > maxIndex {
maxIndex = index
}
}
for i := 0; i <= maxIndex; i++ {
accumulator, exists := toolCallsAccumulator[i]
if !exists {
continue
}
arguments := accumulator.Arguments.String()
idPath := fmt.Sprintf("choices.0.message.tool_calls.%d.id", toolCallsCount)
typePath := fmt.Sprintf("choices.0.message.tool_calls.%d.type", toolCallsCount)
namePath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.name", toolCallsCount)
argumentsPath := fmt.Sprintf("choices.0.message.tool_calls.%d.function.arguments", toolCallsCount)
out, _ = sjson.Set(out, idPath, accumulator.ID)
out, _ = sjson.Set(out, typePath, "function")
out, _ = sjson.Set(out, namePath, accumulator.Name)
out, _ = sjson.Set(out, argumentsPath, arguments)
toolCallsCount++
}
if toolCallsCount > 0 {
out, _ = sjson.Set(out, "choices.0.finish_reason", "tool_calls")
} else {
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
}
} else {
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
}
return out
}
================================================
FILE: internal/translator/claude/openai/chat-completions/init.go
================================================
package chat_completions
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
OpenAI,
Claude,
ConvertOpenAIRequestToClaude,
interfaces.TranslateResponse{
Stream: ConvertClaudeResponseToOpenAI,
NonStream: ConvertClaudeResponseToOpenAINonStream,
},
)
}
================================================
FILE: internal/translator/claude/openai/responses/claude_openai-responses_request.go
================================================
package responses
import (
"crypto/rand"
"crypto/sha256"
"encoding/hex"
"fmt"
"math/big"
"strings"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Process-wide identifiers used to build the Claude metadata.user_id value
// ("user_<user>_account_<account>_session_<session>").
//
// They are populated eagerly during package initialisation, so request
// translation only ever reads them. Generating them lazily on the request path
// wrote to shared package state without synchronization. That is a data race
// when the first requests are translated concurrently.
var (
	account = randomClaudeMetadataID()
	session = randomClaudeMetadataID()
	// user depends on account and session; Go's package initialisation orders
	// it after both of them.
	user = claudeMetadataUserHash(account, session)
)

// randomClaudeMetadataID returns a freshly generated random UUID in string form.
// An error from uuid.NewRandom is ignored (best effort); the UUID value returned
// alongside the error is used as-is.
func randomClaudeMetadataID() string {
	u, _ := uuid.NewRandom()
	return u.String()
}

// claudeMetadataUserHash derives the user component of metadata.user_id as the
// hex-encoded SHA-256 of accountID concatenated with sessionID.
func claudeMetadataUserHash(accountID, sessionID string) string {
	sum := sha256.Sum256([]byte(accountID + sessionID))
	return hex.EncodeToString(sum[:])
}
// ConvertOpenAIResponsesRequestToClaude transforms an OpenAI Responses API request
// into a Claude Messages API request using only gjson/sjson for JSON handling.
// It supports:
// - instructions -> system message
// - input[].type==message with input_text/output_text -> user/assistant messages
// - function_call -> assistant tool_use
// - function_call_output -> user tool_result
// - tools[].parameters -> tools[].input_schema
// - max_output_tokens -> max_tokens
// - stream passthrough via parameter
func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
	rawJSON := inputRawJSON

	// Fill in any process-wide identifiers that are still unset; later requests
	// reuse them to build metadata.user_id.
	// NOTE(review): these are unsynchronized writes to package-level variables and
	// are only race-free if the variables are already populated before concurrent use.
	if account == "" {
		u, _ := uuid.NewRandom()
		account = u.String()
	}
	if session == "" {
		u, _ := uuid.NewRandom()
		session = u.String()
	}
	if user == "" {
		sum := sha256.Sum256([]byte(account + session))
		user = hex.EncodeToString(sum[:])
	}
	userID := fmt.Sprintf("user_%s_account_%s_session_%s", user, account, session)

	// Base Claude message payload
	out := fmt.Sprintf(`{"model":"","max_tokens":32000,"messages":[],"metadata":{"user_id":"%s"}}`, userID)

	root := gjson.ParseBytes(rawJSON)

	// Convert OpenAI Responses reasoning.effort to Claude thinking config.
	if v := root.Get("reasoning.effort"); v.Exists() {
		effort := strings.ToLower(strings.TrimSpace(v.String()))
		if effort != "" {
			mi := registry.LookupModelInfo(modelName, "claude")
			supportsAdaptive := mi != nil && mi.Thinking != nil && len(mi.Thinking.Levels) > 0
			supportsMax := supportsAdaptive && thinking.HasLevel(mi.Thinking.Levels, string(thinking.LevelMax))

			// Claude 4.6 supports adaptive thinking with output_config.effort.
			// MapToClaudeEffort normalizes levels (e.g. minimal→low, xhigh→high) to avoid
			// validation errors since validate treats same-provider unsupported levels as errors.
			if supportsAdaptive {
				switch effort {
				case "none":
					out, _ = sjson.Set(out, "thinking.type", "disabled")
					out, _ = sjson.Delete(out, "thinking.budget_tokens")
					out, _ = sjson.Delete(out, "output_config.effort")
				case "auto":
					out, _ = sjson.Set(out, "thinking.type", "adaptive")
					out, _ = sjson.Delete(out, "thinking.budget_tokens")
					out, _ = sjson.Delete(out, "output_config.effort")
				default:
					if mapped, ok := thinking.MapToClaudeEffort(effort, supportsMax); ok {
						effort = mapped
					}
					out, _ = sjson.Set(out, "thinking.type", "adaptive")
					out, _ = sjson.Delete(out, "thinking.budget_tokens")
					out, _ = sjson.Set(out, "output_config.effort", effort)
				}
			} else {
				// Legacy/manual thinking (budget_tokens).
				budget, ok := thinking.ConvertLevelToBudget(effort)
				if ok {
					switch budget {
					case 0:
						out, _ = sjson.Set(out, "thinking.type", "disabled")
					case -1:
						out, _ = sjson.Set(out, "thinking.type", "enabled")
					default:
						if budget > 0 {
							out, _ = sjson.Set(out, "thinking.type", "enabled")
							out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
						}
					}
				}
			}
		}
	}

	// Helper for generating tool call IDs when missing: "toolu_" followed by 24
	// random alphanumerics. Errors from the random source are ignored (best effort).
	genToolCallID := func() string {
		const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
		var b strings.Builder
		for i := 0; i < 24; i++ {
			n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
			b.WriteByte(letters[n.Int64()])
		}
		return "toolu_" + b.String()
	}

	// Model
	out, _ = sjson.Set(out, "model", modelName)

	// Max tokens
	if mot := root.Get("max_output_tokens"); mot.Exists() {
		out, _ = sjson.Set(out, "max_tokens", mot.Int())
	}

	// Stream
	out, _ = sjson.Set(out, "stream", stream)

	// instructions -> as a leading message (use role user for Claude API compatibility)
	instructionsText := ""
	extractedFromSystem := false
	if instr := root.Get("instructions"); instr.Exists() && instr.Type == gjson.String {
		instructionsText = instr.String()
		if instructionsText != "" {
			sysMsg := `{"role":"user","content":""}`
			sysMsg, _ = sjson.Set(sysMsg, "content", instructionsText)
			out, _ = sjson.SetRaw(out, "messages.-1", sysMsg)
		}
	}

	// Without instructions, promote the first system input message that has text
	// to the leading user message. Iteration stops at the first one that yields
	// text; all system input messages are then skipped below.
	if instructionsText == "" {
		if input := root.Get("input"); input.Exists() && input.IsArray() {
			input.ForEach(func(_, item gjson.Result) bool {
				if strings.EqualFold(item.Get("role").String(), "system") {
					var builder strings.Builder
					if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
						parts.ForEach(func(_, part gjson.Result) bool {
							textResult := part.Get("text")
							text := textResult.String()
							if builder.Len() > 0 && text != "" {
								builder.WriteByte('\n')
							}
							builder.WriteString(text)
							return true
						})
					} else if parts.Type == gjson.String {
						builder.WriteString(parts.String())
					}
					instructionsText = builder.String()
					if instructionsText != "" {
						sysMsg := `{"role":"user","content":""}`
						sysMsg, _ = sjson.Set(sysMsg, "content", instructionsText)
						out, _ = sjson.SetRaw(out, "messages.-1", sysMsg)
						extractedFromSystem = true
					}
				}
				return instructionsText == ""
			})
		}
	}

	// input array processing
	if input := root.Get("input"); input.Exists() && input.IsArray() {
		input.ForEach(func(_, item gjson.Result) bool {
			if extractedFromSystem && strings.EqualFold(item.Get("role").String(), "system") {
				return true
			}
			typ := item.Get("type").String()
			if typ == "" && item.Get("role").String() != "" {
				typ = "message"
			}
			switch typ {
			case "message":
				// Determine role and construct Claude-compatible content parts.
				var role string
				var textAggregate strings.Builder
				var partsJSON []string
				hasImage := false
				hasFile := false
				if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
					parts.ForEach(func(_, part gjson.Result) bool {
						ptype := part.Get("type").String()
						switch ptype {
						case "input_text", "output_text":
							if t := part.Get("text"); t.Exists() {
								txt := t.String()
								textAggregate.WriteString(txt)
								contentPart := `{"type":"text","text":""}`
								contentPart, _ = sjson.Set(contentPart, "text", txt)
								partsJSON = append(partsJSON, contentPart)
							}
							if ptype == "input_text" {
								role = "user"
							} else {
								role = "assistant"
							}
						case "input_image":
							url := part.Get("image_url").String()
							if url == "" {
								url = part.Get("url").String()
							}
							if url != "" {
								var contentPart string
								if strings.HasPrefix(url, "data:") {
									trimmed := strings.TrimPrefix(url, "data:")
									mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
									mediaType := "application/octet-stream"
									data := ""
									if len(mediaAndData) == 2 {
										if mediaAndData[0] != "" {
											mediaType = mediaAndData[0]
										}
										data = mediaAndData[1]
									}
									if data != "" {
										contentPart = `{"type":"image","source":{"type":"base64","media_type":"","data":""}}`
										contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
										contentPart, _ = sjson.Set(contentPart, "source.data", data)
									}
								} else {
									contentPart = `{"type":"image","source":{"type":"url","url":""}}`
									contentPart, _ = sjson.Set(contentPart, "source.url", url)
								}
								if contentPart != "" {
									partsJSON = append(partsJSON, contentPart)
									if role == "" {
										role = "user"
									}
									hasImage = true
								}
							}
						case "input_file":
							fileData := part.Get("file_data").String()
							if fileData != "" {
								mediaType := "application/octet-stream"
								data := fileData
								if strings.HasPrefix(fileData, "data:") {
									trimmed := strings.TrimPrefix(fileData, "data:")
									mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
									if len(mediaAndData) == 2 {
										if mediaAndData[0] != "" {
											mediaType = mediaAndData[0]
										}
										data = mediaAndData[1]
									}
								}
								contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
								contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
								contentPart, _ = sjson.Set(contentPart, "source.data", data)
								partsJSON = append(partsJSON, contentPart)
								if role == "" {
									role = "user"
								}
								hasFile = true
							}
						}
						return true
					})
				} else if parts.Type == gjson.String {
					textAggregate.WriteString(parts.String())
				}

				// Fallback to given role if content types not decisive
				if role == "" {
					r := item.Get("role").String()
					switch r {
					case "user", "assistant", "system":
						role = r
					default:
						role = "user"
					}
				}

				if len(partsJSON) > 0 {
					msg := `{"role":"","content":[]}`
					msg, _ = sjson.Set(msg, "role", role)
					if len(partsJSON) == 1 && !hasImage && !hasFile {
						// Preserve legacy behavior for single text content
						msg, _ = sjson.Delete(msg, "content")
						textPart := gjson.Parse(partsJSON[0])
						msg, _ = sjson.Set(msg, "content", textPart.Get("text").String())
					} else {
						for _, partJSON := range partsJSON {
							msg, _ = sjson.SetRaw(msg, "content.-1", partJSON)
						}
					}
					out, _ = sjson.SetRaw(out, "messages.-1", msg)
				} else if textAggregate.Len() > 0 || role == "system" {
					msg := `{"role":"","content":""}`
					msg, _ = sjson.Set(msg, "role", role)
					msg, _ = sjson.Set(msg, "content", textAggregate.String())
					out, _ = sjson.SetRaw(out, "messages.-1", msg)
				}

			case "function_call":
				// Map to assistant tool_use
				callID := item.Get("call_id").String()
				if callID == "" {
					callID = genToolCallID()
				}
				name := item.Get("name").String()
				argsStr := item.Get("arguments").String()

				toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
				toolUse, _ = sjson.Set(toolUse, "id", callID)
				toolUse, _ = sjson.Set(toolUse, "name", name)
				// Only a JSON object is a valid Claude tool input; anything else
				// keeps the empty {} input.
				if argsStr != "" && gjson.Valid(argsStr) {
					argsJSON := gjson.Parse(argsStr)
					if argsJSON.IsObject() {
						toolUse, _ = sjson.SetRaw(toolUse, "input", argsJSON.Raw)
					}
				}

				asst := `{"role":"assistant","content":[]}`
				asst, _ = sjson.SetRaw(asst, "content.-1", toolUse)
				out, _ = sjson.SetRaw(out, "messages.-1", asst)

			case "function_call_output":
				// Map to user tool_result
				callID := item.Get("call_id").String()
				outputStr := item.Get("output").String()
				toolResult := `{"type":"tool_result","tool_use_id":"","content":""}`
				toolResult, _ = sjson.Set(toolResult, "tool_use_id", callID)
				toolResult, _ = sjson.Set(toolResult, "content", outputStr)

				usr := `{"role":"user","content":[]}`
				usr, _ = sjson.SetRaw(usr, "content.-1", toolResult)
				out, _ = sjson.SetRaw(out, "messages.-1", usr)
			}
			return true
		})
	}

	// tools mapping: parameters -> input_schema
	if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
		toolsJSON := "[]"
		tools.ForEach(func(_, tool gjson.Result) bool {
			tJSON := `{"name":"","description":"","input_schema":{}}`
			if n := tool.Get("name"); n.Exists() {
				tJSON, _ = sjson.Set(tJSON, "name", n.String())
			}
			if d := tool.Get("description"); d.Exists() {
				tJSON, _ = sjson.Set(tJSON, "description", d.String())
			}

			if params := tool.Get("parameters"); params.Exists() {
				tJSON, _ = sjson.SetRaw(tJSON, "input_schema", params.Raw)
			} else if params = tool.Get("parametersJsonSchema"); params.Exists() {
				tJSON, _ = sjson.SetRaw(tJSON, "input_schema", params.Raw)
			}

			toolsJSON, _ = sjson.SetRaw(toolsJSON, "-1", tJSON)
			return true
		})
		// "#" yields the array length; only emit tools when at least one was mapped.
		if gjson.Get(toolsJSON, "#").Int() > 0 {
			out, _ = sjson.SetRaw(out, "tools", toolsJSON)
		}
	}

	// Map tool_choice similar to Chat Completions translator (optional in docs, safe to handle)
	if toolChoice := root.Get("tool_choice"); toolChoice.Exists() {
		switch toolChoice.Type {
		case gjson.String:
			switch toolChoice.String() {
			case "auto":
				out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"auto"}`)
			case "none":
				// Leave unset; implies no tools
			case "required":
				out, _ = sjson.SetRaw(out, "tool_choice", `{"type":"any"}`)
			}
		case gjson.JSON:
			if toolChoice.Get("type").String() == "function" {
				// The Responses API places the function name at the top level
				// ({"type":"function","name":"..."}); also accept the Chat
				// Completions shape ({"type":"function","function":{"name":"..."}}).
				fn := toolChoice.Get("name").String()
				if fn == "" {
					fn = toolChoice.Get("function.name").String()
				}
				// A forced tool choice without a name is invalid for Claude, so
				// leave tool_choice unset rather than emitting a broken value.
				if fn != "" {
					toolChoiceJSON := `{"name":"","type":"tool"}`
					toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "name", fn)
					out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
				}
			}
		default:
		}
	}

	return []byte(out)
}
================================================
FILE: internal/translator/claude/openai/responses/claude_openai-responses_response.go
================================================
package responses
import (
"bufio"
"bytes"
"context"
"fmt"
"strings"
"time"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// claudeToResponsesState is the per-stream state carried between calls of
// ConvertClaudeResponseToOpenAIResponses (stored behind its *any param). It
// tracks sequence numbers, the currently open output item, and the text,
// function-call, reasoning and usage data aggregated for response.completed.
// Per-message fields are reset on each Claude message_start event.
type claudeToResponsesState struct {
	// Seq is the last sequence_number emitted; it is incremented before each event.
	Seq int
	// ResponseID is the Claude message id, reused as the OpenAI response id.
	ResponseID string
	// CreatedAt is the Unix time recorded when message_start was seen.
	CreatedAt int64
	// CurrentMsgID is the id of the assistant message output item ("msg_<ResponseID>_0").
	CurrentMsgID string
	// CurrentFCID is the Claude tool_use id of the most recently started tool block.
	CurrentFCID string
	// InTextBlock / InFuncBlock report whether a text / tool_use block is open;
	// content_block_stop uses them to decide which *.done events to emit.
	InTextBlock bool
	InFuncBlock bool
	FuncArgsBuf map[int]*strings.Builder // index -> args
	// function call bookkeeping for output aggregation
	FuncNames   map[int]string // index -> function name
	FuncCallIDs map[int]string // index -> call id
	// message text aggregation (all text blocks of the current Claude message)
	TextBuf strings.Builder
	// reasoning state
	// ReasoningActive is true while a thinking block is open.
	ReasoningActive bool
	// ReasoningItemID is the id of the reasoning output item ("rs_<ResponseID>_<index>").
	ReasoningItemID string
	// ReasoningBuf accumulates thinking deltas for the current reasoning item.
	ReasoningBuf strings.Builder
	// ReasoningPartAdded records that a reasoning_summary_part.added event was emitted.
	ReasoningPartAdded bool
	// ReasoningIndex is the Claude content-block index of the thinking block,
	// emitted as output_index on reasoning events.
	ReasoningIndex int
	// usage aggregation
	InputTokens  int64
	OutputTokens int64
	// UsageSeen is set once any input/output token count has been observed.
	UsageSeen bool
}
// dataTag is the prefix every Claude SSE data line must carry to be translated.
var dataTag = []byte(`data:`)
// pickRequestJSON chooses the request body to echo back in response.completed.
// The original client request is preferred; the translated request is the
// fallback. A candidate is used only when it is non-empty and valid JSON. It
// returns nil when neither qualifies.
func pickRequestJSON(originalRequestRawJSON, requestRawJSON []byte) []byte {
	for _, candidate := range [][]byte{originalRequestRawJSON, requestRawJSON} {
		if len(candidate) == 0 || !gjson.ValidBytes(candidate) {
			continue
		}
		return candidate
	}
	return nil
}
func emitEvent(event string, payload string) string {
return fmt.Sprintf("event: %s\ndata: %s", event, payload)
}
// ConvertClaudeResponseToOpenAIResponses converts Claude SSE to OpenAI Responses SSE events.
// ConvertClaudeResponseToOpenAIResponses converts Claude SSE to OpenAI Responses SSE events.
//
// It is called once per upstream line. Lines without the "data:" prefix yield no
// events. State that must survive across lines (sequence numbers, item ids,
// aggregated text/arguments/reasoning, usage) lives in *param, which is lazily
// initialised to a *claudeToResponsesState. Each returned string is one
// "event: <name>\ndata: <json>" frame.
//
// Claude events handled:
//   - message_start: reset per-message state; emit response.created and response.in_progress.
//   - content_block_start: open a message (text), function_call (tool_use) or reasoning (thinking) item.
//   - content_block_delta: forward text / tool-argument / thinking deltas and aggregate them.
//   - content_block_stop: emit the *.done events for the block that is open.
//   - message_delta: record usage.
//   - message_stop: emit response.completed with echoed request fields, aggregated output and usage.
func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	if *param == nil {
		*param = &claudeToResponsesState{FuncArgsBuf: make(map[int]*strings.Builder), FuncNames: make(map[int]string), FuncCallIDs: make(map[int]string)}
	}
	st := (*param).(*claudeToResponsesState)

	// Expect `data: {..}` from Claude clients
	if !bytes.HasPrefix(rawJSON, dataTag) {
		return []string{}
	}
	rawJSON = bytes.TrimSpace(rawJSON[5:])
	root := gjson.ParseBytes(rawJSON)
	ev := root.Get("type").String()
	var out []string

	nextSeq := func() int { st.Seq++; return st.Seq }

	switch ev {
	case "message_start":
		if msg := root.Get("message"); msg.Exists() {
			st.ResponseID = msg.Get("id").String()
			st.CreatedAt = time.Now().Unix()
			// Reset per-message aggregation state
			st.TextBuf.Reset()
			st.ReasoningBuf.Reset()
			st.ReasoningActive = false
			st.InTextBlock = false
			st.InFuncBlock = false
			st.CurrentMsgID = ""
			st.CurrentFCID = ""
			st.ReasoningItemID = ""
			st.ReasoningIndex = 0
			st.ReasoningPartAdded = false
			st.FuncArgsBuf = make(map[int]*strings.Builder)
			st.FuncNames = make(map[int]string)
			st.FuncCallIDs = make(map[int]string)
			st.InputTokens = 0
			st.OutputTokens = 0
			st.UsageSeen = false
			if usage := msg.Get("usage"); usage.Exists() {
				if v := usage.Get("input_tokens"); v.Exists() {
					st.InputTokens = v.Int()
					st.UsageSeen = true
				}
				if v := usage.Get("output_tokens"); v.Exists() {
					st.OutputTokens = v.Int()
					st.UsageSeen = true
				}
			}
			// response.created
			created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
			created, _ = sjson.Set(created, "sequence_number", nextSeq())
			created, _ = sjson.Set(created, "response.id", st.ResponseID)
			created, _ = sjson.Set(created, "response.created_at", st.CreatedAt)
			out = append(out, emitEvent("response.created", created))
			// response.in_progress
			inprog := `{"type":"response.in_progress","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress"}}`
			inprog, _ = sjson.Set(inprog, "sequence_number", nextSeq())
			inprog, _ = sjson.Set(inprog, "response.id", st.ResponseID)
			inprog, _ = sjson.Set(inprog, "response.created_at", st.CreatedAt)
			out = append(out, emitEvent("response.in_progress", inprog))
		}
	case "content_block_start":
		cb := root.Get("content_block")
		if !cb.Exists() {
			return out
		}
		idx := int(root.Get("index").Int())
		typ := cb.Get("type").String()
		if typ == "text" {
			// open message item + content part
			st.InTextBlock = true
			st.CurrentMsgID = fmt.Sprintf("msg_%s_0", st.ResponseID)
			item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`
			item, _ = sjson.Set(item, "sequence_number", nextSeq())
			item, _ = sjson.Set(item, "item.id", st.CurrentMsgID)
			out = append(out, emitEvent("response.output_item.added", item))

			part := `{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
			part, _ = sjson.Set(part, "sequence_number", nextSeq())
			part, _ = sjson.Set(part, "item_id", st.CurrentMsgID)
			out = append(out, emitEvent("response.content_part.added", part))
		} else if typ == "tool_use" {
			st.InFuncBlock = true
			st.CurrentFCID = cb.Get("id").String()
			name := cb.Get("name").String()
			item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`
			item, _ = sjson.Set(item, "sequence_number", nextSeq())
			item, _ = sjson.Set(item, "output_index", idx)
			item, _ = sjson.Set(item, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
			item, _ = sjson.Set(item, "item.call_id", st.CurrentFCID)
			item, _ = sjson.Set(item, "item.name", name)
			out = append(out, emitEvent("response.output_item.added", item))
			if st.FuncArgsBuf[idx] == nil {
				st.FuncArgsBuf[idx] = &strings.Builder{}
			}
			// record function metadata for aggregation
			st.FuncCallIDs[idx] = st.CurrentFCID
			st.FuncNames[idx] = name
		} else if typ == "thinking" {
			// start reasoning item
			st.ReasoningActive = true
			st.ReasoningIndex = idx
			st.ReasoningBuf.Reset()
			st.ReasoningItemID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx)
			item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`
			item, _ = sjson.Set(item, "sequence_number", nextSeq())
			item, _ = sjson.Set(item, "output_index", idx)
			item, _ = sjson.Set(item, "item.id", st.ReasoningItemID)
			out = append(out, emitEvent("response.output_item.added", item))
			// add a summary part placeholder
			part := `{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
			part, _ = sjson.Set(part, "sequence_number", nextSeq())
			part, _ = sjson.Set(part, "item_id", st.ReasoningItemID)
			part, _ = sjson.Set(part, "output_index", idx)
			out = append(out, emitEvent("response.reasoning_summary_part.added", part))
			st.ReasoningPartAdded = true
		}
	case "content_block_delta":
		d := root.Get("delta")
		if !d.Exists() {
			return out
		}
		dt := d.Get("type").String()
		if dt == "text_delta" {
			if t := d.Get("text"); t.Exists() {
				msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
				msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
				msg, _ = sjson.Set(msg, "item_id", st.CurrentMsgID)
				msg, _ = sjson.Set(msg, "delta", t.String())
				out = append(out, emitEvent("response.output_text.delta", msg))
				// aggregate text for response.output
				st.TextBuf.WriteString(t.String())
			}
		} else if dt == "input_json_delta" {
			idx := int(root.Get("index").Int())
			if pj := d.Get("partial_json"); pj.Exists() {
				if st.FuncArgsBuf[idx] == nil {
					st.FuncArgsBuf[idx] = &strings.Builder{}
				}
				st.FuncArgsBuf[idx].WriteString(pj.String())
				msg := `{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`
				msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
				msg, _ = sjson.Set(msg, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
				msg, _ = sjson.Set(msg, "output_index", idx)
				msg, _ = sjson.Set(msg, "delta", pj.String())
				out = append(out, emitEvent("response.function_call_arguments.delta", msg))
			}
		} else if dt == "thinking_delta" {
			if st.ReasoningActive {
				if t := d.Get("thinking"); t.Exists() {
					st.ReasoningBuf.WriteString(t.String())
					msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
					msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
					msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID)
					msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
					msg, _ = sjson.Set(msg, "delta", t.String())
					out = append(out, emitEvent("response.reasoning_summary_text.delta", msg))
				}
			}
		}
	case "content_block_stop":
		idx := int(root.Get("index").Int())
		if st.InTextBlock {
			// The done events must carry the finalized text. Every text block of
			// this Claude message shares CurrentMsgID, so the item's text is the
			// aggregate accumulated so far (the same value response.completed uses).
			fullText := st.TextBuf.String()

			done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
			done, _ = sjson.Set(done, "sequence_number", nextSeq())
			done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
			done, _ = sjson.Set(done, "text", fullText)
			out = append(out, emitEvent("response.output_text.done", done))

			partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
			partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
			partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
			partDone, _ = sjson.Set(partDone, "part.text", fullText)
			out = append(out, emitEvent("response.content_part.done", partDone))

			final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
			final, _ = sjson.Set(final, "sequence_number", nextSeq())
			final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
			final, _ = sjson.Set(final, "item.content.0.text", fullText)
			out = append(out, emitEvent("response.output_item.done", final))
			st.InTextBlock = false
		} else if st.InFuncBlock {
			// A tool call with no argument deltas is reported as an empty object.
			args := "{}"
			if buf := st.FuncArgsBuf[idx]; buf != nil {
				if buf.Len() > 0 {
					args = buf.String()
				}
			}
			fcDone := `{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`
			fcDone, _ = sjson.Set(fcDone, "sequence_number", nextSeq())
			fcDone, _ = sjson.Set(fcDone, "item_id", fmt.Sprintf("fc_%s", st.CurrentFCID))
			fcDone, _ = sjson.Set(fcDone, "output_index", idx)
			fcDone, _ = sjson.Set(fcDone, "arguments", args)
			out = append(out, emitEvent("response.function_call_arguments.done", fcDone))

			itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`
			itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
			itemDone, _ = sjson.Set(itemDone, "output_index", idx)
			itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.CurrentFCID))
			itemDone, _ = sjson.Set(itemDone, "item.arguments", args)
			itemDone, _ = sjson.Set(itemDone, "item.call_id", st.CurrentFCID)
			itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
			out = append(out, emitEvent("response.output_item.done", itemDone))
			st.InFuncBlock = false
		} else if st.ReasoningActive {
			full := st.ReasoningBuf.String()
			textDone := `{"type":"response.reasoning_summary_text.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`
			textDone, _ = sjson.Set(textDone, "sequence_number", nextSeq())
			textDone, _ = sjson.Set(textDone, "item_id", st.ReasoningItemID)
			textDone, _ = sjson.Set(textDone, "output_index", st.ReasoningIndex)
			textDone, _ = sjson.Set(textDone, "text", full)
			out = append(out, emitEvent("response.reasoning_summary_text.done", textDone))
			partDone := `{"type":"response.reasoning_summary_part.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
			partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
			partDone, _ = sjson.Set(partDone, "item_id", st.ReasoningItemID)
			partDone, _ = sjson.Set(partDone, "output_index", st.ReasoningIndex)
			partDone, _ = sjson.Set(partDone, "part.text", full)
			out = append(out, emitEvent("response.reasoning_summary_part.done", partDone))
			st.ReasoningActive = false
			st.ReasoningPartAdded = false
		}
	case "message_delta":
		if usage := root.Get("usage"); usage.Exists() {
			if v := usage.Get("output_tokens"); v.Exists() {
				st.OutputTokens = v.Int()
				st.UsageSeen = true
			}
			if v := usage.Get("input_tokens"); v.Exists() {
				st.InputTokens = v.Int()
				st.UsageSeen = true
			}
		}
	case "message_stop":
		completed := `{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`
		completed, _ = sjson.Set(completed, "sequence_number", nextSeq())
		completed, _ = sjson.Set(completed, "response.id", st.ResponseID)
		completed, _ = sjson.Set(completed, "response.created_at", st.CreatedAt)
		// Inject original request fields into response as per docs/response.completed.json
		reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
		if len(reqBytes) > 0 {
			req := gjson.ParseBytes(reqBytes)
			if v := req.Get("instructions"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.instructions", v.String())
			}
			if v := req.Get("max_output_tokens"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.max_output_tokens", v.Int())
			}
			if v := req.Get("max_tool_calls"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.max_tool_calls", v.Int())
			}
			if v := req.Get("model"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.model", v.String())
			}
			if v := req.Get("parallel_tool_calls"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.parallel_tool_calls", v.Bool())
			}
			if v := req.Get("previous_response_id"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.previous_response_id", v.String())
			}
			if v := req.Get("prompt_cache_key"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.prompt_cache_key", v.String())
			}
			if v := req.Get("reasoning"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.reasoning", v.Value())
			}
			if v := req.Get("safety_identifier"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.safety_identifier", v.String())
			}
			if v := req.Get("service_tier"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.service_tier", v.String())
			}
			if v := req.Get("store"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.store", v.Bool())
			}
			if v := req.Get("temperature"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.temperature", v.Float())
			}
			if v := req.Get("text"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.text", v.Value())
			}
			if v := req.Get("tool_choice"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.tool_choice", v.Value())
			}
			if v := req.Get("tools"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.tools", v.Value())
			}
			if v := req.Get("top_logprobs"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.top_logprobs", v.Int())
			}
			if v := req.Get("top_p"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.top_p", v.Float())
			}
			if v := req.Get("truncation"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.truncation", v.String())
			}
			if v := req.Get("user"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.user", v.Value())
			}
			if v := req.Get("metadata"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.metadata", v.Value())
			}
		}

		// Build response.output from aggregated state
		outputsWrapper := `{"arr":[]}`
		// reasoning item (if any)
		if st.ReasoningBuf.Len() > 0 || st.ReasoningPartAdded {
			item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
			item, _ = sjson.Set(item, "id", st.ReasoningItemID)
			item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
			outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
		}
		// assistant message item (if any text)
		if st.TextBuf.Len() > 0 || st.InTextBlock || st.CurrentMsgID != "" {
			item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
			item, _ = sjson.Set(item, "id", st.CurrentMsgID)
			item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
			outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
		}
		// function_call items (in ascending index order for determinism)
		if len(st.FuncArgsBuf) > 0 {
			// collect indices
			idxs := make([]int, 0, len(st.FuncArgsBuf))
			for idx := range st.FuncArgsBuf {
				idxs = append(idxs, idx)
			}
			// simple sort (small N), avoid adding new imports
			for i := 0; i < len(idxs); i++ {
				for j := i + 1; j < len(idxs); j++ {
					if idxs[j] < idxs[i] {
						idxs[i], idxs[j] = idxs[j], idxs[i]
					}
				}
			}
			for _, idx := range idxs {
				// Match the "{}" default used by the streamed
				// function_call_arguments.done / output_item.done events so the
				// completed output agrees with what was already sent.
				args := "{}"
				if b := st.FuncArgsBuf[idx]; b != nil && b.Len() > 0 {
					args = b.String()
				}
				callID := st.FuncCallIDs[idx]
				name := st.FuncNames[idx]
				if callID == "" && st.CurrentFCID != "" {
					callID = st.CurrentFCID
				}
				item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
				item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
				item, _ = sjson.Set(item, "arguments", args)
				item, _ = sjson.Set(item, "call_id", callID)
				item, _ = sjson.Set(item, "name", name)
				outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
			}
		}
		if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
			completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
		}

		// Reasoning tokens are not reported by Claude here; approximate them as
		// one token per four bytes of aggregated thinking text.
		reasoningTokens := int64(0)
		if st.ReasoningBuf.Len() > 0 {
			reasoningTokens = int64(st.ReasoningBuf.Len() / 4)
		}
		usagePresent := st.UsageSeen || reasoningTokens > 0
		if usagePresent {
			completed, _ = sjson.Set(completed, "response.usage.input_tokens", st.InputTokens)
			completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", 0)
			completed, _ = sjson.Set(completed, "response.usage.output_tokens", st.OutputTokens)
			if reasoningTokens > 0 {
				completed, _ = sjson.Set(completed, "response.usage.output_tokens_details.reasoning_tokens", reasoningTokens)
			}
			total := st.InputTokens + st.OutputTokens
			if total > 0 || st.UsageSeen {
				completed, _ = sjson.Set(completed, "response.usage.total_tokens", total)
			}
		}
		out = append(out, emitEvent("response.completed", completed))
	}

	return out
}
// ConvertClaudeResponseToOpenAIResponsesNonStream aggregates a complete Claude SSE stream into a
// single OpenAI Responses JSON object (non-stream).
//
// Only lines that begin with dataTag are parsed; every other line is ignored. The aggregation
// follows the streaming variant: text deltas form one assistant message, thinking deltas form
// one reasoning item, and tool_use blocks become function_call items ordered by content-block
// index. Selected request fields are echoed from the request chosen by pickRequestJSON.
//
// Parameters:
//   - originalRequestRawJSON, requestRawJSON: candidate requests whose fields are echoed back
//   - rawJSON: the buffered Claude SSE stream
//
// Returns:
//   - string: the aggregated OpenAI Responses JSON object
func ConvertClaudeResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
	// maxSSELineSize bounds a single SSE line (50MB). The scanner starts with a small buffer
	// and only grows it towards this limit when a line actually needs the space.
	const maxSSELineSize = 52_428_800

	// Base OpenAI Responses (non-stream) object
	out := `{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null,"output":[],"usage":{"input_tokens":0,"input_tokens_details":{"cached_tokens":0},"output_tokens":0,"output_tokens_details":{},"total_tokens":0}}`

	// Aggregation state
	var (
		responseID      string
		createdAt       int64
		currentMsgID    string
		currentFCID     string
		textBuf         strings.Builder
		reasoningBuf    strings.Builder
		reasoningActive bool
		reasoningItemID string
		inputTokens     int64
		outputTokens    int64
	)

	// Per-index tool call aggregation
	type toolState struct {
		id   string
		name string
		args strings.Builder
	}
	toolCalls := make(map[int]*toolState)

	scanner := bufio.NewScanner(bytes.NewReader(rawJSON))
	scanner.Buffer(make([]byte, 0, 64*1024), maxSSELineSize)
	for scanner.Scan() {
		line := scanner.Bytes()
		if !bytes.HasPrefix(line, dataTag) {
			continue
		}
		// scanner.Bytes() is only valid until the next Scan call, so each event is parsed and
		// folded into the aggregation state immediately instead of being stored for later.
		root := gjson.ParseBytes(line[len(dataTag):])
		switch root.Get("type").String() {
		case "message_start":
			if msg := root.Get("message"); msg.Exists() {
				responseID = msg.Get("id").String()
				createdAt = time.Now().Unix()
				if usage := msg.Get("usage"); usage.Exists() {
					inputTokens = usage.Get("input_tokens").Int()
				}
			}
		case "content_block_start":
			cb := root.Get("content_block")
			if !cb.Exists() {
				continue
			}
			idx := int(root.Get("index").Int())
			switch cb.Get("type").String() {
			case "text":
				currentMsgID = "msg_" + responseID + "_0"
			case "tool_use":
				currentFCID = cb.Get("id").String()
				name := cb.Get("name").String()
				if toolCalls[idx] == nil {
					toolCalls[idx] = &toolState{id: currentFCID, name: name}
				} else {
					toolCalls[idx].id = currentFCID
					toolCalls[idx].name = name
				}
			case "thinking":
				reasoningActive = true
				reasoningItemID = fmt.Sprintf("rs_%s_%d", responseID, idx)
			}
		case "content_block_delta":
			d := root.Get("delta")
			if !d.Exists() {
				continue
			}
			switch d.Get("type").String() {
			case "text_delta":
				if t := d.Get("text"); t.Exists() {
					textBuf.WriteString(t.String())
				}
			case "input_json_delta":
				if pj := d.Get("partial_json"); pj.Exists() {
					idx := int(root.Get("index").Int())
					if toolCalls[idx] == nil {
						toolCalls[idx] = &toolState{}
					}
					toolCalls[idx].args.WriteString(pj.String())
				}
			case "thinking_delta":
				if reasoningActive {
					if t := d.Get("thinking"); t.Exists() {
						reasoningBuf.WriteString(t.String())
					}
				}
			}
		case "message_delta":
			if usage := root.Get("usage"); usage.Exists() {
				outputTokens = usage.Get("output_tokens").Int()
			}
		}
		// content_block_stop and other events need no handling for non-stream aggregation.
	}
	// scanner.Err() is deliberately not checked. If a line exceeds maxSSELineSize the scan stops
	// and the events aggregated so far are still converted (best effort).

	// Populate base fields
	out, _ = sjson.Set(out, "id", responseID)
	out, _ = sjson.Set(out, "created_at", createdAt)

	// Inject request echo fields as top-level (similar to streaming variant)
	reqBytes := pickRequestJSON(originalRequestRawJSON, requestRawJSON)
	if len(reqBytes) > 0 {
		req := gjson.ParseBytes(reqBytes)
		if v := req.Get("instructions"); v.Exists() {
			out, _ = sjson.Set(out, "instructions", v.String())
		}
		if v := req.Get("max_output_tokens"); v.Exists() {
			out, _ = sjson.Set(out, "max_output_tokens", v.Int())
		}
		if v := req.Get("max_tool_calls"); v.Exists() {
			out, _ = sjson.Set(out, "max_tool_calls", v.Int())
		}
		if v := req.Get("model"); v.Exists() {
			out, _ = sjson.Set(out, "model", v.String())
		}
		if v := req.Get("parallel_tool_calls"); v.Exists() {
			out, _ = sjson.Set(out, "parallel_tool_calls", v.Bool())
		}
		if v := req.Get("previous_response_id"); v.Exists() {
			out, _ = sjson.Set(out, "previous_response_id", v.String())
		}
		if v := req.Get("prompt_cache_key"); v.Exists() {
			out, _ = sjson.Set(out, "prompt_cache_key", v.String())
		}
		if v := req.Get("reasoning"); v.Exists() {
			out, _ = sjson.Set(out, "reasoning", v.Value())
		}
		if v := req.Get("safety_identifier"); v.Exists() {
			out, _ = sjson.Set(out, "safety_identifier", v.String())
		}
		if v := req.Get("service_tier"); v.Exists() {
			out, _ = sjson.Set(out, "service_tier", v.String())
		}
		if v := req.Get("store"); v.Exists() {
			out, _ = sjson.Set(out, "store", v.Bool())
		}
		if v := req.Get("temperature"); v.Exists() {
			out, _ = sjson.Set(out, "temperature", v.Float())
		}
		if v := req.Get("text"); v.Exists() {
			out, _ = sjson.Set(out, "text", v.Value())
		}
		if v := req.Get("tool_choice"); v.Exists() {
			out, _ = sjson.Set(out, "tool_choice", v.Value())
		}
		if v := req.Get("tools"); v.Exists() {
			out, _ = sjson.Set(out, "tools", v.Value())
		}
		if v := req.Get("top_logprobs"); v.Exists() {
			out, _ = sjson.Set(out, "top_logprobs", v.Int())
		}
		if v := req.Get("top_p"); v.Exists() {
			out, _ = sjson.Set(out, "top_p", v.Float())
		}
		if v := req.Get("truncation"); v.Exists() {
			out, _ = sjson.Set(out, "truncation", v.String())
		}
		if v := req.Get("user"); v.Exists() {
			out, _ = sjson.Set(out, "user", v.Value())
		}
		if v := req.Get("metadata"); v.Exists() {
			out, _ = sjson.Set(out, "metadata", v.Value())
		}
	}

	// Build output array
	outputsWrapper := `{"arr":[]}`
	if reasoningBuf.Len() > 0 {
		item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
		item, _ = sjson.Set(item, "id", reasoningItemID)
		item, _ = sjson.Set(item, "summary.0.text", reasoningBuf.String())
		outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
	}
	if currentMsgID != "" || textBuf.Len() > 0 {
		item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
		item, _ = sjson.Set(item, "id", currentMsgID)
		item, _ = sjson.Set(item, "content.0.text", textBuf.String())
		outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
	}
	if len(toolCalls) > 0 {
		// Emit tool calls in ascending content-block index order (map iteration is random).
		idxs := make([]int, 0, len(toolCalls))
		for i := range toolCalls {
			idxs = append(idxs, i)
		}
		for i := 0; i < len(idxs); i++ {
			for j := i + 1; j < len(idxs); j++ {
				if idxs[j] < idxs[i] {
					idxs[i], idxs[j] = idxs[j], idxs[i]
				}
			}
		}
		for _, i := range idxs {
			st := toolCalls[i]
			args := st.args.String()
			if args == "" {
				args = "{}"
			}
			item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
			item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", st.id))
			item, _ = sjson.Set(item, "arguments", args)
			item, _ = sjson.Set(item, "call_id", st.id)
			item, _ = sjson.Set(item, "name", st.name)
			outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
		}
	}
	if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
		out, _ = sjson.SetRaw(out, "output", gjson.Get(outputsWrapper, "arr").Raw)
	}

	// Usage
	total := inputTokens + outputTokens
	out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
	out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
	out, _ = sjson.Set(out, "usage.total_tokens", total)
	if reasoningBuf.Len() > 0 {
		// Rough estimate (about 4 bytes per token), similar to chat completions.
		reasoningTokens := int64(reasoningBuf.Len() / 4)
		if reasoningTokens > 0 {
			out, _ = sjson.Set(out, "usage.output_tokens_details.reasoning_tokens", reasoningTokens)
		}
	}
	return out
}
================================================
FILE: internal/translator/claude/openai/responses/init.go
================================================
package responses
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
OpenaiResponse,
Claude,
ConvertOpenAIResponsesRequestToClaude,
interfaces.TranslateResponse{
Stream: ConvertClaudeResponseToOpenAIResponses,
NonStream: ConvertClaudeResponseToOpenAIResponsesNonStream,
},
)
}
================================================
FILE: internal/translator/codex/claude/codex_claude_request.go
================================================
// Package claude provides request translation functionality for Claude Code API compatibility.
// It handles parsing and transforming Claude Code API requests into the internal client format,
// extracting model information, system instructions, message contents, and tool declarations.
// The package also performs JSON data cleaning and transformation to ensure compatibility
// between Claude Code API format and the internal client's expected format.
package claude
import (
"fmt"
"strconv"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertClaudeRequestToCodex parses and transforms a Claude Code API request into the Codex
// request format.
// It extracts the model name, system instruction, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the Codex API.
// The function performs the following transformations:
//  1. Sets up a template with the model name and an empty instructions field
//  2. Converts system text blocks into a single developer message (empty blocks and
//     "x-anthropic-billing-header: " blocks are dropped)
//  3. Transforms message contents (text, image, tool_use, tool_result) to input items
//  4. Converts tool declarations to Codex function tools, shortening long names
//  5. Adds fixed configuration parameters for the Codex API
//  6. Maps Claude thinking configuration to Codex reasoning settings
//
// Parameters:
//   - modelName: The model name written into the output request
//   - inputRawJSON: The raw JSON request data from the Claude Code API
//   - stream: Unused; the output always sets "stream": true
//
// Returns:
//   - []byte: The transformed request data in Codex format
func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
	rawJSON := inputRawJSON
	template := `{"model":"","instructions":"","input":[]}`

	rootResult := gjson.ParseBytes(rawJSON)
	template, _ = sjson.Set(template, "model", modelName)

	// Tool names longer than 64 bytes are shortened. Compute the original->short mapping once
	// (instead of per tool_use block) so that tool_use history items and tool declarations
	// always agree on the same short name.
	toolNameMap := buildReverseMapFromClaudeOriginalToShort(rawJSON)
	codexToolName := func(name string) string {
		if short, ok := toolNameMap[name]; ok {
			return short
		}
		return shortenNameIfNeeded(name)
	}

	// Process system messages and convert them to a single developer message.
	systemsResult := rootResult.Get("system")
	if systemsResult.Exists() {
		message := `{"type":"message","role":"developer","content":[]}`
		contentIndex := 0

		appendSystemText := func(text string) {
			// Skip empty blocks and billing-header blocks.
			if text == "" || strings.HasPrefix(text, "x-anthropic-billing-header: ") {
				return
			}
			message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_text")
			message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
			contentIndex++
		}

		if systemsResult.Type == gjson.String {
			appendSystemText(systemsResult.String())
		} else if systemsResult.IsArray() {
			for _, systemResult := range systemsResult.Array() {
				if systemResult.Get("type").String() == "text" {
					appendSystemText(systemResult.Get("text").String())
				}
			}
		}

		if contentIndex > 0 {
			template, _ = sjson.SetRaw(template, "input.-1", message)
		}
	}

	// Process messages and transform their contents to appropriate formats.
	messagesResult := rootResult.Get("messages")
	if messagesResult.IsArray() {
		for _, messageResult := range messagesResult.Array() {
			messageRole := messageResult.Get("role").String()

			newMessage := func() string {
				msg := `{"type": "message","role":"","content":[]}`
				msg, _ = sjson.Set(msg, "role", messageRole)
				return msg
			}

			message := newMessage()
			contentIndex := 0
			hasContent := false

			// flushMessage appends the pending message (if it has content) and starts a new one,
			// so tool calls and tool results keep their position relative to surrounding text.
			flushMessage := func() {
				if hasContent {
					template, _ = sjson.SetRaw(template, "input.-1", message)
					message = newMessage()
					contentIndex = 0
					hasContent = false
				}
			}

			appendTextContent := func(text string) {
				partType := "input_text"
				if messageRole == "assistant" {
					partType = "output_text"
				}
				message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), partType)
				message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", contentIndex), text)
				contentIndex++
				hasContent = true
			}

			appendImageContent := func(dataURL string) {
				message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", contentIndex), "input_image")
				message, _ = sjson.Set(message, fmt.Sprintf("content.%d.image_url", contentIndex), dataURL)
				contentIndex++
				hasContent = true
			}

			messageContentsResult := messageResult.Get("content")
			if messageContentsResult.IsArray() {
				for _, messageContentResult := range messageContentsResult.Array() {
					switch messageContentResult.Get("type").String() {
					case "text":
						appendTextContent(messageContentResult.Get("text").String())
					case "image":
						if dataURL, ok := claudeImageSourceDataURL(messageContentResult.Get("source")); ok {
							appendImageContent(dataURL)
						}
					case "tool_use":
						flushMessage()
						arguments := messageContentResult.Get("input").Raw
						if arguments == "" {
							// A missing input would otherwise yield an empty (non-JSON) arguments
							// string; default it to an empty object.
							arguments = "{}"
						}
						functionCallMessage := `{"type":"function_call"}`
						functionCallMessage, _ = sjson.Set(functionCallMessage, "call_id", messageContentResult.Get("id").String())
						functionCallMessage, _ = sjson.Set(functionCallMessage, "name", codexToolName(messageContentResult.Get("name").String()))
						functionCallMessage, _ = sjson.Set(functionCallMessage, "arguments", arguments)
						template, _ = sjson.SetRaw(template, "input.-1", functionCallMessage)
					case "tool_result":
						flushMessage()
						functionCallOutputMessage := `{"type":"function_call_output"}`
						functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String())

						contentResult := messageContentResult.Get("content")
						toolResultContent := `[]`
						if contentResult.IsArray() {
							toolResultContentIndex := 0
							for _, part := range contentResult.Array() {
								switch part.Get("type").String() {
								case "image":
									if dataURL, ok := claudeImageSourceDataURL(part.Get("source")); ok {
										toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_image")
										toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.image_url", toolResultContentIndex), dataURL)
										toolResultContentIndex++
									}
								case "text":
									toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.type", toolResultContentIndex), "input_text")
									toolResultContent, _ = sjson.Set(toolResultContent, fmt.Sprintf("%d.text", toolResultContentIndex), part.Get("text").String())
									toolResultContentIndex++
								}
							}
						}

						// Structured parts are forwarded as an array. Otherwise (plain string content,
						// or an array without text/image parts) the content is sent as a string.
						if toolResultContent != `[]` {
							functionCallOutputMessage, _ = sjson.SetRaw(functionCallOutputMessage, "output", toolResultContent)
						} else {
							functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", contentResult.String())
						}
						template, _ = sjson.SetRaw(template, "input.-1", functionCallOutputMessage)
					}
				}
				flushMessage()
			} else if messageContentsResult.Type == gjson.String {
				appendTextContent(messageContentsResult.String())
				flushMessage()
			}
		}
	}

	// Convert tools declarations to the expected format for the Codex API.
	toolsResult := rootResult.Get("tools")
	if toolsResult.IsArray() {
		template, _ = sjson.SetRaw(template, "tools", `[]`)
		template, _ = sjson.Set(template, "tool_choice", `auto`)
		for _, toolResult := range toolsResult.Array() {
			// Special handling: map Claude web search tool to Codex web_search
			if toolResult.Get("type").String() == "web_search_20250305" {
				// Replace the tool content entirely with {"type":"web_search"}
				template, _ = sjson.SetRaw(template, "tools.-1", `{"type":"web_search"}`)
				continue
			}
			tool := toolResult.Raw
			tool, _ = sjson.Set(tool, "type", "function")
			// Apply shortened name if needed
			if v := toolResult.Get("name"); v.Exists() {
				tool, _ = sjson.Set(tool, "name", codexToolName(v.String()))
			}
			tool, _ = sjson.SetRaw(tool, "parameters", normalizeToolParameters(toolResult.Get("input_schema").Raw))
			tool, _ = sjson.Delete(tool, "input_schema")
			tool, _ = sjson.Delete(tool, "parameters.$schema")
			tool, _ = sjson.Delete(tool, "cache_control")
			tool, _ = sjson.Delete(tool, "defer_loading")
			tool, _ = sjson.Set(tool, "strict", false)
			template, _ = sjson.SetRaw(template, "tools.-1", tool)
		}
	}

	// Add additional configuration parameters for the Codex API.
	template, _ = sjson.Set(template, "parallel_tool_calls", true)
	template, _ = sjson.Set(template, "reasoning.effort", claudeThinkingToCodexEffort(rootResult))
	template, _ = sjson.Set(template, "reasoning.summary", "auto")
	template, _ = sjson.Set(template, "stream", true)
	template, _ = sjson.Set(template, "store", false)
	template, _ = sjson.Set(template, "include", []string{"reasoning.encrypted_content"})

	return []byte(template)
}

// claudeImageSourceDataURL builds a base64 data URL from a Claude image "source" object.
// The payload is read from "data" (falling back to "base64"). The media type is read from
// "media_type", falling back to "mime_type" and then "application/octet-stream".
// It reports false when the source is missing or carries no payload.
func claudeImageSourceDataURL(source gjson.Result) (string, bool) {
	if !source.Exists() {
		return "", false
	}
	data := source.Get("data").String()
	if data == "" {
		data = source.Get("base64").String()
	}
	if data == "" {
		return "", false
	}
	mediaType := source.Get("media_type").String()
	if mediaType == "" {
		mediaType = source.Get("mime_type").String()
	}
	if mediaType == "" {
		mediaType = "application/octet-stream"
	}
	return fmt.Sprintf("data:%s;base64,%s", mediaType, data), true
}

// claudeThinkingToCodexEffort maps the Claude "thinking" configuration (and, for adaptive
// thinking, "output_config.effort") onto a Codex reasoning effort. When there is no thinking
// object or its type is unrecognized, it returns "medium".
func claudeThinkingToCodexEffort(root gjson.Result) string {
	reasoningEffort := "medium"
	thinkingConfig := root.Get("thinking")
	if !thinkingConfig.IsObject() {
		return reasoningEffort
	}
	switch thinkingConfig.Get("type").String() {
	case "enabled":
		if budgetTokens := thinkingConfig.Get("budget_tokens"); budgetTokens.Exists() {
			if effort, ok := thinking.ConvertBudgetToLevel(int(budgetTokens.Int())); ok && effort != "" {
				reasoningEffort = effort
			}
		}
	case "adaptive", "auto":
		// Adaptive thinking can carry an explicit effort in output_config.effort (Claude 4.6).
		// Pass through directly; ApplyThinking handles clamping to target model's levels.
		effort := ""
		if v := root.Get("output_config.effort"); v.Exists() && v.Type == gjson.String {
			effort = strings.ToLower(strings.TrimSpace(v.String()))
		}
		if effort != "" {
			reasoningEffort = effort
		} else {
			reasoningEffort = string(thinking.LevelXHigh)
		}
	case "disabled":
		if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" {
			reasoningEffort = effort
		}
	}
	return reasoningEffort
}
// shortenNameIfNeeded returns name unchanged when it is at most 64 bytes long.
// Otherwise, a name starting with "mcp__" is reduced to "mcp__" plus the segment after its
// last "__". Any candidate still longer than 64 bytes is truncated to its first 64 bytes.
func shortenNameIfNeeded(name string) string {
	const maxLen = 64
	if len(name) <= maxLen {
		return name
	}

	candidate := name
	if strings.HasPrefix(name, "mcp__") {
		if sep := strings.LastIndex(name, "__"); sep > 0 {
			candidate = "mcp__" + name[sep+2:]
		}
	}

	if len(candidate) > maxLen {
		candidate = candidate[:maxLen]
	}
	return candidate
}
// buildShortNameMap maps each tool name to a name of at most 64 bytes that is unique within
// the request.
//
// Names that already fit are kept as-is. Longer names are shortened with the same rule as
// shortenNameIfNeeded: "mcp__" names keep only their last "__" segment, and everything is
// truncated to 64 bytes. Collisions are resolved by appending "_1", "_2", ..., trimming the
// base so the result still fits. Names are assigned in input order, so earlier names keep the
// unsuffixed form. A name listed more than once keeps the mapping of its first occurrence.
func buildShortNameMap(names []string) map[string]string {
	const limit = 64
	used := make(map[string]struct{}, len(names))
	m := make(map[string]string, len(names))

	baseCandidate := func(n string) string {
		if len(n) <= limit {
			return n
		}
		if strings.HasPrefix(n, "mcp__") {
			idx := strings.LastIndex(n, "__")
			if idx > 0 {
				cand := "mcp__" + n[idx+2:]
				if len(cand) > limit {
					cand = cand[:limit]
				}
				return cand
			}
		}
		return n[:limit]
	}

	makeUnique := func(cand string) string {
		if _, ok := used[cand]; !ok {
			return cand
		}
		base := cand
		for i := 1; ; i++ {
			suffix := "_" + strconv.Itoa(i)
			allowed := limit - len(suffix)
			if allowed < 0 {
				allowed = 0
			}
			tmp := base
			if len(tmp) > allowed {
				tmp = tmp[:allowed]
			}
			tmp = tmp + suffix
			if _, ok := used[tmp]; !ok {
				return tmp
			}
		}
	}

	for _, n := range names {
		if _, seen := m[n]; seen {
			// Repeated declaration of the same tool: reuse the first mapping instead of
			// allocating a spurious "_N" alias that would overwrite it.
			continue
		}
		uniq := makeUnique(baseCandidate(n))
		used[uniq] = struct{}{}
		m[n] = uniq
	}
	return m
}
// buildReverseMapFromClaudeOriginalToShort returns the original->short tool-name mapping for
// the tools declared in a Claude request, as computed by buildShortNameMap. It is used to
// rename tool_use items consistently with the declared tools. When the request has no "tools"
// array or no named tools, an empty (non-nil) map is returned.
func buildReverseMapFromClaudeOriginalToShort(original []byte) map[string]string {
	declared := gjson.GetBytes(original, "tools")
	if !declared.IsArray() {
		return map[string]string{}
	}

	toolNames := make([]string, 0)
	for _, tool := range declared.Array() {
		if name := tool.Get("name").String(); name != "" {
			toolNames = append(toolNames, name)
		}
	}

	if len(toolNames) == 0 {
		return map[string]string{}
	}
	return buildShortNameMap(toolNames)
}
// normalizeToolParameters converts a Claude tool input_schema (raw JSON) into a Codex function
// "parameters" schema.
//
// Missing, null, invalid, or non-object JSON yields an empty object schema. An object schema
// without "type" is given "type":"object". A schema of type object without "properties" gets
// an empty properties map.
func normalizeToolParameters(raw string) string {
	const emptyObjectSchema = `{"type":"object","properties":{}}`

	raw = strings.TrimSpace(raw)
	if raw == "" || raw == "null" || !gjson.Valid(raw) {
		return emptyObjectSchema
	}

	result := gjson.Parse(raw)
	if !result.IsObject() {
		// Arrays, strings, numbers and booleans cannot describe function parameters and
		// cannot receive "type"/"properties" keys; use the empty object schema instead of
		// forwarding them verbatim.
		return emptyObjectSchema
	}

	schema := raw
	schemaType := result.Get("type").String()
	if schemaType == "" {
		schema, _ = sjson.Set(schema, "type", "object")
		schemaType = "object"
	}
	if schemaType == "object" && !result.Get("properties").Exists() {
		schema, _ = sjson.SetRaw(schema, "properties", `{}`)
	}
	return schema
}
================================================
FILE: internal/translator/codex/claude/codex_claude_request_test.go
================================================
package claude
import (
"testing"
"github.com/tidwall/gjson"
)
// TestConvertClaudeRequestToCodex_SystemMessageScenarios verifies how the Claude "system" field
// becomes the leading developer message of the Codex input. Absent, empty, or fully filtered
// system prompts produce no developer message. Billing-header blocks and non-text blocks are
// dropped, and the remaining text blocks are kept in order as input_text parts.
func TestConvertClaudeRequestToCodex_SystemMessageScenarios(t *testing.T) {
	tests := []struct {
		name             string
		inputJSON        string
		wantHasDeveloper bool
		wantTexts        []string
	}{
		{
			name: "No system field",
			inputJSON: `{
				"model": "claude-3-opus",
				"messages": [{"role": "user", "content": "hello"}]
			}`,
			wantHasDeveloper: false,
		},
		{
			name: "Empty string system field",
			inputJSON: `{
				"model": "claude-3-opus",
				"system": "",
				"messages": [{"role": "user", "content": "hello"}]
			}`,
			wantHasDeveloper: false,
		},
		{
			name: "String system field",
			inputJSON: `{
				"model": "claude-3-opus",
				"system": "Be helpful",
				"messages": [{"role": "user", "content": "hello"}]
			}`,
			wantHasDeveloper: true,
			wantTexts:        []string{"Be helpful"},
		},
		{
			name: "Array system field with filtered billing header",
			inputJSON: `{
				"model": "claude-3-opus",
				"system": [
					{"type": "text", "text": "x-anthropic-billing-header: tenant-123"},
					{"type": "text", "text": "Block 1"},
					{"type": "text", "text": "Block 2"}
				],
				"messages": [{"role": "user", "content": "hello"}]
			}`,
			wantHasDeveloper: true,
			wantTexts:        []string{"Block 1", "Block 2"},
		},
		{
			name: "Array system field with only billing header",
			inputJSON: `{
				"model": "claude-3-opus",
				"system": [
					{"type": "text", "text": "x-anthropic-billing-header: tenant-123"}
				],
				"messages": [{"role": "user", "content": "hello"}]
			}`,
			wantHasDeveloper: false,
		},
		{
			name: "Array system field with non-text block",
			inputJSON: `{
				"model": "claude-3-opus",
				"system": [
					{"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": "AAAA"}},
					{"type": "text", "text": "Only text"}
				],
				"messages": [{"role": "user", "content": "hello"}]
			}`,
			wantHasDeveloper: true,
			wantTexts:        []string{"Only text"},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := ConvertClaudeRequestToCodex("test-model", []byte(tt.inputJSON), false)
			resultJSON := gjson.ParseBytes(result)
			inputs := resultJSON.Get("input").Array()

			hasDeveloper := len(inputs) > 0 && inputs[0].Get("role").String() == "developer"
			if hasDeveloper != tt.wantHasDeveloper {
				t.Fatalf("got hasDeveloper = %v, want %v. Output: %s", hasDeveloper, tt.wantHasDeveloper, resultJSON.Get("input").Raw)
			}
			if !tt.wantHasDeveloper {
				return
			}

			content := inputs[0].Get("content").Array()
			if len(content) != len(tt.wantTexts) {
				t.Fatalf("got %d system content items, want %d. Content: %s", len(content), len(tt.wantTexts), inputs[0].Get("content").Raw)
			}
			for i, wantText := range tt.wantTexts {
				if gotType := content[i].Get("type").String(); gotType != "input_text" {
					t.Fatalf("content[%d] type = %q, want %q", i, gotType, "input_text")
				}
				if gotText := content[i].Get("text").String(); gotText != wantText {
					t.Fatalf("content[%d] text = %q, want %q", i, gotText, wantText)
				}
			}
		})
	}
}
================================================
FILE: internal/translator/codex/claude/codex_claude_response.go
================================================
// Package claude provides response translation functionality for Codex to Claude Code API compatibility.
// This package handles the conversion of Codex API responses into Claude Code-compatible
// Server-Sent Events (SSE) format, implementing a sophisticated state machine that manages
// different response types including text content, thinking processes, and function calls.
// The translation ensures proper sequencing of SSE events and maintains state across
// multiple response chunks to provide a seamless streaming experience.
package claude
import (
"bytes"
"context"
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// dataTag is the SSE line prefix that marks a data line; ConvertCodexResponseToClaude ignores
// any input line that does not start with it.
var dataTag = []byte("data:")
// ConvertCodexResponseToClaudeParams holds the per-stream state that ConvertCodexResponseToClaude
// keeps between calls. It is allocated on the first call and stored behind the *any param.
type ConvertCodexResponseToClaudeParams struct {
	// HasToolCall is set when a function_call output item starts. On response.completed it
	// forces stop_reason "tool_use".
	HasToolCall bool
	// BlockIndex is the Claude content-block index used for emitted block events. It advances
	// after each reasoning-summary part or content part is closed.
	BlockIndex int
	// HasReceivedArgumentsDelta is reset to false whenever a new function_call item starts.
	// NOTE(review): where it becomes true is outside this view — presumably on function-call
	// argument deltas; confirm.
	HasReceivedArgumentsDelta bool
}
// ConvertCodexResponseToClaude performs sophisticated streaming response format conversion.
// This function implements a complex state machine that translates Codex API responses
// into Claude Code-compatible Server-Sent Events (SSE) format. It manages different response types
// and handles state transitions between content blocks, thinking processes, and function calls.
//
// Response type states: 0=none, 1=content, 2=thinking, 3=function
// The function maintains state across multiple calls to ensure proper SSE event sequencing.
//
// Parameters:
// - ctx: The context for the request, used for cancellation and timeout handling
// - modelName: The name of the model being used for the response (unused in current implementation)
// - rawJSON: The raw JSON response from the Codex API
// - param: A pointer to a parameter object for maintaining state between calls
//
// Returns:
// - []string: A slice of strings, each containing a Claude Code-compatible JSON response
func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
if *param == nil {
*param = &ConvertCodexResponseToClaudeParams{
HasToolCall: false,
BlockIndex: 0,
}
}
// log.Debugf("rawJSON: %s", string(rawJSON))
if !bytes.HasPrefix(rawJSON, dataTag) {
return []string{}
}
rawJSON = bytes.TrimSpace(rawJSON[5:])
output := ""
rootResult := gjson.ParseBytes(rawJSON)
typeResult := rootResult.Get("type")
typeStr := typeResult.String()
template := ""
if typeStr == "response.created" {
template = `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`
template, _ = sjson.Set(template, "message.model", rootResult.Get("response.model").String())
template, _ = sjson.Set(template, "message.id", rootResult.Get("response.id").String())
output = "event: message_start\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.reasoning_summary_part.added" {
template = `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output = "event: content_block_start\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.reasoning_summary_text.delta" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.thinking", rootResult.Get("delta").String())
output = "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.reasoning_summary_part.done" {
template = `{"type":"content_block_stop","index":0}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.content_part.added" {
template = `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output = "event: content_block_start\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.output_text.delta" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.text", rootResult.Get("delta").String())
output = "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.content_part.done" {
template = `{"type":"content_block_stop","index":0}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.completed" {
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
p := (*param).(*ConvertCodexResponseToClaudeParams).HasToolCall
stopReason := rootResult.Get("response.stop_reason").String()
if p {
template, _ = sjson.Set(template, "delta.stop_reason", "tool_use")
} else if stopReason == "max_tokens" || stopReason == "stop" {
template, _ = sjson.Set(template, "delta.stop_reason", stopReason)
} else {
template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
}
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(rootResult.Get("response.usage"))
template, _ = sjson.Set(template, "usage.input_tokens", inputTokens)
template, _ = sjson.Set(template, "usage.output_tokens", outputTokens)
if cachedTokens > 0 {
template, _ = sjson.Set(template, "usage.cache_read_input_tokens", cachedTokens)
}
output = "event: message_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
output += "event: message_stop\n"
output += `data: {"type":"message_stop"}`
output += "\n\n"
} else if typeStr == "response.output_item.added" {
itemResult := rootResult.Get("item")
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
(*param).(*ConvertCodexResponseToClaudeParams).HasToolCall = true
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = false
template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "content_block.id", util.SanitizeClaudeToolID(itemResult.Get("call_id").String()))
{
// Restore original tool name if shortened
name := itemResult.Get("name").String()
rev := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
if orig, ok := rev[name]; ok {
name = orig
}
template, _ = sjson.Set(template, "content_block.name", name)
}
output = "event: content_block_start\n"
output += fmt.Sprintf("data: %s\n\n", template)
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
output += "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
}
} else if typeStr == "response.output_item.done" {
itemResult := rootResult.Get("item")
itemType := itemResult.Get("type").String()
if itemType == "function_call" {
template = `{"type":"content_block_stop","index":0}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
(*param).(*ConvertCodexResponseToClaudeParams).BlockIndex++
output = "event: content_block_stop\n"
output += fmt.Sprintf("data: %s\n\n", template)
}
} else if typeStr == "response.function_call_arguments.delta" {
(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta = true
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())
output += "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
} else if typeStr == "response.function_call_arguments.done" {
// Some models (e.g. gpt-5.3-codex-spark) send function call arguments
// in a single "done" event without preceding "delta" events.
// Emit the full arguments as a single input_json_delta so the
// downstream Claude client receives the complete tool input.
// When delta events were already received, skip to avoid duplicating arguments.
if !(*param).(*ConvertCodexResponseToClaudeParams).HasReceivedArgumentsDelta {
if args := rootResult.Get("arguments").String(); args != "" {
template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
template, _ = sjson.Set(template, "index", (*param).(*ConvertCodexResponseToClaudeParams).BlockIndex)
template, _ = sjson.Set(template, "delta.partial_json", args)
output += "event: content_block_delta\n"
output += fmt.Sprintf("data: %s\n\n", template)
}
}
}
return []string{output}
}
// ConvertCodexResponseToClaudeNonStream converts a non-streaming Codex response to a non-streaming Claude Code response.
// This function processes the complete Codex response and transforms it into a single Claude Code-compatible
// JSON response. It handles message content, tool calls, reasoning content, and usage metadata, combining all
// the information into a single response that matches the Claude Code API format.
//
// Parameters:
// - ctx: The context for the request, used for cancellation and timeout handling
// - modelName: The name of the model being used for the response (unused in current implementation)
// - rawJSON: The raw JSON response from the Codex API
// - param: A pointer to a parameter object for the conversion (unused in current implementation)
//
// Returns:
// - string: A Claude Code-compatible JSON response containing all message content and metadata
func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, _ []byte, rawJSON []byte, _ *any) string {
revNames := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON)
rootResult := gjson.ParseBytes(rawJSON)
if rootResult.Get("type").String() != "response.completed" {
return ""
}
responseData := rootResult.Get("response")
if !responseData.Exists() {
return ""
}
out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
out, _ = sjson.Set(out, "id", responseData.Get("id").String())
out, _ = sjson.Set(out, "model", responseData.Get("model").String())
inputTokens, outputTokens, cachedTokens := extractResponsesUsage(responseData.Get("usage"))
out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
if cachedTokens > 0 {
out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
}
hasToolCall := false
if output := responseData.Get("output"); output.Exists() && output.IsArray() {
output.ForEach(func(_, item gjson.Result) bool {
switch item.Get("type").String() {
case "reasoning":
thinkingBuilder := strings.Builder{}
if summary := item.Get("summary"); summary.Exists() {
if summary.IsArray() {
summary.ForEach(func(_, part gjson.Result) bool {
if txt := part.Get("text"); txt.Exists() {
thinkingBuilder.WriteString(txt.String())
} else {
thinkingBuilder.WriteString(part.String())
}
return true
})
} else {
thinkingBuilder.WriteString(summary.String())
}
}
if thinkingBuilder.Len() == 0 {
if content := item.Get("content"); content.Exists() {
if content.IsArray() {
content.ForEach(func(_, part gjson.Result) bool {
if txt := part.Get("text"); txt.Exists() {
thinkingBuilder.WriteString(txt.String())
} else {
thinkingBuilder.WriteString(part.String())
}
return true
})
} else {
thinkingBuilder.WriteString(content.String())
}
}
}
if thinkingBuilder.Len() > 0 {
block := `{"type":"thinking","thinking":""}`
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
out, _ = sjson.SetRaw(out, "content.-1", block)
}
case "message":
if content := item.Get("content"); content.Exists() {
if content.IsArray() {
content.ForEach(func(_, part gjson.Result) bool {
if part.Get("type").String() == "output_text" {
text := part.Get("text").String()
if text != "" {
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", text)
out, _ = sjson.SetRaw(out, "content.-1", block)
}
}
return true
})
} else {
text := content.String()
if text != "" {
block := `{"type":"text","text":""}`
block, _ = sjson.Set(block, "text", text)
out, _ = sjson.SetRaw(out, "content.-1", block)
}
}
}
case "function_call":
hasToolCall = true
name := item.Get("name").String()
if original, ok := revNames[name]; ok {
name = original
}
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
toolBlock, _ = sjson.Set(toolBlock, "id", util.SanitizeClaudeToolID(item.Get("call_id").String()))
toolBlock, _ = sjson.Set(toolBlock, "name", name)
inputRaw := "{}"
if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) {
argsJSON := gjson.Parse(argsStr)
if argsJSON.IsObject() {
inputRaw = argsJSON.Raw
}
}
toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
}
return true
})
}
if stopReason := responseData.Get("stop_reason"); stopReason.Exists() && stopReason.String() != "" {
out, _ = sjson.Set(out, "stop_reason", stopReason.String())
} else if hasToolCall {
out, _ = sjson.Set(out, "stop_reason", "tool_use")
} else {
out, _ = sjson.Set(out, "stop_reason", "end_turn")
}
if stopSequence := responseData.Get("stop_sequence"); stopSequence.Exists() && stopSequence.String() != "" {
out, _ = sjson.SetRaw(out, "stop_sequence", stopSequence.Raw)
}
return out
}
// extractResponsesUsage reads token counts from a Codex Responses usage object and
// returns (input, output, cached). The cached count is subtracted from input (clamped
// at zero) because callers report it separately as cache_read_input_tokens. A missing
// or JSON-null usage yields all zeros.
func extractResponsesUsage(usage gjson.Result) (int64, int64, int64) {
	if !usage.Exists() || usage.Type == gjson.Null {
		return 0, 0, 0
	}
	var (
		input  = usage.Get("input_tokens").Int()
		output = usage.Get("output_tokens").Int()
		cached = usage.Get("input_tokens_details.cached_tokens").Int()
	)
	switch {
	case cached <= 0:
		// Nothing cached: input is reported as-is.
	case input >= cached:
		input -= cached
	default:
		input = 0
	}
	return input, output, cached
}
// buildReverseMapFromClaudeOriginalShortToOriginal inverts buildShortNameMap for the
// tools declared in the original Claude request, returning a map from shortened name to
// original name. Tools without a name are ignored; the map is empty when the request
// has no tools array.
func buildReverseMapFromClaudeOriginalShortToOriginal(original []byte) map[string]string {
	reverse := map[string]string{}
	toolList := gjson.GetBytes(original, "tools")
	if !toolList.IsArray() {
		return reverse
	}
	var declared []string
	for _, tool := range toolList.Array() {
		if name := tool.Get("name").String(); name != "" {
			declared = append(declared, name)
		}
	}
	if len(declared) == 0 {
		return reverse
	}
	for orig, short := range buildShortNameMap(declared) {
		reverse[short] = orig
	}
	return reverse
}
// ClaudeTokenCount renders a Claude count_tokens response body reporting count as the
// number of input tokens. The context is not used.
func ClaudeTokenCount(ctx context.Context, count int64) string {
	return fmt.Sprintf("{\"input_tokens\":%d}", count)
}
================================================
FILE: internal/translator/codex/claude/init.go
================================================
package claude
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
Claude,
Codex,
ConvertClaudeRequestToCodex,
interfaces.TranslateResponse{
Stream: ConvertCodexResponseToClaude,
NonStream: ConvertCodexResponseToClaudeNonStream,
TokenCount: ClaudeTokenCount,
},
)
}
================================================
FILE: internal/translator/codex/gemini/codex_gemini_request.go
================================================
// Package gemini provides request translation functionality for Gemini to Codex API compatibility.
// It handles parsing and transforming Gemini API requests into Codex API format,
// extracting model information, system instructions, message contents, and tool declarations.
// The package performs JSON data transformation to ensure compatibility
// between the Gemini API format and the Codex API's expected format.
package gemini
import (
"crypto/rand"
"fmt"
"math/big"
"strconv"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertGeminiRequestToCodex parses and transforms a Gemini API request into Codex API format.
// It extracts the model name, system instruction, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the Codex API.
// The function performs comprehensive transformation including:
//  1. Model name assignment and thinking-config mapping to reasoning.effort
//  2. System instruction conversion to a Codex developer message
//  3. Message content conversion with proper role mapping
//  4. Tool call and tool result handling with a FIFO queue for call ID pairing
//  5. Tool declaration conversion (tool_choice is always "auto")
//
// Parameters:
//   - modelName: The name of the model to use for the request
//   - inputRawJSON: The raw JSON request data from the Gemini API
//   - stream: unused; the Codex request is always built with "stream": true
//
// Returns:
//   - []byte: The transformed request data in Codex API format
func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
	rawJSON := inputRawJSON
	// Base template
	out := `{"model":"","instructions":"","input":[]}`

	root := gjson.ParseBytes(rawJSON)

	// Pre-compute the tool name shortening map from declared functionDeclarations so
	// declarations and historical functionCall parts agree on shortened names.
	shortMap := map[string]string{}
	if tools := root.Get("tools"); tools.IsArray() {
		var names []string
		for _, tool := range tools.Array() {
			fns := tool.Get("functionDeclarations")
			if !fns.IsArray() {
				continue
			}
			for _, fn := range fns.Array() {
				if v := fn.Get("name"); v.Exists() {
					names = append(names, v.String())
				}
			}
		}
		if len(names) > 0 {
			shortMap = buildShortNameMap(names)
		}
	}
	// mapToolName returns the shortened name for a declared tool, or applies the
	// single-name shortening rule to names that were not declared.
	mapToolName := func(n string) string {
		if short, ok := shortMap[n]; ok {
			return short
		}
		return shortenNameIfNeeded(n)
	}

	// Gemini pairs functionCalls with functionResponses sequentially rather than by ID,
	// so generated call IDs are queued FIFO and consumed in order as responses arrive.
	var pendingCallIDs []string

	// genCallID creates a random call id of the form call_<24 alphanumeric chars>.
	genCallID := func() string {
		const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
		alphabetSize := big.NewInt(int64(len(letters)))
		var b strings.Builder
		b.WriteString("call_")
		for i := 0; i < 24; i++ {
			n, err := rand.Int(rand.Reader, alphabetSize)
			if err != nil {
				// crypto/rand failing is practically impossible; degrade to a
				// deterministic character instead of dereferencing a nil *big.Int.
				b.WriteByte(letters[i%len(letters)])
				continue
			}
			b.WriteByte(letters[n.Int64()])
		}
		return b.String()
	}

	// Model
	out, _ = sjson.Set(out, "model", modelName)

	// System instruction -> a developer message with input_text parts. Clients send
	// either the snake_case or the camelCase field name, so accept both.
	sysParts := root.Get("system_instruction.parts")
	if !sysParts.Exists() {
		sysParts = root.Get("systemInstruction.parts")
	}
	if sysParts.IsArray() {
		msg := `{"type":"message","role":"developer","content":[]}`
		for _, p := range sysParts.Array() {
			if t := p.Get("text"); t.Exists() {
				part := `{}`
				part, _ = sjson.Set(part, "type", "input_text")
				part, _ = sjson.Set(part, "text", t.String())
				msg, _ = sjson.SetRaw(msg, "content.-1", part)
			}
		}
		if len(gjson.Get(msg, "content").Array()) > 0 {
			out, _ = sjson.SetRaw(out, "input.-1", msg)
		}
	}

	// Contents -> messages and function calls/results
	if contents := root.Get("contents"); contents.IsArray() {
		for _, item := range contents.Array() {
			role := item.Get("role").String()
			switch role {
			case "model":
				role = "assistant"
			case "":
				// Gemini allows the role to be omitted; treat such content as user
				// input instead of emitting an empty role.
				role = "user"
			}
			parts := item.Get("parts")
			if !parts.IsArray() {
				continue
			}
			for _, p := range parts.Array() {
				// text part
				if t := p.Get("text"); t.Exists() {
					msg := `{"type":"message","role":"","content":[]}`
					msg, _ = sjson.Set(msg, "role", role)
					partType := "input_text"
					if role == "assistant" {
						partType = "output_text"
					}
					part := `{}`
					part, _ = sjson.Set(part, "type", partType)
					part, _ = sjson.Set(part, "text", t.String())
					msg, _ = sjson.SetRaw(msg, "content.-1", part)
					out, _ = sjson.SetRaw(out, "input.-1", msg)
					continue
				}

				// function call from model
				if fc := p.Get("functionCall"); fc.Exists() {
					fn := `{"type":"function_call"}`
					if name := fc.Get("name"); name.Exists() {
						fn, _ = sjson.Set(fn, "name", mapToolName(name.String()))
					}
					// Codex carries arguments as a JSON-encoded string; use an empty
					// object when Gemini omits args so the field is always present.
					argsRaw := "{}"
					if args := fc.Get("args"); args.Exists() {
						argsRaw = args.Raw
					}
					fn, _ = sjson.Set(fn, "arguments", argsRaw)
					// Enqueue a fresh call_id so the matching functionResponse can pop
					// the earliest id, preserving order across multiple calls.
					id := genCallID()
					fn, _ = sjson.Set(fn, "call_id", id)
					pendingCallIDs = append(pendingCallIDs, id)
					out, _ = sjson.SetRaw(out, "input.-1", fn)
					continue
				}

				// function response from user
				if fr := p.Get("functionResponse"); fr.Exists() {
					fno := `{"type":"function_call_output"}`
					// Prefer a string result if present; otherwise embed the raw response
					// as a string. Always set output so the item is never missing it.
					output := ""
					if res := fr.Get("response.result"); res.Exists() {
						output = res.String()
					} else if resp := fr.Get("response"); resp.Exists() {
						output = resp.Raw
					}
					fno, _ = sjson.Set(fno, "output", output)
					// Pair with the oldest queued call_id; generate a new one if the
					// queue is empty (e.g. the call is not part of this history).
					var id string
					if len(pendingCallIDs) > 0 {
						id, pendingCallIDs = pendingCallIDs[0], pendingCallIDs[1:]
					} else {
						id = genCallID()
					}
					fno, _ = sjson.Set(fno, "call_id", id)
					out, _ = sjson.SetRaw(out, "input.-1", fno)
					continue
				}
			}
		}
	}

	// Tools mapping: Gemini functionDeclarations -> Codex tools
	if tools := root.Get("tools"); tools.IsArray() {
		out, _ = sjson.SetRaw(out, "tools", `[]`)
		out, _ = sjson.Set(out, "tool_choice", "auto")
		for _, td := range tools.Array() {
			fns := td.Get("functionDeclarations")
			if !fns.IsArray() {
				continue
			}
			for _, fn := range fns.Array() {
				tool := `{}`
				tool, _ = sjson.Set(tool, "type", "function")
				if v := fn.Get("name"); v.Exists() {
					tool, _ = sjson.Set(tool, "name", mapToolName(v.String()))
				}
				if v := fn.Get("description"); v.Exists() {
					tool, _ = sjson.Set(tool, "description", v.String())
				}
				// Accept either schema field; drop the optional $schema marker and
				// forbid additional properties.
				prm := fn.Get("parameters")
				if !prm.Exists() {
					prm = fn.Get("parametersJsonSchema")
				}
				if prm.Exists() {
					cleaned, _ := sjson.Delete(prm.Raw, "$schema")
					cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
					tool, _ = sjson.SetRaw(tool, "parameters", cleaned)
				}
				tool, _ = sjson.Set(tool, "strict", false)
				out, _ = sjson.SetRaw(out, "tools.-1", tool)
			}
		}
	}

	// Fixed flags aligning with Codex expectations
	out, _ = sjson.Set(out, "parallel_tool_calls", true)

	// Convert Gemini thinkingConfig to Codex reasoning.effort.
	// Note: Google official Python SDK sends snake_case fields (thinking_level/thinking_budget).
	effortSet := false
	if genConfig := root.Get("generationConfig"); genConfig.Exists() {
		if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
			thinkingLevel := thinkingConfig.Get("thinkingLevel")
			if !thinkingLevel.Exists() {
				thinkingLevel = thinkingConfig.Get("thinking_level")
			}
			if thinkingLevel.Exists() {
				effort := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
				if effort != "" {
					out, _ = sjson.Set(out, "reasoning.effort", effort)
					effortSet = true
				}
			} else {
				thinkingBudget := thinkingConfig.Get("thinkingBudget")
				if !thinkingBudget.Exists() {
					thinkingBudget = thinkingConfig.Get("thinking_budget")
				}
				if thinkingBudget.Exists() {
					if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
						out, _ = sjson.Set(out, "reasoning.effort", effort)
						effortSet = true
					}
				}
			}
		}
	}
	if !effortSet {
		// No thinking config, set default effort
		out, _ = sjson.Set(out, "reasoning.effort", "medium")
	}
	out, _ = sjson.Set(out, "reasoning.summary", "auto")
	out, _ = sjson.Set(out, "stream", true)
	out, _ = sjson.Set(out, "store", false)
	out, _ = sjson.Set(out, "include", []string{"reasoning.encrypted_content"})

	// Gemini schemas may use upper-case type names; lower-case every "type" key
	// under tools.
	var pathsToLower []string
	toolsResult := gjson.Get(out, "tools")
	util.Walk(toolsResult, "", "type", &pathsToLower)
	for _, p := range pathsToLower {
		fullPath := fmt.Sprintf("tools.%s", p)
		out, _ = sjson.Set(out, fullPath, strings.ToLower(gjson.Get(out, fullPath).String()))
	}

	return []byte(out)
}
// shortenNameIfNeeded returns name unchanged when it fits within 64 bytes. A longer
// "mcp__"-prefixed name is reduced to "mcp__" plus the segment after its last "__"
// separator; any result (or non-mcp name) still longer than 64 bytes is truncated to
// its first 64 bytes.
func shortenNameIfNeeded(name string) string {
	const maxLen = 64
	if len(name) <= maxLen {
		return name
	}
	candidate := name
	if strings.HasPrefix(name, "mcp__") {
		if sep := strings.LastIndex(name, "__"); sep > 0 {
			candidate = "mcp__" + name[sep+2:]
		}
	}
	if len(candidate) > maxLen {
		candidate = candidate[:maxLen]
	}
	return candidate
}
// buildShortNameMap maps every name in names to a unique identifier of at most 64 bytes
// (presumably the upstream tool-name length limit — confirm against the Codex API).
//
// Names that already fit are reserved first and always map to themselves, so a
// shortened long name can never claim a valid name and force it to be renamed. Over-long
// names are then shortened with the same rule as shortenNameIfNeeded and, on collision,
// given a numeric "_N" suffix (truncating the base so the result stays within 64 bytes).
// Duplicate input names map to the same identifier. The result is deterministic for a
// given input order, which lets request and response translators rebuild the same map.
func buildShortNameMap(names []string) map[string]string {
	const limit = 64
	used := make(map[string]struct{}, len(names))
	m := make(map[string]string, len(names))

	// baseCandidate applies the single-name shortening rule.
	baseCandidate := func(n string) string {
		if len(n) <= limit {
			return n
		}
		if strings.HasPrefix(n, "mcp__") {
			idx := strings.LastIndex(n, "__")
			if idx > 0 {
				cand := "mcp__" + n[idx+2:]
				if len(cand) > limit {
					cand = cand[:limit]
				}
				return cand
			}
		}
		return n[:limit]
	}
	// makeUnique appends the smallest "_N" suffix that yields an unused name.
	makeUnique := func(cand string) string {
		if _, ok := used[cand]; !ok {
			return cand
		}
		base := cand
		for i := 1; ; i++ {
			suffix := "_" + strconv.Itoa(i)
			allowed := limit - len(suffix)
			if allowed < 0 {
				allowed = 0
			}
			tmp := base
			if len(tmp) > allowed {
				tmp = tmp[:allowed]
			}
			tmp += suffix
			if _, ok := used[tmp]; !ok {
				return tmp
			}
		}
	}

	// Pass 1: names that already fit keep their exact spelling.
	for _, n := range names {
		if len(n) > limit {
			continue
		}
		if _, seen := m[n]; seen {
			continue
		}
		used[n] = struct{}{}
		m[n] = n
	}
	// Pass 2: shorten over-long names, avoiding every identifier already taken.
	for _, n := range names {
		if _, seen := m[n]; seen {
			continue
		}
		uniq := makeUnique(baseCandidate(n))
		used[uniq] = struct{}{}
		m[n] = uniq
	}
	return m
}
================================================
FILE: internal/translator/codex/gemini/codex_gemini_response.go
================================================
// Package gemini provides response translation functionality for Codex to Gemini API compatibility.
// This package handles the conversion of Codex API responses into Gemini-compatible
// JSON format, transforming streaming events and non-streaming responses into the format
// expected by Gemini API clients.
package gemini
import (
"bytes"
"context"
"fmt"
"time"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// dataTag is the SSE field prefix that marks a payload line in the Codex stream.
var dataTag = []byte("data:")
// ConvertCodexResponseToGeminiParams holds the per-stream state that
// ConvertCodexResponseToGemini carries between successive calls for one response.
type ConvertCodexResponseToGeminiParams struct {
	// Model is the model name reported as modelVersion on emitted chunks.
	Model string
	// CreatedAt is the Unix timestamp (seconds) taken from response.created_at.
	CreatedAt int64
	// ResponseID is the Codex response id captured from the response.created event.
	ResponseID string
	// LastStorageOutput buffers a Gemini chunk containing completed function calls;
	// it is emitted ahead of the next outgoing chunk.
	LastStorageOutput string
}
// ConvertCodexResponseToGemini converts one Codex streaming event (an SSE "data:" line)
// into zero or more Gemini streamGenerateContent chunks.
//
// Reasoning summary deltas become thought parts, text deltas become text parts, and
// response.completed carries usage metadata. Completed function calls are buffered in
// the state's LastStorageOutput, so consecutive calls share one chunk. That chunk is
// emitted exactly once, ahead of the next outgoing chunk.
//
// Parameters:
//   - ctx: unused
//   - modelName: the model name reported as modelVersion until response.created supplies one
//   - originalRequestRawJSON: the client's original Gemini request, used to restore
//     shortened tool names
//   - requestRawJSON: unused
//   - rawJSON: one raw line from the Codex stream
//   - param: pointer to the per-stream state; initialized on first use
//
// Returns:
//   - []string: Gemini JSON chunks to send (possibly empty)
func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	if *param == nil {
		*param = &ConvertCodexResponseToGeminiParams{
			Model:             modelName,
			CreatedAt:         0,
			ResponseID:        "",
			LastStorageOutput: "",
		}
	}
	state := (*param).(*ConvertCodexResponseToGeminiParams)

	if !bytes.HasPrefix(rawJSON, dataTag) {
		return []string{}
	}
	rawJSON = bytes.TrimSpace(rawJSON[len(dataTag):])

	rootResult := gjson.ParseBytes(rawJSON)
	typeStr := rootResult.Get("type").String()

	// Base Gemini response template; placeholder metadata is replaced below once the
	// real values are known.
	template := `{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T02:52:03.884209Z","responseId":"06CeaPH7NaCU48APvNXDyA4"}`
	if state.LastStorageOutput != "" && typeStr == "response.output_item.done" {
		// Function calls are already buffered: keep appending to the buffered chunk so
		// consecutive calls end up in the same candidate.
		template = state.LastStorageOutput
	} else {
		template, _ = sjson.Set(template, "modelVersion", state.Model)
		if createdAt := rootResult.Get("response.created_at"); createdAt.Exists() {
			state.CreatedAt = createdAt.Int()
		}
		// Only response lifecycle events carry created_at; reuse the remembered value so
		// delta chunks do not keep the template's placeholder timestamp.
		if state.CreatedAt != 0 {
			template, _ = sjson.Set(template, "createTime", time.Unix(state.CreatedAt, 0).Format(time.RFC3339Nano))
		}
		template, _ = sjson.Set(template, "responseId", state.ResponseID)
	}

	// Handle function call completion: buffer instead of emitting.
	if typeStr == "response.output_item.done" {
		item := rootResult.Get("item")
		if item.Get("type").String() == "function_call" {
			name := item.Get("name").String()
			if orig, ok := buildReverseMapFromGeminiOriginal(originalRequestRawJSON)[name]; ok {
				name = orig
			}
			functionCall := `{"functionCall":{"name":"","args":{}}}`
			functionCall, _ = sjson.Set(functionCall, "functionCall.name", name)
			// Only splice arguments that are a valid JSON object; anything else keeps {}.
			if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) && gjson.Parse(argsStr).IsObject() {
				functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr)
			}
			template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", functionCall)
			template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
			state.LastStorageOutput = template
			return []string{}
		}
	}

	switch typeStr {
	case "response.created":
		// Set model and response ID from the upstream response.
		responseID := rootResult.Get("response.id").String()
		template, _ = sjson.Set(template, "modelVersion", rootResult.Get("response.model").String())
		template, _ = sjson.Set(template, "responseId", responseID)
		state.ResponseID = responseID
	case "response.reasoning_summary_text.delta":
		part, _ := sjson.Set(`{"thought":true,"text":""}`, "text", rootResult.Get("delta").String())
		template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
	case "response.output_text.delta":
		part, _ := sjson.Set(`{"text":""}`, "text", rootResult.Get("delta").String())
		template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
	case "response.completed":
		inputTokens := rootResult.Get("response.usage.input_tokens").Int()
		outputTokens := rootResult.Get("response.usage.output_tokens").Int()
		template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", inputTokens)
		template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", outputTokens)
		template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", inputTokens+outputTokens)
	default:
		return []string{}
	}

	if buffered := state.LastStorageOutput; buffered != "" {
		// Flush the buffered function-call chunk exactly once; leaving it set would
		// replay the same calls before every later chunk.
		state.LastStorageOutput = ""
		return []string{buffered, template}
	}
	return []string{template}
}
// ConvertCodexResponseToGeminiNonStream converts a complete Codex "response.completed"
// event into a single Gemini generateContent response.
//
// Output items are converted in order: reasoning items become thought parts (taken from
// the reasoning summary, falling back to content), message output_text parts become
// text parts, and function calls become functionCall parts with their original tool
// names restored. finishReason is always STOP; the converter has no separate tool-call
// finish reason.
//
// Parameters:
//   - ctx: unused
//   - modelName: the model name reported as modelVersion
//   - originalRequestRawJSON: the client's original Gemini request, used to restore
//     shortened tool names
//   - requestRawJSON: unused
//   - rawJSON: the raw Codex event payload
//   - param: unused
//
// Returns:
//   - string: the Gemini JSON response, or "" when rawJSON is not a response.completed event
func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
	rootResult := gjson.ParseBytes(rawJSON)
	if rootResult.Get("type").String() != "response.completed" {
		return ""
	}

	template := `{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`
	template, _ = sjson.Set(template, "modelVersion", modelName)

	responseData := rootResult.Get("response")
	if !responseData.Exists() {
		return template
	}

	if responseID := responseData.Get("id"); responseID.Exists() {
		template, _ = sjson.Set(template, "responseId", responseID.String())
	}
	if createdAt := responseData.Get("created_at"); createdAt.Exists() {
		template, _ = sjson.Set(template, "createTime", time.Unix(createdAt.Int(), 0).Format(time.RFC3339Nano))
	}
	if usage := responseData.Get("usage"); usage.Exists() {
		inputTokens := usage.Get("input_tokens").Int()
		outputTokens := usage.Get("output_tokens").Int()
		template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", inputTokens)
		template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", outputTokens)
		template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", inputTokens+outputTokens)
	}

	appendPart := func(part string) {
		template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
	}
	// collectText flattens a reasoning summary/content node: array elements contribute
	// their "text" field (or their raw JSON when they have none); any other value
	// contributes its string form.
	collectText := func(node gjson.Result) string {
		if !node.Exists() {
			return ""
		}
		if !node.IsArray() {
			return node.String()
		}
		var buf bytes.Buffer
		node.ForEach(func(_, part gjson.Result) bool {
			if txt := part.Get("text"); txt.Exists() {
				buf.WriteString(txt.String())
			} else {
				buf.WriteString(part.String())
			}
			return true
		})
		return buf.String()
	}

	// Built lazily: only responses containing function calls need it.
	var reverseNames map[string]string

	if output := responseData.Get("output"); output.IsArray() {
		output.ForEach(func(_, item gjson.Result) bool {
			switch item.Get("type").String() {
			case "reasoning":
				thought := collectText(item.Get("summary"))
				if thought == "" {
					thought = collectText(item.Get("content"))
				}
				if thought != "" {
					part, _ := sjson.Set(`{"text":"","thought":true}`, "text", thought)
					appendPart(part)
				}
			case "message":
				if content := item.Get("content"); content.IsArray() {
					content.ForEach(func(_, contentItem gjson.Result) bool {
						if contentItem.Get("type").String() == "output_text" {
							if text := contentItem.Get("text"); text.Exists() {
								part, _ := sjson.Set(`{"text":""}`, "text", text.String())
								appendPart(part)
							}
						}
						return true
					})
				}
			case "function_call":
				if reverseNames == nil {
					reverseNames = buildReverseMapFromGeminiOriginal(originalRequestRawJSON)
				}
				name := item.Get("name").String()
				if orig, ok := reverseNames[name]; ok {
					name = orig
				}
				functionCall := `{"functionCall":{"args":{},"name":""}}`
				functionCall, _ = sjson.Set(functionCall, "functionCall.name", name)
				// Only splice arguments that are a valid JSON object; anything else keeps {}.
				if argsStr := item.Get("arguments").String(); argsStr != "" && gjson.Valid(argsStr) && gjson.Parse(argsStr).IsObject() {
					functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr)
				}
				appendPart(functionCall)
			}
			return true
		})
	}

	return template
}
// buildReverseMapFromGeminiOriginal inverts buildShortNameMap for the function
// declarations in the original Gemini request, returning a map from shortened name to
// original name. It is empty when the request declares no named functions.
func buildReverseMapFromGeminiOriginal(original []byte) map[string]string {
	reverse := map[string]string{}
	toolList := gjson.GetBytes(original, "tools")
	if !toolList.IsArray() {
		return reverse
	}
	var declared []string
	for _, tool := range toolList.Array() {
		decls := tool.Get("functionDeclarations")
		if !decls.IsArray() {
			continue
		}
		for _, decl := range decls.Array() {
			if name := decl.Get("name"); name.Exists() {
				declared = append(declared, name.String())
			}
		}
	}
	if len(declared) == 0 {
		return reverse
	}
	for orig, short := range buildShortNameMap(declared) {
		reverse[short] = orig
	}
	return reverse
}
// GeminiTokenCount renders count as a Gemini-style token-count JSON body,
// reporting it both as totalTokens and as a single TEXT modality entry in
// promptTokensDetails. ctx is accepted for interface compatibility and unused.
func GeminiTokenCount(ctx context.Context, count int64) string {
	const layout = `{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`
	total, textTokens := count, count
	return fmt.Sprintf(layout, total, textTokens)
}
================================================
FILE: internal/translator/codex/gemini/init.go
================================================
package gemini
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
Gemini,
Codex,
ConvertGeminiRequestToCodex,
interfaces.TranslateResponse{
Stream: ConvertCodexResponseToGemini,
NonStream: ConvertCodexResponseToGeminiNonStream,
TokenCount: GeminiTokenCount,
},
)
}
================================================
FILE: internal/translator/codex/gemini-cli/codex_gemini-cli_request.go
================================================
// Package geminiCLI provides request translation functionality for Gemini CLI to Codex API compatibility.
// It handles parsing and transforming Gemini CLI API requests into Codex API format,
// extracting model information, system instructions, message contents, and tool declarations.
// The package performs JSON data transformation to ensure compatibility
// between Gemini CLI API format and Codex API's expected format.
package geminiCLI
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/gemini"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertGeminiCLIRequestToCodex translates a Gemini CLI request into Codex API
// format by reusing the Gemini-to-Codex converter.
//
// Steps:
//  1. Unwrap the payload stored under the top-level "request" key.
//  2. Set "model" on the unwrapped payload to modelName.
//  3. Rename "systemInstruction" to "system_instruction" when present
//     (presumably the key ConvertGeminiRequestToCodex reads).
//  4. Delegate to ConvertGeminiRequestToCodex.
//
// Parameters:
//   - modelName: The name of the model to use for the request
//   - inputRawJSON: The raw JSON request data from the Gemini CLI API
//   - stream: Whether the request is for a streaming response
//
// Returns:
//   - []byte: The transformed request data in Codex API format
func ConvertGeminiCLIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte {
	payload := []byte(gjson.GetBytes(inputRawJSON, "request").Raw)
	payload, _ = sjson.SetBytes(payload, "model", modelName)

	if sysInstr := gjson.GetBytes(payload, "systemInstruction"); sysInstr.Exists() {
		payload, _ = sjson.SetRawBytes(payload, "system_instruction", []byte(sysInstr.Raw))
		payload, _ = sjson.DeleteBytes(payload, "systemInstruction")
	}

	return ConvertGeminiRequestToCodex(modelName, payload, stream)
}
================================================
FILE: internal/translator/codex/gemini-cli/codex_gemini-cli_response.go
================================================
// Package geminiCLI provides response translation functionality for Codex to Gemini CLI API compatibility.
// This package handles the conversion of Codex API responses into Gemini CLI-compatible
// JSON format, transforming streaming events and non-streaming responses into the format
// expected by Gemini CLI API clients.
package geminiCLI
import (
"context"
"fmt"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/gemini"
"github.com/tidwall/sjson"
)
// ConvertCodexResponseToGeminiCLI converts one Codex streaming event into zero or
// more Gemini CLI chunks. It delegates the actual translation to
// ConvertCodexResponseToGemini and then wraps each resulting Gemini JSON chunk
// as {"response": <chunk>}, the envelope used by the Gemini CLI API.
//
// Parameters:
//   - ctx: Passed through to ConvertCodexResponseToGemini
//   - modelName: The name of the model being used for the response
//   - originalRequestRawJSON, requestRawJSON: The client request and the
//     translated upstream request, passed through unchanged
//   - rawJSON: The raw JSON event from the Codex API
//   - param: Converter state shared across calls, passed through unchanged
//
// Returns:
//   - []string: One wrapped chunk per chunk produced by the Gemini converter
//     (an empty, non-nil slice when it produced none)
func ConvertCodexResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	geminiChunks := ConvertCodexResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
	wrapped := make([]string, 0, len(geminiChunks))
	for _, chunk := range geminiChunks {
		envelope, _ := sjson.SetRaw(`{"response": {}}`, "response", chunk)
		wrapped = append(wrapped, envelope)
	}
	return wrapped
}
// ConvertCodexResponseToGeminiCLINonStream converts a complete (non-streaming)
// Codex response into a single Gemini CLI response. The Gemini body comes from
// ConvertCodexResponseToGeminiNonStream and is wrapped as
// {"response": <body>} to match the Gemini CLI API structure.
//
// Parameters:
//   - ctx: Passed through to ConvertCodexResponseToGeminiNonStream
//   - modelName: The name of the model being used for the response
//   - originalRequestRawJSON, requestRawJSON: The client request and the
//     translated upstream request, passed through unchanged
//   - rawJSON: The raw JSON response from the Codex API
//   - param: Converter state, passed through unchanged
//
// Returns:
//   - string: The wrapped Gemini CLI JSON response
func ConvertCodexResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
	geminiJSON := ConvertCodexResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
	wrapped, _ := sjson.SetRaw(`{"response": {}}`, "response", geminiJSON)
	return wrapped
}
// GeminiCLITokenCount renders count as a token-count JSON body with
// totalTokens and a single TEXT entry in promptTokensDetails, both equal to
// count. ctx is accepted for interface compatibility and unused.
func GeminiCLITokenCount(ctx context.Context, count int64) string {
	return fmt.Sprintf(
		`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`,
		count, // totalTokens
		count, // promptTokensDetails[0].tokenCount
	)
}
================================================
FILE: internal/translator/codex/gemini-cli/init.go
================================================
package geminiCLI
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
GeminiCLI,
Codex,
ConvertGeminiCLIRequestToCodex,
interfaces.TranslateResponse{
Stream: ConvertCodexResponseToGeminiCLI,
NonStream: ConvertCodexResponseToGeminiCLINonStream,
TokenCount: GeminiCLITokenCount,
},
)
}
================================================
FILE: internal/translator/codex/openai/chat-completions/codex_openai_request.go
================================================
// Package chat_completions provides utilities to translate OpenAI Chat Completions
// request JSON into OpenAI Responses API request JSON using gjson/sjson.
// It supports tools, multimodal text/image inputs, and Structured Outputs.
// The package handles the conversion of OpenAI API requests into the format
// expected by the OpenAI Responses API, including proper mapping of messages,
// tools, and generation parameters.
package chat_completions
import (
"strconv"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertOpenAIRequestToCodex converts an OpenAI Chat Completions request JSON
// into an OpenAI Responses API request JSON for the Codex backend. It handles
// tools, multi-turn dialog (including tool calls and tool results), multimodal
// text/image/file inputs, and Structured Outputs.
//
// Mapping summary:
//   - system messages become "developer" input messages; user/assistant text,
//     images and files become typed content parts.
//   - assistant tool_calls become top-level function_call items; tool messages
//     become function_call_output items (array-of-text content is flattened
//     to a string).
//   - response_format ("text", "json_object", "json_schema") maps to
//     text.format; text.verbosity or top-level verbosity maps to text.verbosity.
//   - function tools and tool_choice are flattened to the Responses shape,
//     with names shortened to at most 64 bytes; other tool types pass through.
//
// Fixed settings: reasoning.effort defaults to "medium", reasoning.summary is
// "auto", encrypted reasoning content is requested via include,
// parallel_tool_calls is true and store is false. Sampling and token-limit
// fields (temperature, top_p, top_k, max_tokens, max_completion_tokens) are
// intentionally not forwarded because Codex does not support them.
//
// Parameters:
//   - modelName: The name of the model to use for the request
//   - inputRawJSON: The raw JSON request data from the OpenAI Chat Completions API
//   - stream: Whether the upstream request should stream; copied to "stream"
//
// Returns:
//   - []byte: The transformed request data in OpenAI Responses API format
func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte {
	rawJSON := inputRawJSON
	// "instructions" stays empty: system messages are forwarded as
	// developer-role input items below rather than as instructions.
	out := `{"instructions":""}`
	// Forward the caller's streaming preference as-is.
	out, _ = sjson.Set(out, "stream", stream)

	// Map reasoning effort, defaulting to "medium" when the client sends none.
	if v := gjson.GetBytes(rawJSON, "reasoning_effort"); v.Exists() {
		out, _ = sjson.Set(out, "reasoning.effort", v.Value())
	} else {
		out, _ = sjson.Set(out, "reasoning.effort", "medium")
	}
	out, _ = sjson.Set(out, "parallel_tool_calls", true)
	out, _ = sjson.Set(out, "reasoning.summary", "auto")
	out, _ = sjson.Set(out, "include", []string{"reasoning.encrypted_content"})
	out, _ = sjson.Set(out, "model", modelName)

	tools := gjson.GetBytes(rawJSON, "tools")
	hasTools := tools.IsArray() && len(tools.Array()) > 0

	// Build one collision-free original->short name map from the declared
	// function tools so tool definitions, tool_calls and tool_choice all agree
	// on the same upstream name.
	originalToolNameMap := map[string]string{}
	if hasTools {
		var names []string
		for _, t := range tools.Array() {
			if t.Get("type").String() != "function" {
				continue
			}
			if fn := t.Get("function"); fn.Exists() {
				if v := fn.Get("name"); v.Exists() {
					names = append(names, v.String())
				}
			}
		}
		if len(names) > 0 {
			originalToolNameMap = buildShortNameMap(names)
		}
	}
	// upstreamToolName returns the name a Chat Completions function is sent
	// under: its mapped short name when declared in tools, otherwise the
	// single-name shortening rule.
	upstreamToolName := func(name string) string {
		if short, ok := originalToolNameMap[name]; ok {
			return short
		}
		return shortenNameIfNeeded(name)
	}

	// Build input from messages, handling all message types including tool calls.
	out, _ = sjson.SetRaw(out, "input", `[]`)
	messages := gjson.GetBytes(rawJSON, "messages")
	if messages.IsArray() {
		arr := messages.Array()
		for i := 0; i < len(arr); i++ {
			m := arr[i]
			role := m.Get("role").String()
			switch role {
			case "tool":
				// Tool results become top-level function_call_output objects.
				content := m.Get("content")
				output := content.String()
				if content.IsArray() {
					// Chat Completions allows tool content as an array of
					// {"type":"text"} parts. Concatenate their text instead of
					// forwarding the array's raw JSON as the output string;
					// fall back to the raw JSON when no text part is present.
					var sb strings.Builder
					sawText := false
					content.ForEach(func(_, part gjson.Result) bool {
						if part.Get("type").String() == "text" {
							sb.WriteString(part.Get("text").String())
							sawText = true
						}
						return true
					})
					if sawText {
						output = sb.String()
					}
				}
				funcOutput := `{}`
				funcOutput, _ = sjson.Set(funcOutput, "type", "function_call_output")
				funcOutput, _ = sjson.Set(funcOutput, "call_id", m.Get("tool_call_id").String())
				funcOutput, _ = sjson.Set(funcOutput, "output", output)
				out, _ = sjson.SetRaw(out, "input.-1", funcOutput)
			default:
				// Handle regular messages
				msg := `{}`
				msg, _ = sjson.Set(msg, "type", "message")
				if role == "system" {
					msg, _ = sjson.Set(msg, "role", "developer")
				} else {
					msg, _ = sjson.Set(msg, "role", role)
				}
				msg, _ = sjson.SetRaw(msg, "content", `[]`)
				// Assistant text is "output_text"; every other role sends "input_text".
				partType := "input_text"
				if role == "assistant" {
					partType = "output_text"
				}
				c := m.Get("content")
				if c.Exists() && c.Type == gjson.String && c.String() != "" {
					// Single string content
					part := `{}`
					part, _ = sjson.Set(part, "type", partType)
					part, _ = sjson.Set(part, "text", c.String())
					msg, _ = sjson.SetRaw(msg, "content.-1", part)
				} else if c.Exists() && c.IsArray() {
					items := c.Array()
					for j := 0; j < len(items); j++ {
						it := items[j]
						switch it.Get("type").String() {
						case "text":
							part := `{}`
							part, _ = sjson.Set(part, "type", partType)
							part, _ = sjson.Set(part, "text", it.Get("text").String())
							msg, _ = sjson.SetRaw(msg, "content.-1", part)
						case "image_url":
							// Map image inputs to input_image for Responses API
							if role == "user" {
								part := `{}`
								part, _ = sjson.Set(part, "type", "input_image")
								if u := it.Get("image_url.url"); u.Exists() {
									part, _ = sjson.Set(part, "image_url", u.String())
								}
								msg, _ = sjson.SetRaw(msg, "content.-1", part)
							}
						case "file":
							if role == "user" {
								fileData := it.Get("file.file_data").String()
								filename := it.Get("file.filename").String()
								if fileData != "" {
									part := `{}`
									part, _ = sjson.Set(part, "type", "input_file")
									part, _ = sjson.Set(part, "file_data", fileData)
									if filename != "" {
										part, _ = sjson.Set(part, "filename", filename)
									}
									msg, _ = sjson.SetRaw(msg, "content.-1", part)
								}
							}
						}
					}
				}
				// Don't emit empty assistant messages when only tool_calls
				// are present — Responses API needs function_call items
				// directly, otherwise call_id matching fails (#2132).
				if role != "assistant" || len(gjson.Get(msg, "content").Array()) > 0 {
					out, _ = sjson.SetRaw(out, "input.-1", msg)
				}
				// Handle tool calls for assistant messages as separate top-level objects
				if role == "assistant" {
					toolCalls := m.Get("tool_calls")
					if toolCalls.Exists() && toolCalls.IsArray() {
						for _, tc := range toolCalls.Array() {
							if tc.Get("type").String() != "function" {
								continue
							}
							funcCall := `{}`
							funcCall, _ = sjson.Set(funcCall, "type", "function_call")
							funcCall, _ = sjson.Set(funcCall, "call_id", tc.Get("id").String())
							funcCall, _ = sjson.Set(funcCall, "name", upstreamToolName(tc.Get("function.name").String()))
							funcCall, _ = sjson.Set(funcCall, "arguments", tc.Get("function.arguments").String())
							out, _ = sjson.SetRaw(out, "input.-1", funcCall)
						}
					}
				}
			}
		}
	}

	// Map response_format and verbosity to Responses API text.format / text.verbosity.
	rf := gjson.GetBytes(rawJSON, "response_format")
	// Verbosity may arrive as text.verbosity or as the top-level Chat
	// Completions "verbosity" parameter; text.verbosity wins when both exist.
	verbosity := gjson.GetBytes(rawJSON, "text.verbosity")
	if !verbosity.Exists() {
		verbosity = gjson.GetBytes(rawJSON, "verbosity")
	}
	if rf.Exists() {
		// Always create text object when response_format provided
		if !gjson.Get(out, "text").Exists() {
			out, _ = sjson.SetRaw(out, "text", `{}`)
		}
		switch rf.Get("type").String() {
		case "text":
			out, _ = sjson.Set(out, "text.format.type", "text")
		case "json_object":
			// JSON mode: previously dropped, leaving an empty text object.
			out, _ = sjson.Set(out, "text.format.type", "json_object")
		case "json_schema":
			js := rf.Get("json_schema")
			if js.Exists() {
				out, _ = sjson.Set(out, "text.format.type", "json_schema")
				if v := js.Get("name"); v.Exists() {
					out, _ = sjson.Set(out, "text.format.name", v.Value())
				}
				if v := js.Get("description"); v.Exists() {
					out, _ = sjson.Set(out, "text.format.description", v.Value())
				}
				if v := js.Get("strict"); v.Exists() {
					out, _ = sjson.Set(out, "text.format.strict", v.Value())
				}
				if v := js.Get("schema"); v.Exists() {
					out, _ = sjson.SetRaw(out, "text.format.schema", v.Raw)
				}
			}
		}
	}
	if verbosity.Exists() {
		if !gjson.Get(out, "text").Exists() {
			out, _ = sjson.SetRaw(out, "text", `{}`)
		}
		out, _ = sjson.Set(out, "text.verbosity", verbosity.Value())
	}

	// Map tools (flatten function fields)
	if hasTools {
		out, _ = sjson.SetRaw(out, "tools", `[]`)
		for _, t := range tools.Array() {
			toolType := t.Get("type").String()
			// Pass through built-in tools (e.g. {"type":"web_search"}) directly for the Responses API.
			// Only "function" needs structural conversion because Chat Completions nests details under "function".
			if toolType != "" && toolType != "function" && t.IsObject() {
				out, _ = sjson.SetRaw(out, "tools.-1", t.Raw)
				continue
			}
			if toolType != "function" {
				continue
			}
			item := `{}`
			item, _ = sjson.Set(item, "type", "function")
			if fn := t.Get("function"); fn.Exists() {
				if v := fn.Get("name"); v.Exists() {
					item, _ = sjson.Set(item, "name", upstreamToolName(v.String()))
				}
				if v := fn.Get("description"); v.Exists() {
					item, _ = sjson.Set(item, "description", v.Value())
				}
				if v := fn.Get("parameters"); v.Exists() {
					item, _ = sjson.SetRaw(item, "parameters", v.Raw)
				}
				if v := fn.Get("strict"); v.Exists() {
					item, _ = sjson.Set(item, "strict", v.Value())
				}
			}
			out, _ = sjson.SetRaw(out, "tools.-1", item)
		}
	}

	// Map tool_choice when present.
	// Chat Completions: "tool_choice" can be a string ("auto"/"none") or an object (e.g. {"type":"function","function":{"name":"..."}}).
	// Responses API: keep built-in tool choices as-is; flatten function choice to {"type":"function","name":"..."}.
	if tc := gjson.GetBytes(rawJSON, "tool_choice"); tc.Exists() {
		switch {
		case tc.Type == gjson.String:
			out, _ = sjson.Set(out, "tool_choice", tc.String())
		case tc.IsObject():
			tcType := tc.Get("type").String()
			if tcType == "function" {
				name := tc.Get("function.name").String()
				choice := `{}`
				choice, _ = sjson.Set(choice, "type", "function")
				if name != "" {
					choice, _ = sjson.Set(choice, "name", upstreamToolName(name))
				}
				out, _ = sjson.SetRaw(out, "tool_choice", choice)
			} else if tcType != "" {
				// Built-in tool choices (e.g. {"type":"web_search"}) are already Responses-compatible.
				out, _ = sjson.SetRaw(out, "tool_choice", tc.Raw)
			}
		}
	}

	out, _ = sjson.Set(out, "store", false)
	return []byte(out)
}
// shortenNameIfNeeded applies the single-name shortening rule. Names of at most
// 64 bytes are returned unchanged. A longer name that starts with "mcp__" is
// reduced to "mcp__" plus the segment after its last "__" (for names shaped
// like mcp__<server>__<tool>, that is the tool part), truncated to 64 bytes if
// still too long. Any other long name is truncated to its first 64 bytes.
//
// The result is not guaranteed to be unique across a tool list; use
// buildShortNameMap when collisions must be avoided.
func shortenNameIfNeeded(name string) string {
	const limit = 64
	if len(name) <= limit {
		return name
	}
	if strings.HasPrefix(name, "mcp__") {
		// Keep prefix and last segment after '__'
		if idx := strings.LastIndex(name, "__"); idx > 0 {
			candidate := "mcp__" + name[idx+2:]
			if len(candidate) > limit {
				return candidate[:limit]
			}
			return candidate
		}
	}
	return name[:limit]
}

// buildShortNameMap assigns each distinct name in names a unique short name of
// at most 64 bytes and returns the mapping original -> short.
//
// The preferred short name comes from shortenNameIfNeeded. When it is already
// taken, a numeric suffix "_1", "_2", ... is appended (truncating the base so
// the result still fits in 64 bytes) until an unused name is found. Names are
// processed in input order, so earlier names keep their preferred form. A name
// listed more than once is mapped only once, so duplicates share one short
// name instead of the later copy being renamed to "<name>_1".
func buildShortNameMap(names []string) map[string]string {
	const limit = 64
	used := make(map[string]struct{}, len(names))
	m := make(map[string]string, len(names))

	makeUnique := func(cand string) string {
		if _, taken := used[cand]; !taken {
			return cand
		}
		for i := 1; ; i++ {
			suffix := "_" + strconv.Itoa(i)
			allowed := limit - len(suffix)
			if allowed < 0 {
				allowed = 0
			}
			tmp := cand
			if len(tmp) > allowed {
				tmp = tmp[:allowed]
			}
			tmp += suffix
			if _, taken := used[tmp]; !taken {
				return tmp
			}
		}
	}

	for _, n := range names {
		if _, seen := m[n]; seen {
			// Repeated declaration of the same name: keep its first assignment.
			continue
		}
		uniq := makeUnique(shortenNameIfNeeded(n))
		used[uniq] = struct{}{}
		m[n] = uniq
	}
	return m
}
================================================
FILE: internal/translator/codex/openai/chat-completions/codex_openai_request_test.go
================================================
package chat_completions
import (
"testing"
"github.com/tidwall/gjson"
)
// TestToolCallSimple covers the basic round trip: system + user +
// assistant(tool_calls only, content null) + tool result. The translated input
// must be developer msg, user msg, function_call, function_call_output, with
// no empty assistant message between the user message and the function_call.
func TestToolCallSimple(t *testing.T) {
	input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the weather in Paris?"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_1",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"city\":\"Paris\"}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_1",
"content": "sunny, 22C"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather for a city",
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
}
}
]
}`)
	result := string(ConvertOpenAIRequestToCodex("gpt-4o", input, true))
	inputArr := gjson.Get(result, "input")
	items := inputArr.Array()
	if len(items) != 4 {
		t.Fatalf("expected 4 input items, got %d: %s", len(items), inputArr.Raw)
	}
	field := func(i int, path string) string { return items[i].Get(path).String() }

	// system -> developer
	if got := field(0, "type"); got != "message" {
		t.Errorf("item 0: expected type 'message', got '%s'", got)
	}
	if got := field(0, "role"); got != "developer" {
		t.Errorf("item 0: expected role 'developer', got '%s'", got)
	}
	// user
	if got := field(1, "type"); got != "message" {
		t.Errorf("item 1: expected type 'message', got '%s'", got)
	}
	if got := field(1, "role"); got != "user" {
		t.Errorf("item 1: expected role 'user', got '%s'", got)
	}
	// function_call, not an empty assistant msg
	if got := field(2, "type"); got != "function_call" {
		t.Errorf("item 2: expected type 'function_call', got '%s'", got)
	}
	if got := field(2, "call_id"); got != "call_1" {
		t.Errorf("item 2: expected call_id 'call_1', got '%s'", got)
	}
	if got := field(2, "name"); got != "get_weather" {
		t.Errorf("item 2: expected name 'get_weather', got '%s'", got)
	}
	if got := field(2, "arguments"); got != `{"city":"Paris"}` {
		t.Errorf("item 2: unexpected arguments: %s", got)
	}
	// function_call_output
	if got := field(3, "type"); got != "function_call_output" {
		t.Errorf("item 3: expected type 'function_call_output', got '%s'", got)
	}
	if got := field(3, "call_id"); got != "call_1" {
		t.Errorf("item 3: expected call_id 'call_1', got '%s'", got)
	}
	if got := field(3, "output"); got != "sunny, 22C" {
		t.Errorf("item 3: expected output 'sunny, 22C', got '%s'", got)
	}
}
// TestToolCallWithContent: when the assistant turn carries both text and
// tool_calls, the assistant message is kept (its content is non-empty) and is
// followed by the function_call item and then the function_call_output.
func TestToolCallWithContent(t *testing.T) {
	input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "What is the weather?"},
{
"role": "assistant",
"content": "Let me check the weather for you.",
"tool_calls": [
{
"id": "call_abc",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{}"
}
}
]
},
{
"role": "tool",
"tool_call_id": "call_abc",
"content": "rainy, 15C"
}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {"type": "object", "properties": {}}
}
}
]
}`)
	result := string(ConvertOpenAIRequestToCodex("gpt-4o", input, true))
	inputArr := gjson.Get(result, "input")
	items := inputArr.Array()
	// user + assistant(with content) + function_call + function_call_output
	if len(items) != 4 {
		t.Fatalf("expected 4 input items, got %d: %s", len(items), inputArr.Raw)
	}
	field := func(i int, path string) string { return items[i].Get(path).String() }

	if got := field(0, "role"); got != "user" {
		t.Errorf("item 0: expected role 'user', got '%s'", got)
	}
	// assistant with content — should be kept
	if got := field(1, "type"); got != "message" {
		t.Errorf("item 1: expected type 'message', got '%s'", got)
	}
	if got := field(1, "role"); got != "assistant" {
		t.Errorf("item 1: expected role 'assistant', got '%s'", got)
	}
	if len(items[1].Get("content").Array()) == 0 {
		t.Errorf("item 1: assistant message should have content parts")
	}
	if got := field(2, "type"); got != "function_call" {
		t.Errorf("item 2: expected type 'function_call', got '%s'", got)
	}
	if got := field(2, "call_id"); got != "call_abc" {
		t.Errorf("item 2: expected call_id 'call_abc', got '%s'", got)
	}
	if got := field(3, "type"); got != "function_call_output" {
		t.Errorf("item 3: expected type 'function_call_output', got '%s'", got)
	}
	if got := field(3, "call_id"); got != "call_abc" {
		t.Errorf("item 3: expected call_id 'call_abc', got '%s'", got)
	}
}
// TestMultipleToolCalls: the assistant invokes three tools in one turn. All
// call_ids must appear as function_call items in order, followed by their
// function_call_output items with matching call_ids and outputs.
func TestMultipleToolCalls(t *testing.T) {
	input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Compare weather in Paris, London and Tokyo"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_paris",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"city\":\"Paris\"}"
}
},
{
"id": "call_london",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"city\":\"London\"}"
}
},
{
"id": "call_tokyo",
"type": "function",
"function": {
"name": "get_weather",
"arguments": "{\"city\":\"Tokyo\"}"
}
}
]
},
{"role": "tool", "tool_call_id": "call_paris", "content": "sunny, 22C"},
{"role": "tool", "tool_call_id": "call_london", "content": "cloudy, 14C"},
{"role": "tool", "tool_call_id": "call_tokyo", "content": "humid, 28C"}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
}
}
]
}`)
	result := string(ConvertOpenAIRequestToCodex("gpt-4o", input, true))
	inputArr := gjson.Get(result, "input")
	items := inputArr.Array()
	// user + 3 function_call + 3 function_call_output = 7
	if len(items) != 7 {
		t.Fatalf("expected 7 input items, got %d: %s", len(items), inputArr.Raw)
	}
	if got := items[0].Get("role").String(); got != "user" {
		t.Errorf("item 0: expected role 'user', got '%s'", got)
	}

	expectedCallIDs := []string{"call_paris", "call_london", "call_tokyo"}
	expectedOutputs := []string{"sunny, 22C", "cloudy, 14C", "humid, 28C"}

	// Items 1..3: function_call entries in declaration order.
	for i := range expectedCallIDs {
		idx, wantID := i+1, expectedCallIDs[i]
		if got := items[idx].Get("type").String(); got != "function_call" {
			t.Errorf("item %d: expected type 'function_call', got '%s'", idx, got)
		}
		if got := items[idx].Get("call_id").String(); got != wantID {
			t.Errorf("item %d: expected call_id '%s', got '%s'", idx, wantID, got)
		}
	}
	// Items 4..6: the matching function_call_output entries.
	for i := range expectedOutputs {
		idx, wantID, wantOut := i+4, expectedCallIDs[i], expectedOutputs[i]
		if got := items[idx].Get("type").String(); got != "function_call_output" {
			t.Errorf("item %d: expected type 'function_call_output', got '%s'", idx, got)
		}
		if got := items[idx].Get("call_id").String(); got != wantID {
			t.Errorf("item %d: expected call_id '%s', got '%s'", idx, wantID, got)
		}
		if got := items[idx].Get("output").String(); got != wantOut {
			t.Errorf("item %d: expected output '%s', got '%s'", idx, wantOut, got)
		}
	}
}
// TestNoSpuriousEmptyAssistantMessage is the regression test for #2132: an
// assistant message that only carries tool_calls (content null) must not be
// translated into an empty message item, which would break call_id matching.
func TestNoSpuriousEmptyAssistantMessage(t *testing.T) {
	input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Call a tool"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_x",
"type": "function",
"function": {"name": "do_thing", "arguments": "{}"}
}
]
},
{"role": "tool", "tool_call_id": "call_x", "content": "done"}
],
"tools": [
{
"type": "function",
"function": {
"name": "do_thing",
"description": "Do a thing",
"parameters": {"type": "object", "properties": {}}
}
}
]
}`)
	result := string(ConvertOpenAIRequestToCodex("gpt-4o", input, true))
	inputArr := gjson.Get(result, "input")
	items := inputArr.Array()

	for i, item := range items {
		if item.Get("type").String() != "message" || item.Get("role").String() != "assistant" {
			continue
		}
		if len(item.Get("content").Array()) == 0 {
			t.Errorf("item %d: empty assistant message breaks call_id matching. item: %s", i, item.Raw)
		}
	}

	// should be exactly: user + function_call + function_call_output
	if len(items) != 3 {
		t.Fatalf("expected 3 input items (user + function_call + function_call_output), got %d: %s", len(items), inputArr.Raw)
	}
	if items[0].Get("type").String() != "message" || items[0].Get("role").String() != "user" {
		t.Errorf("item 0: expected user message")
	}
	if got := items[1].Get("type").String(); got != "function_call" {
		t.Errorf("item 1: expected function_call, got %s", got)
	}
	if got := items[2].Get("type").String(); got != "function_call_output" {
		t.Errorf("item 2: expected function_call_output, got %s", got)
	}
}
// TestMultiTurnToolCalling runs two rounds of tool calling in one conversation,
// with an assistant text reply between them. Each round must yield a
// function_call/function_call_output pair, the text reply must survive as an
// assistant message, and no empty assistant message may appear.
func TestMultiTurnToolCalling(t *testing.T) {
	input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Weather in Paris?"},
{
"role": "assistant",
"content": null,
"tool_calls": [{"id": "call_r1", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"Paris\"}"}}]
},
{"role": "tool", "tool_call_id": "call_r1", "content": "sunny"},
{"role": "assistant", "content": "It is sunny in Paris."},
{"role": "user", "content": "And London?"},
{
"role": "assistant",
"content": null,
"tool_calls": [{"id": "call_r2", "type": "function", "function": {"name": "get_weather", "arguments": "{\"city\":\"London\"}"}}]
},
{"role": "tool", "tool_call_id": "call_r2", "content": "rainy"}
],
"tools": [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather",
"parameters": {"type": "object", "properties": {"city": {"type": "string"}}}
}
}
]
}`)
	result := string(ConvertOpenAIRequestToCodex("gpt-4o", input, true))
	inputArr := gjson.Get(result, "input")
	items := inputArr.Array()
	// user, func_call(r1), func_output(r1), assistant text, user, func_call(r2), func_output(r2)
	if len(items) != 7 {
		t.Fatalf("expected 7 input items, got %d: %s", len(items), inputArr.Raw)
	}
	for i, item := range items {
		isAssistantMsg := item.Get("type").String() == "message" && item.Get("role").String() == "assistant"
		if isAssistantMsg && len(item.Get("content").Array()) == 0 {
			t.Errorf("item %d: unexpected empty assistant message", i)
		}
	}
	field := func(i int, path string) string { return items[i].Get(path).String() }

	// round 1
	if got := field(1, "type"); got != "function_call" {
		t.Errorf("item 1: expected function_call, got %s", got)
	}
	if got := field(1, "call_id"); got != "call_r1" {
		t.Errorf("item 1: expected call_id 'call_r1', got '%s'", got)
	}
	if got := field(2, "type"); got != "function_call_output" {
		t.Errorf("item 2: expected function_call_output, got %s", got)
	}
	// text reply between rounds
	if typ, role := field(3, "type"), field(3, "role"); typ != "message" || role != "assistant" {
		t.Errorf("item 3: expected assistant message, got type=%s role=%s", typ, role)
	}
	// round 2
	if got := field(5, "type"); got != "function_call" {
		t.Errorf("item 5: expected function_call, got %s", got)
	}
	if got := field(5, "call_id"); got != "call_r2" {
		t.Errorf("item 5: expected call_id 'call_r2', got '%s'", got)
	}
	if got := field(6, "type"); got != "function_call_output" {
		t.Errorf("item 6: expected function_call_output, got %s", got)
	}
}
// Tool names over 64 chars get shortened, call_id stays the same.
func TestToolNameShortening(t *testing.T) {
longName := "a_very_long_tool_name_that_exceeds_sixty_four_characters_limit_here_test"
if len(longName) <= 64 {
t.Fatalf("test setup error: name must be > 64 chars, got %d", len(longName))
}
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Do it"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "call_long",
"type": "function",
"function": {
"name": "` + longName + `",
"arguments": "{}"
}
}
]
},
{"role": "tool", "tool_call_id": "call_long", "content": "ok"}
],
"tools": [
{
"type": "function",
"function": {
"name": "` + longName + `",
"description": "A tool with a very long name",
"parameters": {"type": "object", "properties": {}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
// find function_call
var funcCallItem gjson.Result
for _, item := range items {
if item.Get("type").String() == "function_call" {
funcCallItem = item
break
}
}
if !funcCallItem.Exists() {
t.Fatal("no function_call item found in output")
}
// call_id unchanged
if funcCallItem.Get("call_id").String() != "call_long" {
t.Errorf("call_id changed: expected 'call_long', got '%s'", funcCallItem.Get("call_id").String())
}
// name must be truncated
translatedName := funcCallItem.Get("name").String()
if translatedName == longName {
t.Errorf("tool name was NOT shortened: still '%s'", translatedName)
}
if len(translatedName) > 64 {
t.Errorf("shortened name still > 64 chars: len=%d name='%s'", len(translatedName), translatedName)
}
}
// content:"" (empty string, not null) should be treated the same as null.
func TestEmptyStringContent(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Do something"},
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_empty",
"type": "function",
"function": {"name": "action", "arguments": "{}"}
}
]
},
{"role": "tool", "tool_call_id": "call_empty", "content": "result"}
],
"tools": [
{
"type": "function",
"function": {
"name": "action",
"description": "An action",
"parameters": {"type": "object", "properties": {}}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
for i, item := range items {
if item.Get("type").String() == "message" && item.Get("role").String() == "assistant" {
if len(item.Get("content").Array()) == 0 {
t.Errorf("item %d: empty assistant message from content:\"\"", i)
}
}
}
// user + function_call + function_call_output
if len(items) != 3 {
t.Errorf("expected 3 input items, got %d", len(items))
}
}
// Every function_call_output must have a matching function_call by call_id.
func TestCallIDsMatchBetweenCallAndOutput(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Multi-tool"},
{
"role": "assistant",
"content": null,
"tool_calls": [
{"id": "id_a", "type": "function", "function": {"name": "tool_a", "arguments": "{}"}},
{"id": "id_b", "type": "function", "function": {"name": "tool_b", "arguments": "{}"}}
]
},
{"role": "tool", "tool_call_id": "id_a", "content": "res_a"},
{"role": "tool", "tool_call_id": "id_b", "content": "res_b"}
],
"tools": [
{"type": "function", "function": {"name": "tool_a", "description": "A", "parameters": {"type": "object", "properties": {}}}},
{"type": "function", "function": {"name": "tool_b", "description": "B", "parameters": {"type": "object", "properties": {}}}}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
items := gjson.Get(result, "input").Array()
// collect call_ids from function_call items
callIDs := make(map[string]bool)
for _, item := range items {
if item.Get("type").String() == "function_call" {
callIDs[item.Get("call_id").String()] = true
}
}
for i, item := range items {
if item.Get("type").String() == "function_call_output" {
outID := item.Get("call_id").String()
if !callIDs[outID] {
t.Errorf("item %d: function_call_output has call_id '%s' with no matching function_call", i, outID)
}
}
}
// 2 calls, 2 outputs
funcCallCount := 0
funcOutputCount := 0
for _, item := range items {
switch item.Get("type").String() {
case "function_call":
funcCallCount++
case "function_call_output":
funcOutputCount++
}
}
if funcCallCount != 2 {
t.Errorf("expected 2 function_calls, got %d", funcCallCount)
}
if funcOutputCount != 2 {
t.Errorf("expected 2 function_call_outputs, got %d", funcOutputCount)
}
}
// Tools array should carry over to the Responses format output.
func TestToolsDefinitionTranslated(t *testing.T) {
input := []byte(`{
"model": "gpt-4o",
"messages": [
{"role": "user", "content": "Hi"}
],
"tools": [
{
"type": "function",
"function": {
"name": "search",
"description": "Search the web",
"parameters": {"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]}
}
}
]
}`)
out := ConvertOpenAIRequestToCodex("gpt-4o", input, true)
result := string(out)
tools := gjson.Get(result, "tools").Array()
if len(tools) == 0 {
t.Fatal("no tools found in output")
}
found := false
for _, tool := range tools {
if tool.Get("name").String() == "search" {
found = true
break
}
}
if !found {
t.Errorf("tool 'search' not found in output tools: %s", gjson.Get(result, "tools").Raw)
}
}
================================================
FILE: internal/translator/codex/openai/chat-completions/codex_openai_response.go
================================================
// Package chat_completions provides response translation functionality for Codex to OpenAI API compatibility.
// This package handles the conversion of Codex API responses into OpenAI Chat Completions-compatible
// JSON format, transforming streaming events and non-streaming responses into the format
// expected by OpenAI API clients. It supports both streaming and non-streaming modes,
// handling text content, tool calls, reasoning content, and usage metadata appropriately.
package chat_completions
import (
"bytes"
"context"
"time"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// dataTag is the prefix of a data line in the upstream event stream. ConvertCodexResponseToOpenAI
// ignores any line that does not start with it and strips it from lines that do.
var dataTag = []byte("data:")
// ConvertCliToOpenAIParams holds the per-stream state that ConvertCodexResponseToOpenAI carries
// from one call to the next through its param argument.
type ConvertCliToOpenAIParams struct {
	// ResponseID is taken from response.id of the response.created event and written to each chunk's "id".
	ResponseID string
	// CreatedAt is taken from response.created_at of the response.created event and written to each
	// chunk's "created".
	CreatedAt int64
	// Model starts as the modelName passed to ConvertCodexResponseToOpenAI, is replaced by
	// response.model from the response.created event, and is used for chunks whose event carries no
	// top-level "model".
	Model string
	// FunctionCallIndex is the delta.tool_calls index of the current tool call. It starts at -1 and is
	// incremented for each function_call item; any other value at response.completed yields
	// finish_reason "tool_calls".
	FunctionCallIndex int
	// HasReceivedArgumentsDelta records that arguments for the current tool call were already streamed
	// through response.function_call_arguments.delta, so the matching .done event emits nothing.
	HasReceivedArgumentsDelta bool
	// HasToolCallAnnounced records that response.output_item.added already emitted the current tool
	// call, so the matching response.output_item.done is skipped (and this flag cleared).
	HasToolCallAnnounced bool
}
// ConvertCodexResponseToOpenAI translates one line of a Codex streaming response into zero or one
// OpenAI Chat Completions "chat.completion.chunk" JSON documents.
//
// State that spans events (response id, creation time, model and tool-call bookkeeping) lives in
// *param as a *ConvertCliToOpenAIParams, which is allocated on the first call when *param is nil.
// The function returns an empty slice for:
//   - lines without the "data:" prefix;
//   - the response.created event, which only seeds that state;
//   - event types with no Chat Completions counterpart.
//
// Parameters:
//   - ctx: unused
//   - modelName: the model name used to seed the state and as the last-resort chunk model
//   - originalRequestRawJSON: the client's original OpenAI request, used to restore tool names that
//     were shortened when the request was sent to Codex
//   - requestRawJSON: unused
//   - rawJSON: one raw line from the Codex stream
//   - param: pointer to the per-stream conversion state
//
// Returns:
//   - []string: the OpenAI-compatible chunk produced for this line, or an empty slice
func ConvertCodexResponseToOpenAI(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	if *param == nil {
		*param = &ConvertCliToOpenAIParams{Model: modelName, FunctionCallIndex: -1}
	}
	if !bytes.HasPrefix(rawJSON, dataTag) {
		return []string{}
	}

	payload := bytes.TrimSpace(rawJSON[len(dataTag):])
	event := gjson.ParseBytes(payload)
	eventType := event.Get("type").String()
	state := (*param).(*ConvertCliToOpenAIParams)

	// response.created only records stream-wide metadata; it has no chunk of its own.
	if eventType == "response.created" {
		state.ResponseID = event.Get("response.id").String()
		state.CreatedAt = event.Get("response.created_at").Int()
		state.Model = event.Get("response.model").String()
		return []string{}
	}

	chunk := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`

	// Model precedence: the event's own "model", then the cached stream model, then modelName.
	switch m := event.Get("model"); {
	case m.Exists():
		chunk, _ = sjson.Set(chunk, "model", m.String())
	case state.Model != "":
		chunk, _ = sjson.Set(chunk, "model", state.Model)
	case modelName != "":
		chunk, _ = sjson.Set(chunk, "model", modelName)
	}
	chunk, _ = sjson.Set(chunk, "created", state.CreatedAt)
	chunk, _ = sjson.Set(chunk, "id", state.ResponseID)

	// Copy token counters from Responses naming to Chat Completions naming (order kept stable so the
	// emitted usage object always lists its keys the same way).
	if usage := event.Get("response.usage"); usage.Exists() {
		for _, field := range [...]struct{ from, to string }{
			{"output_tokens", "usage.completion_tokens"},
			{"total_tokens", "usage.total_tokens"},
			{"input_tokens", "usage.prompt_tokens"},
			{"input_tokens_details.cached_tokens", "usage.prompt_tokens_details.cached_tokens"},
			{"output_tokens_details.reasoning_tokens", "usage.completion_tokens_details.reasoning_tokens"},
		} {
			if v := usage.Get(field.from); v.Exists() {
				chunk, _ = sjson.Set(chunk, field.to, v.Int())
			}
		}
	}

	// originalToolName maps a possibly shortened tool name back to the name the client declared.
	originalToolName := func(name string) string {
		if orig, ok := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)[name]; ok {
			return orig
		}
		return name
	}
	// argumentsFragment builds a tool_calls entry carrying only arguments for the current call index.
	argumentsFragment := func(args string) string {
		fragment := `{"index":0,"function":{"arguments":""}}`
		fragment, _ = sjson.Set(fragment, "index", state.FunctionCallIndex)
		fragment, _ = sjson.Set(fragment, "function.arguments", args)
		return fragment
	}
	// fullToolCall builds a complete tool_calls entry for a function_call item with the given arguments.
	fullToolCall := func(item gjson.Result, args string) string {
		call := `{"index":0,"id":"","type":"function","function":{"name":"","arguments":""}}`
		call, _ = sjson.Set(call, "index", state.FunctionCallIndex)
		call, _ = sjson.Set(call, "id", item.Get("call_id").String())
		call, _ = sjson.Set(call, "function.name", originalToolName(item.Get("name").String()))
		call, _ = sjson.Set(call, "function.arguments", args)
		return call
	}
	// setToolCall replaces delta.tool_calls with a one-element array holding entry.
	setToolCall := func(entry string) {
		chunk, _ = sjson.SetRaw(chunk, "choices.0.delta.tool_calls", `[]`)
		chunk, _ = sjson.SetRaw(chunk, "choices.0.delta.tool_calls.-1", entry)
	}

	switch eventType {
	case "response.reasoning_summary_text.delta":
		if delta := event.Get("delta"); delta.Exists() {
			chunk, _ = sjson.Set(chunk, "choices.0.delta.role", "assistant")
			chunk, _ = sjson.Set(chunk, "choices.0.delta.reasoning_content", delta.String())
		}

	case "response.reasoning_summary_text.done":
		// Separate consecutive reasoning summaries with a blank line.
		chunk, _ = sjson.Set(chunk, "choices.0.delta.role", "assistant")
		chunk, _ = sjson.Set(chunk, "choices.0.delta.reasoning_content", "\n\n")

	case "response.output_text.delta":
		if delta := event.Get("delta"); delta.Exists() {
			chunk, _ = sjson.Set(chunk, "choices.0.delta.role", "assistant")
			chunk, _ = sjson.Set(chunk, "choices.0.delta.content", delta.String())
		}

	case "response.completed":
		reason := "stop"
		if state.FunctionCallIndex != -1 {
			reason = "tool_calls"
		}
		chunk, _ = sjson.Set(chunk, "choices.0.finish_reason", reason)
		chunk, _ = sjson.Set(chunk, "choices.0.native_finish_reason", reason)

	case "response.output_item.added":
		item := event.Get("item")
		if !item.Exists() || item.Get("type").String() != "function_call" {
			return []string{}
		}
		// A new tool call starts: advance the index and reset per-call flags.
		state.FunctionCallIndex++
		state.HasReceivedArgumentsDelta = false
		state.HasToolCallAnnounced = true
		chunk, _ = sjson.Set(chunk, "choices.0.delta.role", "assistant")
		setToolCall(fullToolCall(item, ""))

	case "response.function_call_arguments.delta":
		state.HasReceivedArgumentsDelta = true
		setToolCall(argumentsFragment(event.Get("delta").String()))

	case "response.function_call_arguments.done":
		if state.HasReceivedArgumentsDelta {
			// The arguments were already streamed piecewise.
			return []string{}
		}
		// No deltas arrived: emit the complete arguments in one fragment.
		setToolCall(argumentsFragment(event.Get("arguments").String()))

	case "response.output_item.done":
		item := event.Get("item")
		if !item.Exists() || item.Get("type").String() != "function_call" {
			return []string{}
		}
		if state.HasToolCallAnnounced {
			// Already announced by output_item.added; only clear the flag.
			state.HasToolCallAnnounced = false
			return []string{}
		}
		// output_item.added was never seen for this call, so emit the whole call now.
		state.FunctionCallIndex++
		chunk, _ = sjson.Set(chunk, "choices.0.delta.role", "assistant")
		setToolCall(fullToolCall(item, item.Get("arguments").String()))

	default:
		return []string{}
	}

	return []string{chunk}
}
// ConvertCodexResponseToOpenAINonStream converts a complete Codex "response.completed" event into a
// single non-streaming OpenAI Chat Completions ("chat.completion") response.
//
// The message is assembled from every item in response.output:
//   - all "summary_text" parts of "reasoning" items are joined with "\n\n" into reasoning_content,
//     mirroring the "\n\n" the streaming converter emits after each summary part;
//   - all "output_text" parts of "message" items are concatenated into content, as the streaming
//     converter does with its text deltas;
//   - every "function_call" item becomes a tool_calls entry, with shortened tool names mapped back
//     to the names declared in the original request.
//
// When response.status is "completed", finish_reason is "tool_calls" if any tool call was produced
// and "stop" otherwise, matching ConvertCodexResponseToOpenAI.
//
// Parameters:
//   - ctx: unused
//   - modelName: unused; the model is read from the Codex response
//   - originalRequestRawJSON: the client's original OpenAI request, used to restore shortened tool names
//   - requestRawJSON: unused
//   - rawJSON: the raw Codex response.completed event
//   - param: unused
//
// Returns:
//   - string: the OpenAI-compatible response JSON, or "" when rawJSON is not a response.completed event
func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
	rootResult := gjson.ParseBytes(rawJSON)
	if rootResult.Get("type").String() != "response.completed" {
		return ""
	}

	responseResult := rootResult.Get("response")
	template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`

	if modelResult := responseResult.Get("model"); modelResult.Exists() {
		template, _ = sjson.Set(template, "model", modelResult.String())
	}
	// Fall back to the local clock when the upstream omits created_at.
	if createdAtResult := responseResult.Get("created_at"); createdAtResult.Exists() {
		template, _ = sjson.Set(template, "created", createdAtResult.Int())
	} else {
		template, _ = sjson.Set(template, "created", time.Now().Unix())
	}
	if idResult := responseResult.Get("id"); idResult.Exists() {
		template, _ = sjson.Set(template, "id", idResult.String())
	}

	// Copy token counters from Responses naming to Chat Completions naming.
	if usageResult := responseResult.Get("usage"); usageResult.Exists() {
		if outputTokensResult := usageResult.Get("output_tokens"); outputTokensResult.Exists() {
			template, _ = sjson.Set(template, "usage.completion_tokens", outputTokensResult.Int())
		}
		if totalTokensResult := usageResult.Get("total_tokens"); totalTokensResult.Exists() {
			template, _ = sjson.Set(template, "usage.total_tokens", totalTokensResult.Int())
		}
		if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
			template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
		}
		if cachedTokensResult := usageResult.Get("input_tokens_details.cached_tokens"); cachedTokensResult.Exists() {
			template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokensResult.Int())
		}
		if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
			template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
		}
	}

	hasToolCalls := false
	if outputResult := responseResult.Get("output"); outputResult.IsArray() {
		var contentText, reasoningText string
		var toolCalls []string
		// Shortened -> original tool name; built once, on the first function_call that needs it.
		var rev map[string]string

		for _, outputItem := range outputResult.Array() {
			switch outputItem.Get("type").String() {
			case "reasoning":
				// Keep every summary part, not just the first one.
				if summaryResult := outputItem.Get("summary"); summaryResult.IsArray() {
					for _, summaryItem := range summaryResult.Array() {
						if summaryItem.Get("type").String() != "summary_text" {
							continue
						}
						if reasoningText != "" {
							reasoningText += "\n\n"
						}
						reasoningText += summaryItem.Get("text").String()
					}
				}
			case "message":
				// Keep every text part of every message item, as the streaming path does.
				if contentResult := outputItem.Get("content"); contentResult.IsArray() {
					for _, contentItem := range contentResult.Array() {
						if contentItem.Get("type").String() == "output_text" {
							contentText += contentItem.Get("text").String()
						}
					}
				}
			case "function_call":
				functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
				if callIDResult := outputItem.Get("call_id"); callIDResult.Exists() {
					functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", callIDResult.String())
				}
				if nameResult := outputItem.Get("name"); nameResult.Exists() {
					if rev == nil {
						rev = buildReverseMapFromOriginalOpenAI(originalRequestRawJSON)
					}
					name := nameResult.String()
					if orig, ok := rev[name]; ok {
						name = orig
					}
					functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", name)
				}
				if argsResult := outputItem.Get("arguments"); argsResult.Exists() {
					functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", argsResult.String())
				}
				toolCalls = append(toolCalls, functionCallTemplate)
			}
		}

		if contentText != "" {
			template, _ = sjson.Set(template, "choices.0.message.content", contentText)
			template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
		}
		if reasoningText != "" {
			template, _ = sjson.Set(template, "choices.0.message.reasoning_content", reasoningText)
			template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
		}
		if len(toolCalls) > 0 {
			hasToolCalls = true
			template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
			for _, toolCall := range toolCalls {
				template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", toolCall)
			}
			template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
		}
	}

	// OpenAI clients expect "tool_calls" when the model requested tools, as the streaming path reports.
	if responseResult.Get("status").String() == "completed" {
		finishReason := "stop"
		if hasToolCalls {
			finishReason = "tool_calls"
		}
		template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
		template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
	}

	return template
}
// buildReverseMapFromOriginalOpenAI returns a map from shortened tool name to original tool name.
// It collects the names of the "function" tools in the original OpenAI-style request (in declaration
// order) and feeds them to buildShortNameMap, the same shortening used for the outgoing request.
// The result is never nil; it is empty when the request declares no named function tools.
func buildReverseMapFromOriginalOpenAI(original []byte) map[string]string {
	reverse := make(map[string]string)

	tools := gjson.GetBytes(original, "tools")
	if !tools.IsArray() {
		return reverse
	}

	var names []string
	tools.ForEach(func(_, tool gjson.Result) bool {
		if tool.Get("type").String() != "function" {
			return true
		}
		fn := tool.Get("function")
		if !fn.Exists() {
			return true
		}
		if name := fn.Get("name"); name.Exists() {
			names = append(names, name.String())
		}
		return true
	})

	if len(names) == 0 {
		return reverse
	}
	for orig, short := range buildShortNameMap(names) {
		reverse[short] = orig
	}
	return reverse
}
================================================
FILE: internal/translator/codex/openai/chat-completions/codex_openai_response_test.go
================================================
package chat_completions
import (
"context"
"testing"
"github.com/tidwall/gjson"
)
func TestConvertCodexResponseToOpenAI_StreamSetsModelFromResponseCreated(t *testing.T) {
ctx := context.Background()
var param any
modelName := "gpt-5.3-codex"
out := ConvertCodexResponseToOpenAI(ctx, modelName, nil, nil, []byte(`data: {"type":"response.created","response":{"id":"resp_123","created_at":1700000000,"model":"gpt-5.3-codex"}}`), ¶m)
if len(out) != 0 {
t.Fatalf("expected no output for response.created, got %d chunks", len(out))
}
out = ConvertCodexResponseToOpenAI(ctx, modelName, nil, nil, []byte(`data: {"type":"response.output_text.delta","delta":"hello"}`), ¶m)
if len(out) != 1 {
t.Fatalf("expected 1 chunk, got %d", len(out))
}
gotModel := gjson.Get(out[0], "model").String()
if gotModel != modelName {
t.Fatalf("expected model %q, got %q", modelName, gotModel)
}
}
func TestConvertCodexResponseToOpenAI_FirstChunkUsesRequestModelName(t *testing.T) {
ctx := context.Background()
var param any
modelName := "gpt-5.3-codex"
out := ConvertCodexResponseToOpenAI(ctx, modelName, nil, nil, []byte(`data: {"type":"response.output_text.delta","delta":"hello"}`), ¶m)
if len(out) != 1 {
t.Fatalf("expected 1 chunk, got %d", len(out))
}
gotModel := gjson.Get(out[0], "model").String()
if gotModel != modelName {
t.Fatalf("expected model %q, got %q", modelName, gotModel)
}
}
================================================
FILE: internal/translator/codex/openai/chat-completions/init.go
================================================
package chat_completions
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
OpenAI,
Codex,
ConvertOpenAIRequestToCodex,
interfaces.TranslateResponse{
Stream: ConvertCodexResponseToOpenAI,
NonStream: ConvertCodexResponseToOpenAINonStream,
},
)
}
================================================
FILE: internal/translator/codex/openai/responses/codex_openai-responses_request.go
================================================
package responses
import (
"fmt"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertOpenAIResponsesRequestToCodex rewrites an OpenAI Responses API request body into the form
// forwarded to the Codex upstream. It:
//   - wraps a plain-string "input" into a single user message with one input_text part;
//   - forces stream=true, store=false, parallel_tool_calls=true and
//     include=["reasoning.encrypted_content"];
//   - removes max_output_tokens, max_completion_tokens, temperature, top_p, truncation, user and
//     context_management, and drops service_tier unless it is "priority";
//   - renames "system" roles in the input array to "developer".
//
// modelName and the trailing bool are accepted for interface compatibility but not used here.
func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
	out := inputRawJSON

	// Normalize a bare string input into an array holding one user message.
	if in := gjson.GetBytes(out, "input"); in.Type == gjson.String {
		wrapped, _ := sjson.Set(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`, "0.content.0.text", in.String())
		out, _ = sjson.SetRawBytes(out, "input", []byte(wrapped))
	}

	// Fields the Codex upstream always receives with fixed values.
	out, _ = sjson.SetBytes(out, "stream", true)
	out, _ = sjson.SetBytes(out, "store", false)
	out, _ = sjson.SetBytes(out, "parallel_tool_calls", true)
	out, _ = sjson.SetBytes(out, "include", []string{"reasoning.encrypted_content"})

	// Token-limit and sampling fields are stripped before forwarding to Codex Responses.
	for _, field := range []string{"max_output_tokens", "max_completion_tokens", "temperature", "top_p"} {
		out, _ = sjson.DeleteBytes(out, field)
	}

	// Only the "priority" service tier is passed through.
	if tier := gjson.GetBytes(out, "service_tier"); tier.Exists() && tier.String() != "priority" {
		out, _ = sjson.DeleteBytes(out, "service_tier")
	}

	out, _ = sjson.DeleteBytes(out, "truncation")
	out = applyResponsesCompactionCompatibility(out)
	// The Codex upstream does not support the user field.
	out, _ = sjson.DeleteBytes(out, "user")

	// Codex does not accept "system" roles in the input array; it expects "developer".
	return convertSystemRoleToDeveloper(out)
}
// applyResponsesCompactionCompatibility removes the OpenAI Responses "context_management" field
// (which carries compaction settings) before the request is forwarded. Codex /responses currently
// rejects it with {"detail":"Unsupported parameter: context_management"}.
//
// A body without the field is returned unchanged.
func applyResponsesCompactionCompatibility(rawJSON []byte) []byte {
	if cm := gjson.GetBytes(rawJSON, "context_management"); cm.Exists() {
		rawJSON, _ = sjson.DeleteBytes(rawJSON, "context_management")
	}
	return rawJSON
}
// convertSystemRoleToDeveloper rewrites role "system" to role "developer" on every item of the
// top-level "input" array, because the Codex API does not accept "system" roles there. Items with
// any other role are left untouched. If "input" is absent or not an array, rawJSON is returned as-is.
func convertSystemRoleToDeveloper(rawJSON []byte) []byte {
	inputResult := gjson.GetBytes(rawJSON, "input")
	if !inputResult.IsArray() {
		return rawJSON
	}

	result := rawJSON
	// Roles are read from the already-parsed items rather than re-scanning the whole document for
	// every index. Rewriting one item's role never changes another item's role, so this parsed copy
	// stays accurate for every remaining index.
	for i, item := range inputResult.Array() {
		if item.Get("role").String() != "system" {
			continue
		}
		result, _ = sjson.SetBytes(result, fmt.Sprintf("input.%d.role", i), "developer")
	}
	return result
}
================================================
FILE: internal/translator/codex/openai/responses/codex_openai-responses_request_test.go
================================================
package responses
import (
"testing"
"github.com/tidwall/gjson"
)
// TestConvertSystemRoleToDeveloper_BasicConversion checks that a single system message becomes a
// developer message, that the following user message keeps its role, and that the converted
// message's content is preserved.
func TestConvertSystemRoleToDeveloper_BasicConversion(t *testing.T) {
	inputJSON := []byte(`{
"model": "gpt-5.2",
"input": [
{
"type": "message",
"role": "system",
"content": [{"type": "input_text", "text": "You are a pirate."}]
},
{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": "Say hello."}]
}
]
}`)
	out := string(ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false))

	if role := gjson.Get(out, "input.0.role").String(); role != "developer" {
		t.Errorf("Expected role 'developer', got '%s'", role)
	}
	if role := gjson.Get(out, "input.1.role").String(); role != "user" {
		t.Errorf("Expected role 'user', got '%s'", role)
	}
	if text := gjson.Get(out, "input.0.content.0.text").String(); text != "You are a pirate." {
		t.Errorf("Expected content 'You are a pirate.', got '%s'", text)
	}
}
// TestConvertSystemRoleToDeveloper_MultipleSystemMessages checks that every system message in the
// input array is converted to developer, not just the first, while the user message is untouched.
func TestConvertSystemRoleToDeveloper_MultipleSystemMessages(t *testing.T) {
	inputJSON := []byte(`{
"model": "gpt-5.2",
"input": [
{
"type": "message",
"role": "system",
"content": [{"type": "input_text", "text": "You are helpful."}]
},
{
"type": "message",
"role": "system",
"content": [{"type": "input_text", "text": "Be concise."}]
},
{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": "Hello"}]
}
]
}`)
	out := string(ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false))

	if role := gjson.Get(out, "input.0.role").String(); role != "developer" {
		t.Errorf("Expected first role 'developer', got '%s'", role)
	}
	if role := gjson.Get(out, "input.1.role").String(); role != "developer" {
		t.Errorf("Expected second role 'developer', got '%s'", role)
	}
	if role := gjson.Get(out, "input.2.role").String(); role != "user" {
		t.Errorf("Expected third role 'user', got '%s'", role)
	}
}
// TestConvertSystemRoleToDeveloper_NoSystemMessages checks that user and assistant roles pass
// through the conversion unchanged when no system message is present.
func TestConvertSystemRoleToDeveloper_NoSystemMessages(t *testing.T) {
	inputJSON := []byte(`{
"model": "gpt-5.2",
"input": [
{
"type": "message",
"role": "user",
"content": [{"type": "input_text", "text": "Hello"}]
},
{
"type": "message",
"role": "assistant",
"content": [{"type": "output_text", "text": "Hi there!"}]
}
]
}`)
	out := string(ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false))

	if role := gjson.Get(out, "input.0.role").String(); role != "user" {
		t.Errorf("Expected role 'user', got '%s'", role)
	}
	if role := gjson.Get(out, "input.1.role").String(); role != "assistant" {
		t.Errorf("Expected role 'assistant', got '%s'", role)
	}
}
// TestConvertSystemRoleToDeveloper_EmptyInput verifies that an empty "input" array
// survives conversion as an empty array.
func TestConvertSystemRoleToDeveloper_EmptyInput(t *testing.T) {
	output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", []byte(`{
		"model": "gpt-5.2",
		"input": []
	}`), false)

	input := gjson.GetBytes(output, "input")
	if !input.IsArray() {
		t.Error("Input should still be an array")
	}
	if count := len(input.Array()); count != 0 {
		t.Errorf("Expected empty array, got %d items", count)
	}
}
// TestConvertSystemRoleToDeveloper_NoInputField verifies that a request without an
// "input" field still gets the Codex-mandated stream/store settings.
func TestConvertSystemRoleToDeveloper_NoInputField(t *testing.T) {
	output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", []byte(`{
		"model": "gpt-5.2",
		"stream": false
	}`), false)

	// The request asked for stream=false; the converter must force it on.
	if !gjson.GetBytes(output, "stream").Bool() {
		t.Error("Stream should be set to true by conversion")
	}
	if gjson.GetBytes(output, "store").Bool() {
		t.Error("Store should be set to false by conversion")
	}
}
// TestConvertOpenAIResponsesRequestToCodex_OriginalIssue replays the request that
// originally failed upstream with "System messages are not allowed" and checks the
// converted body: system role rewritten, plus the fields Codex requires.
func TestConvertOpenAIResponsesRequestToCodex_OriginalIssue(t *testing.T) {
	// This is the exact input that was failing with "System messages are not allowed"
	inputJSON := []byte(`{
		"model": "gpt-5.2",
		"input": [
			{
				"type": "message",
				"role": "system",
				"content": "You are a pirate. Always respond in pirate speak."
			},
			{
				"type": "message",
				"role": "user",
				"content": "Say hello."
			}
		],
		"stream": false
	}`)
	output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)

	if role := gjson.GetBytes(output, "input.0.role").String(); role != "developer" {
		t.Errorf("Expected role 'developer', got '%s'", role)
	}
	if !gjson.GetBytes(output, "stream").Bool() {
		t.Error("Stream should be set to true")
	}
	if gjson.GetBytes(output, "store").Bool() {
		t.Error("Store should be false")
	}
	if !gjson.GetBytes(output, "parallel_tool_calls").Bool() {
		t.Error("parallel_tool_calls should be true")
	}

	include := gjson.GetBytes(output, "include")
	items := include.Array()
	switch {
	case !include.IsArray() || len(items) != 1:
		t.Error("include should be an array with one element")
	case items[0].String() != "reasoning.encrypted_content":
		t.Errorf("Expected include[0] to be 'reasoning.encrypted_content', got '%s'", items[0].String())
	}
}
// TestConvertSystemRoleToDeveloper_AssistantRole verifies that only the system message
// is rewritten to "developer" while user and assistant roles are preserved.
func TestConvertSystemRoleToDeveloper_AssistantRole(t *testing.T) {
	inputJSON := []byte(`{
		"model": "gpt-5.2",
		"input": [
			{
				"type": "message",
				"role": "system",
				"content": [{"type": "input_text", "text": "You are helpful."}]
			},
			{
				"type": "message",
				"role": "user",
				"content": [{"type": "input_text", "text": "Hello"}]
			},
			{
				"type": "message",
				"role": "assistant",
				"content": [{"type": "output_text", "text": "Hi!"}]
			}
		]
	}`)
	output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", inputJSON, false)

	expectations := []struct {
		path, want, format string
	}{
		{"input.0.role", "developer", "Expected first role 'developer', got '%s'"},
		{"input.1.role", "user", "Expected second role 'user', got '%s'"},
		{"input.2.role", "assistant", "Expected third role 'assistant', got '%s'"},
	}
	for _, exp := range expectations {
		if got := gjson.GetBytes(output, exp.path).String(); got != exp.want {
			t.Errorf(exp.format, got)
		}
	}
}
// TestUserFieldDeletion verifies that the top-level "user" field is stripped during
// conversion.
func TestUserFieldDeletion(t *testing.T) {
	output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", []byte(`{
		"model": "gpt-5.2",
		"user": "test-user",
		"input": [{"role": "user", "content": "Hello"}]
	}`), false)

	if userField := gjson.GetBytes(output, "user"); userField.Exists() {
		t.Errorf("user field should be deleted, but it was found with value: %s", userField.Raw)
	}
}
// TestContextManagementCompactionCompatibility verifies that a compaction
// context_management directive is removed and not re-expressed as "truncation".
func TestContextManagementCompactionCompatibility(t *testing.T) {
	output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", []byte(`{
		"model": "gpt-5.2",
		"context_management": [
			{
				"type": "compaction",
				"compact_threshold": 12000
			}
		],
		"input": [{"role":"user","content":"hello"}]
	}`), false)

	if gjson.GetBytes(output, "context_management").Exists() {
		t.Fatalf("context_management should be removed for Codex compatibility")
	}
	if gjson.GetBytes(output, "truncation").Exists() {
		t.Fatalf("truncation should be removed for Codex compatibility")
	}
}
// TestTruncationRemovedForCodexCompatibility verifies that an explicit "truncation"
// setting is dropped from the converted request.
func TestTruncationRemovedForCodexCompatibility(t *testing.T) {
	output := ConvertOpenAIResponsesRequestToCodex("gpt-5.2", []byte(`{
		"model": "gpt-5.2",
		"truncation": "disabled",
		"input": [{"role":"user","content":"hello"}]
	}`), false)

	if gjson.GetBytes(output, "truncation").Exists() {
		t.Fatalf("truncation should be removed for Codex compatibility")
	}
}
================================================
FILE: internal/translator/codex/openai/responses/codex_openai-responses_response.go
================================================
package responses
import (
"bytes"
"context"
"fmt"
"github.com/tidwall/gjson"
)
// ConvertCodexResponseToOpenAIResponses forwards one streamed Codex line to an OpenAI
// Responses client. Codex already speaks the Responses SSE format, so no translation
// is done: a "data:" line is re-emitted as "data: <payload>" with surrounding
// whitespace trimmed from the payload, and every other line is passed through as-is.
// The result always has exactly one element.
func ConvertCodexResponseToOpenAIResponses(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) []string {
	const dataPrefix = "data:"
	if !bytes.HasPrefix(rawJSON, []byte(dataPrefix)) {
		// Non-data lines (e.g. "event:" lines or blank separators) go out verbatim.
		return []string{string(rawJSON)}
	}
	payload := bytes.TrimSpace(rawJSON[len(dataPrefix):])
	return []string{fmt.Sprintf("data: %s", payload)}
}
// ConvertCodexResponseToOpenAIResponsesNonStream returns the final Responses object for
// a non-streaming client. rawJSON is one Codex event. Only a "response.completed" event
// yields output: its embedded "response" object is returned verbatim as raw JSON.
// Any other event type yields the empty string.
func ConvertCodexResponseToOpenAIResponsesNonStream(_ context.Context, _ string, _, _, rawJSON []byte, _ *any) string {
	event := gjson.ParseBytes(rawJSON)
	if event.Get("type").String() == "response.completed" {
		return event.Get("response").Raw
	}
	return ""
}
================================================
FILE: internal/translator/codex/openai/responses/init.go
================================================
package responses
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
OpenaiResponse,
Codex,
ConvertOpenAIResponsesRequestToCodex,
interfaces.TranslateResponse{
Stream: ConvertCodexResponseToOpenAIResponses,
NonStream: ConvertCodexResponseToOpenAIResponsesNonStream,
},
)
}
================================================
FILE: internal/translator/gemini/claude/gemini_claude_request.go
================================================
// Package claude provides request translation functionality for the Claude API.
// It parses Claude Messages API requests and transforms them into the Gemini request format,
// extracting model information, system instructions, message contents, and tool declarations.
// The package also performs JSON data cleaning and transformation to ensure compatibility
// between the Claude API format and the format Gemini expects.
package claude
import (
"bytes"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// geminiClaudeThoughtSignature is written into the thoughtSignature field of every
// functionCall part built from a Claude tool_use block. Claude requests do not carry
// Gemini thought signatures, so a fixed placeholder is used.
// NOTE(review): the value suggests Gemini treats it as "skip signature validation" — confirm.
const geminiClaudeThoughtSignature = "skip_thought_signature_validator"
// ConvertClaudeRequestToGemini translates a Claude Messages API request body into a
// Gemini generateContent request body. All JSON reading and writing is done with
// gjson/sjson; fields that are not explicitly mapped below are dropped.
//
// Mapping summary:
//   - system (string, or array of text blocks) -> system_instruction
//   - messages -> contents ("assistant" becomes "model"). text, tool_use, tool_result
//     and base64 image blocks are converted; other block types are skipped, and a
//     message left with no convertible blocks is omitted entirely.
//   - tools[].input_schema -> tools[0].functionDeclarations[].parametersJsonSchema
//   - tool_choice -> toolConfig.functionCallingConfig
//   - thinking -> generationConfig.thinkingConfig
//   - temperature / top_p / top_k -> generationConfig.temperature / topP / topK
//
// Parameters:
//   - modelName: The target model; written to "model" and used to look up the maximum
//     thinking budget for adaptive thinking.
//   - inputRawJSON: The raw JSON request from the Claude API.
//   - stream: Unused; the body is the same for streaming and non-streaming calls.
//
// Returns:
//   - []byte: The translated Gemini request with default safety settings attached.
func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte {
	// Drop "format":"uri" from url string schemas before anything else reads the body.
	// NOTE(review): presumably Gemini's schema validation rejects this format — confirm.
	rawJSON := bytes.ReplaceAll(inputRawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`))

	out := `{"contents":[]}`
	out, _ = sjson.Set(out, "model", modelName)

	// system -> system_instruction. From the array form only string-typed text blocks are kept.
	if systemResult := gjson.GetBytes(rawJSON, "system"); systemResult.IsArray() {
		systemInstruction := `{"role":"user","parts":[]}`
		hasSystemParts := false
		systemResult.ForEach(func(_, systemPromptResult gjson.Result) bool {
			if systemPromptResult.Get("type").String() == "text" {
				textResult := systemPromptResult.Get("text")
				if textResult.Type == gjson.String {
					part := `{"text":""}`
					part, _ = sjson.Set(part, "text", textResult.String())
					systemInstruction, _ = sjson.SetRaw(systemInstruction, "parts.-1", part)
					hasSystemParts = true
				}
			}
			return true
		})
		if hasSystemParts {
			out, _ = sjson.SetRaw(out, "system_instruction", systemInstruction)
		}
	} else if systemResult.Type == gjson.String {
		out, _ = sjson.Set(out, "system_instruction.parts.-1.text", systemResult.String())
	}

	// messages -> contents
	if messagesResult := gjson.GetBytes(rawJSON, "messages"); messagesResult.IsArray() {
		messagesResult.ForEach(func(_, messageResult gjson.Result) bool {
			roleResult := messageResult.Get("role")
			if roleResult.Type != gjson.String {
				return true
			}
			role := roleResult.String()
			if role == "assistant" {
				role = "model"
			}
			contentJSON := `{"role":"","parts":[]}`
			contentJSON, _ = sjson.Set(contentJSON, "role", role)
			hasParts := false
			appendPart := func(part string) {
				contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
				hasParts = true
			}

			contentsResult := messageResult.Get("content")
			if contentsResult.IsArray() {
				contentsResult.ForEach(func(_, contentResult gjson.Result) bool {
					switch contentResult.Get("type").String() {
					case "text":
						part := `{"text":""}`
						part, _ = sjson.Set(part, "text", contentResult.Get("text").String())
						appendPart(part)
					case "tool_use":
						functionName := contentResult.Get("name").String()
						// IDs minted by the response translator are "<upstream name>-<n>";
						// prefer the upstream name recovered from the ID when present.
						if toolUseID := contentResult.Get("id").String(); toolUseID != "" {
							if derived := toolNameFromClaudeToolUseID(toolUseID); derived != "" {
								functionName = derived
							}
						}
						functionArgs := contentResult.Get("input").String()
						argsResult := gjson.Parse(functionArgs)
						if argsResult.IsObject() && gjson.Valid(functionArgs) {
							part := `{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`
							part, _ = sjson.Set(part, "thoughtSignature", geminiClaudeThoughtSignature)
							part, _ = sjson.Set(part, "functionCall.name", functionName)
							part, _ = sjson.SetRaw(part, "functionCall.args", functionArgs)
							appendPart(part)
						}
					case "tool_result":
						toolCallID := contentResult.Get("tool_use_id").String()
						if toolCallID == "" {
							return true
						}
						funcName := toolNameFromClaudeToolUseID(toolCallID)
						if funcName == "" {
							funcName = toolCallID
						}
						resultContent := contentResult.Get("content")
						responseData := resultContent.Raw
						if resultContent.Type == gjson.String {
							// Raw would keep the JSON quotes and escapes of a plain-string
							// result; hand the model the actual text instead.
							responseData = resultContent.String()
						}
						part := `{"functionResponse":{"name":"","response":{"result":""}}}`
						part, _ = sjson.Set(part, "functionResponse.name", funcName)
						part, _ = sjson.Set(part, "functionResponse.response.result", responseData)
						appendPart(part)
					case "image":
						source := contentResult.Get("source")
						if source.Get("type").String() != "base64" {
							return true
						}
						mimeType := source.Get("media_type").String()
						data := source.Get("data").String()
						if mimeType == "" || data == "" {
							return true
						}
						part := `{"inline_data":{"mime_type":"","data":""}}`
						part, _ = sjson.Set(part, "inline_data.mime_type", mimeType)
						part, _ = sjson.Set(part, "inline_data.data", data)
						appendPart(part)
					}
					return true
				})
			} else if contentsResult.Type == gjson.String {
				part := `{"text":""}`
				part, _ = sjson.Set(part, "text", contentsResult.String())
				appendPart(part)
			}

			// Gemini rejects a Content whose parts list is empty. That happens when every
			// block was skipped, e.g. an assistant turn holding only thinking blocks.
			if hasParts {
				out, _ = sjson.SetRaw(out, "contents.-1", contentJSON)
			}
			return true
		})
	}

	// tools -> a single functionDeclarations group; Claude-only keys are stripped.
	if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() {
		hasTools := false
		toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
			inputSchemaResult := toolResult.Get("input_schema")
			if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
				inputSchema := inputSchemaResult.Raw
				tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
				tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
				tool, _ = sjson.Delete(tool, "strict")
				tool, _ = sjson.Delete(tool, "input_examples")
				tool, _ = sjson.Delete(tool, "type")
				tool, _ = sjson.Delete(tool, "cache_control")
				tool, _ = sjson.Delete(tool, "defer_loading")
				if gjson.Valid(tool) && gjson.Parse(tool).IsObject() {
					if !hasTools {
						out, _ = sjson.SetRaw(out, "tools", `[{"functionDeclarations":[]}]`)
						hasTools = true
					}
					out, _ = sjson.SetRaw(out, "tools.0.functionDeclarations.-1", tool)
				}
			}
			return true
		})
		if !hasTools {
			out, _ = sjson.Delete(out, "tools")
		}
	}

	// tool_choice -> functionCallingConfig.mode (+ allowedFunctionNames for a specific tool)
	if toolChoiceResult := gjson.GetBytes(rawJSON, "tool_choice"); toolChoiceResult.Exists() {
		toolChoiceType := ""
		toolChoiceName := ""
		if toolChoiceResult.IsObject() {
			toolChoiceType = toolChoiceResult.Get("type").String()
			toolChoiceName = toolChoiceResult.Get("name").String()
		} else if toolChoiceResult.Type == gjson.String {
			toolChoiceType = toolChoiceResult.String()
		}
		switch toolChoiceType {
		case "auto":
			out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "AUTO")
		case "none":
			out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "NONE")
		case "any":
			out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "ANY")
		case "tool":
			out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.mode", "ANY")
			if toolChoiceName != "" {
				out, _ = sjson.Set(out, "toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
			}
		}
	}

	// Map Anthropic thinking -> Gemini thinking config when enabled.
	// Translator only does format conversion; ApplyThinking handles model capability validation.
	if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
		switch t.Get("type").String() {
		case "enabled":
			if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
				budget := int(b.Int())
				out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", budget)
				out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
			}
		case "adaptive", "auto":
			// For adaptive thinking:
			// - If output_config.effort is explicitly present, pass through as thinkingLevel.
			// - Otherwise, treat it as "enabled with target-model maximum" and emit thinkingBudget=max.
			// ApplyThinking handles clamping to target model's supported levels.
			effort := ""
			if v := gjson.GetBytes(rawJSON, "output_config.effort"); v.Exists() && v.Type == gjson.String {
				effort = strings.ToLower(strings.TrimSpace(v.String()))
			}
			if effort != "" {
				out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", effort)
			} else {
				maxBudget := 0
				if mi := registry.LookupModelInfo(modelName, "gemini"); mi != nil && mi.Thinking != nil {
					maxBudget = mi.Thinking.Max
				}
				if maxBudget > 0 {
					out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", maxBudget)
				} else {
					out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingLevel", "high")
				}
			}
			out, _ = sjson.Set(out, "generationConfig.thinkingConfig.includeThoughts", true)
		}
	}

	// Sampling parameters are copied only when they are JSON numbers.
	if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
		out, _ = sjson.Set(out, "generationConfig.temperature", v.Num)
	}
	if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number {
		out, _ = sjson.Set(out, "generationConfig.topP", v.Num)
	}
	if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
		out, _ = sjson.Set(out, "generationConfig.topK", v.Num)
	}

	return common.AttachDefaultSafetySettings([]byte(out), "safetySettings")
}
// toolNameFromClaudeToolUseID recovers the tool name from an ID of the form
// "<name>-<suffix>" by dropping everything from the last '-' onward. IDs without a
// '-' yield "", signalling that no name could be derived.
func toolNameFromClaudeToolUseID(toolUseID string) string {
	cut := strings.LastIndex(toolUseID, "-")
	if cut < 0 {
		return ""
	}
	return toolUseID[:cut]
}
================================================
FILE: internal/translator/gemini/claude/gemini_claude_request_test.go
================================================
package claude
import (
"testing"
"github.com/tidwall/gjson"
)
// TestConvertClaudeRequestToGemini_ToolChoice_SpecificTool verifies that a Claude
// tool_choice naming a single tool becomes mode ANY restricted to that function.
func TestConvertClaudeRequestToGemini_ToolChoice_SpecificTool(t *testing.T) {
	inputJSON := []byte(`{
		"model": "gemini-3-flash-preview",
		"messages": [
			{
				"role": "user",
				"content": [
					{"type": "text", "text": "hi"}
				]
			}
		],
		"tools": [
			{
				"name": "json",
				"description": "A JSON tool",
				"input_schema": {
					"type": "object",
					"properties": {}
				}
			}
		],
		"tool_choice": {"type": "tool", "name": "json"}
	}`)
	output := ConvertClaudeRequestToGemini("gemini-3-flash-preview", inputJSON, false)

	callingConfig := gjson.GetBytes(output, "toolConfig.functionCallingConfig")
	if mode := callingConfig.Get("mode").String(); mode != "ANY" {
		t.Fatalf("Expected toolConfig.functionCallingConfig.mode 'ANY', got '%s'", mode)
	}
	allowedResult := callingConfig.Get("allowedFunctionNames")
	if names := allowedResult.Array(); len(names) != 1 || names[0].String() != "json" {
		t.Fatalf("Expected allowedFunctionNames ['json'], got %s", allowedResult.Raw)
	}
}
// TestConvertClaudeRequestToGemini_ImageContent verifies that a base64 image block is
// emitted as an inline_data part right after the preceding text part.
func TestConvertClaudeRequestToGemini_ImageContent(t *testing.T) {
	inputJSON := []byte(`{
		"model": "gemini-3-flash-preview",
		"messages": [
			{
				"role": "user",
				"content": [
					{"type": "text", "text": "describe this image"},
					{
						"type": "image",
						"source": {
							"type": "base64",
							"media_type": "image/png",
							"data": "aGVsbG8="
						}
					}
				]
			}
		]
	}`)
	output := ConvertClaudeRequestToGemini("gemini-3-flash-preview", inputJSON, false)

	parts := gjson.GetBytes(output, "contents.0.parts").Array()
	if len(parts) != 2 {
		t.Fatalf("Expected 2 parts, got %d", len(parts))
	}
	textPart, imagePart := parts[0], parts[1]
	if text := textPart.Get("text").String(); text != "describe this image" {
		t.Fatalf("Expected first part text 'describe this image', got '%s'", text)
	}
	if mime := imagePart.Get("inline_data.mime_type").String(); mime != "image/png" {
		t.Fatalf("Expected image mime type 'image/png', got '%s'", mime)
	}
	if data := imagePart.Get("inline_data.data").String(); data != "aGVsbG8=" {
		t.Fatalf("Expected image data 'aGVsbG8=', got '%s'", data)
	}
}
================================================
FILE: internal/translator/gemini/claude/gemini_claude_response.go
================================================
// Package claude provides response translation functionality for the Claude API.
// This package converts Gemini responses into Claude-compatible Server-Sent Events (SSE),
// using a state machine that tracks which content block (text, thinking, or function
// call) is currently open. The translation keeps SSE events correctly sequenced and
// carries state across response chunks so the client sees one coherent stream.
package claude
import (
"bytes"
"context"
"fmt"
"strings"
"sync/atomic"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Params holds the per-stream state that ConvertGeminiResponseToClaude keeps between
// calls (it is stored behind the converter's param argument).
type Params struct {
	// IsGlAPIKey is initialized to false by ConvertGeminiResponseToClaude and not read there.
	IsGlAPIKey bool
	// HasFirstResponse is set once the message_start event has been emitted.
	HasFirstResponse bool
	// ResponseType is the kind of content block currently open:
	// 0=none, 1=text, 2=thinking, 3=tool_use.
	ResponseType int
	// ResponseIndex is the Claude content-block index of the block currently open.
	ResponseIndex int
	HasContent    bool // Tracks whether any content (text, thinking, or tool use) has been output
	// ToolNameMap is built from the original Claude request and passed to
	// util.MapToolName to turn upstream tool names into client tool names.
	ToolNameMap map[string]string
	// SawToolCall is set when any functionCall part is seen; it forces the final
	// stop_reason to "tool_use".
	SawToolCall bool
}
// toolUseIDCounter is a process-wide counter, advanced with atomic.AddUint64, whose
// value is appended to the upstream tool name to make each generated tool_use ID unique.
var toolUseIDCounter uint64
// ConvertGeminiResponseToClaude converts one streamed Gemini response chunk into Claude
// Messages API Server-Sent Events.
//
// The conversion is a small state machine kept in *param (a *Params, created on the
// first call). Params.ResponseType records which content block is open:
// 0=none, 1=text, 2=thinking, 3=tool_use. When a part of a different kind arrives, the
// open block is closed with content_block_stop and a new block starts at the next
// index. The first chunk additionally produces message_start.
//
// A chunk carrying finishReason together with usageMetadata closes the open block and
// emits message_delta with stop_reason and usage. The literal "[DONE]" chunk produces
// message_stop. Both terminal outputs are emitted only if some content was produced.
//
// Parameters:
//   - ctx, modelName, requestRawJSON: Unused.
//   - originalRequestRawJSON: The client's Claude request; read once to build the tool
//     name map.
//   - rawJSON: One Gemini response chunk, or the literal "[DONE]".
//   - param: Pointer to the conversion state shared across calls for one stream.
//
// Returns:
//   - []string: One element with the SSE text for this chunk (possibly empty), or an
//     empty slice for "[DONE]" when no content was produced.
func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	if *param == nil {
		*param = &Params{
			IsGlAPIKey:       false,
			HasFirstResponse: false,
			ResponseType:     0,
			ResponseIndex:    0,
			ToolNameMap:      util.ToolNameMapFromClaudeRequest(originalRequestRawJSON),
			SawToolCall:      false,
		}
	}
	p := (*param).(*Params)

	if bytes.Equal(rawJSON, []byte("[DONE]")) {
		// Only send message_stop if we have actually output content.
		if p.HasContent {
			return []string{
				"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n\n",
			}
		}
		return []string{}
	}

	var output strings.Builder

	// writeBlockStop emits content_block_stop for the block at p.ResponseIndex.
	writeBlockStop := func() {
		output.WriteString("event: content_block_stop\n")
		fmt.Fprintf(&output, `data: {"type":"content_block_stop","index":%d}`, p.ResponseIndex)
		output.WriteString("\n\n\n")
	}
	// closeOpenBlock stops the open block, if any, and advances to the next index.
	// No signature_delta is emitted when a thinking block closes.
	closeOpenBlock := func() {
		if p.ResponseType == 0 {
			return
		}
		writeBlockStop()
		p.ResponseIndex++
	}
	// writeTextDelta appends value to a text (1) or thinking (2) block, first opening a
	// block of that kind if a different one (or none) is open.
	writeTextDelta := func(blockType int, contentBlock, deltaType, field, value string) {
		if p.ResponseType != blockType {
			closeOpenBlock()
			output.WriteString("event: content_block_start\n")
			fmt.Fprintf(&output, `data: {"type":"content_block_start","index":%d,"content_block":%s}`, p.ResponseIndex, contentBlock)
			output.WriteString("\n\n\n")
			p.ResponseType = blockType
		}
		output.WriteString("event: content_block_delta\n")
		data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"%s","%s":""}}`, p.ResponseIndex, deltaType, field), "delta."+field, value)
		fmt.Fprintf(&output, "data: %s\n\n\n", data)
		p.HasContent = true
	}
	// writeArgsDelta emits raw function-call arguments as an input_json_delta.
	writeArgsDelta := func(rawArgs string) {
		output.WriteString("event: content_block_delta\n")
		data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, p.ResponseIndex), "delta.partial_json", rawArgs)
		fmt.Fprintf(&output, "data: %s\n\n\n", data)
	}

	// The very first chunk opens the Claude message; defaults are overridden with the
	// upstream model version and response ID when present.
	if !p.HasFirstResponse {
		messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
		if modelVersionResult := gjson.GetBytes(rawJSON, "modelVersion"); modelVersionResult.Exists() {
			messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
		}
		if responseIDResult := gjson.GetBytes(rawJSON, "responseId"); responseIDResult.Exists() {
			messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
		}
		output.WriteString("event: message_start\n")
		fmt.Fprintf(&output, "data: %s\n\n\n", messageStartTemplate)
		p.HasFirstResponse = true
	}

	if partsResult := gjson.GetBytes(rawJSON, "candidates.0.content.parts"); partsResult.IsArray() {
		for _, part := range partsResult.Array() {
			textResult := part.Get("text")
			functionCallResult := part.Get("functionCall")
			switch {
			case textResult.Exists():
				if part.Get("thought").Bool() {
					writeTextDelta(2, `{"type":"thinking","thinking":""}`, "thinking_delta", "thinking", textResult.String())
				} else {
					writeTextDelta(1, `{"type":"text","text":""}`, "text_delta", "text", textResult.String())
				}
			case functionCallResult.Exists():
				p.SawToolCall = true
				upstreamToolName := functionCallResult.Get("name").String()
				clientToolName := util.MapToolName(p.ToolNameMap, upstreamToolName)
				argsResult := functionCallResult.Get("args")

				// A nameless functionCall while a tool_use block is open is treated as a
				// continuation of that call's arguments rather than a new call.
				if p.ResponseType == 3 && upstreamToolName == "" {
					if argsResult.Exists() {
						writeArgsDelta(argsResult.Raw)
					}
					continue
				}

				closeOpenBlock()
				data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, p.ResponseIndex)
				data, _ = sjson.Set(data, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamToolName, atomic.AddUint64(&toolUseIDCounter, 1))))
				data, _ = sjson.Set(data, "content_block.name", clientToolName)
				output.WriteString("event: content_block_start\n")
				fmt.Fprintf(&output, "data: %s\n\n\n", data)
				if argsResult.Exists() {
					writeArgsDelta(argsResult.Raw)
				}
				p.ResponseType = 3
				p.HasContent = true
			}
		}
	}

	// The final chunk carries finishReason alongside usageMetadata. Zero-valued counters
	// (e.g. candidatesTokenCount when the model produced only thoughts before hitting
	// MAX_TOKENS) may be omitted under the proto3 JSON mapping, so a missing counter must
	// not suppress the closing events; missing counters are read as 0.
	usageResult := gjson.GetBytes(rawJSON, "usageMetadata")
	if usageResult.Exists() && bytes.Contains(rawJSON, []byte(`"finishReason"`)) && p.HasContent {
		writeBlockStop()
		template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
		if p.SawToolCall {
			template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
		} else if finish := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finish.Exists() && finish.String() == "MAX_TOKENS" {
			template = `{"type":"message_delta","delta":{"stop_reason":"max_tokens","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
		}
		outputTokens := usageResult.Get("candidatesTokenCount").Int() + usageResult.Get("thoughtsTokenCount").Int()
		template, _ = sjson.Set(template, "usage.output_tokens", outputTokens)
		template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
		output.WriteString("event: message_delta\n")
		output.WriteString("data: ")
		output.WriteString(template)
		output.WriteString("\n\n\n")
	}

	return []string{output.String()}
}
// ConvertGeminiResponseToClaudeNonStream translates a complete (non-streaming) Gemini
// response into a Claude Messages API response.
//
// Consecutive non-thought text parts are merged into one "text" block and consecutive
// thought parts into one "thinking" block; each functionCall part becomes a "tool_use"
// block whose name is mapped back to the client's tool name. The stop reason is
// "tool_use" when any tool call was emitted, "max_tokens" when the first candidate
// finished with MAX_TOKENS, and "end_turn" otherwise. The usage object is removed
// only when the response has no usageMetadata at all.
//
// Parameters:
//   - ctx: Unused.
//   - modelName: Unused; the reported model comes from the response's modelVersion.
//   - originalRequestRawJSON: The client's original Claude request, used to build the tool-name map.
//   - requestRawJSON: The translated upstream request (unused).
//   - rawJSON: The raw JSON response from the Gemini API.
//   - param: Unused.
//
// Returns:
//   - string: A Claude-compatible JSON response.
func ConvertGeminiResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
	_ = requestRawJSON
	resp := gjson.ParseBytes(rawJSON)
	toolNames := util.ToolNameMapFromClaudeRequest(originalRequestRawJSON)

	msg := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
	msg, _ = sjson.Set(msg, "id", resp.Get("responseId").String())
	msg, _ = sjson.Set(msg, "model", resp.Get("modelVersion").String())

	usage := resp.Get("usageMetadata")
	promptTokens := usage.Get("promptTokenCount").Int()
	// Thought tokens are reported together with candidate tokens as output.
	completionTokens := usage.Get("candidatesTokenCount").Int() + usage.Get("thoughtsTokenCount").Int()
	msg, _ = sjson.Set(msg, "usage.input_tokens", promptTokens)
	msg, _ = sjson.Set(msg, "usage.output_tokens", completionTokens)

	// Only one kind of text ("text" or "thinking") is buffered at a time; switching
	// kind, hitting a tool call, or reaching the end closes the current block.
	var pending strings.Builder
	pendingKind := ""
	flush := func() {
		if pending.Len() == 0 {
			return
		}
		var block string
		if pendingKind == "thinking" {
			block, _ = sjson.Set(`{"type":"thinking","thinking":""}`, "thinking", pending.String())
		} else {
			block, _ = sjson.Set(`{"type":"text","text":""}`, "text", pending.String())
		}
		msg, _ = sjson.SetRaw(msg, "content.-1", block)
		pending.Reset()
	}

	toolSeq := 0
	sawToolCall := false
	if parts := resp.Get("candidates.0.content.parts"); parts.IsArray() {
		for _, part := range parts.Array() {
			if text := part.Get("text"); text.Exists() && text.String() != "" {
				kind := "text"
				if part.Get("thought").Bool() {
					kind = "thinking"
				}
				if kind != pendingKind {
					flush()
					pendingKind = kind
				}
				pending.WriteString(text.String())
				continue
			}

			call := part.Get("functionCall")
			if !call.Exists() {
				continue
			}
			flush()
			sawToolCall = true

			upstreamName := call.Get("name").String()
			clientName := util.MapToolName(toolNames, upstreamName)
			toolSeq++
			tool := `{"type":"tool_use","id":"","name":"","input":{}}`
			tool, _ = sjson.Set(tool, "id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d", upstreamName, toolSeq)))
			tool, _ = sjson.Set(tool, "name", clientName)
			// Only well-formed JSON objects are forwarded as tool input; anything else becomes {}.
			input := "{}"
			if args := call.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
				input = args.Raw
			}
			tool, _ = sjson.SetRaw(tool, "input", input)
			msg, _ = sjson.SetRaw(msg, "content.-1", tool)
		}
	}
	flush()

	stopReason := "end_turn"
	switch {
	case sawToolCall:
		stopReason = "tool_use"
	case resp.Get("candidates.0.finishReason").String() == "MAX_TOKENS":
		stopReason = "max_tokens"
	}
	msg, _ = sjson.Set(msg, "stop_reason", stopReason)

	// Without usageMetadata both counters are necessarily zero; drop the placeholder.
	if !usage.Exists() {
		msg, _ = sjson.Delete(msg, "usage")
	}
	return msg
}
// ClaudeTokenCount renders a token count as a Claude-style JSON body of the form
// {"input_tokens":N}. ctx is not used.
func ClaudeTokenCount(ctx context.Context, count int64) string {
	return fmt.Sprint(`{"input_tokens":`, count, `}`)
}
================================================
FILE: internal/translator/gemini/claude/init.go
================================================
package claude
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
Claude,
Gemini,
ConvertClaudeRequestToGemini,
interfaces.TranslateResponse{
Stream: ConvertGeminiResponseToClaude,
NonStream: ConvertGeminiResponseToClaudeNonStream,
TokenCount: ClaudeTokenCount,
},
)
}
================================================
FILE: internal/translator/gemini/common/safety.go
================================================
package common
import (
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// DefaultSafetySettings builds the safety configuration attached to outgoing Gemini
// requests: every harm category is disabled ("OFF"), except civic integrity, which
// uses "BLOCK_NONE". A fresh slice of fresh maps is returned on every call, so
// callers may modify the result freely.
func DefaultSafetySettings() []map[string]string {
	table := []struct{ category, threshold string }{
		{"HARM_CATEGORY_HARASSMENT", "OFF"},
		{"HARM_CATEGORY_HATE_SPEECH", "OFF"},
		{"HARM_CATEGORY_SEXUALLY_EXPLICIT", "OFF"},
		{"HARM_CATEGORY_DANGEROUS_CONTENT", "OFF"},
		{"HARM_CATEGORY_CIVIC_INTEGRITY", "BLOCK_NONE"},
	}
	settings := make([]map[string]string, 0, len(table))
	for _, row := range table {
		settings = append(settings, map[string]string{
			"category":  row.category,
			"threshold": row.threshold,
		})
	}
	return settings
}
// AttachDefaultSafetySettings writes DefaultSafetySettings at path unless rawJSON
// already holds a value there. path is a gjson/sjson path chosen by the caller,
// e.g. "safetySettings" or "request.safetySettings". If the write fails, rawJSON is
// returned unchanged.
func AttachDefaultSafetySettings(rawJSON []byte, path string) []byte {
	if existing := gjson.GetBytes(rawJSON, path); existing.Exists() {
		return rawJSON
	}
	updated, err := sjson.SetBytes(rawJSON, path, DefaultSafetySettings())
	if err == nil {
		return updated
	}
	return rawJSON
}
================================================
FILE: internal/translator/gemini/gemini/gemini_gemini_request.go
================================================
// Package gemini provides in-provider request normalization for Gemini API.
// It ensures incoming v1beta requests meet minimal schema requirements
// expected by Google's Generative Language API.
package gemini
import (
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertGeminiRequestToGemini normalizes a Gemini v1beta request before it is sent
// upstream. Apart from the adjustments below, the payload is left as-is:
//   - tools[i].functionDeclarations is renamed to function_declarations, and each
//     declaration's parameters key is renamed to parametersJsonSchema.
//   - Any content whose role is missing or not "user"/"model" gets a role: "model"
//     when the previous content's role is "user", otherwise "user" (so the first
//     content defaults to "user").
//   - In model contents, every part with a functionCall or a thoughtSignature gets
//     thoughtSignature set to "skip_thought_signature_validator".
//   - generationConfig.responseSchema is renamed to responseJsonSchema.
//   - Empty functionResponse names are backfilled from the preceding model turn.
//   - Default safety settings are attached when safetySettings is absent.
//
// Requests without a contents field only receive the default safety settings.
func ConvertGeminiRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte {
	contents := gjson.GetBytes(inputRawJSON, "contents")
	if !contents.Exists() {
		return common.AttachDefaultSafetySettings(inputRawJSON, "safetySettings")
	}
	out := inputRawJSON

	// Tool declarations: normalize the container key, then each schema key.
	if tools := gjson.GetBytes(out, "tools"); tools.IsArray() {
		for i := range tools.Array() {
			camelPath := fmt.Sprintf("tools.%d.functionDeclarations", i)
			declPath := fmt.Sprintf("tools.%d.function_declarations", i)
			if gjson.GetBytes(out, camelPath).Exists() {
				renamed, _ := util.RenameKey(string(out), camelPath, declPath)
				out = []byte(renamed)
			}
			decls := gjson.GetBytes(out, declPath)
			if !decls.IsArray() {
				continue
			}
			for j := range decls.Array() {
				paramsPath := fmt.Sprintf("%s.%d.parameters", declPath, j)
				if gjson.GetBytes(out, paramsPath).Exists() {
					renamed, _ := util.RenameKey(string(out), paramsPath, fmt.Sprintf("%s.%d.parametersJsonSchema", declPath, j))
					out = []byte(renamed)
				}
			}
		}
	}

	// Roles: only "user" and "model" are accepted; repair anything else.
	previous := ""
	position := 0
	contents.ForEach(func(_, entry gjson.Result) bool {
		role := entry.Get("role").String()
		if role != "user" && role != "model" {
			role = "user"
			if previous == "user" {
				role = "model"
			}
			out, _ = sjson.SetBytes(out, fmt.Sprintf("contents.%d.role", position), role)
		}
		previous = role
		position++
		return true
	})

	// Thought signatures on model parts are replaced with the skip marker.
	gjson.GetBytes(out, "contents").ForEach(func(contentKey, entry gjson.Result) bool {
		if entry.Get("role").String() != "model" {
			return true
		}
		entry.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
			if part.Get("functionCall").Exists() || part.Get("thoughtSignature").Exists() {
				out, _ = sjson.SetBytes(out, fmt.Sprintf("contents.%d.parts.%d.thoughtSignature", contentKey.Int(), partKey.Int()), "skip_thought_signature_validator")
			}
			return true
		})
		return true
	})

	if gjson.GetBytes(out, "generationConfig.responseSchema").Exists() {
		renamed, _ := util.RenameKey(string(out), "generationConfig.responseSchema", "generationConfig.responseJsonSchema")
		out = []byte(renamed)
	}

	// Amp may send function responses with empty names; the Gemini API rejects these.
	out = backfillEmptyFunctionResponseNames(out)
	return common.AttachDefaultSafetySettings(out, "safetySettings")
}
// backfillEmptyFunctionResponseNames fills in empty functionResponse names using the
// functionCall names of the preceding model turn; the Gemini API rejects function
// responses whose name is empty.
//
// Each model turn records the names of its functionCall parts in order (a model turn
// without calls clears them). The next non-model turn consumes the recorded names.
// First, every functionResponse that already has a non-blank name claims one call
// with that name. Then each blank (empty or whitespace-only) name is filled from the
// still-unclaimed calls, in order. When every response is blank, this is plain
// positional matching. Blank responses left over once all calls are used keep their
// name and are logged at debug level.
//
// data is returned unchanged when it has no contents field.
func backfillEmptyFunctionResponseNames(data []byte) []byte {
	contents := gjson.GetBytes(data, "contents")
	if !contents.Exists() {
		return data
	}
	out := data
	var pendingCallNames []string
	contents.ForEach(func(contentIdx, content gjson.Result) bool {
		parts := content.Get("parts")
		if content.Get("role").String() == "model" {
			// Each model turn replaces (or, without calls, clears) the pending names.
			pendingCallNames = nil
			parts.ForEach(func(_, part gjson.Result) bool {
				if call := part.Get("functionCall"); call.Exists() {
					pendingCallNames = append(pendingCallNames, call.Get("name").String())
				}
				return true
			})
			return true
		}
		if len(pendingCallNames) == 0 {
			return true
		}
		// Only the turn immediately following the model turn may use its call names.
		available := unclaimedCallNames(pendingCallNames, parts)
		pendingCallNames = nil
		parts.ForEach(func(partIdx, part gjson.Result) bool {
			response := part.Get("functionResponse")
			if !response.Exists() || strings.TrimSpace(response.Get("name").String()) != "" {
				return true
			}
			if len(available) == 0 {
				log.Debugf("more function responses than calls at contents[%d], skipping name backfill", contentIdx.Int())
				return true
			}
			out, _ = sjson.SetBytes(out,
				fmt.Sprintf("contents.%d.parts.%d.functionResponse.name", contentIdx.Int(), partIdx.Int()),
				available[0])
			available = available[1:]
			return true
		})
		return true
	})
	return out
}

// unclaimedCallNames returns a copy of callNames with one occurrence removed for
// every functionResponse in parts that already carries a non-blank name.
func unclaimedCallNames(callNames []string, parts gjson.Result) []string {
	remaining := append([]string(nil), callNames...)
	parts.ForEach(func(_, part gjson.Result) bool {
		name := part.Get("functionResponse.name").String()
		if strings.TrimSpace(name) == "" {
			return true
		}
		for i, candidate := range remaining {
			if candidate == name {
				remaining = append(remaining[:i], remaining[i+1:]...)
				break
			}
		}
		return true
	})
	return remaining
}
================================================
FILE: internal/translator/gemini/gemini/gemini_gemini_request_test.go
================================================
package gemini
import (
"testing"
"github.com/tidwall/gjson"
)
// TestBackfillEmptyFunctionResponseNames_Single checks that an empty response name
// is filled from the single functionCall in the preceding model turn.
func TestBackfillEmptyFunctionResponseNames_Single(t *testing.T) {
	input := []byte(`{
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Bash", "args": {"cmd": "ls"}}}
]
},
{
"role": "user",
"parts": [
{"functionResponse": {"name": "", "response": {"output": "file1.txt"}}}
]
}
]
}`)
	result := backfillEmptyFunctionResponseNames(input)
	if got := gjson.GetBytes(result, "contents.1.parts.0.functionResponse.name").String(); got != "Bash" {
		t.Errorf("Expected backfilled name 'Bash', got '%s'", got)
	}
}
// TestBackfillEmptyFunctionResponseNames_Parallel checks that two empty response names
// are filled, in order, from two parallel functionCalls.
func TestBackfillEmptyFunctionResponseNames_Parallel(t *testing.T) {
	input := []byte(`{
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Read", "args": {"path": "/a"}}},
{"functionCall": {"name": "Grep", "args": {"pattern": "x"}}}
]
},
{
"role": "user",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "content a"}}},
{"functionResponse": {"name": "", "response": {"result": "match x"}}}
]
}
]
}`)
	result := backfillEmptyFunctionResponseNames(input)
	first := gjson.GetBytes(result, "contents.1.parts.0.functionResponse.name").String()
	second := gjson.GetBytes(result, "contents.1.parts.1.functionResponse.name").String()
	if first != "Read" {
		t.Errorf("Expected first name 'Read', got '%s'", first)
	}
	if second != "Grep" {
		t.Errorf("Expected second name 'Grep', got '%s'", second)
	}
}
// TestBackfillEmptyFunctionResponseNames_PreservesExisting checks that a response that
// already has a name is left untouched.
func TestBackfillEmptyFunctionResponseNames_PreservesExisting(t *testing.T) {
	input := []byte(`{
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Bash", "args": {}}}
]
},
{
"role": "user",
"parts": [
{"functionResponse": {"name": "Bash", "response": {"result": "ok"}}}
]
}
]
}`)
	result := backfillEmptyFunctionResponseNames(input)
	if got := gjson.GetBytes(result, "contents.1.parts.0.functionResponse.name").String(); got != "Bash" {
		t.Errorf("Expected preserved name 'Bash', got '%s'", got)
	}
}
// TestConvertGeminiRequestToGemini_BackfillsEmptyName checks that the full request
// normalizer applies the functionResponse name backfill.
func TestConvertGeminiRequestToGemini_BackfillsEmptyName(t *testing.T) {
	input := []byte(`{
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Bash", "args": {"cmd": "ls"}}}
]
},
{
"role": "user",
"parts": [
{"functionResponse": {"name": "", "response": {"output": "file1.txt"}}}
]
}
]
}`)
	result := ConvertGeminiRequestToGemini("", input, false)
	if got := gjson.GetBytes(result, "contents.1.parts.0.functionResponse.name").String(); got != "Bash" {
		t.Errorf("Expected backfilled name 'Bash', got '%s'", got)
	}
}
// TestBackfillEmptyFunctionResponseNames_MoreResponsesThanCalls checks that responses
// beyond the number of calls neither panic nor receive a name.
func TestBackfillEmptyFunctionResponseNames_MoreResponsesThanCalls(t *testing.T) {
	input := []byte(`{
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Bash", "args": {}}}
]
},
{
"role": "user",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "ok"}}},
{"functionResponse": {"name": "", "response": {"result": "extra"}}}
]
}
]
}`)
	result := backfillEmptyFunctionResponseNames(input)
	if first := gjson.GetBytes(result, "contents.1.parts.0.functionResponse.name").String(); first != "Bash" {
		t.Errorf("Expected first name 'Bash', got '%s'", first)
	}
	// The second response has no matching call, so it must stay empty.
	if second := gjson.GetBytes(result, "contents.1.parts.1.functionResponse.name").String(); second != "" {
		t.Errorf("Expected second name to remain empty, got '%s'", second)
	}
}
// TestBackfillEmptyFunctionResponseNames_MultipleGroups checks that each call/response
// pair in a sequence is matched against its own preceding model turn.
func TestBackfillEmptyFunctionResponseNames_MultipleGroups(t *testing.T) {
	input := []byte(`{
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "Read", "args": {}}}
]
},
{
"role": "user",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "content"}}}
]
},
{
"role": "model",
"parts": [
{"functionCall": {"name": "Grep", "args": {}}}
]
},
{
"role": "user",
"parts": [
{"functionResponse": {"name": "", "response": {"result": "match"}}}
]
}
]
}`)
	result := backfillEmptyFunctionResponseNames(input)
	firstGroup := gjson.GetBytes(result, "contents.1.parts.0.functionResponse.name").String()
	secondGroup := gjson.GetBytes(result, "contents.3.parts.0.functionResponse.name").String()
	if firstGroup != "Read" {
		t.Errorf("Expected first group name 'Read', got '%s'", firstGroup)
	}
	if secondGroup != "Grep" {
		t.Errorf("Expected second group name 'Grep', got '%s'", secondGroup)
	}
}
================================================
FILE: internal/translator/gemini/gemini/gemini_gemini_response.go
================================================
package gemini
import (
"bytes"
"context"
"fmt"
)
// PassthroughGeminiResponseStream forwards a Gemini streaming chunk without
// translating it. When the chunk starts with the SSE "data:" prefix, the prefix and
// any surrounding whitespace are stripped. The "[DONE]" terminator produces no
// output (an empty, non-nil slice).
func PassthroughGeminiResponseStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []string {
	const ssePrefix = "data:"
	chunk := rawJSON
	if bytes.HasPrefix(chunk, []byte(ssePrefix)) {
		chunk = bytes.TrimSpace(chunk[len(ssePrefix):])
	}
	if string(chunk) == "[DONE]" {
		return []string{}
	}
	return []string{string(chunk)}
}
// PassthroughGeminiResponseNonStream returns a complete Gemini response body verbatim;
// no translation is needed when client and upstream both speak the Gemini format.
func PassthroughGeminiResponseNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
	body := string(rawJSON)
	return body
}
// GeminiTokenCount renders a token count as a Gemini countTokens-style JSON body,
// reporting count both as totalTokens and as a single TEXT modality entry.
// ctx is not used.
func GeminiTokenCount(ctx context.Context, count int64) string {
	return fmt.Sprintf(`{"totalTokens":%[1]d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%[1]d}]}`, count)
}
================================================
FILE: internal/translator/gemini/gemini/init.go
================================================
package gemini
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
// Register a no-op response translator and a request normalizer for Gemini→Gemini.
// The request converter ensures missing or invalid roles are normalized to valid values.
func init() {
translator.Register(
Gemini,
Gemini,
ConvertGeminiRequestToGemini,
interfaces.TranslateResponse{
Stream: PassthroughGeminiResponseStream,
NonStream: PassthroughGeminiResponseNonStream,
TokenCount: GeminiTokenCount,
},
)
}
================================================
FILE: internal/translator/gemini/gemini-cli/gemini_gemini-cli_request.go
================================================
// Package geminiCLI translates between the Gemini CLI API format and the Gemini API.
// Incoming Gemini CLI requests wrap the Gemini payload in an envelope
// ({"model": ..., "request": {...}}); they are unwrapped and normalized into plain
// Gemini API requests. Gemini API responses are wrapped back into the
// {"response": ...} envelope that Gemini CLI clients expect.
package geminiCLI
import (
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertGeminiCLIRequestToGemini converts a Gemini CLI request into a Gemini API request.
//
// The CLI sends the Gemini payload inside an envelope of the form
// {"model": ..., "request": {...}}. This function unwraps request, copies the
// envelope's model into it, and then normalizes the payload:
//   - systemInstruction is renamed to system_instruction;
//   - tools[i].functionDeclarations (camelCase) is renamed to function_declarations,
//     and each declaration's parameters key is renamed to parametersJsonSchema;
//   - in model contents, every part carrying a functionCall or a thoughtSignature gets
//     thoughtSignature set to "skip_thought_signature_validator";
//   - default safety settings are attached when safetySettings is absent.
//
// Parameters:
//   - modelName: Unused; the model is taken from the envelope.
//   - inputRawJSON: The raw Gemini CLI request.
//   - stream: Unused.
//
// Returns:
//   - []byte: The request in Gemini API format.
func ConvertGeminiCLIRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte {
	modelResult := gjson.GetBytes(inputRawJSON, "model")
	rawJSON := []byte(gjson.GetBytes(inputRawJSON, "request").Raw)
	rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelResult.String())

	if systemInstruction := gjson.GetBytes(rawJSON, "systemInstruction"); systemInstruction.Exists() {
		rawJSON, _ = sjson.SetRawBytes(rawJSON, "system_instruction", []byte(systemInstruction.Raw))
		rawJSON, _ = sjson.DeleteBytes(rawJSON, "systemInstruction")
	}

	if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() {
		for i := range toolsResult.Array() {
			camelPath := fmt.Sprintf("tools.%d.functionDeclarations", i)
			declPath := fmt.Sprintf("tools.%d.function_declarations", i)
			// Without this, declarations sent under the camelCase key would keep their
			// parameters key, because the rename below only looks at function_declarations.
			if gjson.GetBytes(rawJSON, camelPath).Exists() {
				strJSON, _ := util.RenameKey(string(rawJSON), camelPath, declPath)
				rawJSON = []byte(strJSON)
			}
			decls := gjson.GetBytes(rawJSON, declPath)
			if !decls.IsArray() {
				continue
			}
			for j := range decls.Array() {
				paramsPath := fmt.Sprintf("%s.%d.parameters", declPath, j)
				if gjson.GetBytes(rawJSON, paramsPath).Exists() {
					strJSON, _ := util.RenameKey(string(rawJSON), paramsPath, fmt.Sprintf("%s.%d.parametersJsonSchema", declPath, j))
					rawJSON = []byte(strJSON)
				}
			}
		}
	}

	gjson.GetBytes(rawJSON, "contents").ForEach(func(key, content gjson.Result) bool {
		if content.Get("role").String() != "model" {
			return true
		}
		content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
			if part.Get("functionCall").Exists() || part.Get("thoughtSignature").Exists() {
				rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
			}
			return true
		})
		return true
	})

	return common.AttachDefaultSafetySettings(rawJSON, "safetySettings")
}
================================================
FILE: internal/translator/gemini/gemini-cli/gemini_gemini-cli_response.go
================================================
// Package geminiCLI provides response translation from the Gemini API to the Gemini CLI API.
// It wraps Gemini API responses, both streaming chunks and complete non-streaming bodies,
// in the {"response": ...} envelope expected by Gemini CLI API clients.
// The Gemini payload itself is forwarded unchanged inside the envelope.
package geminiCLI
import (
"bytes"
"context"
"fmt"
"github.com/tidwall/sjson"
)
var dataTag = []byte("data:")

// ConvertGeminiResponseToGeminiCLI wraps one Gemini streaming (SSE) chunk in the Gemini
// CLI envelope {"response": <chunk>}. The chunk itself is embedded verbatim; no
// content-level translation is performed.
//
// Only input that starts with "data:" is translated. Anything else, the "[DONE]"
// terminator, and an empty data payload all produce no output (an empty, non-nil slice).
//
// Parameters:
//   - ctx: Unused.
//   - modelName: Unused.
//   - originalRequestRawJSON, requestRawJSON: Unused.
//   - rawJSON: One raw SSE line from the Gemini API.
//   - param: Unused.
//
// Returns:
//   - []string: Zero or one Gemini CLI-compatible JSON response.
func ConvertGeminiResponseToGeminiCLI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []string {
	if !bytes.HasPrefix(rawJSON, dataTag) {
		return []string{}
	}
	rawJSON = bytes.TrimSpace(rawJSON[len(dataTag):])
	// An empty payload is not a JSON value; embedding it raw would yield a malformed
	// envelope ({"response": }), so it is dropped like the terminator.
	if len(rawJSON) == 0 || bytes.Equal(rawJSON, []byte("[DONE]")) {
		return []string{}
	}
	envelope, _ := sjson.SetRawBytes([]byte(`{"response": {}}`), "response", rawJSON)
	return []string{string(envelope)}
}
// ConvertGeminiResponseToGeminiCLINonStream wraps a complete Gemini response body in the
// Gemini CLI envelope {"response": <body>}. The body is embedded verbatim.
//
// Only rawJSON is used; the context, model name, request payloads and param are ignored.
func ConvertGeminiResponseToGeminiCLINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
	const envelope = `{"response": {}}`
	wrapped, _ := sjson.SetRawBytes([]byte(envelope), "response", rawJSON)
	return string(wrapped)
}
// GeminiCLITokenCount renders a token count as a Gemini countTokens-style JSON body,
// reporting count both as totalTokens and as a single TEXT modality entry.
// ctx is not used.
func GeminiCLITokenCount(ctx context.Context, count int64) string {
	const layout = `{"totalTokens":%[1]d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%[1]d}]}`
	return fmt.Sprintf(layout, count)
}
================================================
FILE: internal/translator/gemini/gemini-cli/init.go
================================================
package geminiCLI
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
GeminiCLI,
Gemini,
ConvertGeminiCLIRequestToGemini,
interfaces.TranslateResponse{
Stream: ConvertGeminiResponseToGeminiCLI,
NonStream: ConvertGeminiResponseToGeminiCLINonStream,
TokenCount: GeminiCLITokenCount,
},
)
}
================================================
FILE: internal/translator/gemini/openai/chat-completions/gemini_openai_request.go
================================================
// Package chat_completions provides request translation from OpenAI Chat Completions to the Gemini API.
// It converts OpenAI Chat Completions requests into Gemini-compatible JSON using only gjson/sjson.
package chat_completions
import (
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
	// geminiFunctionThoughtSignature is the placeholder thoughtSignature attached to
	// replayed functionCall parts and inline-image parts when building Gemini contents.
	// Judging by its value, it asks the upstream to skip signature validation —
	// NOTE(review): confirm against upstream behavior.
	geminiFunctionThoughtSignature = "skip_thought_signature_validator"
)
// ConvertOpenAIRequestToGemini converts an OpenAI Chat Completions request (raw JSON)
// into a complete Gemini request JSON. All JSON construction uses sjson and lookups use gjson.
//
// Parameters:
// - modelName: The name of the model to use for the request
// - rawJSON: The raw JSON request data from the OpenAI API
// - stream: A boolean indicating if the request is for a streaming response (unused in current implementation)
//
// Returns:
// - []byte: The transformed request data in Gemini API format
func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := inputRawJSON
// Base envelope (no default thinkingConfig)
out := []byte(`{"contents":[]}`)
// Model
out, _ = sjson.SetBytes(out, "model", modelName)
// Let user-provided generationConfig pass through
if genConfig := gjson.GetBytes(rawJSON, "generationConfig"); genConfig.Exists() {
out, _ = sjson.SetRawBytes(out, "generationConfig", []byte(genConfig.Raw))
}
// Apply thinking configuration: convert OpenAI reasoning_effort to Gemini thinkingConfig.
// Inline translation-only mapping; capability checks happen later in ApplyThinking.
re := gjson.GetBytes(rawJSON, "reasoning_effort")
if re.Exists() {
effort := strings.ToLower(strings.TrimSpace(re.String()))
if effort != "" {
thinkingPath := "generationConfig.thinkingConfig"
if effort == "auto" {
out, _ = sjson.SetBytes(out, thinkingPath+".thinkingBudget", -1)
out, _ = sjson.SetBytes(out, thinkingPath+".includeThoughts", true)
} else {
out, _ = sjson.SetBytes(out, thinkingPath+".thinkingLevel", effort)
out, _ = sjson.SetBytes(out, thinkingPath+".includeThoughts", effort != "none")
}
}
}
// Temperature/top_p/top_k
if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number {
out, _ = sjson.SetBytes(out, "generationConfig.temperature", tr.Num)
}
if tpr := gjson.GetBytes(rawJSON, "top_p"); tpr.Exists() && tpr.Type == gjson.Number {
out, _ = sjson.SetBytes(out, "generationConfig.topP", tpr.Num)
}
if tkr := gjson.GetBytes(rawJSON, "top_k"); tkr.Exists() && tkr.Type == gjson.Number {
out, _ = sjson.SetBytes(out, "generationConfig.topK", tkr.Num)
}
// Candidate count (OpenAI 'n' parameter)
if n := gjson.GetBytes(rawJSON, "n"); n.Exists() && n.Type == gjson.Number {
if val := n.Int(); val > 1 {
out, _ = sjson.SetBytes(out, "generationConfig.candidateCount", val)
}
}
// Map OpenAI modalities -> Gemini generationConfig.responseModalities
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
var responseMods []string
for _, m := range mods.Array() {
switch strings.ToLower(m.String()) {
case "text":
responseMods = append(responseMods, "TEXT")
case "image":
responseMods = append(responseMods, "IMAGE")
}
}
if len(responseMods) > 0 {
out, _ = sjson.SetBytes(out, "generationConfig.responseModalities", responseMods)
}
}
// OpenRouter-style image_config support
// If the input uses top-level image_config.aspect_ratio, map it into generationConfig.imageConfig.aspectRatio.
if imgCfg := gjson.GetBytes(rawJSON, "image_config"); imgCfg.Exists() && imgCfg.IsObject() {
if ar := imgCfg.Get("aspect_ratio"); ar.Exists() && ar.Type == gjson.String {
out, _ = sjson.SetBytes(out, "generationConfig.imageConfig.aspectRatio", ar.Str)
}
if size := imgCfg.Get("image_size"); size.Exists() && size.Type == gjson.String {
out, _ = sjson.SetBytes(out, "generationConfig.imageConfig.imageSize", size.Str)
}
}
// messages -> systemInstruction + contents
messages := gjson.GetBytes(rawJSON, "messages")
if messages.IsArray() {
arr := messages.Array()
// First pass: assistant tool_calls id->name map
tcID2Name := map[string]string{}
for i := 0; i < len(arr); i++ {
m := arr[i]
if m.Get("role").String() == "assistant" {
tcs := m.Get("tool_calls")
if tcs.IsArray() {
for _, tc := range tcs.Array() {
if tc.Get("type").String() == "function" {
id := tc.Get("id").String()
name := tc.Get("function.name").String()
if id != "" && name != "" {
tcID2Name[id] = name
}
}
}
}
}
}
// Second pass build systemInstruction/tool responses cache
toolResponses := map[string]string{} // tool_call_id -> response text
for i := 0; i < len(arr); i++ {
m := arr[i]
role := m.Get("role").String()
if role == "tool" {
toolCallID := m.Get("tool_call_id").String()
if toolCallID != "" {
c := m.Get("content")
toolResponses[toolCallID] = c.Raw
}
}
}
systemPartIndex := 0
for i := 0; i < len(arr); i++ {
m := arr[i]
role := m.Get("role").String()
content := m.Get("content")
if (role == "system" || role == "developer") && len(arr) > 1 {
// system -> systemInstruction as a user message style
if content.Type == gjson.String {
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.String())
systemPartIndex++
} else if content.IsObject() && content.Get("type").String() == "text" {
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), content.Get("text").String())
systemPartIndex++
} else if content.IsArray() {
contents := content.Array()
if len(contents) > 0 {
out, _ = sjson.SetBytes(out, "systemInstruction.role", "user")
for j := 0; j < len(contents); j++ {
out, _ = sjson.SetBytes(out, fmt.Sprintf("systemInstruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
systemPartIndex++
}
}
}
} else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) {
// Build single user content node to avoid splitting into multiple contents
node := []byte(`{"role":"user","parts":[]}`)
if content.Type == gjson.String {
node, _ = sjson.SetBytes(node, "parts.0.text", content.String())
} else if content.IsArray() {
items := content.Array()
p := 0
for _, item := range items {
switch item.Get("type").String() {
case "text":
text := item.Get("text").String()
if text != "" {
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", text)
}
p++
case "image_url":
imageURL := item.Get("image_url.url").String()
if len(imageURL) > 5 {
pieces := strings.SplitN(imageURL[5:], ";", 2)
if len(pieces) == 2 && len(pieces[1]) > 7 {
mime := pieces[0]
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiFunctionThoughtSignature)
p++
}
}
case "file":
filename := item.Get("file.filename").String()
fileData := item.Get("file.file_data").String()
ext := ""
if sp := strings.Split(filename, "."); len(sp) > 1 {
ext = sp[len(sp)-1]
}
if mimeType, ok := misc.MimeTypes[ext]; ok {
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData)
p++
} else {
log.Warnf("Unknown file name extension '%s' in user message, skip", ext)
}
}
}
}
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
} else if role == "assistant" {
node := []byte(`{"role":"model","parts":[]}`)
p := 0
if content.Type == gjson.String {
// Assistant text -> single model content
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
p++
} else if content.IsArray() {
// Assistant multimodal content (e.g. text + image) -> single model content with parts
for _, item := range content.Array() {
switch item.Get("type").String() {
case "text":
text := item.Get("text").String()
if text != "" {
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", text)
}
p++
case "image_url":
// If the assistant returned an inline data URL, preserve it for history fidelity.
imageURL := item.Get("image_url.url").String()
if len(imageURL) > 5 { // expect data:...
pieces := strings.SplitN(imageURL[5:], ";", 2)
if len(pieces) == 2 && len(pieces[1]) > 7 {
mime := pieces[0]
data := pieces[1][7:]
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiFunctionThoughtSignature)
p++
}
}
}
}
}
// Tool calls -> single model content with functionCall parts
tcs := m.Get("tool_calls")
if tcs.IsArray() {
fIDs := make([]string, 0)
for _, tc := range tcs.Array() {
if tc.Get("type").String() != "function" {
continue
}
fid := tc.Get("id").String()
fname := tc.Get("function.name").String()
fargs := tc.Get("function.arguments").String()
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiFunctionThoughtSignature)
p++
if fid != "" {
fIDs = append(fIDs, fid)
}
}
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
// Append a single tool content combining name + response per function
toolNode := []byte(`{"role":"user","parts":[]}`)
pp := 0
for _, fid := range fIDs {
if name, ok := tcID2Name[fid]; ok {
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
resp := toolResponses[fid]
if resp == "" {
resp = "{}"
}
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
pp++
}
}
if pp > 0 {
out, _ = sjson.SetRawBytes(out, "contents.-1", toolNode)
}
} else {
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
}
}
}
}
// tools -> tools[].functionDeclarations + tools[].googleSearch/codeExecution/urlContext passthrough
tools := gjson.GetBytes(rawJSON, "tools")
if tools.IsArray() && len(tools.Array()) > 0 {
functionToolNode := []byte(`{}`)
hasFunction := false
googleSearchNodes := make([][]byte, 0)
codeExecutionNodes := make([][]byte, 0)
urlContextNodes := make([][]byte, 0)
for _, t := range tools.Array() {
if t.Get("type").String() == "function" {
fn := t.Get("function")
if fn.Exists() && fn.IsObject() {
fnRaw := fn.Raw
if fn.Get("parameters").Exists() {
renamed, errRename := util.RenameKey(fnRaw, "parameters", "parametersJsonSchema")
if errRename != nil {
log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
var errSet error
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
if errSet != nil {
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
} else {
fnRaw = renamed
}
} else {
var errSet error
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
if errSet != nil {
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
}
fnRaw, _ = sjson.Delete(fnRaw, "strict")
if !hasFunction {
functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
}
tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
if errSet != nil {
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
continue
}
functionToolNode = tmp
hasFunction = true
}
}
if gs := t.Get("google_search"); gs.Exists() {
googleToolNode := []byte(`{}`)
var errSet error
googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
if errSet != nil {
log.Warnf("Failed to set googleSearch tool: %v", errSet)
continue
}
googleSearchNodes = append(googleSearchNodes, googleToolNode)
}
if ce := t.Get("code_execution"); ce.Exists() {
codeToolNode := []byte(`{}`)
var errSet error
codeToolNode, errSet = sjson.SetRawBytes(codeToolNode, "codeExecution", []byte(ce.Raw))
if errSet != nil {
log.Warnf("Failed to set codeExecution tool: %v", errSet)
continue
}
codeExecutionNodes = append(codeExecutionNodes, codeToolNode)
}
if uc := t.Get("url_context"); uc.Exists() {
urlToolNode := []byte(`{}`)
var errSet error
urlToolNode, errSet = sjson.SetRawBytes(urlToolNode, "urlContext", []byte(uc.Raw))
if errSet != nil {
log.Warnf("Failed to set urlContext tool: %v", errSet)
continue
}
urlContextNodes = append(urlContextNodes, urlToolNode)
}
}
if hasFunction || len(googleSearchNodes) > 0 || len(codeExecutionNodes) > 0 || len(urlContextNodes) > 0 {
toolsNode := []byte("[]")
if hasFunction {
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
}
for _, googleNode := range googleSearchNodes {
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
}
for _, codeNode := range codeExecutionNodes {
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", codeNode)
}
for _, urlNode := range urlContextNodes {
toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", urlNode)
}
out, _ = sjson.SetRawBytes(out, "tools", toolsNode)
}
}
out = common.AttachDefaultSafetySettings(out, "safetySettings")
return out
}
// itoa renders i in base 10. It is a tiny helper over fmt used to build numeric
// sjson path segments such as "parts.3.text" without pulling in strconv.
func itoa(i int) string {
	return fmt.Sprint(i)
}
================================================
FILE: internal/translator/gemini/openai/chat-completions/gemini_openai_response.go
================================================
// Package chat_completions provides response translation from the Gemini API to the
// OpenAI Chat Completions API. It converts Gemini API responses into OpenAI Chat
// Completions-compatible JSON, transforming both streaming events and non-streaming
// responses into the format expected by OpenAI API clients. It handles text content,
// tool calls, reasoning content, inline images, and usage metadata.
package chat_completions
import (
"bytes"
"context"
"fmt"
"strings"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// convertGeminiResponseToOpenAIChatParams is the per-stream state threaded through
// successive ConvertGeminiResponseToOpenAI calls via their param argument.
type convertGeminiResponseToOpenAIChatParams struct {
	// UnixTimestamp holds the last successfully parsed createTime in Unix seconds; it is
	// reused as "created" for chunks whose createTime is missing or unparsable.
	UnixTimestamp int64

	// FunctionIndex counts the tool calls emitted so far, keyed by candidate index, so
	// that each candidate numbers its tool calls independently.
	FunctionIndex map[int]int
}

var (
	// functionCallIDCounter is a process-wide counter, incremented atomically by the
	// converters, that is mixed into generated tool call IDs to keep them unique.
	functionCallIDCounter uint64
)
// ConvertGeminiResponseToOpenAI translates a single chunk of a streaming response from the
// Gemini API format to the OpenAI Chat Completions streaming format.
//
// Each Gemini candidate in the chunk produces one "chat.completion.chunk" object whose
// choices[0].index is the candidate's index. Text parts map to delta.content (or
// delta.reasoning_content for parts flagged "thought"), functionCall parts map to
// delta.tool_calls, and inlineData/inline_data parts map to data-URL entries in delta.images.
// Parts that carry only a thoughtSignature are dropped. usageMetadata is copied onto every
// emitted chunk.
//
// Parameters:
//   - ctx: unused
//   - modelName: unused
//   - originalRequestRawJSON, requestRawJSON: unused by this translator
//   - rawJSON: one Gemini stream payload, optionally prefixed with "data:"
//   - param: per-stream state; lazily initialized to *convertGeminiResponseToOpenAIChatParams
//     and reused across calls so tool call indices keep increasing for the whole stream
//
// Returns:
//   - []string: one OpenAI-compatible JSON chunk per candidate; when the payload has no
//     candidates (missing or empty) but carries usageMetadata, a single usage-bearing chunk;
//     an empty slice for the "[DONE]" sentinel.
func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	// Initialize per-stream state on first use.
	if *param == nil {
		*param = &convertGeminiResponseToOpenAIChatParams{
			UnixTimestamp: 0,
			FunctionIndex: make(map[int]int),
		}
	}
	// Ensure the map is initialized (param may have been created without it).
	p := (*param).(*convertGeminiResponseToOpenAIChatParams)
	if p.FunctionIndex == nil {
		p.FunctionIndex = make(map[int]int)
	}

	if bytes.HasPrefix(rawJSON, []byte("data:")) {
		rawJSON = bytes.TrimSpace(rawJSON[5:])
	}
	if bytes.Equal(rawJSON, []byte("[DONE]")) {
		return []string{}
	}

	// Base chunk shared by every candidate; each candidate works on its own copy.
	baseTemplate := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`

	if modelVersionResult := gjson.GetBytes(rawJSON, "modelVersion"); modelVersionResult.Exists() {
		baseTemplate, _ = sjson.Set(baseTemplate, "model", modelVersionResult.String())
	}

	// Chunks without a parsable createTime reuse the last timestamp seen on this stream.
	if createTimeResult := gjson.GetBytes(rawJSON, "createTime"); createTimeResult.Exists() {
		if t, err := time.Parse(time.RFC3339Nano, createTimeResult.String()); err == nil {
			p.UnixTimestamp = t.Unix()
		}
	}
	baseTemplate, _ = sjson.Set(baseTemplate, "created", p.UnixTimestamp)

	if responseIDResult := gjson.GetBytes(rawJSON, "responseId"); responseIDResult.Exists() {
		baseTemplate, _ = sjson.Set(baseTemplate, "id", responseIDResult.String())
	}

	// Usage is applied to the base template so it appears on every emitted chunk.
	usageResult := gjson.GetBytes(rawJSON, "usageMetadata")
	if usageResult.Exists() {
		cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
		if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
			baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens", candidatesTokenCountResult.Int())
		}
		if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
			baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int())
		}
		promptTokenCount := usageResult.Get("promptTokenCount").Int()
		thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
		baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount)
		if thoughtsTokenCount > 0 {
			baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
		}
		// Include cached token count if present (indicates prompt caching is working).
		if cachedTokenCount > 0 {
			var err error
			baseTemplate, err = sjson.Set(baseTemplate, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
			if err != nil {
				log.Warnf("gemini openai response: failed to set cached_tokens in streaming: %v", err)
			}
		}
	}

	// A top-level stop_reason, when present and non-empty, takes precedence over the
	// per-candidate finishReason.
	stopReason := gjson.GetBytes(rawJSON, "stop_reason").String()

	var responseStrings []string
	candidates := gjson.GetBytes(rawJSON, "candidates")
	// Iterate over all candidates to support candidate_count > 1.
	if candidates.IsArray() {
		candidates.ForEach(func(_, candidate gjson.Result) bool {
			template := baseTemplate
			candidateIndex := int(candidate.Get("index").Int())
			template, _ = sjson.Set(template, "choices.0.index", candidateIndex)

			finishReason := stopReason
			if finishReason == "" {
				// Use this candidate's own finishReason; reading candidates.0 here would give
				// every candidate the first candidate's reason.
				finishReason = candidate.Get("finishReason").String()
			}
			finishReason = strings.ToLower(finishReason)

			hasFunctionCall := false
			if partsResult := candidate.Get("content.parts"); partsResult.IsArray() {
				for _, partResult := range partsResult.Array() {
					partTextResult := partResult.Get("text")
					functionCallResult := partResult.Get("functionCall")
					inlineDataResult := partResult.Get("inlineData")
					if !inlineDataResult.Exists() {
						inlineDataResult = partResult.Get("inline_data")
					}
					thoughtSignatureResult := partResult.Get("thoughtSignature")
					if !thoughtSignatureResult.Exists() {
						thoughtSignatureResult = partResult.Get("thought_signature")
					}
					hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != ""
					hasContentPayload := partTextResult.Exists() || functionCallResult.Exists() || inlineDataResult.Exists()
					// Skip pure thoughtSignature parts but keep any actual payload in the same part.
					if hasThoughtSignature && !hasContentPayload {
						continue
					}

					if partTextResult.Exists() {
						// Distinguish regular content from reasoning/thoughts.
						if partResult.Get("thought").Bool() {
							template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
						} else {
							template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
						}
						template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
					} else if functionCallResult.Exists() {
						hasFunctionCall = true
						// OpenAI tool call indices identify a call across the whole stream, so take
						// the next value of this candidate's persistent counter. Using the
						// chunk-local tool_calls length instead restarts numbering in later
						// chunks and can duplicate indices that were already emitted.
						functionCallIndex := p.FunctionIndex[candidateIndex]
						p.FunctionIndex[candidateIndex]++
						if !gjson.Get(template, "choices.0.delta.tool_calls").IsArray() {
							template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
						}
						fcName := functionCallResult.Get("name").String()
						functionCallTemplate := `{"id":"","index":0,"type":"function","function":{"name":"","arguments":""}}`
						functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
						functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
						functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
						if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
							functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
						}
						template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
						template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
					} else if inlineDataResult.Exists() {
						data := inlineDataResult.Get("data").String()
						if data == "" {
							continue
						}
						mimeType := inlineDataResult.Get("mimeType").String()
						if mimeType == "" {
							mimeType = inlineDataResult.Get("mime_type").String()
						}
						if mimeType == "" {
							mimeType = "image/png"
						}
						imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
						if !gjson.Get(template, "choices.0.delta.images").IsArray() {
							template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
						}
						imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
						imagePayload := `{"type":"image_url","image_url":{"url":""}}`
						imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
						imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
						template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
						template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
					}
				}
			}

			if hasFunctionCall {
				template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
				template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
			} else if finishReason == "max_tokens" || finishReason == "stop" {
				// Only these finish reasons are passed through on streaming chunks.
				template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
				template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
			}
			responseStrings = append(responseStrings, template)
			return true // continue with the next candidate
		})
	}

	// No candidates at all (absent or an empty array, e.g. a pure usageMetadata chunk):
	// still surface the usage so clients can account for tokens.
	if len(responseStrings) == 0 && usageResult.Exists() {
		responseStrings = append(responseStrings, baseTemplate)
	}
	return responseStrings
}
// ConvertGeminiResponseToOpenAINonStream converts a complete (non-streaming) Gemini response
// into a single OpenAI Chat Completions "chat.completion" object.
//
// Every Gemini candidate becomes one entry in "choices". Within a candidate, text parts are
// concatenated into message.content (or message.reasoning_content when the part is flagged
// "thought"), functionCall parts become message.tool_calls entries with freshly generated
// IDs, and inlineData/inline_data parts become data-URL entries in message.images. When any
// tool call is present, finish_reason is forced to "tool_calls"; otherwise the candidate's
// finishReason is lowercased and passed through. usageMetadata maps onto "usage".
//
// Parameters:
//   - ctx, modelName, originalRequestRawJSON, requestRawJSON, param: unused
//   - rawJSON: the raw Gemini response body
//
// Returns:
//   - string: the OpenAI-compatible chat.completion JSON document
func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
	response := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[]}`

	if model := gjson.GetBytes(rawJSON, "modelVersion"); model.Exists() {
		response, _ = sjson.Set(response, "model", model.String())
	}

	// "created" is always overwritten: with createTime when it parses, otherwise with 0.
	var created int64
	if createTime := gjson.GetBytes(rawJSON, "createTime"); createTime.Exists() {
		if parsed, err := time.Parse(time.RFC3339Nano, createTime.String()); err == nil {
			created = parsed.Unix()
		}
	}
	response, _ = sjson.Set(response, "created", created)

	if responseID := gjson.GetBytes(rawJSON, "responseId"); responseID.Exists() {
		response, _ = sjson.Set(response, "id", responseID.String())
	}

	if usage := gjson.GetBytes(rawJSON, "usageMetadata"); usage.Exists() {
		// Keys inside "usage" are emitted in the order of these Set calls.
		if completion := usage.Get("candidatesTokenCount"); completion.Exists() {
			response, _ = sjson.Set(response, "usage.completion_tokens", completion.Int())
		}
		if total := usage.Get("totalTokenCount"); total.Exists() {
			response, _ = sjson.Set(response, "usage.total_tokens", total.Int())
		}
		response, _ = sjson.Set(response, "usage.prompt_tokens", usage.Get("promptTokenCount").Int())
		if reasoning := usage.Get("thoughtsTokenCount").Int(); reasoning > 0 {
			response, _ = sjson.Set(response, "usage.completion_tokens_details.reasoning_tokens", reasoning)
		}
		// Cached tokens are only reported when prompt caching actually kicked in.
		if cached := usage.Get("cachedContentTokenCount").Int(); cached > 0 {
			updated, err := sjson.Set(response, "usage.prompt_tokens_details.cached_tokens", cached)
			if err != nil {
				log.Warnf("gemini openai response: failed to set cached_tokens in non-streaming: %v", err)
			}
			response = updated
		}
	}

	// buildChoice renders one Gemini candidate as an OpenAI choice object.
	buildChoice := func(candidate gjson.Result) string {
		choice := `{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}`
		choice, _ = sjson.Set(choice, "index", candidate.Get("index").Int())
		if reason := candidate.Get("finishReason"); reason.Exists() {
			lowered := strings.ToLower(reason.String())
			choice, _ = sjson.Set(choice, "finish_reason", lowered)
			choice, _ = sjson.Set(choice, "native_finish_reason", lowered)
		}

		sawFunctionCall := false
		if parts := candidate.Get("content.parts"); parts.IsArray() {
			for _, part := range parts.Array() {
				text := part.Get("text")
				call := part.Get("functionCall")
				inline := part.Get("inlineData")
				if !inline.Exists() {
					inline = part.Get("inline_data")
				}

				switch {
				case text.Exists():
					field := "message.content"
					if part.Get("thought").Bool() {
						field = "message.reasoning_content"
					}
					choice, _ = sjson.Set(choice, field, gjson.Get(choice, field).String()+text.String())
					choice, _ = sjson.Set(choice, "message.role", "assistant")

				case call.Exists():
					sawFunctionCall = true
					if !gjson.Get(choice, "message.tool_calls").IsArray() {
						choice, _ = sjson.SetRaw(choice, "message.tool_calls", `[]`)
					}
					name := call.Get("name").String()
					toolCall := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
					toolCall, _ = sjson.Set(toolCall, "id", fmt.Sprintf("%s-%d-%d", name, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
					toolCall, _ = sjson.Set(toolCall, "function.name", name)
					if args := call.Get("args"); args.Exists() {
						toolCall, _ = sjson.Set(toolCall, "function.arguments", args.Raw)
					}
					choice, _ = sjson.Set(choice, "message.role", "assistant")
					choice, _ = sjson.SetRaw(choice, "message.tool_calls.-1", toolCall)

				case inline.Exists():
					data := inline.Get("data").String()
					if data == "" {
						continue
					}
					mimeType := inline.Get("mimeType").String()
					if mimeType == "" {
						mimeType = inline.Get("mime_type").String()
					}
					if mimeType == "" {
						mimeType = "image/png"
					}
					if !gjson.Get(choice, "message.images").IsArray() {
						choice, _ = sjson.SetRaw(choice, "message.images", `[]`)
					}
					image := `{"type":"image_url","image_url":{"url":""}}`
					image, _ = sjson.Set(image, "index", len(gjson.Get(choice, "message.images").Array()))
					image, _ = sjson.Set(image, "image_url.url", fmt.Sprintf("data:%s;base64,%s", mimeType, data))
					choice, _ = sjson.Set(choice, "message.role", "assistant")
					choice, _ = sjson.SetRaw(choice, "message.images.-1", image)
				}
			}
		}

		if sawFunctionCall {
			choice, _ = sjson.Set(choice, "finish_reason", "tool_calls")
			choice, _ = sjson.Set(choice, "native_finish_reason", "tool_calls")
		}
		return choice
	}

	candidates := gjson.GetBytes(rawJSON, "candidates")
	if !candidates.IsArray() {
		return response
	}
	candidates.ForEach(func(_, candidate gjson.Result) bool {
		response, _ = sjson.SetRaw(response, "choices.-1", buildChoice(candidate))
		return true
	})
	return response
}
================================================
FILE: internal/translator/gemini/openai/chat-completions/init.go
================================================
package chat_completions
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
OpenAI,
Gemini,
ConvertOpenAIRequestToGemini,
interfaces.TranslateResponse{
Stream: ConvertGeminiResponseToOpenAI,
NonStream: ConvertGeminiResponseToOpenAINonStream,
},
)
}
================================================
FILE: internal/translator/gemini/openai/responses/gemini_openai-responses_request.go
================================================
package responses
import (
"encoding/json"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const geminiResponsesThoughtSignature = "skip_thought_signature_validator"
func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte, stream bool) []byte {
rawJSON := inputRawJSON
// Note: modelName and stream parameters are part of the fixed method signature
_ = modelName // Unused but required by interface
_ = stream // Unused but required by interface
// Base Gemini API template (do not include thinkingConfig by default)
out := `{"contents":[]}`
root := gjson.ParseBytes(rawJSON)
// Extract system instruction from OpenAI "instructions" field
if instructions := root.Get("instructions"); instructions.Exists() {
systemInstr := `{"parts":[{"text":""}]}`
systemInstr, _ = sjson.Set(systemInstr, "parts.0.text", instructions.String())
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
}
// Convert input messages to Gemini contents format
if input := root.Get("input"); input.Exists() && input.IsArray() {
items := input.Array()
// Normalize consecutive function calls and outputs so each call is immediately followed by its response
normalized := make([]gjson.Result, 0, len(items))
for i := 0; i < len(items); {
item := items[i]
itemType := item.Get("type").String()
itemRole := item.Get("role").String()
if itemType == "" && itemRole != "" {
itemType = "message"
}
if itemType == "function_call" {
var calls []gjson.Result
var outputs []gjson.Result
for i < len(items) {
next := items[i]
nextType := next.Get("type").String()
nextRole := next.Get("role").String()
if nextType == "" && nextRole != "" {
nextType = "message"
}
if nextType != "function_call" {
break
}
calls = append(calls, next)
i++
}
for i < len(items) {
next := items[i]
nextType := next.Get("type").String()
nextRole := next.Get("role").String()
if nextType == "" && nextRole != "" {
nextType = "message"
}
if nextType != "function_call_output" {
break
}
outputs = append(outputs, next)
i++
}
if len(calls) > 0 {
outputMap := make(map[string]gjson.Result, len(outputs))
for _, out := range outputs {
outputMap[out.Get("call_id").String()] = out
}
for _, call := range calls {
normalized = append(normalized, call)
callID := call.Get("call_id").String()
if resp, ok := outputMap[callID]; ok {
normalized = append(normalized, resp)
delete(outputMap, callID)
}
}
for _, out := range outputs {
if _, ok := outputMap[out.Get("call_id").String()]; ok {
normalized = append(normalized, out)
}
}
continue
}
}
if itemType == "function_call_output" {
normalized = append(normalized, item)
i++
continue
}
normalized = append(normalized, item)
i++
}
for _, item := range normalized {
itemType := item.Get("type").String()
itemRole := item.Get("role").String()
if itemType == "" && itemRole != "" {
itemType = "message"
}
switch itemType {
case "message":
if strings.EqualFold(itemRole, "system") {
if contentArray := item.Get("content"); contentArray.Exists() {
systemInstr := ""
if systemInstructionResult := gjson.Get(out, "systemInstruction"); systemInstructionResult.Exists() {
systemInstr = systemInstructionResult.Raw
} else {
systemInstr = `{"parts":[]}`
}
if contentArray.IsArray() {
contentArray.ForEach(func(_, contentItem gjson.Result) bool {
part := `{"text":""}`
text := contentItem.Get("text").String()
part, _ = sjson.Set(part, "text", text)
systemInstr, _ = sjson.SetRaw(systemInstr, "parts.-1", part)
return true
})
} else if contentArray.Type == gjson.String {
part := `{"text":""}`
part, _ = sjson.Set(part, "text", contentArray.String())
systemInstr, _ = sjson.SetRaw(systemInstr, "parts.-1", part)
}
if systemInstr != `{"parts":[]}` {
out, _ = sjson.SetRaw(out, "systemInstruction", systemInstr)
}
}
continue
}
// Handle regular messages
// Note: In Responses format, model outputs may appear as content items with type "output_text"
// even when the message.role is "user". We split such items into distinct Gemini messages
// with roles derived from the content type to match docs/convert-2.md.
if contentArray := item.Get("content"); contentArray.Exists() && contentArray.IsArray() {
currentRole := ""
var currentParts []string
flush := func() {
if currentRole == "" || len(currentParts) == 0 {
currentParts = nil
return
}
one := `{"role":"","parts":[]}`
one, _ = sjson.Set(one, "role", currentRole)
for _, part := range currentParts {
one, _ = sjson.SetRaw(one, "parts.-1", part)
}
out, _ = sjson.SetRaw(out, "contents.-1", one)
currentParts = nil
}
contentArray.ForEach(func(_, contentItem gjson.Result) bool {
contentType := contentItem.Get("type").String()
if contentType == "" {
contentType = "input_text"
}
effRole := "user"
if itemRole != "" {
switch strings.ToLower(itemRole) {
case "assistant", "model":
effRole = "model"
default:
effRole = strings.ToLower(itemRole)
}
}
if contentType == "output_text" {
effRole = "model"
}
if effRole == "assistant" {
effRole = "model"
}
if currentRole != "" && effRole != currentRole {
flush()
currentRole = ""
}
if currentRole == "" {
currentRole = effRole
}
var partJSON string
switch contentType {
case "input_text", "output_text":
if text := contentItem.Get("text"); text.Exists() {
partJSON = `{"text":""}`
partJSON, _ = sjson.Set(partJSON, "text", text.String())
}
case "input_image":
imageURL := contentItem.Get("image_url").String()
if imageURL == "" {
imageURL = contentItem.Get("url").String()
}
if imageURL != "" {
mimeType := "application/octet-stream"
data := ""
if strings.HasPrefix(imageURL, "data:") {
trimmed := strings.TrimPrefix(imageURL, "data:")
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
if len(mediaAndData) == 2 {
if mediaAndData[0] != "" {
mimeType = mediaAndData[0]
}
data = mediaAndData[1]
} else {
mediaAndData = strings.SplitN(trimmed, ",", 2)
if len(mediaAndData) == 2 {
if mediaAndData[0] != "" {
mimeType = mediaAndData[0]
}
data = mediaAndData[1]
}
}
}
if data != "" {
partJSON = `{"inline_data":{"mime_type":"","data":""}}`
partJSON, _ = sjson.Set(partJSON, "inline_data.mime_type", mimeType)
partJSON, _ = sjson.Set(partJSON, "inline_data.data", data)
}
}
case "input_audio":
audioData := contentItem.Get("data").String()
audioFormat := contentItem.Get("format").String()
if audioData != "" {
audioMimeMap := map[string]string{
"mp3": "audio/mpeg",
"wav": "audio/wav",
"ogg": "audio/ogg",
"flac": "audio/flac",
"aac": "audio/aac",
"webm": "audio/webm",
"pcm16": "audio/pcm",
"g711_ulaw": "audio/basic",
"g711_alaw": "audio/basic",
}
mimeType := "audio/wav"
if audioFormat != "" {
if mapped, ok := audioMimeMap[audioFormat]; ok {
mimeType = mapped
} else {
mimeType = "audio/" + audioFormat
}
}
partJSON = `{"inline_data":{"mime_type":"","data":""}}`
partJSON, _ = sjson.Set(partJSON, "inline_data.mime_type", mimeType)
partJSON, _ = sjson.Set(partJSON, "inline_data.data", audioData)
}
}
if partJSON != "" {
currentParts = append(currentParts, partJSON)
}
return true
})
flush()
} else if contentArray.Type == gjson.String {
effRole := "user"
if itemRole != "" {
switch strings.ToLower(itemRole) {
case "assistant", "model":
effRole = "model"
default:
effRole = strings.ToLower(itemRole)
}
}
one := `{"role":"","parts":[{"text":""}]}`
one, _ = sjson.Set(one, "role", effRole)
one, _ = sjson.Set(one, "parts.0.text", contentArray.String())
out, _ = sjson.SetRaw(out, "contents.-1", one)
}
case "function_call":
// Handle function calls - convert to model message with functionCall
name := item.Get("name").String()
arguments := item.Get("arguments").String()
modelContent := `{"role":"model","parts":[]}`
functionCall := `{"functionCall":{"name":"","args":{}}}`
functionCall, _ = sjson.Set(functionCall, "functionCall.name", name)
functionCall, _ = sjson.Set(functionCall, "thoughtSignature", geminiResponsesThoughtSignature)
functionCall, _ = sjson.Set(functionCall, "functionCall.id", item.Get("call_id").String())
// Parse arguments JSON string and set as args object
if arguments != "" {
argsResult := gjson.Parse(arguments)
functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsResult.Raw)
}
modelContent, _ = sjson.SetRaw(modelContent, "parts.-1", functionCall)
out, _ = sjson.SetRaw(out, "contents.-1", modelContent)
case "function_call_output":
// Handle function call outputs - convert to function message with functionResponse
callID := item.Get("call_id").String()
// Use .Raw to preserve the JSON encoding (includes quotes for strings)
outputRaw := item.Get("output").Str
functionContent := `{"role":"function","parts":[]}`
functionResponse := `{"functionResponse":{"name":"","response":{}}}`
// We need to extract the function name from the previous function_call
// For now, we'll use a placeholder or extract from context if available
functionName := "unknown" // This should ideally be matched with the corresponding function_call
// Find the corresponding function call name by matching call_id
// We need to look back through the input array to find the matching call
if inputArray := root.Get("input"); inputArray.Exists() && inputArray.IsArray() {
inputArray.ForEach(func(_, prevItem gjson.Result) bool {
if prevItem.Get("type").String() == "function_call" && prevItem.Get("call_id").String() == callID {
functionName = prevItem.Get("name").String()
return false // Stop iteration
}
return true
})
}
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.name", functionName)
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.id", callID)
// Set the raw JSON output directly (preserves string encoding)
if outputRaw != "" && outputRaw != "null" {
output := gjson.Parse(outputRaw)
if output.Type == gjson.JSON && json.Valid([]byte(output.Raw)) {
functionResponse, _ = sjson.SetRaw(functionResponse, "functionResponse.response.result", output.Raw)
} else {
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.response.result", outputRaw)
}
}
functionContent, _ = sjson.SetRaw(functionContent, "parts.-1", functionResponse)
out, _ = sjson.SetRaw(out, "contents.-1", functionContent)
case "reasoning":
thoughtContent := `{"role":"model","parts":[]}`
thought := `{"text":"","thoughtSignature":"","thought":true}`
thought, _ = sjson.Set(thought, "text", item.Get("summary.0.text").String())
thought, _ = sjson.Set(thought, "thoughtSignature", item.Get("encrypted_content").String())
thoughtContent, _ = sjson.SetRaw(thoughtContent, "parts.-1", thought)
out, _ = sjson.SetRaw(out, "contents.-1", thoughtContent)
}
}
} else if input.Exists() && input.Type == gjson.String {
// Simple string input conversion to user message
userContent := `{"role":"user","parts":[{"text":""}]}`
userContent, _ = sjson.Set(userContent, "parts.0.text", input.String())
out, _ = sjson.SetRaw(out, "contents.-1", userContent)
}
// Convert tools to Gemini functionDeclarations format
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
geminiTools := `[{"functionDeclarations":[]}]`
tools.ForEach(func(_, tool gjson.Result) bool {
if tool.Get("type").String() == "function" {
funcDecl := `{"name":"","description":"","parametersJsonSchema":{}}`
if name := tool.Get("name"); name.Exists() {
funcDecl, _ = sjson.Set(funcDecl, "name", name.String())
}
if desc := tool.Get("description"); desc.Exists() {
funcDecl, _ = sjson.Set(funcDecl, "description", desc.String())
}
if params := tool.Get("parameters"); params.Exists() {
funcDecl, _ = sjson.SetRaw(funcDecl, "parametersJsonSchema", params.Raw)
}
geminiTools, _ = sjson.SetRaw(geminiTools, "0.functionDeclarations.-1", funcDecl)
}
return true
})
// Only add tools if there are function declarations
if funcDecls := gjson.Get(geminiTools, "0.functionDeclarations"); funcDecls.Exists() && len(funcDecls.Array()) > 0 {
out, _ = sjson.SetRaw(out, "tools", geminiTools)
}
}
// Handle generation config from OpenAI format
if maxOutputTokens := root.Get("max_output_tokens"); maxOutputTokens.Exists() {
genConfig := `{"maxOutputTokens":0}`
genConfig, _ = sjson.Set(genConfig, "maxOutputTokens", maxOutputTokens.Int())
out, _ = sjson.SetRaw(out, "generationConfig", genConfig)
}
// Handle temperature if present
if temperature := root.Get("temperature"); temperature.Exists() {
if !gjson.Get(out, "generationConfig").Exists() {
out, _ = sjson.SetRaw(out, "generationConfig", `{}`)
}
out, _ = sjson.Set(out, "generationConfig.temperature", temperature.Float())
}
// Handle top_p if present
if topP := root.Get("top_p"); topP.Exists() {
if !gjson.Get(out, "generationConfig").Exists() {
out, _ = sjson.SetRaw(out, "generationConfig", `{}`)
}
out, _ = sjson.Set(out, "generationConfig.topP", topP.Float())
}
// Handle stop sequences
if stopSequences := root.Get("stop_sequences"); stopSequences.Exists() && stopSequences.IsArray() {
if !gjson.Get(out, "generationConfig").Exists() {
out, _ = sjson.SetRaw(out, "generationConfig", `{}`)
}
var sequences []string
stopSequences.ForEach(func(_, seq gjson.Result) bool {
sequences = append(sequences, seq.String())
return true
})
out, _ = sjson.Set(out, "generationConfig.stopSequences", sequences)
}
// Apply thinking configuration: convert OpenAI Responses API reasoning.effort to Gemini thinkingConfig.
// Inline translation-only mapping; capability checks happen later in ApplyThinking.
re := root.Get("reasoning.effort")
if re.Exists() {
effort := strings.ToLower(strings.TrimSpace(re.String()))
if effort != "" {
thinkingPath := "generationConfig.thinkingConfig"
if effort == "auto" {
out, _ = sjson.Set(out, thinkingPath+".thinkingBudget", -1)
out, _ = sjson.Set(out, thinkingPath+".includeThoughts", true)
} else {
out, _ = sjson.Set(out, thinkingPath+".thinkingLevel", effort)
out, _ = sjson.Set(out, thinkingPath+".includeThoughts", effort != "none")
}
}
}
result := []byte(out)
result = common.AttachDefaultSafetySettings(result, "safetySettings")
return result
}
================================================
FILE: internal/translator/gemini/openai/responses/gemini_openai-responses_response.go
================================================
package responses
import (
"bytes"
"context"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// geminiToResponsesState is the per-stream bookkeeping that
// ConvertGeminiResponseToOpenAIResponses keeps in *param between successive
// upstream chunks of one response.
type geminiToResponsesState struct {
	// Response envelope: last emitted sequence_number, response id, creation
	// time, and whether response.created/in_progress were already sent.
	Seq        int
	ResponseID string
	CreatedAt  int64
	Started    bool

	// Assistant message item. TextBuf accumulates all visible text for the
	// final response.completed payload; ItemTextBuf is reset when the message
	// item opens and feeds its *.done events.
	MsgOpened, MsgClosed bool
	MsgIndex             int
	CurrentMsgID         string
	TextBuf, ItemTextBuf strings.Builder

	// Reasoning item built from thought parts; ReasoningEnc holds the latest
	// thought signature seen.
	ReasoningOpened               bool
	ReasoningIndex                int
	ReasoningItemID, ReasoningEnc string
	ReasoningBuf                  strings.Builder
	ReasoningClosed               bool

	// NextIndex is the next free output_index; the function-call maps below
	// are keyed by output_index.
	NextIndex   int
	FuncArgsBuf map[int]*strings.Builder
	FuncNames   map[int]string
	FuncCallIDs map[int]string
	FuncDone    map[int]bool
}
var (
	// responseIDCounter is bumped atomically to keep synthesized "resp_..."
	// identifiers unique within the process.
	responseIDCounter uint64
	// funcCallIDCounter is bumped atomically to keep generated "call_..."
	// identifiers unique within the process.
	funcCallIDCounter uint64
)
// pickRequestJSON chooses which request payload to echo back: the client's
// original request when it is non-empty valid JSON, otherwise the translated
// request under the same condition, otherwise nil.
func pickRequestJSON(originalRequestRawJSON, requestRawJSON []byte) []byte {
	for _, candidate := range [][]byte{originalRequestRawJSON, requestRawJSON} {
		if len(candidate) > 0 && gjson.ValidBytes(candidate) {
			return candidate
		}
	}
	return nil
}
// unwrapRequestRoot returns the nested "request" object when it looks like an
// OpenAI Responses request (it has "model", "input" or "instructions");
// otherwise root is returned unchanged.
func unwrapRequestRoot(root gjson.Result) gjson.Result {
	inner := root.Get("request")
	if !inner.Exists() {
		return root
	}
	for _, marker := range []string{"model", "input", "instructions"} {
		if inner.Get(marker).Exists() {
			return inner
		}
	}
	return root
}
// unwrapGeminiResponseRoot peels off a wrapping "response" object (as used by
// Vertex-style Gemini payloads) when it carries "candidates", "responseId" or
// "usageMetadata"; otherwise root is returned unchanged.
func unwrapGeminiResponseRoot(root gjson.Result) gjson.Result {
	inner := root.Get("response")
	if !inner.Exists() {
		return root
	}
	for _, marker := range []string{"candidates", "responseId", "usageMetadata"} {
		if inner.Get(marker).Exists() {
			return inner
		}
	}
	return root
}
// emitEvent renders one SSE frame as "event: <event>\ndata: <payload>" with no
// trailing newline.
func emitEvent(event string, payload string) string {
	header := fmt.Sprintf("event: %s", event)
	return header + "\ndata: " + payload
}
// ConvertGeminiResponseToOpenAIResponses converts Gemini SSE chunks into OpenAI Responses SSE events.
//
// It is called once per upstream chunk. *param carries a *geminiToResponsesState
// across calls and is allocated on the first call when *param is nil. Each returned
// string is one frame produced by emitEvent.
//
// Event flow:
//   - The first chunk emits response.created and response.in_progress.
//   - Thought parts become a reasoning item.
//   - Visible text becomes a single assistant message item.
//   - Each functionCall part becomes its own function_call item, finalized immediately.
//   - A chunk carrying a non-empty candidates.0.finishReason closes all open items and
//     emits response.completed. That event echoes request fields taken from
//     originalRequestRawJSON, falling back to requestRawJSON.
//
// Empty chunks, "[DONE]", and unparseable chunks yield an empty slice.
func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	if *param == nil {
		*param = &geminiToResponsesState{
			FuncArgsBuf: make(map[int]*strings.Builder),
			FuncNames:   make(map[int]string),
			FuncCallIDs: make(map[int]string),
			FuncDone:    make(map[int]bool),
		}
	}
	st := (*param).(*geminiToResponsesState)
	// Guard against a state value whose maps were never allocated; writing to a nil map panics.
	if st.FuncArgsBuf == nil {
		st.FuncArgsBuf = make(map[int]*strings.Builder)
	}
	if st.FuncNames == nil {
		st.FuncNames = make(map[int]string)
	}
	if st.FuncCallIDs == nil {
		st.FuncCallIDs = make(map[int]string)
	}
	if st.FuncDone == nil {
		st.FuncDone = make(map[int]bool)
	}

	// Trim before checking the SSE prefix so that a "data:" line with leading
	// whitespace is still unwrapped instead of reaching the JSON parser verbatim
	// (where it would fail to parse and be dropped).
	rawJSON = bytes.TrimSpace(rawJSON)
	if bytes.HasPrefix(rawJSON, []byte("data:")) {
		rawJSON = bytes.TrimSpace(rawJSON[5:])
	}
	if len(rawJSON) == 0 || bytes.Equal(rawJSON, []byte("[DONE]")) {
		return []string{}
	}

	root := gjson.ParseBytes(rawJSON)
	if !root.Exists() {
		return []string{}
	}
	root = unwrapGeminiResponseRoot(root)

	var out []string
	// nextSeq returns the next sequence_number; the counter lives in st so that
	// numbering continues across chunks.
	nextSeq := func() int { st.Seq++; return st.Seq }

	// finalizeReasoning emits, in order, response.reasoning_summary_text.done,
	// response.reasoning_summary_part.done and response.output_item.done for the
	// reasoning item. It runs at most once, and only if a reasoning item was opened.
	finalizeReasoning := func() {
		if !st.ReasoningOpened || st.ReasoningClosed {
			return
		}
		full := st.ReasoningBuf.String()
		textDone := `{"type":"response.reasoning_summary_text.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`
		textDone, _ = sjson.Set(textDone, "sequence_number", nextSeq())
		textDone, _ = sjson.Set(textDone, "item_id", st.ReasoningItemID)
		textDone, _ = sjson.Set(textDone, "output_index", st.ReasoningIndex)
		textDone, _ = sjson.Set(textDone, "text", full)
		out = append(out, emitEvent("response.reasoning_summary_text.done", textDone))
		partDone := `{"type":"response.reasoning_summary_part.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
		partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
		partDone, _ = sjson.Set(partDone, "item_id", st.ReasoningItemID)
		partDone, _ = sjson.Set(partDone, "output_index", st.ReasoningIndex)
		partDone, _ = sjson.Set(partDone, "part.text", full)
		out = append(out, emitEvent("response.reasoning_summary_part.done", partDone))
		itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","encrypted_content":"","summary":[{"type":"summary_text","text":""}]}}`
		itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
		itemDone, _ = sjson.Set(itemDone, "item.id", st.ReasoningItemID)
		itemDone, _ = sjson.Set(itemDone, "output_index", st.ReasoningIndex)
		itemDone, _ = sjson.Set(itemDone, "item.encrypted_content", st.ReasoningEnc)
		itemDone, _ = sjson.Set(itemDone, "item.summary.0.text", full)
		out = append(out, emitEvent("response.output_item.done", itemDone))
		st.ReasoningClosed = true
	}

	// finalizeMessage emits, in order, response.output_text.done,
	// response.content_part.done and response.output_item.done for the assistant
	// message. It runs at most once, and only if a message item was opened.
	finalizeMessage := func() {
		if !st.MsgOpened || st.MsgClosed {
			return
		}
		fullText := st.ItemTextBuf.String()
		done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
		done, _ = sjson.Set(done, "sequence_number", nextSeq())
		done, _ = sjson.Set(done, "item_id", st.CurrentMsgID)
		done, _ = sjson.Set(done, "output_index", st.MsgIndex)
		done, _ = sjson.Set(done, "text", fullText)
		out = append(out, emitEvent("response.output_text.done", done))
		partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
		partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
		partDone, _ = sjson.Set(partDone, "item_id", st.CurrentMsgID)
		partDone, _ = sjson.Set(partDone, "output_index", st.MsgIndex)
		partDone, _ = sjson.Set(partDone, "part.text", fullText)
		out = append(out, emitEvent("response.content_part.done", partDone))
		final := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","text":""}],"role":"assistant"}}`
		final, _ = sjson.Set(final, "sequence_number", nextSeq())
		final, _ = sjson.Set(final, "output_index", st.MsgIndex)
		final, _ = sjson.Set(final, "item.id", st.CurrentMsgID)
		final, _ = sjson.Set(final, "item.content.0.text", fullText)
		out = append(out, emitEvent("response.output_item.done", final))
		st.MsgClosed = true
	}

	// On the first chunk, fix the response id and creation time, then emit
	// response.created followed by response.in_progress.
	if !st.Started {
		st.ResponseID = root.Get("responseId").String()
		if st.ResponseID == "" {
			st.ResponseID = fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1))
		}
		if !strings.HasPrefix(st.ResponseID, "resp_") {
			st.ResponseID = fmt.Sprintf("resp_%s", st.ResponseID)
		}
		if v := root.Get("createTime"); v.Exists() {
			if t, errParseCreateTime := time.Parse(time.RFC3339Nano, v.String()); errParseCreateTime == nil {
				st.CreatedAt = t.Unix()
			}
		}
		if st.CreatedAt == 0 {
			st.CreatedAt = time.Now().Unix()
		}
		created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
		created, _ = sjson.Set(created, "sequence_number", nextSeq())
		created, _ = sjson.Set(created, "response.id", st.ResponseID)
		created, _ = sjson.Set(created, "response.created_at", st.CreatedAt)
		out = append(out, emitEvent("response.created", created))
		inprog := `{"type":"response.in_progress","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress"}}`
		inprog, _ = sjson.Set(inprog, "sequence_number", nextSeq())
		inprog, _ = sjson.Set(inprog, "response.id", st.ResponseID)
		inprog, _ = sjson.Set(inprog, "response.created_at", st.CreatedAt)
		out = append(out, emitEvent("response.in_progress", inprog))
		st.Started = true
		st.NextIndex = 0
	}

	// Translate candidate parts (thought text, visible text, functionCall) into events.
	if parts := root.Get("candidates.0.content.parts"); parts.Exists() && parts.IsArray() {
		parts.ForEach(func(_, part gjson.Result) bool {
			// Reasoning text.
			if part.Get("thought").Bool() {
				if st.ReasoningClosed {
					// Ignore any late thought chunks after reasoning is finalized.
					return true
				}
				// Keep the latest real signature; the placeholder signature is not forwarded.
				if sig := part.Get("thoughtSignature"); sig.Exists() && sig.String() != "" && sig.String() != geminiResponsesThoughtSignature {
					st.ReasoningEnc = sig.String()
				} else if sig = part.Get("thought_signature"); sig.Exists() && sig.String() != "" && sig.String() != geminiResponsesThoughtSignature {
					st.ReasoningEnc = sig.String()
				}
				if !st.ReasoningOpened {
					st.ReasoningOpened = true
					st.ReasoningIndex = st.NextIndex
					st.NextIndex++
					st.ReasoningItemID = fmt.Sprintf("rs_%s_%d", st.ResponseID, st.ReasoningIndex)
					item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","encrypted_content":"","summary":[]}}`
					item, _ = sjson.Set(item, "sequence_number", nextSeq())
					item, _ = sjson.Set(item, "output_index", st.ReasoningIndex)
					item, _ = sjson.Set(item, "item.id", st.ReasoningItemID)
					item, _ = sjson.Set(item, "item.encrypted_content", st.ReasoningEnc)
					out = append(out, emitEvent("response.output_item.added", item))
					partAdded := `{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
					partAdded, _ = sjson.Set(partAdded, "sequence_number", nextSeq())
					partAdded, _ = sjson.Set(partAdded, "item_id", st.ReasoningItemID)
					partAdded, _ = sjson.Set(partAdded, "output_index", st.ReasoningIndex)
					out = append(out, emitEvent("response.reasoning_summary_part.added", partAdded))
				}
				if t := part.Get("text"); t.Exists() && t.String() != "" {
					st.ReasoningBuf.WriteString(t.String())
					msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
					msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
					msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID)
					msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
					msg, _ = sjson.Set(msg, "delta", t.String())
					out = append(out, emitEvent("response.reasoning_summary_text.delta", msg))
				}
				return true
			}
			// Assistant visible text.
			if t := part.Get("text"); t.Exists() && t.String() != "" {
				// Before emitting non-reasoning outputs, finalize reasoning if open.
				finalizeReasoning()
				// NOTE(review): if text arrives after finalizeMessage() already ran (e.g. text
				// following a functionCall), deltas are still emitted for the closed message
				// item. Supporting several message items would need additional state.
				if !st.MsgOpened {
					st.MsgOpened = true
					st.MsgIndex = st.NextIndex
					st.NextIndex++
					st.CurrentMsgID = fmt.Sprintf("msg_%s_0", st.ResponseID)
					item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`
					item, _ = sjson.Set(item, "sequence_number", nextSeq())
					item, _ = sjson.Set(item, "output_index", st.MsgIndex)
					item, _ = sjson.Set(item, "item.id", st.CurrentMsgID)
					out = append(out, emitEvent("response.output_item.added", item))
					partAdded := `{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
					partAdded, _ = sjson.Set(partAdded, "sequence_number", nextSeq())
					partAdded, _ = sjson.Set(partAdded, "item_id", st.CurrentMsgID)
					partAdded, _ = sjson.Set(partAdded, "output_index", st.MsgIndex)
					out = append(out, emitEvent("response.content_part.added", partAdded))
					st.ItemTextBuf.Reset()
				}
				st.TextBuf.WriteString(t.String())
				st.ItemTextBuf.WriteString(t.String())
				msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
				msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
				msg, _ = sjson.Set(msg, "item_id", st.CurrentMsgID)
				msg, _ = sjson.Set(msg, "output_index", st.MsgIndex)
				msg, _ = sjson.Set(msg, "delta", t.String())
				out = append(out, emitEvent("response.output_text.delta", msg))
				return true
			}
			// Function call.
			if fc := part.Get("functionCall"); fc.Exists() {
				// Before emitting function-call outputs, finalize reasoning and the message (if open).
				// Responses streaming requires message done events before the next output_item.added.
				finalizeReasoning()
				finalizeMessage()
				name := fc.Get("name").String()
				idx := st.NextIndex
				st.NextIndex++
				// Ensure buffers
				if st.FuncArgsBuf[idx] == nil {
					st.FuncArgsBuf[idx] = &strings.Builder{}
				}
				if st.FuncCallIDs[idx] == "" {
					st.FuncCallIDs[idx] = fmt.Sprintf("call_%d_%d", time.Now().UnixNano(), atomic.AddUint64(&funcCallIDCounter, 1))
				}
				st.FuncNames[idx] = name
				argsJSON := "{}"
				if args := fc.Get("args"); args.Exists() {
					argsJSON = args.Raw
				}
				if st.FuncArgsBuf[idx].Len() == 0 && argsJSON != "" {
					st.FuncArgsBuf[idx].WriteString(argsJSON)
				}
				// Emit item.added for function call
				item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`
				item, _ = sjson.Set(item, "sequence_number", nextSeq())
				item, _ = sjson.Set(item, "output_index", idx)
				item, _ = sjson.Set(item, "item.id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
				item, _ = sjson.Set(item, "item.call_id", st.FuncCallIDs[idx])
				item, _ = sjson.Set(item, "item.name", name)
				out = append(out, emitEvent("response.output_item.added", item))
				// Emit arguments delta (full args in one chunk).
				// When Gemini omits args, emit "{}" to keep Responses streaming event order consistent.
				if argsJSON != "" {
					ad := `{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`
					ad, _ = sjson.Set(ad, "sequence_number", nextSeq())
					ad, _ = sjson.Set(ad, "item_id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
					ad, _ = sjson.Set(ad, "output_index", idx)
					ad, _ = sjson.Set(ad, "delta", argsJSON)
					out = append(out, emitEvent("response.function_call_arguments.delta", ad))
				}
				// Gemini emits the full function call payload at once, so we can finalize it immediately.
				if !st.FuncDone[idx] {
					fcDone := `{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`
					fcDone, _ = sjson.Set(fcDone, "sequence_number", nextSeq())
					fcDone, _ = sjson.Set(fcDone, "item_id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
					fcDone, _ = sjson.Set(fcDone, "output_index", idx)
					fcDone, _ = sjson.Set(fcDone, "arguments", argsJSON)
					out = append(out, emitEvent("response.function_call_arguments.done", fcDone))
					itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`
					itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
					itemDone, _ = sjson.Set(itemDone, "output_index", idx)
					itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
					itemDone, _ = sjson.Set(itemDone, "item.arguments", argsJSON)
					itemDone, _ = sjson.Set(itemDone, "item.call_id", st.FuncCallIDs[idx])
					itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
					out = append(out, emitEvent("response.output_item.done", itemDone))
					st.FuncDone[idx] = true
				}
				return true
			}
			return true
		})
	}

	// A non-empty finishReason closes every open item and emits response.completed.
	if fr := root.Get("candidates.0.finishReason"); fr.Exists() && fr.String() != "" {
		// Finalize reasoning first to keep ordering tight with last delta
		finalizeReasoning()
		finalizeMessage()
		// Close any function calls not yet finalized, in output_index order.
		if len(st.FuncArgsBuf) > 0 {
			// sort indices (small N); avoid extra imports
			idxs := make([]int, 0, len(st.FuncArgsBuf))
			for idx := range st.FuncArgsBuf {
				idxs = append(idxs, idx)
			}
			for i := 0; i < len(idxs); i++ {
				for j := i + 1; j < len(idxs); j++ {
					if idxs[j] < idxs[i] {
						idxs[i], idxs[j] = idxs[j], idxs[i]
					}
				}
			}
			for _, idx := range idxs {
				if st.FuncDone[idx] {
					continue
				}
				args := "{}"
				if b := st.FuncArgsBuf[idx]; b != nil && b.Len() > 0 {
					args = b.String()
				}
				fcDone := `{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`
				fcDone, _ = sjson.Set(fcDone, "sequence_number", nextSeq())
				fcDone, _ = sjson.Set(fcDone, "item_id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
				fcDone, _ = sjson.Set(fcDone, "output_index", idx)
				fcDone, _ = sjson.Set(fcDone, "arguments", args)
				out = append(out, emitEvent("response.function_call_arguments.done", fcDone))
				itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`
				itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
				itemDone, _ = sjson.Set(itemDone, "output_index", idx)
				itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", st.FuncCallIDs[idx]))
				itemDone, _ = sjson.Set(itemDone, "item.arguments", args)
				itemDone, _ = sjson.Set(itemDone, "item.call_id", st.FuncCallIDs[idx])
				itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[idx])
				out = append(out, emitEvent("response.output_item.done", itemDone))
				st.FuncDone[idx] = true
			}
		}
		// Build response.completed with aggregated outputs and request echo fields
		completed := `{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`
		completed, _ = sjson.Set(completed, "sequence_number", nextSeq())
		completed, _ = sjson.Set(completed, "response.id", st.ResponseID)
		completed, _ = sjson.Set(completed, "response.created_at", st.CreatedAt)
		if reqJSON := pickRequestJSON(originalRequestRawJSON, requestRawJSON); len(reqJSON) > 0 {
			req := unwrapRequestRoot(gjson.ParseBytes(reqJSON))
			if v := req.Get("instructions"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.instructions", v.String())
			}
			if v := req.Get("max_output_tokens"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.max_output_tokens", v.Int())
			}
			if v := req.Get("max_tool_calls"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.max_tool_calls", v.Int())
			}
			if v := req.Get("model"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.model", v.String())
			}
			if v := req.Get("parallel_tool_calls"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.parallel_tool_calls", v.Bool())
			}
			if v := req.Get("previous_response_id"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.previous_response_id", v.String())
			}
			if v := req.Get("prompt_cache_key"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.prompt_cache_key", v.String())
			}
			if v := req.Get("reasoning"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.reasoning", v.Value())
			}
			if v := req.Get("safety_identifier"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.safety_identifier", v.String())
			}
			if v := req.Get("service_tier"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.service_tier", v.String())
			}
			if v := req.Get("store"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.store", v.Bool())
			}
			if v := req.Get("temperature"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.temperature", v.Float())
			}
			if v := req.Get("text"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.text", v.Value())
			}
			if v := req.Get("tool_choice"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.tool_choice", v.Value())
			}
			if v := req.Get("tools"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.tools", v.Value())
			}
			if v := req.Get("top_logprobs"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.top_logprobs", v.Int())
			}
			if v := req.Get("top_p"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.top_p", v.Float())
			}
			if v := req.Get("truncation"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.truncation", v.String())
			}
			if v := req.Get("user"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.user", v.Value())
			}
			if v := req.Get("metadata"); v.Exists() {
				completed, _ = sjson.Set(completed, "response.metadata", v.Value())
			}
		}
		// Compose outputs in output_index order.
		outputsWrapper := `{"arr":[]}`
		for idx := 0; idx < st.NextIndex; idx++ {
			if st.ReasoningOpened && idx == st.ReasoningIndex {
				item := `{"id":"","type":"reasoning","encrypted_content":"","summary":[{"type":"summary_text","text":""}]}`
				item, _ = sjson.Set(item, "id", st.ReasoningItemID)
				item, _ = sjson.Set(item, "encrypted_content", st.ReasoningEnc)
				item, _ = sjson.Set(item, "summary.0.text", st.ReasoningBuf.String())
				outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
				continue
			}
			if st.MsgOpened && idx == st.MsgIndex {
				item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
				item, _ = sjson.Set(item, "id", st.CurrentMsgID)
				item, _ = sjson.Set(item, "content.0.text", st.TextBuf.String())
				outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
				continue
			}
			if callID, ok := st.FuncCallIDs[idx]; ok && callID != "" {
				args := "{}"
				if b := st.FuncArgsBuf[idx]; b != nil && b.Len() > 0 {
					args = b.String()
				}
				item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
				item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
				item, _ = sjson.Set(item, "arguments", args)
				item, _ = sjson.Set(item, "call_id", callID)
				item, _ = sjson.Set(item, "name", st.FuncNames[idx])
				outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
			}
		}
		if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
			completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
		}
		// Usage mapping. Gemini reports thinking tokens (thoughtsTokenCount) separately
		// from candidatesTokenCount. OpenAI's output_tokens includes reasoning tokens, and
		// output_tokens_details.reasoning_tokens reports that subset. Thoughts are therefore
		// added to the output side, so input_tokens + output_tokens matches total_tokens.
		// Missing counters read as 0 via gjson's Int().
		if um := root.Get("usageMetadata"); um.Exists() {
			// input tokens = prompt only (thoughts go to output)
			completed, _ = sjson.Set(completed, "response.usage.input_tokens", um.Get("promptTokenCount").Int())
			// cached token details: align with OpenAI "cached_tokens" semantics.
			completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
			reasoningTokens := um.Get("thoughtsTokenCount").Int()
			completed, _ = sjson.Set(completed, "response.usage.output_tokens", um.Get("candidatesTokenCount").Int()+reasoningTokens)
			completed, _ = sjson.Set(completed, "response.usage.output_tokens_details.reasoning_tokens", reasoningTokens)
			completed, _ = sjson.Set(completed, "response.usage.total_tokens", um.Get("totalTokenCount").Int())
		}
		out = append(out, emitEvent("response.completed", completed))
	}
	return out
}
// ConvertGeminiResponseToOpenAIResponsesNonStream aggregates Gemini response JSON into a single OpenAI Responses JSON object.
func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
root := gjson.ParseBytes(rawJSON)
root = unwrapGeminiResponseRoot(root)
// Base response scaffold
resp := `{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null}`
// id: prefer provider responseId, otherwise synthesize
id := root.Get("responseId").String()
if id == "" {
id = fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1))
}
// Normalize to response-style id (prefix resp_ if missing)
if !strings.HasPrefix(id, "resp_") {
id = fmt.Sprintf("resp_%s", id)
}
resp, _ = sjson.Set(resp, "id", id)
// created_at: map from createTime if available
createdAt := time.Now().Unix()
if v := root.Get("createTime"); v.Exists() {
if t, errParseCreateTime := time.Parse(time.RFC3339Nano, v.String()); errParseCreateTime == nil {
createdAt = t.Unix()
}
}
resp, _ = sjson.Set(resp, "created_at", createdAt)
// Echo request fields when present; fallback model from response modelVersion
if reqJSON := pickRequestJSON(originalRequestRawJSON, requestRawJSON); len(reqJSON) > 0 {
req := unwrapRequestRoot(gjson.ParseBytes(reqJSON))
if v := req.Get("instructions"); v.Exists() {
resp, _ = sjson.Set(resp, "instructions", v.String())
}
if v := req.Get("max_output_tokens"); v.Exists() {
resp, _ = sjson.Set(resp, "max_output_tokens", v.Int())
}
if v := req.Get("max_tool_calls"); v.Exists() {
resp, _ = sjson.Set(resp, "max_tool_calls", v.Int())
}
if v := req.Get("model"); v.Exists() {
resp, _ = sjson.Set(resp, "model", v.String())
} else if v = root.Get("modelVersion"); v.Exists() {
resp, _ = sjson.Set(resp, "model", v.String())
}
if v := req.Get("parallel_tool_calls"); v.Exists() {
resp, _ = sjson.Set(resp, "parallel_tool_calls", v.Bool())
}
if v := req.Get("previous_response_id"); v.Exists() {
resp, _ = sjson.Set(resp, "previous_response_id", v.String())
}
if v := req.Get("prompt_cache_key"); v.Exists() {
resp, _ = sjson.Set(resp, "prompt_cache_key", v.String())
}
if v := req.Get("reasoning"); v.Exists() {
resp, _ = sjson.Set(resp, "reasoning", v.Value())
}
if v := req.Get("safety_identifier"); v.Exists() {
resp, _ = sjson.Set(resp, "safety_identifier", v.String())
}
if v := req.Get("service_tier"); v.Exists() {
resp, _ = sjson.Set(resp, "service_tier", v.String())
}
if v := req.Get("store"); v.Exists() {
resp, _ = sjson.Set(resp, "store", v.Bool())
}
if v := req.Get("temperature"); v.Exists() {
resp, _ = sjson.Set(resp, "temperature", v.Float())
}
if v := req.Get("text"); v.Exists() {
resp, _ = sjson.Set(resp, "text", v.Value())
}
if v := req.Get("tool_choice"); v.Exists() {
resp, _ = sjson.Set(resp, "tool_choice", v.Value())
}
if v := req.Get("tools"); v.Exists() {
resp, _ = sjson.Set(resp, "tools", v.Value())
}
if v := req.Get("top_logprobs"); v.Exists() {
resp, _ = sjson.Set(resp, "top_logprobs", v.Int())
}
if v := req.Get("top_p"); v.Exists() {
resp, _ = sjson.Set(resp, "top_p", v.Float())
}
if v := req.Get("truncation"); v.Exists() {
resp, _ = sjson.Set(resp, "truncation", v.String())
}
if v := req.Get("user"); v.Exists() {
resp, _ = sjson.Set(resp, "user", v.Value())
}
if v := req.Get("metadata"); v.Exists() {
resp, _ = sjson.Set(resp, "metadata", v.Value())
}
} else if v := root.Get("modelVersion"); v.Exists() {
resp, _ = sjson.Set(resp, "model", v.String())
}
// Build outputs from candidates[0].content.parts
var reasoningText strings.Builder
var reasoningEncrypted string
var messageText strings.Builder
var haveMessage bool
haveOutput := false
ensureOutput := func() {
if haveOutput {
return
}
resp, _ = sjson.SetRaw(resp, "output", "[]")
haveOutput = true
}
appendOutput := func(itemJSON string) {
ensureOutput()
resp, _ = sjson.SetRaw(resp, "output.-1", itemJSON)
}
if parts := root.Get("candidates.0.content.parts"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, p gjson.Result) bool {
if p.Get("thought").Bool() {
if t := p.Get("text"); t.Exists() {
reasoningText.WriteString(t.String())
}
if sig := p.Get("thoughtSignature"); sig.Exists() && sig.String() != "" {
reasoningEncrypted = sig.String()
}
return true
}
if t := p.Get("text"); t.Exists() && t.String() != "" {
messageText.WriteString(t.String())
haveMessage = true
return true
}
if fc := p.Get("functionCall"); fc.Exists() {
name := fc.Get("name").String()
args := fc.Get("args")
callID := fmt.Sprintf("call_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&funcCallIDCounter, 1))
itemJSON := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
itemJSON, _ = sjson.Set(itemJSON, "id", fmt.Sprintf("fc_%s", callID))
itemJSON, _ = sjson.Set(itemJSON, "call_id", callID)
itemJSON, _ = sjson.Set(itemJSON, "name", name)
argsStr := ""
if args.Exists() {
argsStr = args.Raw
}
itemJSON, _ = sjson.Set(itemJSON, "arguments", argsStr)
appendOutput(itemJSON)
return true
}
return true
})
}
// Reasoning output item
if reasoningText.Len() > 0 || reasoningEncrypted != "" {
rid := strings.TrimPrefix(id, "resp_")
itemJSON := `{"id":"","type":"reasoning","encrypted_content":""}`
itemJSON, _ = sjson.Set(itemJSON, "id", fmt.Sprintf("rs_%s", rid))
itemJSON, _ = sjson.Set(itemJSON, "encrypted_content", reasoningEncrypted)
if reasoningText.Len() > 0 {
summaryJSON := `{"type":"summary_text","text":""}`
summaryJSON, _ = sjson.Set(summaryJSON, "text", reasoningText.String())
itemJSON, _ = sjson.SetRaw(itemJSON, "summary", "[]")
itemJSON, _ = sjson.SetRaw(itemJSON, "summary.-1", summaryJSON)
}
appendOutput(itemJSON)
}
// Assistant message output item
if haveMessage {
itemJSON := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
itemJSON, _ = sjson.Set(itemJSON, "id", fmt.Sprintf("msg_%s_0", strings.TrimPrefix(id, "resp_")))
itemJSON, _ = sjson.Set(itemJSON, "content.0.text", messageText.String())
appendOutput(itemJSON)
}
// usage mapping
if um := root.Get("usageMetadata"); um.Exists() {
// input tokens = prompt only (thoughts go to output)
input := um.Get("promptTokenCount").Int()
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
// cached token details: align with OpenAI "cached_tokens" semantics.
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
// output tokens
if v := um.Get("candidatesTokenCount"); v.Exists() {
resp, _ = sjson.Set(resp, "usage.output_tokens", v.Int())
}
if v := um.Get("thoughtsTokenCount"); v.Exists() {
resp, _ = sjson.Set(resp, "usage.output_tokens_details.reasoning_tokens", v.Int())
}
if v := um.Get("totalTokenCount"); v.Exists() {
resp, _ = sjson.Set(resp, "usage.total_tokens", v.Int())
}
}
return resp
}
================================================
FILE: internal/translator/gemini/openai/responses/gemini_openai-responses_response_test.go
================================================
package responses
import (
"context"
"strings"
"testing"
"github.com/tidwall/gjson"
)
// parseSSEEvent splits one SSE chunk into its event name and its JSON payload.
// The event name comes from the first line with the "event:" prefix removed.
// The payload comes from the second line with the "data:" prefix removed.
// The test fails immediately if the chunk has fewer than two lines or the payload is
// not valid JSON.
func parseSSEEvent(t *testing.T, chunk string) (string, gjson.Result) {
	t.Helper()
	fields := strings.SplitN(chunk, "\n", 3)
	if len(fields) < 2 {
		t.Fatalf("unexpected SSE chunk: %q", chunk)
	}
	eventLine, payloadLine := fields[0], fields[1]
	name := strings.TrimSpace(strings.TrimPrefix(eventLine, "event:"))
	payload := strings.TrimSpace(strings.TrimPrefix(payloadLine, "data:"))
	if !gjson.Valid(payload) {
		t.Fatalf("invalid SSE data JSON: %q", payload)
	}
	return name, gjson.Parse(payload)
}
func TestConvertGeminiResponseToOpenAIResponses_UnwrapAndAggregateText(t *testing.T) {
// Vertex-style Gemini stream wraps the actual response payload under "response".
// This test ensures we unwrap and that output_text.done contains the full text.
in := []string{
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"让"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"我先"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"了解"}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"mcp__serena__list_dir","args":{"recursive":false,"relative_path":"internal"},"id":"toolu_1"}}]}}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15,"cachedContentTokenCount":2},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
}
originalReq := []byte(`{"instructions":"test instructions","model":"gpt-5","max_output_tokens":123}`)
var param any
var out []string
for _, line := range in {
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", originalReq, nil, []byte(line), ¶m)...)
}
var (
gotTextDone bool
gotMessageDone bool
gotResponseDone bool
gotFuncDone bool
textDone string
messageText string
responseID string
instructions string
cachedTokens int64
funcName string
funcArgs string
posTextDone = -1
posPartDone = -1
posMessageDone = -1
posFuncAdded = -1
)
for i, chunk := range out {
ev, data := parseSSEEvent(t, chunk)
switch ev {
case "response.output_text.done":
gotTextDone = true
if posTextDone == -1 {
posTextDone = i
}
textDone = data.Get("text").String()
case "response.content_part.done":
if posPartDone == -1 {
posPartDone = i
}
case "response.output_item.done":
switch data.Get("item.type").String() {
case "message":
gotMessageDone = true
if posMessageDone == -1 {
posMessageDone = i
}
messageText = data.Get("item.content.0.text").String()
case "function_call":
gotFuncDone = true
funcName = data.Get("item.name").String()
funcArgs = data.Get("item.arguments").String()
}
case "response.output_item.added":
if data.Get("item.type").String() == "function_call" && posFuncAdded == -1 {
posFuncAdded = i
}
case "response.completed":
gotResponseDone = true
responseID = data.Get("response.id").String()
instructions = data.Get("response.instructions").String()
cachedTokens = data.Get("response.usage.input_tokens_details.cached_tokens").Int()
}
}
if !gotTextDone {
t.Fatalf("missing response.output_text.done event")
}
if posTextDone == -1 || posPartDone == -1 || posMessageDone == -1 || posFuncAdded == -1 {
t.Fatalf("missing ordering events: textDone=%d partDone=%d messageDone=%d funcAdded=%d", posTextDone, posPartDone, posMessageDone, posFuncAdded)
}
if !(posTextDone < posPartDone && posPartDone < posMessageDone && posMessageDone < posFuncAdded) {
t.Fatalf("unexpected message/function ordering: textDone=%d partDone=%d messageDone=%d funcAdded=%d", posTextDone, posPartDone, posMessageDone, posFuncAdded)
}
if !gotMessageDone {
t.Fatalf("missing message response.output_item.done event")
}
if !gotFuncDone {
t.Fatalf("missing function_call response.output_item.done event")
}
if !gotResponseDone {
t.Fatalf("missing response.completed event")
}
if textDone != "让我先了解" {
t.Fatalf("unexpected output_text.done text: got %q", textDone)
}
if messageText != "让我先了解" {
t.Fatalf("unexpected message done text: got %q", messageText)
}
if responseID != "resp_req_vrtx_1" {
t.Fatalf("unexpected response id: got %q", responseID)
}
if instructions != "test instructions" {
t.Fatalf("unexpected instructions echo: got %q", instructions)
}
if cachedTokens != 2 {
t.Fatalf("unexpected cached token count: got %d", cachedTokens)
}
if funcName != "mcp__serena__list_dir" {
t.Fatalf("unexpected function name: got %q", funcName)
}
if !gjson.Valid(funcArgs) {
t.Fatalf("invalid function arguments JSON: %q", funcArgs)
}
if gjson.Get(funcArgs, "recursive").Bool() != false {
t.Fatalf("unexpected recursive arg: %v", gjson.Get(funcArgs, "recursive").Value())
}
if gjson.Get(funcArgs, "relative_path").String() != "internal" {
t.Fatalf("unexpected relative_path arg: %q", gjson.Get(funcArgs, "relative_path").String())
}
}
func TestConvertGeminiResponseToOpenAIResponses_ReasoningEncryptedContent(t *testing.T) {
sig := "RXE0RENrZ0lDeEFDR0FJcVFOZDdjUzlleGFuRktRdFcvSzNyZ2MvWDNCcDQ4RmxSbGxOWUlOVU5kR1l1UHMrMGdkMVp0Vkg3ekdKU0g4YVljc2JjN3lNK0FrdGpTNUdqamI4T3Z0VVNETzdQd3pmcFhUOGl3U3hXUEJvTVFRQ09mWTFyMEtTWGZxUUlJakFqdmFGWk83RW1XRlBKckJVOVpkYzdDKw=="
in := []string{
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"thoughtSignature":"` + sig + `","text":""}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"thought":true,"text":"a"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"hello"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"modelVersion":"test-model","responseId":"req_vrtx_sig"},"traceId":"t1"}`,
}
var param any
var out []string
for _, line := range in {
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), ¶m)...)
}
var (
addedEnc string
doneEnc string
)
for _, chunk := range out {
ev, data := parseSSEEvent(t, chunk)
switch ev {
case "response.output_item.added":
if data.Get("item.type").String() == "reasoning" {
addedEnc = data.Get("item.encrypted_content").String()
}
case "response.output_item.done":
if data.Get("item.type").String() == "reasoning" {
doneEnc = data.Get("item.encrypted_content").String()
}
}
}
if addedEnc != sig {
t.Fatalf("unexpected encrypted_content in response.output_item.added: got %q", addedEnc)
}
if doneEnc != sig {
t.Fatalf("unexpected encrypted_content in response.output_item.done: got %q", doneEnc)
}
}
func TestConvertGeminiResponseToOpenAIResponses_FunctionCallEventOrder(t *testing.T) {
in := []string{
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool0"}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool1"}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool2","args":{"a":1}}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_1"},"traceId":"t1"}`,
}
var param any
var out []string
for _, line := range in {
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), ¶m)...)
}
posAdded := []int{-1, -1, -1}
posArgsDelta := []int{-1, -1, -1}
posArgsDone := []int{-1, -1, -1}
posItemDone := []int{-1, -1, -1}
posCompleted := -1
deltaByIndex := map[int]string{}
for i, chunk := range out {
ev, data := parseSSEEvent(t, chunk)
switch ev {
case "response.output_item.added":
if data.Get("item.type").String() != "function_call" {
continue
}
idx := int(data.Get("output_index").Int())
if idx >= 0 && idx < len(posAdded) {
posAdded[idx] = i
}
case "response.function_call_arguments.delta":
idx := int(data.Get("output_index").Int())
if idx >= 0 && idx < len(posArgsDelta) {
posArgsDelta[idx] = i
deltaByIndex[idx] = data.Get("delta").String()
}
case "response.function_call_arguments.done":
idx := int(data.Get("output_index").Int())
if idx >= 0 && idx < len(posArgsDone) {
posArgsDone[idx] = i
}
case "response.output_item.done":
if data.Get("item.type").String() != "function_call" {
continue
}
idx := int(data.Get("output_index").Int())
if idx >= 0 && idx < len(posItemDone) {
posItemDone[idx] = i
}
case "response.completed":
posCompleted = i
output := data.Get("response.output")
if !output.Exists() || !output.IsArray() {
t.Fatalf("missing response.output in response.completed")
}
if len(output.Array()) != 3 {
t.Fatalf("unexpected response.output length: got %d", len(output.Array()))
}
if data.Get("response.output.0.name").String() != "tool0" || data.Get("response.output.0.arguments").String() != "{}" {
t.Fatalf("unexpected output[0]: %s", data.Get("response.output.0").Raw)
}
if data.Get("response.output.1.name").String() != "tool1" || data.Get("response.output.1.arguments").String() != "{}" {
t.Fatalf("unexpected output[1]: %s", data.Get("response.output.1").Raw)
}
if data.Get("response.output.2.name").String() != "tool2" {
t.Fatalf("unexpected output[2] name: %s", data.Get("response.output.2").Raw)
}
if !gjson.Valid(data.Get("response.output.2.arguments").String()) {
t.Fatalf("unexpected output[2] arguments: %q", data.Get("response.output.2.arguments").String())
}
}
}
if posCompleted == -1 {
t.Fatalf("missing response.completed event")
}
for idx := 0; idx < 3; idx++ {
if posAdded[idx] == -1 || posArgsDelta[idx] == -1 || posArgsDone[idx] == -1 || posItemDone[idx] == -1 {
t.Fatalf("missing function call events for output_index %d: added=%d argsDelta=%d argsDone=%d itemDone=%d", idx, posAdded[idx], posArgsDelta[idx], posArgsDone[idx], posItemDone[idx])
}
if !(posAdded[idx] < posArgsDelta[idx] && posArgsDelta[idx] < posArgsDone[idx] && posArgsDone[idx] < posItemDone[idx]) {
t.Fatalf("unexpected ordering for output_index %d: added=%d argsDelta=%d argsDone=%d itemDone=%d", idx, posAdded[idx], posArgsDelta[idx], posArgsDone[idx], posItemDone[idx])
}
if idx > 0 && !(posItemDone[idx-1] < posAdded[idx]) {
t.Fatalf("function call events overlap between %d and %d: prevDone=%d nextAdded=%d", idx-1, idx, posItemDone[idx-1], posAdded[idx])
}
}
if deltaByIndex[0] != "{}" {
t.Fatalf("unexpected delta for output_index 0: got %q", deltaByIndex[0])
}
if deltaByIndex[1] != "{}" {
t.Fatalf("unexpected delta for output_index 1: got %q", deltaByIndex[1])
}
if deltaByIndex[2] == "" || !gjson.Valid(deltaByIndex[2]) || gjson.Get(deltaByIndex[2], "a").Int() != 1 {
t.Fatalf("unexpected delta for output_index 2: got %q", deltaByIndex[2])
}
if !(posItemDone[2] < posCompleted) {
t.Fatalf("response.completed should be after last output_item.done: last=%d completed=%d", posItemDone[2], posCompleted)
}
}
func TestConvertGeminiResponseToOpenAIResponses_ResponseOutputOrdering(t *testing.T) {
in := []string{
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"functionCall":{"name":"tool0","args":{"x":"y"}}}]}}],"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":"hi"}]}}],"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
`data: {"response":{"candidates":[{"content":{"role":"model","parts":[{"text":""}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":1,"candidatesTokenCount":1,"totalTokenCount":2,"cachedContentTokenCount":0},"modelVersion":"test-model","responseId":"req_vrtx_2"},"traceId":"t2"}`,
}
var param any
var out []string
for _, line := range in {
out = append(out, ConvertGeminiResponseToOpenAIResponses(context.Background(), "test-model", nil, nil, []byte(line), ¶m)...)
}
posFuncDone := -1
posMsgAdded := -1
posCompleted := -1
for i, chunk := range out {
ev, data := parseSSEEvent(t, chunk)
switch ev {
case "response.output_item.done":
if data.Get("item.type").String() == "function_call" && data.Get("output_index").Int() == 0 {
posFuncDone = i
}
case "response.output_item.added":
if data.Get("item.type").String() == "message" && data.Get("output_index").Int() == 1 {
posMsgAdded = i
}
case "response.completed":
posCompleted = i
if data.Get("response.output.0.type").String() != "function_call" {
t.Fatalf("expected response.output[0] to be function_call: %s", data.Get("response.output.0").Raw)
}
if data.Get("response.output.1.type").String() != "message" {
t.Fatalf("expected response.output[1] to be message: %s", data.Get("response.output.1").Raw)
}
if data.Get("response.output.1.content.0.text").String() != "hi" {
t.Fatalf("unexpected message text in response.output[1]: %s", data.Get("response.output.1").Raw)
}
}
}
if posFuncDone == -1 || posMsgAdded == -1 || posCompleted == -1 {
t.Fatalf("missing required events: funcDone=%d msgAdded=%d completed=%d", posFuncDone, posMsgAdded, posCompleted)
}
if !(posFuncDone < posMsgAdded) {
t.Fatalf("expected function_call to complete before message is added: funcDone=%d msgAdded=%d", posFuncDone, posMsgAdded)
}
if !(posMsgAdded < posCompleted) {
t.Fatalf("expected response.completed after message added: msgAdded=%d completed=%d", posMsgAdded, posCompleted)
}
}
================================================
FILE: internal/translator/gemini/openai/responses/init.go
================================================
package responses
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
OpenaiResponse,
Gemini,
ConvertOpenAIResponsesRequestToGemini,
interfaces.TranslateResponse{
Stream: ConvertGeminiResponseToOpenAIResponses,
NonStream: ConvertGeminiResponseToOpenAIResponsesNonStream,
},
)
}
================================================
FILE: internal/translator/gemini-cli/claude/gemini-cli_claude_request.go
================================================
// Package claude provides request translation functionality for Claude Code API compatibility.
// This package handles the conversion of Claude Code API requests into Gemini CLI-compatible
// JSON format, transforming message contents, system instructions, and tool declarations
// into the format expected by Gemini CLI API clients. It performs JSON data transformation
// to ensure compatibility between Claude Code API format and Gemini CLI API's expected format.
package claude
import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// geminiCLIClaudeThoughtSignature is the fixed thoughtSignature attached to every
// functionCall part built from a Claude tool_use block in ConvertClaudeRequestToCLI.
// The value suggests it asks the backend to skip thought-signature validation.
// NOTE(review): confirm against the Gemini CLI backend.
const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator"
// ConvertClaudeRequestToCLI parses and transforms a Claude Code API request into Gemini CLI API format.
// It extracts the system instruction, message contents, tool declarations, tool choice,
// thinking configuration and sampling parameters from the raw JSON request. It returns
// them wrapped in the Gemini CLI envelope {"model": ..., "request": {...}}.
//
// The function performs the following transformations:
//  1. Sets the target model name on the envelope
//  2. Converts the system prompt (string or text-block array) into request.systemInstruction
//  3. Maps message contents with role transformation (assistant -> model), converting
//     text, image, tool_use and tool_result blocks into Gemini parts
//  4. Converts tools into request.tools[0].functionDeclarations and tool_choice into
//     request.toolConfig.functionCallingConfig
//  5. Maps thinking settings into generationConfig.thinkingConfig
//  6. Maps temperature, top_p and top_k into generationConfig
//
// A tool_result block becomes a functionResponse. Its name is taken from the earlier
// tool_use block with the same id. If no such block exists in the request, the name is
// derived from the id by dropping its last "-"-separated segment. Messages whose content
// array yields no parts are skipped rather than sent with an empty parts array.
//
// Parameters:
//   - modelName: The name of the model to use for the request
//   - inputRawJSON: The raw JSON request data from the Claude Code API
//   - stream: A boolean indicating if the request is for a streaming response (unused in current implementation)
//
// Returns:
//   - []byte: The transformed request data in Gemini CLI API format
func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []byte {
	rawJSON := inputRawJSON

	// Build output Gemini CLI request JSON
	out := `{"model":"","request":{"contents":[]}}`
	out, _ = sjson.Set(out, "model", modelName)

	// system instruction
	if systemResult := gjson.GetBytes(rawJSON, "system"); systemResult.IsArray() {
		systemInstruction := `{"role":"user","parts":[]}`
		hasSystemParts := false
		systemResult.ForEach(func(_, systemPromptResult gjson.Result) bool {
			if systemPromptResult.Get("type").String() == "text" {
				textResult := systemPromptResult.Get("text")
				if textResult.Type == gjson.String {
					part := `{"text":""}`
					part, _ = sjson.Set(part, "text", textResult.String())
					systemInstruction, _ = sjson.SetRaw(systemInstruction, "parts.-1", part)
					hasSystemParts = true
				}
			}
			return true
		})
		if hasSystemParts {
			out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstruction)
		}
	} else if systemResult.Type == gjson.String {
		out, _ = sjson.Set(out, "request.systemInstruction.parts.-1.text", systemResult.String())
	}

	// contents
	//
	// toolNameByID maps each tool_use id seen so far to its function name. A later
	// tool_result uses it to name its functionResponse after the call it answers.
	toolNameByID := make(map[string]string)
	if messagesResult := gjson.GetBytes(rawJSON, "messages"); messagesResult.IsArray() {
		messagesResult.ForEach(func(_, messageResult gjson.Result) bool {
			roleResult := messageResult.Get("role")
			if roleResult.Type != gjson.String {
				return true
			}
			role := roleResult.String()
			if role == "assistant" {
				role = "model"
			}
			contentJSON := `{"role":"","parts":[]}`
			contentJSON, _ = sjson.Set(contentJSON, "role", role)
			contentsResult := messageResult.Get("content")
			if contentsResult.IsArray() {
				hasParts := false
				appendPart := func(part string) {
					contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
					hasParts = true
				}
				contentsResult.ForEach(func(_, contentResult gjson.Result) bool {
					switch contentResult.Get("type").String() {
					case "text":
						part := `{"text":""}`
						part, _ = sjson.Set(part, "text", contentResult.Get("text").String())
						appendPart(part)
					case "tool_use":
						functionName := contentResult.Get("name").String()
						if toolUseID := contentResult.Get("id").String(); toolUseID != "" && functionName != "" {
							toolNameByID[toolUseID] = functionName
						}
						functionArgs := contentResult.Get("input").String()
						argsResult := gjson.Parse(functionArgs)
						if argsResult.IsObject() && gjson.Valid(functionArgs) {
							part := `{"thoughtSignature":"","functionCall":{"name":"","args":{}}}`
							part, _ = sjson.Set(part, "thoughtSignature", geminiCLIClaudeThoughtSignature)
							part, _ = sjson.Set(part, "functionCall.name", functionName)
							part, _ = sjson.SetRaw(part, "functionCall.args", functionArgs)
							appendPart(part)
						}
					case "tool_result":
						toolCallID := contentResult.Get("tool_use_id").String()
						if toolCallID == "" {
							return true
						}
						funcName, known := toolNameByID[toolCallID]
						if !known {
							// No matching tool_use in this request: fall back to ids shaped
							// like "<name>-<suffix>" by dropping the last "-" segment.
							funcName = toolCallID
							toolCallIDs := strings.Split(toolCallID, "-")
							if len(toolCallIDs) > 1 {
								funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
							}
						}
						responseData := contentResult.Get("content").Raw
						part := `{"functionResponse":{"name":"","response":{"result":""}}}`
						part, _ = sjson.Set(part, "functionResponse.name", funcName)
						part, _ = sjson.Set(part, "functionResponse.response.result", responseData)
						appendPart(part)
					case "image":
						source := contentResult.Get("source")
						if source.Get("type").String() == "base64" {
							mimeType := source.Get("media_type").String()
							data := source.Get("data").String()
							if mimeType != "" && data != "" {
								part := `{"inlineData":{"mime_type":"","data":""}}`
								part, _ = sjson.Set(part, "inlineData.mime_type", mimeType)
								part, _ = sjson.Set(part, "inlineData.data", data)
								appendPart(part)
							}
						}
					}
					return true
				})
				// Skip messages that produced no parts (e.g. only unsupported block types).
				// An empty parts array carries nothing for Gemini and is believed to be
				// rejected upstream.
				if hasParts {
					out, _ = sjson.SetRaw(out, "request.contents.-1", contentJSON)
				}
			} else if contentsResult.Type == gjson.String {
				part := `{"text":""}`
				part, _ = sjson.Set(part, "text", contentsResult.String())
				contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
				out, _ = sjson.SetRaw(out, "request.contents.-1", contentJSON)
			}
			return true
		})
	}

	// tools
	if toolsResult := gjson.GetBytes(rawJSON, "tools"); toolsResult.IsArray() {
		hasTools := false
		toolsResult.ForEach(func(_, toolResult gjson.Result) bool {
			inputSchemaResult := toolResult.Get("input_schema")
			if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
				inputSchema := util.CleanJSONSchemaForGemini(inputSchemaResult.Raw)
				tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
				tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
				tool, _ = sjson.Delete(tool, "strict")
				tool, _ = sjson.Delete(tool, "input_examples")
				tool, _ = sjson.Delete(tool, "type")
				tool, _ = sjson.Delete(tool, "cache_control")
				tool, _ = sjson.Delete(tool, "defer_loading")
				tool, _ = sjson.Delete(tool, "eager_input_streaming")
				if gjson.Valid(tool) && gjson.Parse(tool).IsObject() {
					if !hasTools {
						out, _ = sjson.SetRaw(out, "request.tools", `[{"functionDeclarations":[]}]`)
						hasTools = true
					}
					out, _ = sjson.SetRaw(out, "request.tools.0.functionDeclarations.-1", tool)
				}
			}
			return true
		})
		if !hasTools {
			out, _ = sjson.Delete(out, "request.tools")
		}
	}

	// tool_choice
	toolChoiceResult := gjson.GetBytes(rawJSON, "tool_choice")
	if toolChoiceResult.Exists() {
		toolChoiceType := ""
		toolChoiceName := ""
		if toolChoiceResult.IsObject() {
			toolChoiceType = toolChoiceResult.Get("type").String()
			toolChoiceName = toolChoiceResult.Get("name").String()
		} else if toolChoiceResult.Type == gjson.String {
			toolChoiceType = toolChoiceResult.String()
		}

		switch toolChoiceType {
		case "auto":
			out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "AUTO")
		case "none":
			out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "NONE")
		case "any":
			out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
		case "tool":
			out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.mode", "ANY")
			if toolChoiceName != "" {
				out, _ = sjson.Set(out, "request.toolConfig.functionCallingConfig.allowedFunctionNames", []string{toolChoiceName})
			}
		}
	}

	// Map Anthropic thinking -> Gemini CLI thinkingConfig when enabled
	// Translator only does format conversion, ApplyThinking handles model capability validation.
	if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
		switch t.Get("type").String() {
		case "enabled":
			if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
				budget := int(b.Int())
				out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
				out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
			}
		case "adaptive", "auto":
			// For adaptive thinking:
			// - If output_config.effort is explicitly present, pass through as thinkingLevel.
			// - Otherwise, treat it as "enabled with target-model maximum" and emit high.
			// ApplyThinking handles clamping to target model's supported levels.
			effort := ""
			if v := gjson.GetBytes(rawJSON, "output_config.effort"); v.Exists() && v.Type == gjson.String {
				effort = strings.ToLower(strings.TrimSpace(v.String()))
			}
			if effort != "" {
				out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", effort)
			} else {
				out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.thinkingLevel", "high")
			}
			out, _ = sjson.Set(out, "request.generationConfig.thinkingConfig.includeThoughts", true)
		}
	}

	if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() && v.Type == gjson.Number {
		out, _ = sjson.Set(out, "request.generationConfig.temperature", v.Num)
	}
	if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() && v.Type == gjson.Number {
		out, _ = sjson.Set(out, "request.generationConfig.topP", v.Num)
	}
	if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
		out, _ = sjson.Set(out, "request.generationConfig.topK", v.Num)
	}

	outBytes := []byte(out)
	outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings")

	return outBytes
}
================================================
FILE: internal/translator/gemini-cli/claude/gemini-cli_claude_request_test.go
================================================
package claude
import (
"testing"
"github.com/tidwall/gjson"
)
// TestConvertClaudeRequestToCLI_ToolChoice_SpecificTool checks the translation of a Claude
// tool_choice naming one specific tool. It must become Gemini functionCallingConfig mode
// ANY, restricted to that single function name.
func TestConvertClaudeRequestToCLI_ToolChoice_SpecificTool(t *testing.T) {
	const model = "gemini-3-flash-preview"
	inputJSON := []byte(`{
		"model": "gemini-3-flash-preview",
		"messages": [
			{
				"role": "user",
				"content": [
					{"type": "text", "text": "hi"}
				]
			}
		],
		"tools": [
			{
				"name": "json",
				"description": "A JSON tool",
				"input_schema": {
					"type": "object",
					"properties": {}
				}
			}
		],
		"tool_choice": {"type": "tool", "name": "json"}
	}`)

	output := ConvertClaudeRequestToCLI(model, inputJSON, false)
	callingConfig := gjson.GetBytes(output, "request.toolConfig.functionCallingConfig")

	if mode := callingConfig.Get("mode").String(); mode != "ANY" {
		t.Fatalf("Expected request.toolConfig.functionCallingConfig.mode 'ANY', got '%s'", mode)
	}

	allowedNames := callingConfig.Get("allowedFunctionNames")
	if names := allowedNames.Array(); len(names) != 1 || names[0].String() != "json" {
		t.Fatalf("Expected allowedFunctionNames ['json'], got %s", allowedNames.Raw)
	}
}
================================================
FILE: internal/translator/gemini-cli/claude/gemini-cli_claude_response.go
================================================
// Package claude provides response translation functionality for Claude Code API compatibility.
// This package handles the conversion of backend client responses into Claude Code-compatible
// Server-Sent Events (SSE) format, implementing a sophisticated state machine that manages
// different response types including text content, thinking processes, and function calls.
// The translation ensures proper sequencing of SSE events and maintains state across
// multiple response chunks to provide a seamless streaming experience.
package claude
import (
"bytes"
"context"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Params holds parameters for response conversion and maintains state across streaming chunks.
// This structure tracks the current state of the response translation process to ensure
// proper sequencing of SSE events and transitions between different content types.
type Params struct {
	HasFirstResponse bool // Indicates if the initial message_start event has been sent
	ResponseType     int  // Current response type: 0=none, 1=content, 2=thinking, 3=function
	ResponseIndex    int  // Index counter for content blocks in the streaming response
	HasContent       bool // Tracks whether any content (text, thinking, or tool use) has been output
	HasToolUse       bool // Tracks whether any tool_use block has been emitted in any chunk of this stream
}

// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
// It is incremented with atomic.AddUint64 so concurrent streams can share it.
var toolUseIDCounter uint64

// ConvertGeminiCLIResponseToClaude performs streaming response format conversion.
// It translates Gemini CLI response chunks into Claude Code-compatible Server-Sent Events
// (SSE), opening and closing content blocks as the stream switches between text, thinking,
// and function-call parts.
//
// Response type states (Params.ResponseType): 0=none, 1=content, 2=thinking, 3=function.
// State is kept in *param across calls so that SSE event sequencing stays correct.
//
// The final message_delta reports stop_reason "tool_use" when a tool_use block was emitted
// anywhere in the stream, not only in the chunk that carries the finishReason. Gemini may
// deliver the functionCall and the finishReason/usage in separate chunks.
//
// Parameters:
//   - ctx: The context for the request (unused)
//   - modelName: The name of the model being used for the response (unused)
//   - originalRequestRawJSON, requestRawJSON: The client and translated requests (unused)
//   - rawJSON: The raw JSON response chunk from the Gemini CLI API, or "[DONE]"
//   - param: A pointer to the per-stream state; initialized to *Params on first use
//
// Returns:
//   - []string: A single-element slice holding the SSE text for this chunk, or an empty
//     slice for "[DONE]" when no content was ever emitted
func ConvertGeminiCLIResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	// Values stored in Params.ResponseType.
	const (
		blockNone     = 0
		blockText     = 1
		blockThinking = 2
		blockTool     = 3
	)
	// Delta payload templates: %d is the content block index, and the empty string field
	// is filled in with sjson so the value is JSON-escaped correctly.
	const (
		thinkingDeltaTemplate = `{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`
		textDeltaTemplate     = `{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`
		inputDeltaTemplate    = `{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`
	)

	if *param == nil {
		*param = &Params{}
	}
	p := (*param).(*Params)

	if bytes.Equal(rawJSON, []byte("[DONE]")) {
		// Only send message_stop if we have actually output content.
		if p.HasContent {
			return []string{"event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n\n"}
		}
		return []string{}
	}

	var out strings.Builder

	// emit appends one SSE event using the "event: X\ndata: Y\n\n\n" framing.
	emit := func(event, data string) {
		out.WriteString("event: " + event + "\ndata: " + data + "\n\n\n")
	}
	// emitDelta appends a content_block_delta for the current block, setting path to value.
	emitDelta := func(template, path, value string) {
		data, _ := sjson.Set(fmt.Sprintf(template, p.ResponseIndex), path, value)
		emit("content_block_delta", data)
	}
	// closeOpenBlock stops the currently open block, if any, and advances the block index.
	// Callers always set ResponseType right afterwards.
	closeOpenBlock := func() {
		if p.ResponseType == blockNone {
			return
		}
		emit("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, p.ResponseIndex))
		p.ResponseIndex++
	}

	// The first chunk opens the stream with message_start, overriding the placeholder
	// id/model with the upstream values when present.
	if !p.HasFirstResponse {
		messageStart := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
		if modelVersion := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersion.Exists() {
			messageStart, _ = sjson.Set(messageStart, "message.model", modelVersion.String())
		}
		if responseID := gjson.GetBytes(rawJSON, "response.responseId"); responseID.Exists() {
			messageStart, _ = sjson.Set(messageStart, "message.id", responseID.String())
		}
		emit("message_start", messageStart)
		p.HasFirstResponse = true
	}

	// IsArray guard: gjson's Array() would wrap a non-array value into a one-element slice.
	if parts := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts"); parts.IsArray() {
		for _, part := range parts.Array() {
			text := part.Get("text")
			functionCall := part.Get("functionCall")

			switch {
			case text.Exists() && part.Get("thought").Bool():
				// Thinking content: continue the open thinking block or start a new one.
				if p.ResponseType != blockThinking {
					closeOpenBlock()
					emit("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, p.ResponseIndex))
					p.ResponseType = blockThinking
				}
				emitDelta(thinkingDeltaTemplate, "delta.thinking", text.String())
				p.HasContent = true

			case text.Exists():
				// User-visible text: continue the open text block or start a new one.
				if p.ResponseType != blockText {
					closeOpenBlock()
					emit("content_block_start", fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, p.ResponseIndex))
					p.ResponseType = blockText
				}
				emitDelta(textDeltaTemplate, "delta.text", text.String())
				p.HasContent = true

			case functionCall.Exists():
				// Every function call gets its own tool_use block, so always close the
				// previous block (including a previous tool_use block).
				p.HasToolUse = true
				name := functionCall.Get("name").String()
				closeOpenBlock()

				start := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, p.ResponseIndex)
				start, _ = sjson.Set(start, "content_block.id", util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", name, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1))))
				start, _ = sjson.Set(start, "content_block.name", name)
				emit("content_block_start", start)

				if args := functionCall.Get("args"); args.Exists() {
					emitDelta(inputDeltaTemplate, "delta.partial_json", args.Raw)
				}
				p.ResponseType = blockTool
				p.HasContent = true
			}
		}
	}

	// The chunk carrying both usage metadata and a finishReason ends the message.
	usage := gjson.GetBytes(rawJSON, "response.usageMetadata")
	if usage.Exists() && bytes.Contains(rawJSON, []byte(`"finishReason"`)) {
		// Only send final events if we have actually output content.
		if candidatesTokens := usage.Get("candidatesTokenCount"); candidatesTokens.Exists() && p.HasContent {
			emit("content_block_stop", fmt.Sprintf(`{"type":"content_block_stop","index":%d}`, p.ResponseIndex))

			stopReason := "end_turn"
			if p.HasToolUse {
				stopReason = "tool_use"
			} else if gjson.GetBytes(rawJSON, "response.candidates.0.finishReason").String() == "MAX_TOKENS" {
				stopReason = "max_tokens"
			}

			messageDelta := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
			messageDelta, _ = sjson.Set(messageDelta, "delta.stop_reason", stopReason)
			// Thinking tokens are billed as output, so they are included in output_tokens.
			messageDelta, _ = sjson.Set(messageDelta, "usage.output_tokens", candidatesTokens.Int()+usage.Get("thoughtsTokenCount").Int())
			messageDelta, _ = sjson.Set(messageDelta, "usage.input_tokens", usage.Get("promptTokenCount").Int())
			emit("message_delta", messageDelta)
		}
	}

	return []string{out.String()}
}
// ConvertGeminiCLIResponseToClaudeNonStream converts a complete (non-streaming) Gemini CLI
// response into a single Claude Messages API response object.
//
// Consecutive text parts are merged into one "text" block, and consecutive thought parts into
// one "thinking" block. Each functionCall part becomes a "tool_use" block. Tool use IDs use the
// same "<name>-<unixnano>-<counter>" scheme (passed through util.SanitizeClaudeToolID) as the
// streaming converter. They are therefore unique across responses instead of restarting at
// "tool_1" on every turn.
//
// Parameters:
//   - ctx: The context for the request (unused).
//   - modelName: The name of the model (unused).
//   - originalRequestRawJSON, requestRawJSON: The client and translated requests (unused).
//   - rawJSON: The raw JSON response from the Gemini CLI API.
//   - param: Conversion state (unused).
//
// Returns:
//   - string: A Claude-compatible JSON response. The "usage" object is omitted when the
//     upstream response has no usageMetadata and both token counts are zero.
func ConvertGeminiCLIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
	_ = originalRequestRawJSON
	_ = requestRawJSON

	root := gjson.ParseBytes(rawJSON)
	out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
	out, _ = sjson.Set(out, "id", root.Get("response.responseId").String())
	out, _ = sjson.Set(out, "model", root.Get("response.modelVersion").String())

	usage := root.Get("response.usageMetadata")
	inputTokens := usage.Get("promptTokenCount").Int()
	// Thinking tokens are counted as output tokens.
	outputTokens := usage.Get("candidatesTokenCount").Int() + usage.Get("thoughtsTokenCount").Int()
	out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
	out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)

	textBuilder := strings.Builder{}
	thinkingBuilder := strings.Builder{}
	hasToolCall := false

	// flushText appends the accumulated text (if any) as one "text" block.
	flushText := func() {
		if textBuilder.Len() == 0 {
			return
		}
		block, _ := sjson.Set(`{"type":"text","text":""}`, "text", textBuilder.String())
		out, _ = sjson.SetRaw(out, "content.-1", block)
		textBuilder.Reset()
	}
	// flushThinking appends the accumulated thoughts (if any) as one "thinking" block.
	flushThinking := func() {
		if thinkingBuilder.Len() == 0 {
			return
		}
		block, _ := sjson.Set(`{"type":"thinking","thinking":""}`, "thinking", thinkingBuilder.String())
		out, _ = sjson.SetRaw(out, "content.-1", block)
		thinkingBuilder.Reset()
	}

	if parts := root.Get("response.candidates.0.content.parts"); parts.IsArray() {
		for _, part := range parts.Array() {
			if text := part.Get("text"); text.Exists() && text.String() != "" {
				if part.Get("thought").Bool() {
					flushText()
					thinkingBuilder.WriteString(text.String())
				} else {
					flushThinking()
					textBuilder.WriteString(text.String())
				}
				continue
			}

			functionCall := part.Get("functionCall")
			if !functionCall.Exists() {
				continue
			}
			flushThinking()
			flushText()
			hasToolCall = true

			name := functionCall.Get("name").String()
			toolID := util.SanitizeClaudeToolID(fmt.Sprintf("%s-%d-%d", name, time.Now().UnixNano(), atomic.AddUint64(&toolUseIDCounter, 1)))
			toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
			toolBlock, _ = sjson.Set(toolBlock, "id", toolID)
			toolBlock, _ = sjson.Set(toolBlock, "name", name)

			// Claude requires tool input to be an object; fall back to {} otherwise.
			inputRaw := "{}"
			if args := functionCall.Get("args"); args.Exists() && gjson.Valid(args.Raw) && args.IsObject() {
				inputRaw = args.Raw
			}
			toolBlock, _ = sjson.SetRaw(toolBlock, "input", inputRaw)
			out, _ = sjson.SetRaw(out, "content.-1", toolBlock)
		}
	}
	flushThinking()
	flushText()

	stopReason := "end_turn"
	switch {
	case hasToolCall:
		stopReason = "tool_use"
	case root.Get("response.candidates.0.finishReason").String() == "MAX_TOKENS":
		stopReason = "max_tokens"
	}
	out, _ = sjson.Set(out, "stop_reason", stopReason)

	if inputTokens == 0 && outputTokens == 0 && !usage.Exists() {
		out, _ = sjson.Delete(out, "usage")
	}
	return out
}
// ClaudeTokenCount renders the body of a Claude count_tokens response, reporting count
// as input_tokens. The context is not used.
func ClaudeTokenCount(_ context.Context, count int64) string {
	return `{"input_tokens":` + fmt.Sprint(count) + `}`
}
================================================
FILE: internal/translator/gemini-cli/claude/init.go
================================================
package claude
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
Claude,
GeminiCLI,
ConvertClaudeRequestToCLI,
interfaces.TranslateResponse{
Stream: ConvertGeminiCLIResponseToClaude,
NonStream: ConvertGeminiCLIResponseToClaudeNonStream,
TokenCount: ClaudeTokenCount,
},
)
}
================================================
FILE: internal/translator/gemini-cli/gemini/gemini-cli_gemini_request.go
================================================
// Package gemini provides request translation functionality for Gemini API to Gemini CLI compatibility.
// It wraps Gemini API requests into the Gemini CLI request envelope, hoisting the model name
// and normalizing system instructions, message contents, and tool declarations.
// The package performs JSON data transformation to ensure compatibility
// between the Gemini API format and the Gemini CLI API's expected format.
package gemini
import (
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertGeminiRequestToGeminiCLI parses and transforms a Gemini CLI API request into Gemini API format.
// It extracts the model name, system instruction, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the Gemini API.
// The function performs the following transformations:
// 1. Extracts the model information from the request
// 2. Restructures the JSON to match Gemini API format
// 3. Converts system instructions to the expected format
// 4. Fixes CLI tool response format and grouping
//
// Parameters:
// - modelName: The name of the model to use for the request (unused in current implementation)
// - rawJSON: The raw JSON request data from the Gemini CLI API
// - stream: A boolean indicating if the request is for a streaming response (unused in current implementation)
//
// Returns:
// - []byte: The transformed request data in Gemini API format
func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []byte {
rawJSON := inputRawJSON
template := ""
template = `{"project":"","request":{},"model":""}`
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
template, _ = sjson.Delete(template, "request.model")
template, errFixCLIToolResponse := fixCLIToolResponse(template)
if errFixCLIToolResponse != nil {
return []byte{}
}
systemInstructionResult := gjson.Get(template, "request.system_instruction")
if systemInstructionResult.Exists() {
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
template, _ = sjson.Delete(template, "request.system_instruction")
}
rawJSON = []byte(template)
// Normalize roles in request.contents: default to valid values if missing/invalid
contents := gjson.GetBytes(rawJSON, "request.contents")
if contents.Exists() {
prevRole := ""
idx := 0
contents.ForEach(func(_ gjson.Result, value gjson.Result) bool {
role := value.Get("role").String()
valid := role == "user" || role == "model"
if role == "" || !valid {
var newRole string
if prevRole == "" {
newRole = "user"
} else if prevRole == "user" {
newRole = "model"
} else {
newRole = "user"
}
path := fmt.Sprintf("request.contents.%d.role", idx)
rawJSON, _ = sjson.SetBytes(rawJSON, path, newRole)
role = newRole
}
prevRole = role
idx++
return true
})
}
toolsResult := gjson.GetBytes(rawJSON, "request.tools")
if toolsResult.Exists() && toolsResult.IsArray() {
toolResults := toolsResult.Array()
for i := 0; i < len(toolResults); i++ {
functionDeclarationsResult := gjson.GetBytes(rawJSON, fmt.Sprintf("request.tools.%d.function_declarations", i))
if functionDeclarationsResult.Exists() && functionDeclarationsResult.IsArray() {
functionDeclarationsResults := functionDeclarationsResult.Array()
for j := 0; j < len(functionDeclarationsResults); j++ {
parametersResult := gjson.GetBytes(rawJSON, fmt.Sprintf("request.tools.%d.function_declarations.%d.parameters", i, j))
if parametersResult.Exists() {
strJson, _ := util.RenameKey(string(rawJSON), fmt.Sprintf("request.tools.%d.function_declarations.%d.parameters", i, j), fmt.Sprintf("request.tools.%d.function_declarations.%d.parametersJsonSchema", i, j))
rawJSON = []byte(strJson)
}
}
}
}
}
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(key, content gjson.Result) bool {
if content.Get("role").String() == "model" {
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
} else if part.Get("thoughtSignature").Exists() {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
}
return true
})
}
return true
})
// Filter out contents with empty parts to avoid Gemini API error:
// "required oneof field 'data' must have one initialized field"
filteredContents := "[]"
hasFiltered := false
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(_, content gjson.Result) bool {
parts := content.Get("parts")
if !parts.IsArray() || len(parts.Array()) == 0 {
hasFiltered = true
return true
}
filteredContents, _ = sjson.SetRaw(filteredContents, "-1", content.Raw)
return true
})
if hasFiltered {
rawJSON, _ = sjson.SetRawBytes(rawJSON, "request.contents", []byte(filteredContents))
}
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
}
// FunctionCallGroup represents a group of function calls and their responses.
// fixCLIToolResponse creates one group per model message that contains functionCall
// parts. The group records how many functionResponse parts must be collected before
// they are emitted together as a single role:"function" message.
type FunctionCallGroup struct {
	ResponsesNeeded int      // number of functionCall parts in the originating model message
	CallNames       []string // ordered function call names for backfilling empty response names
}
// backfillFunctionResponseName returns raw with "functionResponse.name" set to fallbackName
// when the existing name is missing or whitespace-only. raw is returned unchanged when the
// name is already present or when fallbackName is empty.
func backfillFunctionResponseName(raw string, fallbackName string) string {
	if fallbackName == "" {
		return raw
	}
	if strings.TrimSpace(gjson.Get(raw, "functionResponse.name").String()) != "" {
		return raw
	}
	updated, _ := sjson.Set(raw, "functionResponse.name", fallbackName)
	return updated
}
// fixCLIToolResponse regroups function calls with their responses in request.contents.
//
// Every model message containing functionCall parts opens a FunctionCallGroup. Messages
// containing functionResponse parts are removed from the flow, and their response parts are
// queued. As soon as the oldest open group has enough queued responses, those responses are
// emitted (in order) as a single {"role":"function","parts":[...]} message, with empty
// response names backfilled from the matching call names. All other messages are copied
// through unchanged. Groups still open after the last message are satisfied from whatever
// responses remain, if enough are available.
//
// Parameters:
//   - input: The request JSON (Gemini CLI envelope) to be processed
//
// Returns:
//   - string: The JSON with request.contents replaced by the regrouped messages
//   - error: Non-nil when request.contents is missing; input is then returned unchanged
func fixCLIToolResponse(input string) (string, error) {
	contents := gjson.Parse(input).Get("request.contents")
	if !contents.Exists() {
		return input, fmt.Errorf("contents not found in input")
	}

	wrapper := `{"contents":[]}`
	var pending []*FunctionCallGroup // open groups, oldest first
	var queued []gjson.Result        // functionResponse parts not yet assigned to a group

	// emitGroup appends one role:"function" message holding responses. Entries that are
	// not JSON objects are skipped with a warning, and nothing is appended if none remain.
	emitGroup := func(group *FunctionCallGroup, responses []gjson.Result) {
		merged := `{"parts":[],"role":"function"}`
		for idx, resp := range responses {
			if !resp.IsObject() {
				log.Warnf("failed to parse function response")
				continue
			}
			merged, _ = sjson.SetRaw(merged, "parts.-1", backfillFunctionResponseName(resp.Raw, group.CallNames[idx]))
		}
		if gjson.Get(merged, "parts.#").Int() > 0 {
			wrapper, _ = sjson.SetRaw(wrapper, "contents.-1", merged)
		}
	}

	contents.ForEach(func(_, message gjson.Result) bool {
		parts := message.Get("parts")

		var responses []gjson.Result
		parts.ForEach(func(_, part gjson.Result) bool {
			if part.Get("functionResponse").Exists() {
				responses = append(responses, part)
			}
			return true
		})

		if len(responses) > 0 {
			// The message itself is dropped; its responses feed the oldest open groups.
			queued = append(queued, responses...)
			for len(pending) > 0 && len(queued) >= pending[0].ResponsesNeeded {
				head := pending[0]
				pending = pending[1:]
				batch := queued[:head.ResponsesNeeded]
				queued = queued[head.ResponsesNeeded:]
				emitGroup(head, batch)
			}
			return true
		}

		var callNames []string
		if message.Get("role").String() == "model" {
			parts.ForEach(func(_, part gjson.Result) bool {
				if part.Get("functionCall").Exists() {
					callNames = append(callNames, part.Get("functionCall.name").String())
				}
				return true
			})
		}

		if !message.IsObject() {
			if len(callNames) > 0 {
				log.Warnf("failed to parse model content")
			} else {
				log.Warnf("failed to parse content")
			}
			return true
		}

		wrapper, _ = sjson.SetRaw(wrapper, "contents.-1", message.Raw)
		if len(callNames) > 0 {
			pending = append(pending, &FunctionCallGroup{
				ResponsesNeeded: len(callNames),
				CallNames:       callNames,
			})
		}
		return true
	})

	// Flush any groups that can still be satisfied by the leftover responses.
	for _, group := range pending {
		if len(queued) < group.ResponsesNeeded {
			continue
		}
		batch := queued[:group.ResponsesNeeded]
		queued = queued[group.ResponsesNeeded:]
		emitGroup(group, batch)
	}

	result, _ := sjson.SetRaw(input, "request.contents", gjson.Get(wrapper, "contents").Raw)
	return result, nil
}
================================================
FILE: internal/translator/gemini-cli/gemini/gemini-cli_gemini_response.go
================================================
// Package gemini provides response translation functionality for Gemini CLI to Gemini API compatibility.
// It unwraps Gemini CLI responses, which nest the Gemini payload under a "response" field,
// into the plain Gemini API format for both streaming and non-streaming calls.
// The package performs JSON data transformation to ensure compatibility
// between the Gemini CLI API format and the Gemini API's expected format.
package gemini
import (
"bytes"
"context"
"fmt"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertGeminiCliResponseToGemini unwraps a streamed Gemini CLI response chunk into
// Gemini API format. An optional leading "data:" SSE prefix is stripped first.
//
// The output shape depends on the string value stored under the "alt" key in ctx:
//   - key absent (or not a string): nothing is returned
//   - alt == "": the chunk's "response" object is returned ("" when it is missing)
//   - any other alt: the chunk is expected to be a JSON array of {"response": ...}
//     envelopes, and a JSON array of their "response" values is returned ("[]" otherwise)
//
// Previously the array branch parsed an uninitialized buffer instead of rawJSON, so it
// always produced "[]".
//
// Parameters:
//   - ctx: The request context; its "alt" value selects the output shape
//   - modelName: unused
//   - originalRequestRawJSON, requestRawJSON: unused
//   - rawJSON: The raw response chunk from the Gemini CLI API
//   - param: unused
//
// Returns:
//   - []string: zero or one converted chunk
func ConvertGeminiCliResponseToGemini(ctx context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) []string {
	if bytes.HasPrefix(rawJSON, []byte("data:")) {
		rawJSON = bytes.TrimSpace(rawJSON[5:])
	}

	alt, ok := ctx.Value("alt").(string)
	if !ok {
		return []string{}
	}

	if alt == "" {
		var chunk []byte
		if inner := gjson.GetBytes(rawJSON, "response"); inner.Exists() {
			chunk = []byte(inner.Raw)
		}
		return []string{string(chunk)}
	}

	unwrapped := "[]"
	if items := gjson.ParseBytes(rawJSON); items.IsArray() {
		for _, item := range items.Array() {
			if inner := item.Get("response"); inner.Exists() {
				unwrapped, _ = sjson.SetRaw(unwrapped, "-1", inner.Raw)
			}
		}
	}
	return []string{unwrapped}
}
// ConvertGeminiCliResponseToGeminiNonStream converts a complete (non-streaming) Gemini CLI
// response into a Gemini API response. It returns the raw JSON of the "response" field when
// present; otherwise rawJSON is returned as-is.
//
// Parameters:
//   - ctx: unused
//   - modelName: unused
//   - originalRequestRawJSON, requestRawJSON: unused
//   - rawJSON: The raw response body from the Gemini CLI API
//   - param: unused
//
// Returns:
//   - string: The Gemini-compatible JSON response
func ConvertGeminiCliResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
	inner := gjson.GetBytes(rawJSON, "response")
	if !inner.Exists() {
		return string(rawJSON)
	}
	return inner.Raw
}
// GeminiTokenCount renders a token-count JSON body for count. The same value is reported
// as totalTokens and as the tokenCount of a single TEXT entry in promptTokensDetails.
// ctx is not used.
func GeminiTokenCount(ctx context.Context, count int64) string {
	const layout = `{"totalTokens":%[1]d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%[1]d}]}`
	return fmt.Sprintf(layout, count)
}
================================================
FILE: internal/translator/gemini-cli/gemini/init.go
================================================
package gemini
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
Gemini,
GeminiCLI,
ConvertGeminiRequestToGeminiCLI,
interfaces.TranslateResponse{
Stream: ConvertGeminiCliResponseToGemini,
NonStream: ConvertGeminiCliResponseToGeminiNonStream,
TokenCount: GeminiTokenCount,
},
)
}
================================================
FILE: internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_request.go
================================================
// Package chat_completions provides request translation functionality for OpenAI Chat Completions to Gemini CLI API compatibility.
// It converts OpenAI Chat Completions requests into Gemini CLI compatible JSON using gjson/sjson only.
package chat_completions
import (
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// geminiCLIFunctionThoughtSignature is the thoughtSignature value attached to functionCall
// parts and inline image parts built from OpenAI messages. NOTE(review): presumably a
// sentinel that makes the upstream skip thought-signature validation for parts that
// were not produced by Gemini itself — confirm against the upstream API.
const (
	geminiCLIFunctionThoughtSignature = "skip_thought_signature_validator"
)
// ConvertOpenAIRequestToGeminiCLI converts an OpenAI Chat Completions request (raw JSON)
// into a complete Gemini CLI request JSON. All JSON construction uses sjson and lookups use gjson.
//
// The result is an envelope of the form {"project":"","request":{...},"model":modelName}.
// Inside "request" the function fills generationConfig (generationConfig passthrough,
// reasoning_effort, temperature/top_p/top_k, n, modalities, image_config), systemInstruction
// and contents (from messages), and tools. It then attaches default safety settings via
// common.AttachDefaultSafetySettings.
//
// Parameters:
// - modelName: The name of the model to use for the request
// - inputRawJSON: The raw JSON request data from the OpenAI API
// - stream: A boolean indicating if the request is for a streaming response (unused in current implementation)
//
// Returns:
// - []byte: The transformed request data in Gemini CLI API format
func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bool) []byte {
	rawJSON := inputRawJSON
	// Base envelope (no default thinkingConfig); the placeholder model is overwritten below.
	out := []byte(`{"project":"","request":{"contents":[]},"model":"gemini-2.5-pro"}`)

	// Model
	out, _ = sjson.SetBytes(out, "model", modelName)

	// Let user-provided generationConfig pass through. The OpenAI fields handled below are
	// written afterwards, so they override matching keys inside it.
	if genConfig := gjson.GetBytes(rawJSON, "generationConfig"); genConfig.Exists() {
		out, _ = sjson.SetRawBytes(out, "request.generationConfig", []byte(genConfig.Raw))
	}

	// Apply thinking configuration: convert OpenAI reasoning_effort to Gemini CLI thinkingConfig.
	// Inline translation-only mapping; capability checks happen later in ApplyThinking.
	re := gjson.GetBytes(rawJSON, "reasoning_effort")
	if re.Exists() {
		effort := strings.ToLower(strings.TrimSpace(re.String()))
		if effort != "" {
			thinkingPath := "request.generationConfig.thinkingConfig"
			if effort == "auto" {
				// "auto" becomes thinkingBudget -1 (presumably Gemini's dynamic budget) instead of a level.
				out, _ = sjson.SetBytes(out, thinkingPath+".thinkingBudget", -1)
				out, _ = sjson.SetBytes(out, thinkingPath+".includeThoughts", true)
			} else {
				out, _ = sjson.SetBytes(out, thinkingPath+".thinkingLevel", effort)
				out, _ = sjson.SetBytes(out, thinkingPath+".includeThoughts", effort != "none")
			}
		}
	}

	// Temperature/top_p/top_k (only numeric values are forwarded)
	if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number {
		out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num)
	}
	if tpr := gjson.GetBytes(rawJSON, "top_p"); tpr.Exists() && tpr.Type == gjson.Number {
		out, _ = sjson.SetBytes(out, "request.generationConfig.topP", tpr.Num)
	}
	if tkr := gjson.GetBytes(rawJSON, "top_k"); tkr.Exists() && tkr.Type == gjson.Number {
		out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num)
	}

	// Candidate count (OpenAI 'n' parameter); only values above 1 are forwarded.
	if n := gjson.GetBytes(rawJSON, "n"); n.Exists() && n.Type == gjson.Number {
		if val := n.Int(); val > 1 {
			out, _ = sjson.SetBytes(out, "request.generationConfig.candidateCount", val)
		}
	}

	// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
	// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]; unknown entries are dropped.
	if mods := gjson.GetBytes(rawJSON, "modalities"); mods.Exists() && mods.IsArray() {
		var responseMods []string
		for _, m := range mods.Array() {
			switch strings.ToLower(m.String()) {
			case "text":
				responseMods = append(responseMods, "TEXT")
			case "image":
				responseMods = append(responseMods, "IMAGE")
			}
		}
		if len(responseMods) > 0 {
			out, _ = sjson.SetBytes(out, "request.generationConfig.responseModalities", responseMods)
		}
	}

	// OpenRouter-style image_config support
	// If the input uses top-level image_config.aspect_ratio, map it into request.generationConfig.imageConfig.aspectRatio.
	if imgCfg := gjson.GetBytes(rawJSON, "image_config"); imgCfg.Exists() && imgCfg.IsObject() {
		if ar := imgCfg.Get("aspect_ratio"); ar.Exists() && ar.Type == gjson.String {
			out, _ = sjson.SetBytes(out, "request.generationConfig.imageConfig.aspectRatio", ar.Str)
		}
		if size := imgCfg.Get("image_size"); size.Exists() && size.Type == gjson.String {
			out, _ = sjson.SetBytes(out, "request.generationConfig.imageConfig.imageSize", size.Str)
		}
	}

	// messages -> systemInstruction + contents
	messages := gjson.GetBytes(rawJSON, "messages")
	if messages.IsArray() {
		arr := messages.Array()

		// First pass: assistant tool_calls id->name map
		tcID2Name := map[string]string{}
		for i := 0; i < len(arr); i++ {
			m := arr[i]
			if m.Get("role").String() == "assistant" {
				tcs := m.Get("tool_calls")
				if tcs.IsArray() {
					for _, tc := range tcs.Array() {
						if tc.Get("type").String() == "function" {
							id := tc.Get("id").String()
							name := tc.Get("function.name").String()
							if id != "" && name != "" {
								tcID2Name[id] = name
							}
						}
					}
				}
			}
		}

		// Second pass: cache tool responses so they can be emitted right after the
		// assistant message that issued the matching tool call.
		toolResponses := map[string]string{} // tool_call_id -> raw JSON of the tool message content
		for i := 0; i < len(arr); i++ {
			m := arr[i]
			role := m.Get("role").String()
			if role == "tool" {
				toolCallID := m.Get("tool_call_id").String()
				if toolCallID != "" {
					c := m.Get("content")
					toolResponses[toolCallID] = c.Raw
				}
			}
		}

		// Third pass: build systemInstruction and contents. "tool" messages are not handled
		// here; they were cached above and are emitted after their assistant message.
		systemPartIndex := 0
		for i := 0; i < len(arr); i++ {
			m := arr[i]
			role := m.Get("role").String()
			content := m.Get("content")

			if (role == "system" || role == "developer") && len(arr) > 1 {
				// system -> request.systemInstruction as a user message style
				if content.Type == gjson.String {
					out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
					out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), content.String())
					systemPartIndex++
				} else if content.IsObject() && content.Get("type").String() == "text" {
					out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
					out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), content.Get("text").String())
					systemPartIndex++
				} else if content.IsArray() {
					contents := content.Array()
					if len(contents) > 0 {
						out, _ = sjson.SetBytes(out, "request.systemInstruction.role", "user")
						for j := 0; j < len(contents); j++ {
							out, _ = sjson.SetBytes(out, fmt.Sprintf("request.systemInstruction.parts.%d.text", systemPartIndex), contents[j].Get("text").String())
							systemPartIndex++
						}
					}
				}
			} else if role == "user" || ((role == "system" || role == "developer") && len(arr) == 1) {
				// A lone system/developer message is sent as user content (presumably so the
				// request is never left without contents).
				// Build single user content node to avoid splitting into multiple contents
				node := []byte(`{"role":"user","parts":[]}`)
				if content.Type == gjson.String {
					node, _ = sjson.SetBytes(node, "parts.0.text", content.String())
				} else if content.IsArray() {
					items := content.Array()
					p := 0
					for _, item := range items {
						switch item.Get("type").String() {
						case "text":
							node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
							p++
						case "image_url":
							// Assumes a data URL "data:<mime>;base64,<data>": the first 5 bytes
							// ("data:") and the 7-byte "base64," prefix are skipped unchecked.
							imageURL := item.Get("image_url.url").String()
							if len(imageURL) > 5 {
								pieces := strings.SplitN(imageURL[5:], ";", 2)
								if len(pieces) == 2 && len(pieces[1]) > 7 {
									mime := pieces[0]
									data := pieces[1][7:]
									node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
									node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
									node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
									p++
								}
							}
						case "file":
							// The MIME type is derived from the file name extension; unknown
							// extensions are skipped with a warning.
							filename := item.Get("file.filename").String()
							fileData := item.Get("file.file_data").String()
							ext := ""
							if sp := strings.Split(filename, "."); len(sp) > 1 {
								ext = sp[len(sp)-1]
							}
							if mimeType, ok := misc.MimeTypes[ext]; ok {
								node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mimeType)
								node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", fileData)
								p++
							} else {
								log.Warnf("Unknown file name extension '%s' in user message, skip", ext)
							}
						}
					}
				}
				out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
			} else if role == "assistant" {
				p := 0
				node := []byte(`{"role":"model","parts":[]}`)
				if content.Type == gjson.String {
					// Assistant text -> single model content
					node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
					p++
				} else if content.IsArray() {
					// Assistant multimodal content (e.g. text + image) -> single model content with parts
					for _, item := range content.Array() {
						switch item.Get("type").String() {
						case "text":
							node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
							p++
						case "image_url":
							// If the assistant returned an inline data URL, preserve it for history fidelity.
							imageURL := item.Get("image_url.url").String()
							if len(imageURL) > 5 { // expect data:...
								pieces := strings.SplitN(imageURL[5:], ";", 2)
								if len(pieces) == 2 && len(pieces[1]) > 7 {
									mime := pieces[0]
									data := pieces[1][7:]
									node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
									node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
									node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
									p++
								}
							}
						}
					}
				}

				// Tool calls -> single model content with functionCall parts
				tcs := m.Get("tool_calls")
				if tcs.IsArray() {
					fIDs := make([]string, 0)
					for _, tc := range tcs.Array() {
						if tc.Get("type").String() != "function" {
							continue
						}
						fid := tc.Get("id").String()
						fname := tc.Get("function.name").String()
						fargs := tc.Get("function.arguments").String()
						// functionCall.args is inserted as raw JSON. An empty or malformed
						// arguments string (e.g. "" for a no-argument call) would otherwise
						// produce invalid JSON and corrupt the whole request body.
						if !gjson.Valid(fargs) {
							if strings.TrimSpace(fargs) != "" {
								log.Warnf("Invalid JSON arguments for tool call '%s', using empty object", fname)
							}
							fargs = "{}"
						}
						node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
						node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
						node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
						p++
						if fid != "" {
							fIDs = append(fIDs, fid)
						}
					}
					out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)

					// Append a single tool content combining name + response per function
					toolNode := []byte(`{"role":"user","parts":[]}`)
					pp := 0
					for _, fid := range fIDs {
						if name, ok := tcID2Name[fid]; ok {
							toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
							resp := toolResponses[fid]
							if resp == "" {
								resp = "{}"
							}
							// NOTE(review): sjson encodes a []byte value as a JSON string, so
							// result carries the tool output's raw JSON text as a string (a
							// string output keeps its quotes) — confirm this is intended.
							toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
							pp++
						}
					}
					if pp > 0 {
						out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
					}
				} else {
					out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
				}
			}
		}
	}

	// tools -> request.tools[].functionDeclarations + request.tools[].googleSearch/codeExecution/urlContext passthrough
	tools := gjson.GetBytes(rawJSON, "tools")
	if tools.IsArray() && len(tools.Array()) > 0 {
		functionToolNode := []byte(`{}`)
		hasFunction := false
		googleSearchNodes := make([][]byte, 0)
		codeExecutionNodes := make([][]byte, 0)
		urlContextNodes := make([][]byte, 0)
		for _, t := range tools.Array() {
			if t.Get("type").String() == "function" {
				fn := t.Get("function")
				if fn.Exists() && fn.IsObject() {
					fnRaw := fn.Raw
					if fn.Get("parameters").Exists() {
						// OpenAI "parameters" becomes Gemini "parametersJsonSchema"; on failure
						// an empty object schema is used instead.
						renamed, errRename := util.RenameKey(fnRaw, "parameters", "parametersJsonSchema")
						if errRename != nil {
							log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
							var errSet error
							fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
							if errSet != nil {
								log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
								continue
							}
							fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
							if errSet != nil {
								log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
								continue
							}
						} else {
							fnRaw = renamed
						}
					} else {
						// No parameters declared: supply an empty object schema.
						var errSet error
						fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
						if errSet != nil {
							log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
							continue
						}
						fnRaw, errSet = sjson.SetRaw(fnRaw, "parametersJsonSchema.properties", `{}`)
						if errSet != nil {
							log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
							continue
						}
					}
					// The OpenAI-only "strict" flag is removed before forwarding.
					fnRaw, _ = sjson.Delete(fnRaw, "strict")
					if !hasFunction {
						functionToolNode, _ = sjson.SetRawBytes(functionToolNode, "functionDeclarations", []byte("[]"))
					}
					tmp, errSet := sjson.SetRawBytes(functionToolNode, "functionDeclarations.-1", []byte(fnRaw))
					if errSet != nil {
						log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
						continue
					}
					functionToolNode = tmp
					hasFunction = true
				}
			}
			if gs := t.Get("google_search"); gs.Exists() {
				googleToolNode := []byte(`{}`)
				var errSet error
				googleToolNode, errSet = sjson.SetRawBytes(googleToolNode, "googleSearch", []byte(gs.Raw))
				if errSet != nil {
					log.Warnf("Failed to set googleSearch tool: %v", errSet)
					continue
				}
				googleSearchNodes = append(googleSearchNodes, googleToolNode)
			}
			if ce := t.Get("code_execution"); ce.Exists() {
				codeToolNode := []byte(`{}`)
				var errSet error
				codeToolNode, errSet = sjson.SetRawBytes(codeToolNode, "codeExecution", []byte(ce.Raw))
				if errSet != nil {
					log.Warnf("Failed to set codeExecution tool: %v", errSet)
					continue
				}
				codeExecutionNodes = append(codeExecutionNodes, codeToolNode)
			}
			if uc := t.Get("url_context"); uc.Exists() {
				urlToolNode := []byte(`{}`)
				var errSet error
				urlToolNode, errSet = sjson.SetRawBytes(urlToolNode, "urlContext", []byte(uc.Raw))
				if errSet != nil {
					log.Warnf("Failed to set urlContext tool: %v", errSet)
					continue
				}
				urlContextNodes = append(urlContextNodes, urlToolNode)
			}
		}
		// All function declarations share one tool entry; each built-in tool gets its own.
		if hasFunction || len(googleSearchNodes) > 0 || len(codeExecutionNodes) > 0 || len(urlContextNodes) > 0 {
			toolsNode := []byte("[]")
			if hasFunction {
				toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", functionToolNode)
			}
			for _, googleNode := range googleSearchNodes {
				toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", googleNode)
			}
			for _, codeNode := range codeExecutionNodes {
				toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", codeNode)
			}
			for _, urlNode := range urlContextNodes {
				toolsNode, _ = sjson.SetRawBytes(toolsNode, "-1", urlNode)
			}
			out, _ = sjson.SetRawBytes(out, "request.tools", toolsNode)
		}
	}

	return common.AttachDefaultSafetySettings(out, "request.safetySettings")
}
// itoa formats i as a base-10 string. It is used to build sjson index paths such as
// "parts.3.text" without importing strconv; fmt.Sprint yields the same digits for an int.
func itoa(i int) string {
	return fmt.Sprint(i)
}
================================================
FILE: internal/translator/gemini-cli/openai/chat-completions/gemini-cli_openai_response.go
================================================
// Package chat_completions provides response translation functionality for Gemini CLI to OpenAI Chat Completions API compatibility.
// This package handles the conversion of Gemini CLI API responses into OpenAI Chat Completions-compatible
// JSON format, transforming streaming events and non-streaming responses into the format
// expected by OpenAI API clients. It supports both streaming and non-streaming modes,
// handling text content, tool calls, reasoning content, and usage metadata appropriately.
package chat_completions
import (
"bytes"
"context"
"fmt"
"strings"
"sync/atomic"
"time"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// convertCliResponseToOpenAIChatParams is the state the streaming converter keeps
// between successive calls that share the same param pointer.
type convertCliResponseToOpenAIChatParams struct {
	// UnixTimestamp is the last successfully parsed response.createTime in Unix seconds.
	// It is reused as "created" for chunks whose createTime is missing or unparseable.
	UnixTimestamp int64
	// FunctionIndex counts the function calls converted so far.
	FunctionIndex int
}
var (
	// functionCallIDCounter is a process-wide counter, incremented atomically when a tool
	// call ID is generated, so IDs stay unique even when name and timestamp coincide.
	functionCallIDCounter uint64
)
// ConvertCliResponseToOpenAI translates a single chunk of a streaming response from the
// Gemini CLI API format to the OpenAI Chat Completions streaming format.
// It processes various Gemini CLI event types and transforms them into OpenAI-compatible JSON responses.
// The function handles text content, tool calls, reasoning content, inline images and usage
// metadata, outputting responses that match the OpenAI API format. State that must persist
// across chunks (the created timestamp and the running tool-call index) lives in *param.
//
// Parameters:
// - ctx: The context for the request (unused)
// - modelName: The name of the model being used for the response (unused in current implementation)
// - originalRequestRawJSON, requestRawJSON: The original and translated requests (unused)
// - rawJSON: The raw JSON response from the Gemini CLI API
// - param: A pointer to a parameter object for maintaining state between calls; it is
// initialized on the first call when *param is nil
//
// Returns:
// - []string: A slice of strings, each containing an OpenAI-compatible JSON response (empty for "[DONE]")
func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	if *param == nil {
		*param = &convertCliResponseToOpenAIChatParams{
			UnixTimestamp: 0,
			FunctionIndex: 0,
		}
	}

	if bytes.Equal(rawJSON, []byte("[DONE]")) {
		return []string{}
	}

	state := (*param).(*convertCliResponseToOpenAIChatParams)

	// Initialize the OpenAI SSE template.
	template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`

	// Extract and set the model version.
	if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
		template, _ = sjson.Set(template, "model", modelVersionResult.String())
	}

	// Extract and set the creation timestamp. A missing or unparseable createTime reuses the
	// last timestamp recorded in state (0 if none has been seen yet).
	if createTimeResult := gjson.GetBytes(rawJSON, "response.createTime"); createTimeResult.Exists() {
		if t, err := time.Parse(time.RFC3339Nano, createTimeResult.String()); err == nil {
			state.UnixTimestamp = t.Unix()
		}
	}
	template, _ = sjson.Set(template, "created", state.UnixTimestamp)

	// Extract and set the response ID.
	if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
		template, _ = sjson.Set(template, "id", responseIDResult.String())
	}

	// Prefer response.stop_reason, falling back to the first candidate's finishReason;
	// lower-cased for the comparisons at the end.
	finishReason := ""
	if stopReasonResult := gjson.GetBytes(rawJSON, "response.stop_reason"); stopReasonResult.Exists() {
		finishReason = stopReasonResult.String()
	}
	if finishReason == "" {
		if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
			finishReason = finishReasonResult.String()
		}
	}
	finishReason = strings.ToLower(finishReason)

	// Extract and set usage metadata (token counts).
	if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
		cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
		if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
			template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
		}
		if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
			template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
		}
		promptTokenCount := usageResult.Get("promptTokenCount").Int()
		thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
		template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
		if thoughtsTokenCount > 0 {
			template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
		}
		// Include cached token count if present (indicates prompt caching is working)
		if cachedTokenCount > 0 {
			var err error
			template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
			if err != nil {
				log.Warnf("gemini-cli openai response: failed to set cached_tokens: %v", err)
			}
		}
	}

	// Process the main content part of the response.
	partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
	hasFunctionCall := false
	if partsResult.IsArray() {
		partResults := partsResult.Array()
		for i := 0; i < len(partResults); i++ {
			partResult := partResults[i]
			partTextResult := partResult.Get("text")
			functionCallResult := partResult.Get("functionCall")
			// Accept both camelCase and snake_case field spellings.
			thoughtSignatureResult := partResult.Get("thoughtSignature")
			if !thoughtSignatureResult.Exists() {
				thoughtSignatureResult = partResult.Get("thought_signature")
			}
			inlineDataResult := partResult.Get("inlineData")
			if !inlineDataResult.Exists() {
				inlineDataResult = partResult.Get("inline_data")
			}

			hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != ""
			hasContentPayload := partTextResult.Exists() || functionCallResult.Exists() || inlineDataResult.Exists()

			// Ignore encrypted thoughtSignature but keep any actual content in the same part.
			if hasThoughtSignature && !hasContentPayload {
				continue
			}

			if partTextResult.Exists() {
				textContent := partTextResult.String()

				// Handle text content, distinguishing between regular content and reasoning/thoughts.
				if partResult.Get("thought").Bool() {
					template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", textContent)
				} else {
					template, _ = sjson.Set(template, "choices.0.delta.content", textContent)
				}
				template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
			} else if functionCallResult.Exists() {
				// Handle function call content.
				hasFunctionCall = true

				// The OpenAI tool_calls index always comes from the stream-wide counter.
				// Recomputing it from the number of calls already in this chunk collided with
				// indices from earlier chunks (a second chunk carrying two calls emitted
				// index 1 twice), which makes clients merge distinct tool calls.
				functionCallIndex := state.FunctionIndex
				state.FunctionIndex++
				if !gjson.Get(template, "choices.0.delta.tool_calls").IsArray() {
					template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
				}

				functionCallTemplate := `{"id": "","index": 0,"type": "function","function": {"name": "","arguments": ""}}`
				fcName := functionCallResult.Get("name").String()
				functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d-%d", fcName, time.Now().UnixNano(), atomic.AddUint64(&functionCallIDCounter, 1)))
				functionCallTemplate, _ = sjson.Set(functionCallTemplate, "index", functionCallIndex)
				functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
				if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
					functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
				}
				template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
				template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallTemplate)
			} else if inlineDataResult.Exists() {
				data := inlineDataResult.Get("data").String()
				if data == "" {
					continue
				}
				mimeType := inlineDataResult.Get("mimeType").String()
				if mimeType == "" {
					mimeType = inlineDataResult.Get("mime_type").String()
				}
				if mimeType == "" {
					mimeType = "image/png"
				}
				imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
				imagesResult := gjson.Get(template, "choices.0.delta.images")
				if !imagesResult.Exists() || !imagesResult.IsArray() {
					template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
				}
				imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
				imagePayload := `{"type":"image_url","image_url":{"url":""}}`
				imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
				imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
				template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
				template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
			}
		}
	}

	if hasFunctionCall {
		template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
		template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
	} else if finishReason != "" && state.FunctionIndex == 0 {
		// Once any tool call has been emitted in this stream, its chunk already carried
		// finish_reason "tool_calls", so later stop reasons are not forwarded.
		// Only pass through specific finish reasons
		if finishReason == "max_tokens" || finishReason == "stop" {
			template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
			template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReason)
		}
	}

	return []string{template}
}
// ConvertCliResponseToOpenAINonStream converts a complete Gemini CLI response into a
// non-streaming OpenAI Chat Completions response.
//
// The top-level "response" field of rawJSON is extracted and handed, together with all
// other arguments, to ConvertGeminiResponseToOpenAINonStream, which builds the OpenAI JSON.
// When rawJSON has no "response" field an empty string is returned.
//
// Parameters:
// - ctx: The context for the request, forwarded unchanged
// - modelName: The name of the model being used for the response
// - originalRequestRawJSON, requestRawJSON: The original and translated requests, forwarded unchanged
// - rawJSON: The raw JSON response from the Gemini CLI API
// - param: A pointer to a parameter object for the conversion, forwarded unchanged
//
// Returns:
// - string: An OpenAI-compatible JSON response, or "" when rawJSON has no "response" field
func ConvertCliResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
	inner := gjson.GetBytes(rawJSON, "response")
	if !inner.Exists() {
		return ""
	}
	payload := []byte(inner.Raw)
	return ConvertGeminiResponseToOpenAINonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, payload, param)
}
================================================
FILE: internal/translator/gemini-cli/openai/chat-completions/init.go
================================================
package chat_completions
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
OpenAI,
GeminiCLI,
ConvertOpenAIRequestToGeminiCLI,
interfaces.TranslateResponse{
Stream: ConvertCliResponseToOpenAI,
NonStream: ConvertCliResponseToOpenAINonStream,
},
)
}
================================================
FILE: internal/translator/gemini-cli/openai/responses/gemini-cli_openai-responses_request.go
================================================
package responses
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini-cli/gemini"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/responses"
)
// ConvertOpenAIResponsesRequestToGeminiCLI converts an OpenAI Responses API request into a
// Gemini CLI request in two steps. ConvertOpenAIResponsesRequestToGemini first produces a
// Gemini request, and ConvertGeminiRequestToGeminiCLI then wraps it in the Gemini CLI format.
// modelName and stream are passed to both steps.
func ConvertOpenAIResponsesRequestToGeminiCLI(modelName string, inputRawJSON []byte, stream bool) []byte {
	geminiRequest := ConvertOpenAIResponsesRequestToGemini(modelName, inputRawJSON, stream)
	return ConvertGeminiRequestToGeminiCLI(modelName, geminiRequest, stream)
}
================================================
FILE: internal/translator/gemini-cli/openai/responses/gemini-cli_openai-responses_response.go
================================================
package responses
import (
"context"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/responses"
"github.com/tidwall/gjson"
)
// ConvertGeminiCLIResponseToOpenAIResponses converts one streaming Gemini CLI chunk into
// OpenAI Responses API output. A top-level "response" field is unwrapped when present;
// otherwise rawJSON is forwarded unchanged to ConvertGeminiResponseToOpenAIResponses.
func ConvertGeminiCLIResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	payload := rawJSON
	if inner := gjson.GetBytes(rawJSON, "response"); inner.Exists() {
		payload = []byte(inner.Raw)
	}
	return ConvertGeminiResponseToOpenAIResponses(ctx, modelName, originalRequestRawJSON, requestRawJSON, payload, param)
}
// ConvertGeminiCLIResponseToOpenAIResponsesNonStream converts a complete Gemini CLI response
// into a non-streaming OpenAI Responses API response.
//
// rawJSON is unwrapped from its top-level "response" field, and each of
// originalRequestRawJSON and requestRawJSON from its top-level "request" field. Each unwrap
// happens only when that field is present; otherwise the input is passed on unchanged.
// The results go to ConvertGeminiResponseToOpenAIResponsesNonStream.
func ConvertGeminiCLIResponseToOpenAIResponsesNonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
	if responseResult := gjson.GetBytes(rawJSON, "response"); responseResult.Exists() {
		rawJSON = []byte(responseResult.Raw)
	}

	// Each request is unwrapped only when it has its own "request" field. The previous code
	// checked the response result here. That replaced a request lacking the field (e.g. an
	// original client request not in a Gemini CLI envelope) with an empty payload.
	if requestResult := gjson.GetBytes(originalRequestRawJSON, "request"); requestResult.Exists() {
		originalRequestRawJSON = []byte(requestResult.Raw)
	}
	if requestResult := gjson.GetBytes(requestRawJSON, "request"); requestResult.Exists() {
		requestRawJSON = []byte(requestResult.Raw)
	}

	return ConvertGeminiResponseToOpenAIResponsesNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
================================================
FILE: internal/translator/gemini-cli/openai/responses/init.go
================================================
package responses
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
OpenaiResponse,
GeminiCLI,
ConvertOpenAIResponsesRequestToGeminiCLI,
interfaces.TranslateResponse{
Stream: ConvertGeminiCLIResponseToOpenAIResponses,
NonStream: ConvertGeminiCLIResponseToOpenAIResponsesNonStream,
},
)
}
================================================
FILE: internal/translator/init.go
================================================
package translator
import (
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/gemini"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/gemini-cli"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/openai/chat-completions"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/openai/responses"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/claude"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/gemini"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/gemini-cli"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/openai/chat-completions"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/openai/responses"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini-cli/claude"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini-cli/gemini"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini-cli/openai/chat-completions"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini-cli/openai/responses"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/claude"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/gemini"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/gemini-cli"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/responses"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/claude"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/gemini"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/gemini-cli"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/openai/chat-completions"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/openai/responses"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/claude"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/chat-completions"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/openai/responses"
)
================================================
FILE: internal/translator/openai/claude/init.go
================================================
package claude
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
Claude,
OpenAI,
ConvertClaudeRequestToOpenAI,
interfaces.TranslateResponse{
Stream: ConvertOpenAIResponseToClaude,
NonStream: ConvertOpenAIResponseToClaudeNonStream,
TokenCount: ClaudeTokenCount,
},
)
}
================================================
FILE: internal/translator/openai/claude/openai_claude_request.go
================================================
// Package claude provides request translation functionality for Anthropic to OpenAI API.
// It handles parsing and transforming Anthropic API requests into OpenAI Chat Completions API format,
// extracting model information, system instructions, message contents, and tool declarations.
// The package performs JSON data transformation to ensure compatibility
// between Anthropic API format and OpenAI API's expected format.
package claude
import (
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertClaudeRequestToOpenAI parses and transforms an Anthropic API request into OpenAI Chat Completions API format.
// It extracts the model name, system instruction, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the OpenAI API.
//
// Parameters:
//   - modelName: value written to the output "model" field (the request's own "model" is not used).
//   - inputRawJSON: the Anthropic Messages API request body.
//   - stream: value written to the output "stream" field.
//
// Returns the OpenAI Chat Completions request body. Missing or unrecognized optional
// fields are skipped rather than reported.
func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream bool) []byte {
	root := gjson.ParseBytes(inputRawJSON)

	// Base OpenAI Chat Completions API template
	out := `{"model":"","messages":[]}`

	// Model mapping
	out, _ = sjson.Set(out, "model", modelName)

	// Max tokens
	if maxTokens := root.Get("max_tokens"); maxTokens.Exists() {
		out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
	}

	// Sampling: temperature takes precedence; top_p is only forwarded when temperature is absent.
	if temp := root.Get("temperature"); temp.Exists() {
		out, _ = sjson.Set(out, "temperature", temp.Float())
	} else if topP := root.Get("top_p"); topP.Exists() {
		out, _ = sjson.Set(out, "top_p", topP.Float())
	}

	// Stop sequences -> stop (a single sequence is sent as a plain string)
	if stopSequences := root.Get("stop_sequences"); stopSequences.IsArray() {
		var stops []string
		stopSequences.ForEach(func(_, value gjson.Result) bool {
			stops = append(stops, value.String())
			return true
		})
		switch len(stops) {
		case 0:
			// Nothing to forward.
		case 1:
			out, _ = sjson.Set(out, "stop", stops[0])
		default:
			out, _ = sjson.Set(out, "stop", stops)
		}
	}

	// Stream
	out, _ = sjson.Set(out, "stream", stream)

	// Thinking: Convert Claude thinking.budget_tokens to OpenAI reasoning_effort
	if thinkingConfig := root.Get("thinking"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
		if thinkingType := thinkingConfig.Get("type"); thinkingType.Exists() {
			switch thinkingType.String() {
			case "enabled":
				if budgetTokens := thinkingConfig.Get("budget_tokens"); budgetTokens.Exists() {
					budget := int(budgetTokens.Int())
					if effort, ok := thinking.ConvertBudgetToLevel(budget); ok && effort != "" {
						out, _ = sjson.Set(out, "reasoning_effort", effort)
					}
				} else {
					// No budget_tokens specified, default to "auto" for enabled thinking
					if effort, ok := thinking.ConvertBudgetToLevel(-1); ok && effort != "" {
						out, _ = sjson.Set(out, "reasoning_effort", effort)
					}
				}
			case "adaptive", "auto":
				// Adaptive thinking can carry an explicit effort in output_config.effort (Claude 4.6).
				// Pass through directly; ApplyThinking handles clamping to target model's levels.
				effort := ""
				if v := root.Get("output_config.effort"); v.Exists() && v.Type == gjson.String {
					effort = strings.ToLower(strings.TrimSpace(v.String()))
				}
				if effort != "" {
					out, _ = sjson.Set(out, "reasoning_effort", effort)
				} else {
					out, _ = sjson.Set(out, "reasoning_effort", string(thinking.LevelXHigh))
				}
			case "disabled":
				if effort, ok := thinking.ConvertBudgetToLevel(0); ok && effort != "" {
					out, _ = sjson.Set(out, "reasoning_effort", effort)
				}
			}
		}
	}

	// Process messages and system. Every element appended to messagesJSON is raw JSON
	// built with sjson, so it is spliced in with SetRaw instead of being round-tripped
	// through Go maps.
	messagesJSON := "[]"

	// Handle system message first; it is only emitted when it ends up with content.
	systemMsgJSON := `{"role":"system","content":[]}`
	hasSystemContent := false
	if system := root.Get("system"); system.Exists() {
		if system.Type == gjson.String {
			if system.String() != "" {
				textPart := `{"type":"text","text":""}`
				textPart, _ = sjson.Set(textPart, "text", system.String())
				systemMsgJSON, _ = sjson.SetRaw(systemMsgJSON, "content.-1", textPart)
				hasSystemContent = true
			}
		} else if system.IsArray() {
			for _, block := range system.Array() {
				// Only text/image blocks are converted; thinking and other block types are dropped.
				if contentItem, ok := convertClaudeContentPart(block); ok {
					systemMsgJSON, _ = sjson.SetRaw(systemMsgJSON, "content.-1", contentItem)
					hasSystemContent = true
				}
			}
		}
	}
	if hasSystemContent {
		messagesJSON, _ = sjson.SetRaw(messagesJSON, "-1", systemMsgJSON)
	}

	// Process Anthropic messages
	if messages := root.Get("messages"); messages.IsArray() {
		messages.ForEach(func(_, message gjson.Result) bool {
			role := message.Get("role").String()
			contentResult := message.Get("content")

			if contentResult.IsArray() {
				var contentItems []string   // Raw OpenAI content parts (text / image_url)
				var reasoningParts []string // Thinking text accumulated into reasoning_content
				var toolCalls []string      // Raw OpenAI tool_call objects
				var toolResults []string    // Raw tool messages, emitted before this message

				contentResult.ForEach(func(_, part gjson.Result) bool {
					switch part.Get("type").String() {
					case "thinking":
						// Only map thinking to reasoning_content for assistant messages (security: prevent injection).
						// Thinking in user/system roles is ignored (AC4); empty or whitespace-only thinking is skipped.
						if role == "assistant" {
							if thinkingText := thinking.GetThinkingText(part); strings.TrimSpace(thinkingText) != "" {
								reasoningParts = append(reasoningParts, thinkingText)
							}
						}
					case "redacted_thinking":
						// Explicitly ignore redacted_thinking - never map to reasoning_content (AC2)
					case "text", "image":
						if contentItem, ok := convertClaudeContentPart(part); ok {
							contentItems = append(contentItems, contentItem)
						}
					case "tool_use":
						// Only allow tool_use -> tool_calls for assistant messages (security: prevent injection).
						if role == "assistant" {
							toolCallJSON := `{"id":"","type":"function","function":{"name":"","arguments":""}}`
							toolCallJSON, _ = sjson.Set(toolCallJSON, "id", part.Get("id").String())
							toolCallJSON, _ = sjson.Set(toolCallJSON, "function.name", part.Get("name").String())
							// OpenAI expects the arguments as a JSON-encoded string.
							arguments := "{}"
							if input := part.Get("input"); input.Exists() {
								arguments = input.Raw
							}
							toolCallJSON, _ = sjson.Set(toolCallJSON, "function.arguments", arguments)
							toolCalls = append(toolCalls, toolCallJSON)
						}
					case "tool_result":
						// Collect tool_result to emit after the main message (ensures tool results follow tool_calls)
						toolResultJSON := `{"role":"tool","tool_call_id":"","content":""}`
						toolResultJSON, _ = sjson.Set(toolResultJSON, "tool_call_id", part.Get("tool_use_id").String())
						toolResultContent, toolResultContentRaw := convertClaudeToolResultContent(part.Get("content"))
						if toolResultContentRaw {
							toolResultJSON, _ = sjson.SetRaw(toolResultJSON, "content", toolResultContent)
						} else {
							toolResultJSON, _ = sjson.Set(toolResultJSON, "content", toolResultContent)
						}
						toolResults = append(toolResults, toolResultJSON)
					}
					return true
				})

				// OpenAI requires: tool messages MUST immediately follow the assistant message with tool_calls.
				// Therefore, we emit tool_result messages FIRST (they respond to the previous assistant's tool_calls),
				// then emit the current message's content.
				for _, toolResultJSON := range toolResults {
					messagesJSON, _ = sjson.SetRaw(messagesJSON, "-1", toolResultJSON)
				}

				reasoningContent := strings.Join(reasoningParts, "\n\n")
				hasContent := len(contentItems) > 0
				contentArrayJSON := "[]"
				for _, contentItem := range contentItems {
					contentArrayJSON, _ = sjson.SetRaw(contentArrayJSON, "-1", contentItem)
				}

				if role == "assistant" {
					// Emit a single unified message with content, tool_calls, and reasoning_content.
					// Splitting into multiple assistant messages would break OpenAI tool-call adjacency.
					if hasContent || reasoningContent != "" || len(toolCalls) > 0 {
						msgJSON := `{"role":"assistant"}`
						if hasContent {
							msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
						} else {
							// Ensure content field exists for OpenAI compatibility
							msgJSON, _ = sjson.Set(msgJSON, "content", "")
						}
						if reasoningContent != "" {
							msgJSON, _ = sjson.Set(msgJSON, "reasoning_content", reasoningContent)
						}
						if len(toolCalls) > 0 {
							toolCallsJSON := "[]"
							for _, toolCallJSON := range toolCalls {
								toolCallsJSON, _ = sjson.SetRaw(toolCallsJSON, "-1", toolCallJSON)
							}
							msgJSON, _ = sjson.SetRaw(msgJSON, "tool_calls", toolCallsJSON)
						}
						messagesJSON, _ = sjson.SetRaw(messagesJSON, "-1", msgJSON)
					}
				} else if hasContent {
					// Non-assistant roles only produce a message when they carry text/image content;
					// a turn holding only tool_results was fully emitted above as tool messages.
					msgJSON := `{"role":""}`
					msgJSON, _ = sjson.Set(msgJSON, "role", role)
					msgJSON, _ = sjson.SetRaw(msgJSON, "content", contentArrayJSON)
					messagesJSON, _ = sjson.SetRaw(messagesJSON, "-1", msgJSON)
				}
			} else if contentResult.Type == gjson.String {
				// Simple string content
				msgJSON := `{"role":"","content":""}`
				msgJSON, _ = sjson.Set(msgJSON, "role", role)
				msgJSON, _ = sjson.Set(msgJSON, "content", contentResult.String())
				messagesJSON, _ = sjson.SetRaw(messagesJSON, "-1", msgJSON)
			}
			return true
		})
	}

	// Set messages only when at least one was produced.
	if gjson.Get(messagesJSON, "#").Int() > 0 {
		out, _ = sjson.SetRaw(out, "messages", messagesJSON)
	}

	// Process tools - convert Anthropic tools to OpenAI functions
	if tools := root.Get("tools"); tools.IsArray() {
		toolsJSON := "[]"
		tools.ForEach(func(_, tool gjson.Result) bool {
			openAIToolJSON := `{"type":"function","function":{"name":"","description":""}}`
			openAIToolJSON, _ = sjson.Set(openAIToolJSON, "function.name", tool.Get("name").String())
			openAIToolJSON, _ = sjson.Set(openAIToolJSON, "function.description", tool.Get("description").String())
			// Copy input_schema verbatim: round-tripping it through Go maps would reorder
			// its properties and coerce large integers to float64.
			if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
				openAIToolJSON, _ = sjson.SetRaw(openAIToolJSON, "function.parameters", inputSchema.Raw)
			}
			toolsJSON, _ = sjson.SetRaw(toolsJSON, "-1", openAIToolJSON)
			return true
		})
		if gjson.Get(toolsJSON, "#").Int() > 0 {
			out, _ = sjson.SetRaw(out, "tools", toolsJSON)
		}
	}

	// Tool choice mapping - convert Anthropic tool_choice to OpenAI format
	if toolChoice := root.Get("tool_choice"); toolChoice.Exists() {
		switch toolChoice.Get("type").String() {
		case "any":
			out, _ = sjson.Set(out, "tool_choice", "required")
		case "none":
			// Claude "none" forbids tool use; mapping it to "auto" would re-enable tools.
			out, _ = sjson.Set(out, "tool_choice", "none")
		case "tool":
			// Specific tool choice
			toolName := toolChoice.Get("name").String()
			toolChoiceJSON := `{"type":"function","function":{"name":""}}`
			toolChoiceJSON, _ = sjson.Set(toolChoiceJSON, "function.name", toolName)
			out, _ = sjson.SetRaw(out, "tool_choice", toolChoiceJSON)
		default:
			// "auto" and unrecognized types map to auto.
			out, _ = sjson.Set(out, "tool_choice", "auto")
		}
	}

	// Handle user parameter (for tracking)
	if user := root.Get("user"); user.Exists() {
		out, _ = sjson.Set(out, "user", user.String())
	}

	return []byte(out)
}
// convertClaudeContentPart converts one Claude content block into a raw-JSON OpenAI
// chat content part.
//
// Supported blocks:
//   - "text": becomes {"type":"text","text":...}; empty or whitespace-only text is skipped.
//   - "image": becomes {"type":"image_url","image_url":{"url":...}}. A "base64" source is
//     turned into a data URL (media type defaults to application/octet-stream), a "url"
//     source is used as-is, and the block's own top-level "url" field is the fallback.
//
// The boolean result is false when the block type is unsupported or yields nothing usable.
func convertClaudeContentPart(part gjson.Result) (string, bool) {
	blockType := part.Get("type").String()

	if blockType == "text" {
		txt := part.Get("text").String()
		if strings.TrimSpace(txt) == "" {
			return "", false
		}
		encoded, _ := sjson.Set(`{"type":"text","text":""}`, "text", txt)
		return encoded, true
	}

	if blockType != "image" {
		return "", false
	}

	imgURL := ""
	src := part.Get("source")
	switch src.Get("type").String() {
	case "base64":
		mime := src.Get("media_type").String()
		if mime == "" {
			mime = "application/octet-stream"
		}
		if payload := src.Get("data").String(); payload != "" {
			imgURL = "data:" + mime + ";base64," + payload
		}
	case "url":
		imgURL = src.Get("url").String()
	}
	if imgURL == "" {
		imgURL = part.Get("url").String()
	}
	if imgURL == "" {
		return "", false
	}

	encoded, _ := sjson.Set(`{"type":"image_url","image_url":{"url":""}}`, "image_url.url", imgURL)
	return encoded, true
}
// convertClaudeToolResultContent converts the "content" of a Claude tool_result block
// into the value for an OpenAI tool message's "content" field.
//
// It returns the converted value plus a flag that is true when the value is a raw JSON
// content-part array (to be inserted with sjson.SetRaw) rather than a plain string:
//   - missing content: "", false
//   - string content: the string, false
//   - array content: when at least one image converts, a content-part array holding
//     every item in order (true); otherwise the textual items joined by a blank line,
//     or the array's raw JSON when that join is blank (false)
//   - image object: a single-element content-part array, true
//   - other objects: the string "text" field if present, else the raw JSON (false)
func convertClaudeToolResultContent(content gjson.Result) (string, bool) {
	if !content.Exists() {
		return "", false
	}
	if content.Type == gjson.String {
		return content.String(), false
	}
	if content.IsArray() {
		var parts []string
		contentJSON := "[]"
		hasImagePart := false
		// appendText records a textual item both for the plain-string result and as a
		// text part of the multimodal array, so no item is lost whichever form is returned.
		appendText := func(text string) {
			parts = append(parts, text)
			textContent, _ := sjson.Set(`{"type":"text","text":""}`, "text", text)
			contentJSON, _ = sjson.SetRaw(contentJSON, "-1", textContent)
		}
		content.ForEach(func(_, item gjson.Result) bool {
			switch {
			case item.Type == gjson.String:
				appendText(item.String())
			case item.IsObject() && item.Get("type").String() == "text":
				appendText(item.Get("text").String())
			case item.IsObject() && item.Get("type").String() == "image":
				if contentItem, ok := convertClaudeContentPart(item); ok {
					contentJSON, _ = sjson.SetRaw(contentJSON, "-1", contentItem)
					hasImagePart = true
				} else {
					// Unconvertible image: keep it visible as its raw JSON text.
					appendText(item.Raw)
				}
			case item.IsObject() && item.Get("text").Type == gjson.String:
				appendText(item.Get("text").String())
			default:
				appendText(item.Raw)
			}
			return true
		})
		if hasImagePart {
			return contentJSON, true
		}
		joined := strings.Join(parts, "\n\n")
		if strings.TrimSpace(joined) != "" {
			return joined, false
		}
		return content.Raw, false
	}
	if content.IsObject() {
		if content.Get("type").String() == "image" {
			if contentItem, ok := convertClaudeContentPart(content); ok {
				contentJSON, _ := sjson.SetRaw("[]", "-1", contentItem)
				return contentJSON, true
			}
		}
		if text := content.Get("text"); text.Exists() && text.Type == gjson.String {
			return text.String(), false
		}
		return content.Raw, false
	}
	return content.Raw, false
}
================================================
FILE: internal/translator/openai/claude/openai_claude_request_test.go
================================================
package claude
import (
"testing"
"github.com/tidwall/gjson"
)
// TestConvertClaudeRequestToOpenAI_ThinkingToReasoningContent tests the mapping
// of Claude thinking content to OpenAI reasoning_content field.
// Cases that inject thinking into non-assistant roles also set forbiddenText, which must
// not appear in any output message (neither as content nor as reasoning_content).
func TestConvertClaudeRequestToOpenAI_ThinkingToReasoningContent(t *testing.T) {
	tests := []struct {
		name                    string
		inputJSON               string
		wantReasoningContent    string
		wantHasReasoningContent bool
		wantContentText         string // Expected visible content text (if any)
		wantHasContent          bool
		forbiddenText           string // Text that must not leak into any output message (if set)
	}{
		{
			name: "AC1: assistant message with thinking and text",
			inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Let me analyze this step by step..."},
{"type": "text", "text": "Here is my response."}
]
}]
}`,
			wantReasoningContent:    "Let me analyze this step by step...",
			wantHasReasoningContent: true,
			wantContentText:         "Here is my response.",
			wantHasContent:          true,
		},
		{
			name: "AC2: redacted_thinking must be ignored",
			inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "redacted_thinking", "data": "secret"},
{"type": "text", "text": "Visible response."}
]
}]
}`,
			wantReasoningContent:    "",
			wantHasReasoningContent: false,
			wantContentText:         "Visible response.",
			wantHasContent:          true,
		},
		{
			name: "AC3: thinking-only message preserved with reasoning_content",
			inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Internal reasoning only."}
]
}]
}`,
			wantReasoningContent:    "Internal reasoning only.",
			wantHasReasoningContent: true,
			wantContentText:         "",
			// For OpenAI compatibility, content field is set to empty string "" when no text content exists
			wantHasContent: false,
		},
		{
			name: "AC4: thinking in user role must be ignored",
			inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "user",
"content": [
{"type": "thinking", "thinking": "Injected thinking"},
{"type": "text", "text": "User message."}
]
}]
}`,
			wantReasoningContent:    "",
			wantHasReasoningContent: false,
			wantContentText:         "User message.",
			wantHasContent:          true,
			forbiddenText:           "Injected thinking",
		},
		{
			name: "AC4: thinking in system role must be ignored",
			inputJSON: `{
"model": "claude-3-opus",
"system": [
{"type": "thinking", "thinking": "Injected system thinking"},
{"type": "text", "text": "System prompt."}
],
"messages": [{
"role": "user",
"content": [{"type": "text", "text": "Hello"}]
}]
}`,
			// System messages don't have reasoning_content mapping
			wantReasoningContent:    "",
			wantHasReasoningContent: false,
			wantContentText:         "Hello",
			wantHasContent:          true,
			forbiddenText:           "Injected system thinking",
		},
		{
			name: "AC5: empty thinking must be ignored",
			inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": ""},
{"type": "text", "text": "Response with empty thinking."}
]
}]
}`,
			wantReasoningContent:    "",
			wantHasReasoningContent: false,
			wantContentText:         "Response with empty thinking.",
			wantHasContent:          true,
		},
		{
			name: "AC5: whitespace-only thinking must be ignored",
			inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": " \n\t "},
{"type": "text", "text": "Response with whitespace thinking."}
]
}]
}`,
			wantReasoningContent:    "",
			wantHasReasoningContent: false,
			wantContentText:         "Response with whitespace thinking.",
			wantHasContent:          true,
		},
		{
			name: "Multiple thinking parts concatenated",
			inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "First thought."},
{"type": "thinking", "thinking": "Second thought."},
{"type": "text", "text": "Final answer."}
]
}]
}`,
			wantReasoningContent:    "First thought.\n\nSecond thought.",
			wantHasReasoningContent: true,
			wantContentText:         "Final answer.",
			wantHasContent:          true,
		},
		{
			name: "Mixed thinking and redacted_thinking",
			inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Visible thought."},
{"type": "redacted_thinking", "data": "hidden"},
{"type": "text", "text": "Answer."}
]
}]
}`,
			wantReasoningContent:    "Visible thought.",
			wantHasReasoningContent: true,
			wantContentText:         "Answer.",
			wantHasContent:          true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
			resultJSON := gjson.ParseBytes(result)
			// Find the relevant message
			messages := resultJSON.Get("messages").Array()
			if len(messages) < 1 {
				if tt.wantHasReasoningContent || tt.wantHasContent {
					t.Fatalf("Expected at least 1 message, got %d", len(messages))
				}
				return
			}
			// Injected thinking must not leak into any message, including the system message.
			if tt.forbiddenText != "" {
				for i, msg := range messages {
					if msg.Get("reasoning_content").String() == tt.forbiddenText {
						t.Errorf("messages[%d].reasoning_content leaked %q", i, tt.forbiddenText)
					}
					msgContent := msg.Get("content")
					if msgContent.Type == gjson.String && msgContent.String() == tt.forbiddenText {
						t.Errorf("messages[%d].content leaked %q", i, tt.forbiddenText)
					}
					for j, part := range msgContent.Array() {
						if part.Get("text").String() == tt.forbiddenText {
							t.Errorf("messages[%d].content[%d] leaked %q", i, j, tt.forbiddenText)
						}
					}
				}
			}
			// Check the last non-system message
			var targetMsg gjson.Result
			for i := len(messages) - 1; i >= 0; i-- {
				if messages[i].Get("role").String() != "system" {
					targetMsg = messages[i]
					break
				}
			}
			// Check reasoning_content
			gotReasoningContent := targetMsg.Get("reasoning_content").String()
			gotHasReasoningContent := targetMsg.Get("reasoning_content").Exists()
			if gotHasReasoningContent != tt.wantHasReasoningContent {
				t.Errorf("reasoning_content existence = %v, want %v", gotHasReasoningContent, tt.wantHasReasoningContent)
			}
			if gotReasoningContent != tt.wantReasoningContent {
				t.Errorf("reasoning_content = %q, want %q", gotReasoningContent, tt.wantReasoningContent)
			}
			// Check content
			content := targetMsg.Get("content")
			// content has meaningful content if it's a non-empty array, or a non-empty string
			var gotHasContent bool
			switch {
			case content.IsArray():
				gotHasContent = len(content.Array()) > 0
			case content.Type == gjson.String:
				gotHasContent = content.String() != ""
			default:
				gotHasContent = false
			}
			if gotHasContent != tt.wantHasContent {
				t.Errorf("content existence = %v, want %v", gotHasContent, tt.wantHasContent)
			}
			if tt.wantHasContent && tt.wantContentText != "" {
				// Find text content
				var foundText string
				content.ForEach(func(_, v gjson.Result) bool {
					if v.Get("type").String() == "text" {
						foundText = v.Get("text").String()
						return false
					}
					return true
				})
				if foundText != tt.wantContentText {
					t.Errorf("content text = %q, want %q", foundText, tt.wantContentText)
				}
			}
		})
	}
}
// TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved tests AC3:
// that a message with only thinking content is preserved (not dropped).
func TestConvertClaudeRequestToOpenAI_ThinkingOnlyMessagePreserved(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "What is 2+2?"}]
},
{
"role": "assistant",
"content": [{"type": "thinking", "thinking": "Let me calculate: 2+2=4"}]
},
{
"role": "user",
"content": [{"type": "text", "text": "Thanks"}]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// Should have: user + assistant (thinking-only) + user = 3 messages
if len(messages) != 3 {
t.Fatalf("Expected 3 messages, got %d. Messages: %v", len(messages), resultJSON.Get("messages").Raw)
}
// Check the assistant message (index 1) has reasoning_content
assistantMsg := messages[1]
if assistantMsg.Get("role").String() != "assistant" {
t.Errorf("Expected message[1] to be assistant, got %s", assistantMsg.Get("role").String())
}
if !assistantMsg.Get("reasoning_content").Exists() {
t.Error("Expected assistant message to have reasoning_content")
}
if assistantMsg.Get("reasoning_content").String() != "Let me calculate: 2+2=4" {
t.Errorf("Unexpected reasoning_content: %s", assistantMsg.Get("reasoning_content").String())
}
}
// TestConvertClaudeRequestToOpenAI_SystemMessageScenarios checks how the Claude "system"
// field is mapped: a leading OpenAI system message is emitted only when the field yields
// non-empty text, and its content must carry every system text block in order.
func TestConvertClaudeRequestToOpenAI_SystemMessageScenarios(t *testing.T) {
	tests := []struct {
		name         string
		inputJSON    string
		wantHasSys   bool
		wantSysTexts []string // Expected text of every system content block, in order
	}{
		{
			name: "No system field",
			inputJSON: `{
"model": "claude-3-opus",
"messages": [{"role": "user", "content": "hello"}]
}`,
			wantHasSys: false,
		},
		{
			name: "Empty string system field",
			inputJSON: `{
"model": "claude-3-opus",
"system": "",
"messages": [{"role": "user", "content": "hello"}]
}`,
			wantHasSys: false,
		},
		{
			name: "String system field",
			inputJSON: `{
"model": "claude-3-opus",
"system": "Be helpful",
"messages": [{"role": "user", "content": "hello"}]
}`,
			wantHasSys:   true,
			wantSysTexts: []string{"Be helpful"},
		},
		{
			name: "Array system field with text",
			inputJSON: `{
"model": "claude-3-opus",
"system": [{"type": "text", "text": "Array system"}],
"messages": [{"role": "user", "content": "hello"}]
}`,
			wantHasSys:   true,
			wantSysTexts: []string{"Array system"},
		},
		{
			name: "Array system field with multiple text blocks",
			inputJSON: `{
"model": "claude-3-opus",
"system": [
{"type": "text", "text": "Block 1"},
{"type": "text", "text": "Block 2"}
],
"messages": [{"role": "user", "content": "hello"}]
}`,
			wantHasSys:   true,
			wantSysTexts: []string{"Block 1", "Block 2"},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
			resultJSON := gjson.ParseBytes(result)
			messages := resultJSON.Get("messages").Array()
			hasSys := false
			var sysMsg gjson.Result
			if len(messages) > 0 && messages[0].Get("role").String() == "system" {
				hasSys = true
				sysMsg = messages[0]
			}
			if hasSys != tt.wantHasSys {
				t.Errorf("got hasSystem = %v, want %v", hasSys, tt.wantHasSys)
			}
			if !tt.wantHasSys || !hasSys {
				return
			}
			// Content may be a part array or a plain string in OpenAI; collect all texts.
			content := sysMsg.Get("content")
			var gotTexts []string
			if content.IsArray() {
				for _, part := range content.Array() {
					gotTexts = append(gotTexts, part.Get("text").String())
				}
			} else {
				gotTexts = []string{content.String()}
			}
			if len(gotTexts) != len(tt.wantSysTexts) {
				t.Fatalf("got system texts %q, want %q", gotTexts, tt.wantSysTexts)
			}
			for i, want := range tt.wantSysTexts {
				if gotTexts[i] != want {
					t.Errorf("system text[%d] = %q, want %q", i, gotTexts[i], want)
				}
			}
		})
	}
}
func TestConvertClaudeRequestToOpenAI_ToolResultOrderAndContent(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
]
},
{
"role": "user",
"content": [
{"type": "text", "text": "before"},
{"type": "tool_result", "tool_use_id": "call_1", "content": [{"type":"text","text":"tool ok"}]},
{"type": "text", "text": "after"}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// OpenAI requires: tool messages MUST immediately follow assistant(tool_calls).
// Correct order: assistant(tool_calls) + tool(result) + user(before+after)
if len(messages) != 3 {
t.Fatalf("Expected 3 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if messages[0].Get("role").String() != "assistant" || !messages[0].Get("tool_calls").Exists() {
t.Fatalf("Expected messages[0] to be assistant tool_calls, got %s: %s", messages[0].Get("role").String(), messages[0].Raw)
}
// tool message MUST immediately follow assistant(tool_calls) per OpenAI spec
if messages[1].Get("role").String() != "tool" {
t.Fatalf("Expected messages[1] to be tool (must follow tool_calls), got %s", messages[1].Get("role").String())
}
if got := messages[1].Get("tool_call_id").String(); got != "call_1" {
t.Fatalf("Expected tool_call_id %q, got %q", "call_1", got)
}
if got := messages[1].Get("content").String(); got != "tool ok" {
t.Fatalf("Expected tool content %q, got %q", "tool ok", got)
}
// User message comes after tool message
if messages[2].Get("role").String() != "user" {
t.Fatalf("Expected messages[2] to be user, got %s", messages[2].Get("role").String())
}
// User message should contain both "before" and "after" text
if got := messages[2].Get("content.0.text").String(); got != "before" {
t.Fatalf("Expected user text[0] %q, got %q", "before", got)
}
if got := messages[2].Get("content.1.text").String(); got != "after" {
t.Fatalf("Expected user text[1] %q, got %q", "after", got)
}
}
func TestConvertClaudeRequestToOpenAI_ToolResultObjectContent(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
]
},
{
"role": "user",
"content": [
{"type": "tool_result", "tool_use_id": "call_1", "content": {"foo": "bar"}}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
// assistant(tool_calls) + tool(result)
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
if messages[1].Get("role").String() != "tool" {
t.Fatalf("Expected messages[1] to be tool, got %s", messages[1].Get("role").String())
}
toolContent := messages[1].Get("content").String()
parsed := gjson.Parse(toolContent)
if parsed.Get("foo").String() != "bar" {
t.Fatalf("Expected tool content JSON foo=bar, got %q", toolContent)
}
}
func TestConvertClaudeRequestToOpenAI_ToolResultTextAndImageContent(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "call_1",
"content": [
{"type": "text", "text": "tool ok"},
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": "iVBORw0KGgoAAAANSUhEUg=="
}
}
]
}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
toolContent := messages[1].Get("content")
if !toolContent.IsArray() {
t.Fatalf("Expected tool content array, got %s", toolContent.Raw)
}
if got := toolContent.Get("0.type").String(); got != "text" {
t.Fatalf("Expected first tool content type %q, got %q", "text", got)
}
if got := toolContent.Get("0.text").String(); got != "tool ok" {
t.Fatalf("Expected first tool content text %q, got %q", "tool ok", got)
}
if got := toolContent.Get("1.type").String(); got != "image_url" {
t.Fatalf("Expected second tool content type %q, got %q", "image_url", got)
}
if got := toolContent.Get("1.image_url.url").String(); got != "data:image/png;base64,iVBORw0KGgoAAAANSUhEUg==" {
t.Fatalf("Unexpected image_url: %q", got)
}
}
func TestConvertClaudeRequestToOpenAI_ToolResultURLImageOnly(t *testing.T) {
inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "call_1",
"content": {
"type": "image",
"source": {
"type": "url",
"url": "https://example.com/tool.png"
}
}
}
]
}
]
}`
result := ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
if len(messages) != 2 {
t.Fatalf("Expected 2 messages, got %d. Messages: %s", len(messages), resultJSON.Get("messages").Raw)
}
toolContent := messages[1].Get("content")
if !toolContent.IsArray() {
t.Fatalf("Expected tool content array, got %s", toolContent.Raw)
}
if got := toolContent.Get("0.type").String(); got != "image_url" {
t.Fatalf("Expected tool content type %q, got %q", "image_url", got)
}
if got := toolContent.Get("0.image_url.url").String(); got != "https://example.com/tool.png" {
t.Fatalf("Unexpected image_url: %q", got)
}
}
func TestConvertClaudeRequestToOpenAI_AssistantTextToolUseTextOrder(t *testing.T) {
	inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "text", "text": "pre"},
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}},
{"type": "text", "text": "post"}
]
}
]
}`
	converted := gjson.ParseBytes(ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false))
	msgs := converted.Get("messages").Array()

	// Text blocks surrounding a tool_use are merged with the tool call into ONE
	// assistant message: content[pre, post] plus tool_calls.
	if len(msgs) != 1 {
		t.Fatalf("Expected 1 message, got %d. Messages: %s", len(msgs), converted.Get("messages").Raw)
	}
	msg := msgs[0]
	if role := msg.Get("role").String(); role != "assistant" {
		t.Fatalf("Expected messages[0] to be assistant, got %s", role)
	}
	if !msg.Get("tool_calls").Exists() {
		t.Fatalf("Expected assistant message to have tool_calls")
	}

	expectations := []struct {
		path, want, format string
	}{
		{"tool_calls.0.id", "call_1", "Expected tool_call id %q, got %q"},
		{"tool_calls.0.function.name", "do_work", "Expected tool_call name %q, got %q"},
		// The text parts keep their original relative order.
		{"content.0.text", "pre", "Expected content[0] text %q, got %q"},
		{"content.1.text", "post", "Expected content[1] text %q, got %q"},
	}
	for _, e := range expectations {
		if got := msg.Get(e.path).String(); got != e.want {
			t.Fatalf(e.format, e.want, got)
		}
	}
}
func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *testing.T) {
	inputJSON := `{
"model": "claude-3-opus",
"messages": [
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "t1"},
{"type": "text", "text": "pre"},
{"type": "tool_use", "id": "call_1", "name": "do_work", "input": {"a": 1}},
{"type": "thinking", "thinking": "t2"},
{"type": "text", "text": "post"}
]
}
]
}`
	converted := gjson.ParseBytes(ConvertClaudeRequestToOpenAI("test-model", []byte(inputJSON), false))
	msgs := converted.Get("messages").Array()

	// Text, thinking and tool_use all collapse into ONE assistant message:
	// content[pre, post] + tool_calls + reasoning_content "t1\n\nt2".
	if len(msgs) != 1 {
		t.Fatalf("Expected 1 message, got %d. Messages: %s", len(msgs), converted.Get("messages").Raw)
	}
	msg := msgs[0]
	if role := msg.Get("role").String(); role != "assistant" {
		t.Fatalf("Expected messages[0] to be assistant, got %s", role)
	}

	for _, e := range []struct{ path, want, format string }{
		{"content.0.text", "pre", "Expected content[0] text %q, got %q"},
		{"content.1.text", "post", "Expected content[1] text %q, got %q"},
	} {
		if got := msg.Get(e.path).String(); got != e.want {
			t.Fatalf(e.format, e.want, got)
		}
	}
	if !msg.Get("tool_calls").Exists() {
		t.Fatalf("Expected assistant message to have tool_calls")
	}
	// Both thinking blocks are joined, separated by a blank line.
	const wantReasoning = "t1\n\nt2"
	if got := msg.Get("reasoning_content").String(); got != wantReasoning {
		t.Fatalf("Expected reasoning_content %q, got %q", wantReasoning, got)
	}
}
================================================
FILE: internal/translator/openai/claude/openai_claude_response.go
================================================
// Package claude provides response translation functionality for OpenAI to Anthropic API.
// This package handles the conversion of OpenAI Chat Completions API responses into Anthropic API-compatible
// JSON format, transforming streaming events and non-streaming responses into the format
// expected by Anthropic API clients. It supports both streaming and non-streaming modes,
// handling text content, reasoning (thinking) content, tool calls, and usage metadata.
package claude
import (
"bytes"
"context"
"fmt"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// dataTag is the SSE field prefix that precedes every streamed OpenAI payload line.
var dataTag = []byte("data:")
// ConvertOpenAIResponseToAnthropicParams holds the per-stream state used while
// converting OpenAI Chat Completions responses into Anthropic Messages events.
// ConvertOpenAIResponseToClaude allocates it on first use with
// ToolCallBlockIndexes initialized and TextContentBlockIndex /
// ThinkingContentBlockIndex set to -1, meaning "no index assigned".
type ConvertOpenAIResponseToAnthropicParams struct {
// MessageID is the upstream "id", captured from the first chunk and echoed in message_start.
MessageID string
// Model is the upstream "model", captured from the first chunk and echoed in message_start.
Model string
// CreatedAt is the upstream "created" timestamp captured from the first chunk.
CreatedAt int64
// ToolNameMap is built from the original Claude request and passed to
// util.MapToolName to restore client-facing tool names.
ToolNameMap map[string]string
// SawToolCall is set once any tool call delta is seen; it forces the
// reported finish reason to "tool_calls".
SawToolCall bool
// Content accumulator for streaming: every text delta is appended here.
ContentAccumulator strings.Builder
// Tool calls accumulator for streaming, keyed by the OpenAI tool call "index".
ToolCallsAccumulator map[int]*ToolCallAccumulator
// Track if text content block has been started (is currently open)
TextContentBlockStarted bool
// Track if thinking content block has been started (is currently open)
ThinkingContentBlockStarted bool
// Track finish reason for later use ("tool_calls" whenever SawToolCall was set)
FinishReason string
// Track if the tool_use content blocks have been flushed and stopped
ContentBlocksStopped bool
// Track if message_delta has been sent
MessageDeltaSent bool
// Track if message_start has been sent
MessageStarted bool
// Track if message_stop has been sent
MessageStopSent bool
// Tool call content block index mapping: OpenAI tool call index -> Anthropic
// content block index. Entries are removed when the block is stopped.
ToolCallBlockIndexes map[int]int
// Index assigned to text content block (-1 when none is assigned)
TextContentBlockIndex int
// Index assigned to thinking content block (-1 when none is assigned)
ThinkingContentBlockIndex int
// Next available content block index, shared by text, thinking and tool_use
// blocks so indices are unique within a message.
NextContentBlockIndex int
}
// ToolCallAccumulator holds the state for accumulating tool call data
// across the streamed OpenAI tool_calls deltas that share one tool call index.
type ToolCallAccumulator struct {
// ID is the upstream tool call id (passed through util.SanitizeClaudeToolID when emitted).
ID string
// Name is the tool name after mapping through ToolNameMap via util.MapToolName.
Name string
// Arguments concatenates the streamed "arguments" fragments; the whole string
// is emitted as one input_json_delta when the tool_use block is stopped.
Arguments strings.Builder
}
// ConvertOpenAIResponseToClaude converts one OpenAI Chat Completions SSE line into
// zero or more Anthropic Messages API SSE event strings.
//
// Lines that do not start with "data:" yield no events. A "[DONE]" payload closes
// any open content blocks and emits the final message_delta/message_stop events.
// When the original Claude request's "stream" field is missing or false, the
// payload is translated as a complete (non-streaming) OpenAI response instead.
//
// Parameters:
//   - ctx: The context for the request (unused).
//   - modelName: The name of the model (unused).
//   - originalRequestRawJSON: The client's original Claude request; supplies the
//     "stream" flag and the tool-name mapping.
//   - requestRawJSON: The translated upstream request (unused).
//   - rawJSON: One raw SSE line from the OpenAI API.
//   - param: Conversion state kept across calls; allocated on the first call.
//
// Returns:
//   - []string: Anthropic-compatible events; empty when the line produces none.
func ConvertOpenAIResponseToClaude(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	if *param == nil {
		// Only the non-zero defaults need spelling out; -1 marks "no block index assigned".
		*param = &ConvertOpenAIResponseToAnthropicParams{
			ToolCallBlockIndexes:      make(map[int]int),
			TextContentBlockIndex:     -1,
			ThinkingContentBlockIndex: -1,
		}
	}

	// Anything that is not an SSE data line (event names, comments, blanks) is ignored.
	if !bytes.HasPrefix(rawJSON, dataTag) {
		return []string{}
	}
	payload := bytes.TrimSpace(rawJSON[len(dataTag):])
	state := (*param).(*ConvertOpenAIResponseToAnthropicParams)

	if state.ToolNameMap == nil {
		state.ToolNameMap = util.ToolNameMapFromClaudeRequest(originalRequestRawJSON)
	}

	if string(payload) == "[DONE]" {
		return convertOpenAIDoneToAnthropic(state)
	}

	// Streaming mode is selected by any present "stream" value other than false.
	stream := gjson.GetBytes(originalRequestRawJSON, "stream")
	if stream.Exists() && stream.Type != gjson.False {
		return convertOpenAIStreamingChunkToAnthropic(payload, state)
	}
	return convertOpenAINonStreamingToAnthropic(payload)
}
// effectiveOpenAIFinishReason returns the OpenAI finish reason to report for the
// stream: "tool_calls" once any tool call delta has been seen, otherwise the
// finish reason recorded so far. A nil state yields "".
func effectiveOpenAIFinishReason(param *ConvertOpenAIResponseToAnthropicParams) string {
	switch {
	case param == nil:
		return ""
	case param.SawToolCall:
		return "tool_calls"
	default:
		return param.FinishReason
	}
}
// convertOpenAIStreamingChunkToAnthropic translates one OpenAI streaming chunk
// (already stripped of its "data:" prefix) into Anthropic SSE events and updates
// param with the state that later chunks and [DONE] depend on.
//
// Event flow produced here:
//   - message_start once, on the first chunk that carries choices[0].delta.
//   - reasoning_content deltas open/extend a thinking block, closing any open text block.
//   - content deltas open/extend a text block, closing any open thinking block.
//   - each OpenAI tool call gets its own tool_use block, opened once when its first
//     non-empty name arrives; argument fragments are buffered and flushed as a
//     single input_json_delta when the block is stopped.
//   - on the first non-empty finish_reason, every open block is stopped.
//   - message_delta (with usage) and message_stop are sent at most once, on the
//     first chunk with non-null usage at or after finish_reason; otherwise
//     convertOpenAIDoneToAnthropic sends them on [DONE].
func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAIResponseToAnthropicParams) []string {
	root := gjson.ParseBytes(rawJSON)
	var results []string

	// Capture response metadata from the first chunk that provides it.
	if param.MessageID == "" {
		param.MessageID = root.Get("id").String()
	}
	if param.Model == "" {
		param.Model = root.Get("model").String()
	}
	if param.CreatedAt == 0 {
		param.CreatedAt = root.Get("created").Int()
	}

	if delta := root.Get("choices.0.delta"); delta.Exists() {
		// Emit message_start on the very first delta, regardless of whether it has a
		// role field: some providers (like Copilot) send tool_calls in the first
		// chunk without one. Text blocks are opened lazily when content arrives.
		if !param.MessageStarted {
			messageStartJSON := `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}}`
			messageStartJSON, _ = sjson.Set(messageStartJSON, "message.id", param.MessageID)
			messageStartJSON, _ = sjson.Set(messageStartJSON, "message.model", param.Model)
			results = append(results, "event: message_start\ndata: "+messageStartJSON+"\n\n")
			param.MessageStarted = true
		}

		// Reasoning deltas -> thinking block.
		if reasoning := delta.Get("reasoning_content"); reasoning.Exists() {
			for _, reasoningText := range collectOpenAIReasoningTexts(reasoning) {
				if reasoningText == "" {
					continue
				}
				stopTextContentBlock(param, &results)
				if !param.ThinkingContentBlockStarted {
					if param.ThinkingContentBlockIndex == -1 {
						param.ThinkingContentBlockIndex = param.NextContentBlockIndex
						param.NextContentBlockIndex++
					}
					contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
					contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", param.ThinkingContentBlockIndex)
					results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
					param.ThinkingContentBlockStarted = true
				}
				thinkingDeltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
				thinkingDeltaJSON, _ = sjson.Set(thinkingDeltaJSON, "index", param.ThinkingContentBlockIndex)
				thinkingDeltaJSON, _ = sjson.Set(thinkingDeltaJSON, "delta.thinking", reasoningText)
				results = append(results, "event: content_block_delta\ndata: "+thinkingDeltaJSON+"\n\n")
			}
		}

		// Text deltas -> text block.
		if content := delta.Get("content"); content.Exists() && content.String() != "" {
			text := content.String()
			if !param.TextContentBlockStarted {
				stopThinkingContentBlock(param, &results)
				if param.TextContentBlockIndex == -1 {
					param.TextContentBlockIndex = param.NextContentBlockIndex
					param.NextContentBlockIndex++
				}
				contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
				contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", param.TextContentBlockIndex)
				results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
				param.TextContentBlockStarted = true
			}
			contentDeltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
			contentDeltaJSON, _ = sjson.Set(contentDeltaJSON, "index", param.TextContentBlockIndex)
			contentDeltaJSON, _ = sjson.Set(contentDeltaJSON, "delta.text", text)
			results = append(results, "event: content_block_delta\ndata: "+contentDeltaJSON+"\n\n")
			param.ContentAccumulator.WriteString(text)
		}

		// Tool call deltas -> tool_use blocks.
		if toolCalls := delta.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
			if param.ToolCallsAccumulator == nil {
				param.ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
			}
			toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
				param.SawToolCall = true
				index := int(toolCall.Get("index").Int())
				blockIndex := param.toolContentBlockIndex(index)

				accumulator, exists := param.ToolCallsAccumulator[index]
				if !exists {
					accumulator = &ToolCallAccumulator{}
					param.ToolCallsAccumulator[index] = accumulator
				}
				// Later fragments may repeat or blank the id; keep the first real one.
				if id := toolCall.Get("id"); id.Exists() && id.String() != "" {
					accumulator.ID = id.String()
				}

				function := toolCall.Get("function")
				if !function.Exists() {
					return true
				}
				// Open the tool_use block exactly once, when the first non-empty name
				// arrives. Some providers repeat the name (or send "") on later
				// fragments; re-emitting content_block_start for the same index would
				// produce an invalid Anthropic stream.
				if name := function.Get("name").String(); name != "" && accumulator.Name == "" {
					accumulator.Name = util.MapToolName(param.ToolNameMap, name)
					stopThinkingContentBlock(param, &results)
					stopTextContentBlock(param, &results)
					contentBlockStartJSON := `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
					contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "index", blockIndex)
					contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.id", util.SanitizeClaudeToolID(accumulator.ID))
					contentBlockStartJSON, _ = sjson.Set(contentBlockStartJSON, "content_block.name", accumulator.Name)
					results = append(results, "event: content_block_start\ndata: "+contentBlockStartJSON+"\n\n")
				}
				// Arguments are buffered and flushed as one input_json_delta on stop.
				if args := function.Get("arguments"); args.Exists() {
					if argsText := args.String(); argsText != "" {
						accumulator.Arguments.WriteString(argsText)
					}
				}
				return true
			})
		}
	}

	// finish_reason: stop every open block, but defer message_delta/message_stop
	// until usage arrives (below) or [DONE].
	if finishReason := root.Get("choices.0.finish_reason"); finishReason.Exists() && finishReason.String() != "" {
		if param.SawToolCall {
			param.FinishReason = "tool_calls"
		} else {
			param.FinishReason = finishReason.String()
		}

		stopThinkingContentBlock(param, &results)
		stopTextContentBlock(param, &results)

		if !param.ContentBlocksStopped {
			for index, accumulator := range param.ToolCallsAccumulator {
				blockIndex := param.toolContentBlockIndex(index)
				if accumulator.Arguments.Len() > 0 {
					inputDeltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
					inputDeltaJSON, _ = sjson.Set(inputDeltaJSON, "index", blockIndex)
					inputDeltaJSON, _ = sjson.Set(inputDeltaJSON, "delta.partial_json", util.FixJSON(accumulator.Arguments.String()))
					results = append(results, "event: content_block_delta\ndata: "+inputDeltaJSON+"\n\n")
				}
				contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
				contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", blockIndex)
				results = append(results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
				delete(param.ToolCallBlockIndexes, index)
			}
			param.ContentBlocksStopped = true
		}
	}

	// Usage usually arrives in a later chunk. Only non-null usage is used, and only
	// once: some providers repeat usage in several trailing chunks, and a second
	// message_delta after message_stop would corrupt the Anthropic stream.
	if param.FinishReason != "" && !param.MessageDeltaSent {
		if usage := root.Get("usage"); usage.Exists() && usage.Type != gjson.Null {
			inputTokens, outputTokens, cachedTokens := extractOpenAIUsage(usage)
			messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
			messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(effectiveOpenAIFinishReason(param)))
			messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.input_tokens", inputTokens)
			messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.output_tokens", outputTokens)
			if cachedTokens > 0 {
				messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "usage.cache_read_input_tokens", cachedTokens)
			}
			results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
			param.MessageDeltaSent = true
			emitMessageStopIfNeeded(param, &results)
		}
	}
	return results
}
// convertOpenAIDoneToAnthropic handles the OpenAI [DONE] sentinel. It stops any
// content blocks that are still open (flushing buffered tool arguments), sends
// message_delta if it has not been sent yet, and finishes with message_stop.
//
// message_delta is sent whenever the stream started a message or recorded a
// finish reason. This covers providers that end a tool-calling stream without a
// finish_reason (stop_reason is reported as "tool_use" via SawToolCall) or that
// omit finish_reason entirely (reported as "end_turn"). Without it, Anthropic
// clients would never see a stop_reason. No usage is known at this point, so
// token counts stay at zero.
func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams) []string {
	var results []string

	// Stop thinking and text blocks first, mirroring the finish_reason path.
	stopThinkingContentBlock(param, &results)
	stopTextContentBlock(param, &results)

	if !param.ContentBlocksStopped {
		for index, accumulator := range param.ToolCallsAccumulator {
			blockIndex := param.toolContentBlockIndex(index)
			if accumulator.Arguments.Len() > 0 {
				inputDeltaJSON := `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
				inputDeltaJSON, _ = sjson.Set(inputDeltaJSON, "index", blockIndex)
				inputDeltaJSON, _ = sjson.Set(inputDeltaJSON, "delta.partial_json", util.FixJSON(accumulator.Arguments.String()))
				results = append(results, "event: content_block_delta\ndata: "+inputDeltaJSON+"\n\n")
			}
			contentBlockStopJSON := `{"type":"content_block_stop","index":0}`
			contentBlockStopJSON, _ = sjson.Set(contentBlockStopJSON, "index", blockIndex)
			results = append(results, "event: content_block_stop\ndata: "+contentBlockStopJSON+"\n\n")
			delete(param.ToolCallBlockIndexes, index)
		}
		param.ContentBlocksStopped = true
	}

	// If no usage chunk triggered message_delta, send it now so the client still
	// receives a stop_reason.
	if !param.MessageDeltaSent && (param.MessageStarted || effectiveOpenAIFinishReason(param) != "") {
		messageDeltaJSON := `{"type":"message_delta","delta":{"stop_reason":"","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
		messageDeltaJSON, _ = sjson.Set(messageDeltaJSON, "delta.stop_reason", mapOpenAIFinishReasonToAnthropic(effectiveOpenAIFinishReason(param)))
		results = append(results, "event: message_delta\ndata: "+messageDeltaJSON+"\n\n")
		param.MessageDeltaSent = true
	}
	emitMessageStopIfNeeded(param, &results)
	return results
}
// convertOpenAINonStreamingToAnthropic converts a complete OpenAI Chat Completions
// response into a single Anthropic Messages response. ConvertOpenAIResponseToClaude
// uses it when the original Claude request was not streaming.
//
// Blocks are emitted in Anthropic order: thinking (from reasoning_content), then
// text, then tool_use. message.content may be a plain string or an array of
// content parts; only parts of type "text" contribute to the text block, so the
// raw array JSON is never emitted as text.
//
// stop_reason is "tool_use" whenever the response carries tool calls. This
// matches the streaming path, which reports tool_calls once any tool call was
// seen, because some providers send finish_reason "stop" alongside tool calls.
// Otherwise stop_reason is the mapped finish_reason, or "end_turn" when it is absent.
//
// NOTE(review): tool names are not mapped through the Claude request's tool-name
// map here, unlike the streaming path and ConvertOpenAIResponseToClaudeNonStream.
func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string {
	root := gjson.ParseBytes(rawJSON)
	out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
	out, _ = sjson.Set(out, "id", root.Get("id").String())
	out, _ = sjson.Set(out, "model", root.Get("model").String())

	hasToolCall := false
	stopReason := ""

	choices := root.Get("choices")
	if choices.IsArray() && len(choices.Array()) > 0 {
		choice := choices.Array()[0] // Only the first choice is translated.

		// Thinking blocks precede text and tool_use in an Anthropic message.
		for _, reasoningText := range collectOpenAIReasoningTexts(choice.Get("message.reasoning_content")) {
			if reasoningText == "" {
				continue
			}
			block := `{"type":"thinking","thinking":""}`
			block, _ = sjson.Set(block, "thinking", reasoningText)
			out, _ = sjson.SetRaw(out, "content.-1", block)
		}

		// Text content: either a plain string or an array of content parts.
		content := choice.Get("message.content")
		var text string
		if content.IsArray() {
			var sb strings.Builder
			content.ForEach(func(_, part gjson.Result) bool {
				if part.Get("type").String() == "text" {
					sb.WriteString(part.Get("text").String())
				}
				return true
			})
			text = sb.String()
		} else {
			text = content.String()
		}
		if text != "" {
			block := `{"type":"text","text":""}`
			block, _ = sjson.Set(block, "text", text)
			out, _ = sjson.SetRaw(out, "content.-1", block)
		}

		// Tool calls -> tool_use blocks; Claude requires "input" to be an object.
		if toolCalls := choice.Get("message.tool_calls"); toolCalls.IsArray() {
			toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
				hasToolCall = true
				toolUseBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
				toolUseBlock, _ = sjson.Set(toolUseBlock, "id", util.SanitizeClaudeToolID(toolCall.Get("id").String()))
				toolUseBlock, _ = sjson.Set(toolUseBlock, "name", toolCall.Get("function.name").String())
				input := "{}"
				if argsStr := util.FixJSON(toolCall.Get("function.arguments").String()); argsStr != "" && gjson.Valid(argsStr) {
					if argsJSON := gjson.Parse(argsStr); argsJSON.IsObject() {
						input = argsJSON.Raw
					}
				}
				toolUseBlock, _ = sjson.SetRaw(toolUseBlock, "input", input)
				out, _ = sjson.SetRaw(out, "content.-1", toolUseBlock)
				return true
			})
		}

		if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
			stopReason = mapOpenAIFinishReasonToAnthropic(finishReason.String())
		}
	}

	switch {
	case hasToolCall:
		stopReason = "tool_use"
	case stopReason == "":
		stopReason = "end_turn"
	}
	out, _ = sjson.Set(out, "stop_reason", stopReason)

	if usage := root.Get("usage"); usage.Exists() {
		inputTokens, outputTokens, cachedTokens := extractOpenAIUsage(usage)
		out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
		out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
		if cachedTokens > 0 {
			out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
		}
	}
	return []string{out}
}
// mapOpenAIFinishReasonToAnthropic translates an OpenAI finish_reason into the
// closest Anthropic stop_reason. "length" becomes "max_tokens"; "tool_calls" and
// the legacy "function_call" become "tool_use". Everything else, including
// "stop", "content_filter" (no direct Anthropic equivalent) and unknown or empty
// values, falls back to "end_turn".
func mapOpenAIFinishReasonToAnthropic(openAIReason string) string {
	switch openAIReason {
	case "length":
		return "max_tokens"
	case "tool_calls", "function_call":
		return "tool_use"
	default:
		return "end_turn"
	}
}
// toolContentBlockIndex returns the Anthropic content-block index assigned to the
// OpenAI tool call at openAIToolIndex. The first time an index is seen, it takes
// the next free block index from NextContentBlockIndex and records it.
func (p *ConvertOpenAIResponseToAnthropicParams) toolContentBlockIndex(openAIToolIndex int) int {
	if idx, ok := p.ToolCallBlockIndexes[openAIToolIndex]; ok {
		return idx
	}
	// Create the map lazily so a zero-value or hand-built params struct does not
	// panic on the first write to a nil map.
	if p.ToolCallBlockIndexes == nil {
		p.ToolCallBlockIndexes = make(map[int]int)
	}
	idx := p.NextContentBlockIndex
	p.NextContentBlockIndex++
	p.ToolCallBlockIndexes[openAIToolIndex] = idx
	return idx
}
// collectOpenAIReasoningTexts flattens an OpenAI reasoning payload into its
// non-empty text fragments, in document order. It accepts a plain string, an
// object with a "text" field, or an arbitrarily nested array of those. Missing
// nodes, empty strings, and objects without "text" contribute nothing.
func collectOpenAIReasoningTexts(node gjson.Result) []string {
	var collected []string
	switch {
	case !node.Exists():
		return collected
	case node.IsArray():
		for _, child := range node.Array() {
			collected = append(collected, collectOpenAIReasoningTexts(child)...)
		}
		return collected
	case node.Type == gjson.String:
		if s := node.String(); s != "" {
			collected = append(collected, s)
		}
	case node.Type == gjson.JSON:
		if textField := node.Get("text"); textField.Exists() {
			if s := textField.String(); s != "" {
				collected = append(collected, s)
			}
			return collected
		}
		// NOTE(review): gjson reports only objects/arrays as JSON, with Raw starting
		// at '{' or '[', so this fallback appears unreachable; kept as-is.
		raw := node.Raw
		if raw != "" && !strings.HasPrefix(raw, "{") && !strings.HasPrefix(raw, "[") {
			collected = append(collected, raw)
		}
	}
	return collected
}
// stopThinkingContentBlock closes the open thinking block, if there is one. It
// appends a content_block_stop event for that block's index, marks the block
// closed, and clears its index so the next thinking delta opens a fresh block.
func stopThinkingContentBlock(param *ConvertOpenAIResponseToAnthropicParams, results *[]string) {
	if param.ThinkingContentBlockStarted {
		event, _ := sjson.Set(`{"type":"content_block_stop","index":0}`, "index", param.ThinkingContentBlockIndex)
		*results = append(*results, "event: content_block_stop\ndata: "+event+"\n\n")
		param.ThinkingContentBlockStarted = false
		param.ThinkingContentBlockIndex = -1
	}
}
// emitMessageStopIfNeeded appends the terminal message_stop event exactly once per
// stream; later calls are no-ops.
func emitMessageStopIfNeeded(param *ConvertOpenAIResponseToAnthropicParams, results *[]string) {
	if !param.MessageStopSent {
		param.MessageStopSent = true
		*results = append(*results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
	}
}
// stopTextContentBlock closes the open text block, if there is one. It appends a
// content_block_stop event for that block's index, marks the block closed, and
// clears its index so the next text delta opens a fresh block.
func stopTextContentBlock(param *ConvertOpenAIResponseToAnthropicParams, results *[]string) {
	if param.TextContentBlockStarted {
		event, _ := sjson.Set(`{"type":"content_block_stop","index":0}`, "index", param.TextContentBlockIndex)
		*results = append(*results, "event: content_block_stop\ndata: "+event+"\n\n")
		param.TextContentBlockStarted = false
		param.TextContentBlockIndex = -1
	}
}
// ConvertOpenAIResponseToClaudeNonStream converts a complete (non-streaming) OpenAI
// Chat Completions response into a single Anthropic Messages API response.
//
// Content blocks are emitted in Anthropic order:
//   - thinking blocks from message.reasoning_content come first;
//   - then blocks derived from message.content, which is either a plain string or
//     an array mixing "text", "reasoning" and "tool_calls" parts;
//   - then tool_use blocks from message.tool_calls.
//
// Tool names are mapped back to the names used in the client's original Claude
// request.
//
// stop_reason is "tool_use" whenever any tool call is present. This matches the
// streaming path, which overrides the finish reason once a tool call has been
// seen, because some providers report finish_reason "stop" alongside tool calls.
// Otherwise stop_reason is the mapped finish_reason, or "end_turn" when absent.
//
// Parameters:
//   - ctx: The context for the request (unused).
//   - modelName: The name of the model (unused).
//   - originalRequestRawJSON: The client's original Claude request, used for the tool-name mapping.
//   - requestRawJSON: The translated upstream request (unused).
//   - rawJSON: The raw JSON response from the OpenAI API.
//   - param: Conversion state (unused).
//
// Returns:
//   - string: An Anthropic-compatible JSON response.
func ConvertOpenAIResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
	_ = requestRawJSON
	root := gjson.ParseBytes(rawJSON)
	toolNameMap := util.ToolNameMapFromClaudeRequest(originalRequestRawJSON)

	out := `{"id":"","type":"message","role":"assistant","model":"","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
	out, _ = sjson.Set(out, "id", root.Get("id").String())
	out, _ = sjson.Set(out, "model", root.Get("model").String())

	appendBlock := func(block string) {
		out, _ = sjson.SetRaw(out, "content.-1", block)
	}
	appendText := func(text string) {
		block, _ := sjson.Set(`{"type":"text","text":""}`, "text", text)
		appendBlock(block)
	}
	appendThinking := func(thinking string) {
		block, _ := sjson.Set(`{"type":"thinking","thinking":""}`, "thinking", thinking)
		appendBlock(block)
	}
	hasToolCall := false
	// appendToolUse converts one OpenAI tool call into a Claude tool_use block.
	// Claude requires "input" to be an object; anything else becomes {}.
	appendToolUse := func(tc gjson.Result) {
		hasToolCall = true
		toolUse := `{"type":"tool_use","id":"","name":"","input":{}}`
		toolUse, _ = sjson.Set(toolUse, "id", util.SanitizeClaudeToolID(tc.Get("id").String()))
		toolUse, _ = sjson.Set(toolUse, "name", util.MapToolName(toolNameMap, tc.Get("function.name").String()))
		input := "{}"
		if argsStr := util.FixJSON(tc.Get("function.arguments").String()); argsStr != "" && gjson.Valid(argsStr) {
			if argsJSON := gjson.Parse(argsStr); argsJSON.IsObject() {
				input = argsJSON.Raw
			}
		}
		toolUse, _ = sjson.SetRaw(toolUse, "input", input)
		appendBlock(toolUse)
	}

	stopReason := ""
	if choices := root.Get("choices"); choices.IsArray() && len(choices.Array()) > 0 {
		choice := choices.Array()[0]
		if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
			stopReason = mapOpenAIFinishReasonToAnthropic(finishReason.String())
		}

		if message := choice.Get("message"); message.Exists() {
			// Thinking must precede text and tool_use blocks in an Anthropic message.
			for _, reasoningText := range collectOpenAIReasoningTexts(message.Get("reasoning_content")) {
				if reasoningText != "" {
					appendThinking(reasoningText)
				}
			}

			content := message.Get("content")
			switch {
			case content.IsArray():
				// Merge consecutive text / reasoning parts; any other part type
				// acts as a boundary between blocks.
				var textBuilder, thinkingBuilder strings.Builder
				flushText := func() {
					if textBuilder.Len() > 0 {
						appendText(textBuilder.String())
						textBuilder.Reset()
					}
				}
				flushThinking := func() {
					if thinkingBuilder.Len() > 0 {
						appendThinking(thinkingBuilder.String())
						thinkingBuilder.Reset()
					}
				}
				for _, item := range content.Array() {
					switch item.Get("type").String() {
					case "text":
						flushThinking()
						textBuilder.WriteString(item.Get("text").String())
					case "tool_calls":
						flushThinking()
						flushText()
						if toolCalls := item.Get("tool_calls"); toolCalls.IsArray() {
							toolCalls.ForEach(func(_, tc gjson.Result) bool {
								appendToolUse(tc)
								return true
							})
						}
					case "reasoning":
						flushText()
						if thinking := item.Get("text"); thinking.Exists() {
							thinkingBuilder.WriteString(thinking.String())
						}
					default:
						flushThinking()
						flushText()
					}
				}
				flushThinking()
				flushText()
			case content.Type == gjson.String:
				if text := content.String(); text != "" {
					appendText(text)
				}
			}

			if toolCalls := message.Get("tool_calls"); toolCalls.IsArray() {
				toolCalls.ForEach(func(_, tc gjson.Result) bool {
					appendToolUse(tc)
					return true
				})
			}
		}
	}

	if respUsage := root.Get("usage"); respUsage.Exists() {
		inputTokens, outputTokens, cachedTokens := extractOpenAIUsage(respUsage)
		out, _ = sjson.Set(out, "usage.input_tokens", inputTokens)
		out, _ = sjson.Set(out, "usage.output_tokens", outputTokens)
		if cachedTokens > 0 {
			out, _ = sjson.Set(out, "usage.cache_read_input_tokens", cachedTokens)
		}
	}

	switch {
	case hasToolCall:
		stopReason = "tool_use"
	case stopReason == "":
		stopReason = "end_turn"
	}
	out, _ = sjson.Set(out, "stop_reason", stopReason)
	return out
}
// ClaudeTokenCount renders an input token count as an Anthropic count_tokens
// response body, e.g. {"input_tokens":42}. The context is not used.
func ClaudeTokenCount(_ context.Context, count int64) string {
	return fmt.Sprintf("{\"input_tokens\":%d}", count)
}
// extractOpenAIUsage reads an OpenAI usage object and returns
// (inputTokens, outputTokens, cachedTokens) in Anthropic terms.
//
// Cached prompt tokens (prompt_tokens_details.cached_tokens) are reported on their
// own and subtracted from prompt_tokens, with the result clamped at zero. A
// missing or null usage object yields all zeros.
func extractOpenAIUsage(usage gjson.Result) (int64, int64, int64) {
	if !usage.Exists() || usage.Type == gjson.Null {
		return 0, 0, 0
	}
	prompt := usage.Get("prompt_tokens").Int()
	completion := usage.Get("completion_tokens").Int()
	cached := usage.Get("prompt_tokens_details.cached_tokens").Int()

	uncached := prompt
	if cached > 0 {
		if prompt < cached {
			uncached = 0
		} else {
			uncached = prompt - cached
		}
	}
	return uncached, completion, cached
}
================================================
FILE: internal/translator/openai/gemini/init.go
================================================
package gemini
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
Gemini,
OpenAI,
ConvertGeminiRequestToOpenAI,
interfaces.TranslateResponse{
Stream: ConvertOpenAIResponseToGemini,
NonStream: ConvertOpenAIResponseToGeminiNonStream,
TokenCount: GeminiTokenCount,
},
)
}
================================================
FILE: internal/translator/openai/gemini/openai_gemini_request.go
================================================
// Package gemini provides request translation functionality for Gemini to OpenAI API.
// It handles parsing and transforming Gemini API requests into OpenAI Chat Completions API format,
// extracting model information, generation config, message contents, and tool declarations.
// The package performs JSON data transformation to ensure compatibility
// between Gemini API format and OpenAI API's expected format.
package gemini
import (
"crypto/rand"
"fmt"
"math/big"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertGeminiRequestToOpenAI parses and transforms a Gemini API request into OpenAI Chat Completions API format.
// It extracts the model name, generation config, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the OpenAI API.
func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream bool) []byte {
rawJSON := inputRawJSON
// Base OpenAI Chat Completions API template
out := `{"model":"","messages":[]}`
root := gjson.ParseBytes(rawJSON)
// Helper for generating tool call IDs in the form: call_
genToolCallID := func() string {
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
var b strings.Builder
// 24 chars random suffix
for i := 0; i < 24; i++ {
n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
b.WriteByte(letters[n.Int64()])
}
return "call_" + b.String()
}
// Model mapping
out, _ = sjson.Set(out, "model", modelName)
// Generation config mapping
if genConfig := root.Get("generationConfig"); genConfig.Exists() {
// Temperature
if temp := genConfig.Get("temperature"); temp.Exists() {
out, _ = sjson.Set(out, "temperature", temp.Float())
}
// Max tokens
if maxTokens := genConfig.Get("maxOutputTokens"); maxTokens.Exists() {
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
}
// Top P
if topP := genConfig.Get("topP"); topP.Exists() {
out, _ = sjson.Set(out, "top_p", topP.Float())
}
// Top K (OpenAI doesn't have direct equivalent, but we can map it)
if topK := genConfig.Get("topK"); topK.Exists() {
// Store as custom parameter for potential use
out, _ = sjson.Set(out, "top_k", topK.Int())
}
// Stop sequences
if stopSequences := genConfig.Get("stopSequences"); stopSequences.Exists() && stopSequences.IsArray() {
var stops []string
stopSequences.ForEach(func(_, value gjson.Result) bool {
stops = append(stops, value.String())
return true
})
if len(stops) > 0 {
out, _ = sjson.Set(out, "stop", stops)
}
}
// Candidate count (OpenAI 'n' parameter)
if candidateCount := genConfig.Get("candidateCount"); candidateCount.Exists() {
out, _ = sjson.Set(out, "n", candidateCount.Int())
}
// Map Gemini thinkingConfig to OpenAI reasoning_effort.
// Always perform conversion to support allowCompat models that may not be in registry.
// Note: Google official Python SDK sends snake_case fields (thinking_level/thinking_budget).
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
thinkingLevel := thinkingConfig.Get("thinkingLevel")
if !thinkingLevel.Exists() {
thinkingLevel = thinkingConfig.Get("thinking_level")
}
if thinkingLevel.Exists() {
effort := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
if effort != "" {
out, _ = sjson.Set(out, "reasoning_effort", effort)
}
} else {
thinkingBudget := thinkingConfig.Get("thinkingBudget")
if !thinkingBudget.Exists() {
thinkingBudget = thinkingConfig.Get("thinking_budget")
}
if thinkingBudget.Exists() {
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
out, _ = sjson.Set(out, "reasoning_effort", effort)
}
}
}
}
}
// Stream parameter
out, _ = sjson.Set(out, "stream", stream)
// Process contents (Gemini messages) -> OpenAI messages
var toolCallIDs []string // Track tool call IDs for matching with tool results
// System instruction -> OpenAI system message
// Gemini may provide `systemInstruction` or `system_instruction`; support both keys.
systemInstruction := root.Get("systemInstruction")
if !systemInstruction.Exists() {
systemInstruction = root.Get("system_instruction")
}
if systemInstruction.Exists() {
parts := systemInstruction.Get("parts")
msg := `{"role":"system","content":[]}`
hasContent := false
if parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool {
// Handle text parts
if text := part.Get("text"); text.Exists() {
contentPart := `{"type":"text","text":""}`
contentPart, _ = sjson.Set(contentPart, "text", text.String())
msg, _ = sjson.SetRaw(msg, "content.-1", contentPart)
hasContent = true
}
// Handle inline data (e.g., images)
if inlineData := part.Get("inlineData"); inlineData.Exists() {
mimeType := inlineData.Get("mimeType").String()
if mimeType == "" {
mimeType = "application/octet-stream"
}
data := inlineData.Get("data").String()
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
contentPart := `{"type":"image_url","image_url":{"url":""}}`
contentPart, _ = sjson.Set(contentPart, "image_url.url", imageURL)
msg, _ = sjson.SetRaw(msg, "content.-1", contentPart)
hasContent = true
}
return true
})
}
if hasContent {
out, _ = sjson.SetRaw(out, "messages.-1", msg)
}
}
if contents := root.Get("contents"); contents.Exists() && contents.IsArray() {
contents.ForEach(func(_, content gjson.Result) bool {
role := content.Get("role").String()
parts := content.Get("parts")
// Convert role: model -> assistant
if role == "model" {
role = "assistant"
}
msg := `{"role":"","content":""}`
msg, _ = sjson.Set(msg, "role", role)
var textBuilder strings.Builder
contentWrapper := `{"arr":[]}`
contentPartsCount := 0
onlyTextContent := true
toolCallsWrapper := `{"arr":[]}`
toolCallsCount := 0
if parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool {
// Handle text parts
if text := part.Get("text"); text.Exists() {
formattedText := text.String()
textBuilder.WriteString(formattedText)
contentPart := `{"type":"text","text":""}`
contentPart, _ = sjson.Set(contentPart, "text", formattedText)
contentWrapper, _ = sjson.SetRaw(contentWrapper, "arr.-1", contentPart)
contentPartsCount++
}
// Handle inline data (e.g., images)
if inlineData := part.Get("inlineData"); inlineData.Exists() {
onlyTextContent = false
mimeType := inlineData.Get("mimeType").String()
if mimeType == "" {
mimeType = "application/octet-stream"
}
data := inlineData.Get("data").String()
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
contentPart := `{"type":"image_url","image_url":{"url":""}}`
contentPart, _ = sjson.Set(contentPart, "image_url.url", imageURL)
contentWrapper, _ = sjson.SetRaw(contentWrapper, "arr.-1", contentPart)
contentPartsCount++
}
// Handle function calls (Gemini) -> tool calls (OpenAI)
if functionCall := part.Get("functionCall"); functionCall.Exists() {
toolCallID := genToolCallID()
toolCallIDs = append(toolCallIDs, toolCallID)
toolCall := `{"id":"","type":"function","function":{"name":"","arguments":""}}`
toolCall, _ = sjson.Set(toolCall, "id", toolCallID)
toolCall, _ = sjson.Set(toolCall, "function.name", functionCall.Get("name").String())
// Convert args to arguments JSON string
if args := functionCall.Get("args"); args.Exists() {
toolCall, _ = sjson.Set(toolCall, "function.arguments", args.Raw)
} else {
toolCall, _ = sjson.Set(toolCall, "function.arguments", "{}")
}
toolCallsWrapper, _ = sjson.SetRaw(toolCallsWrapper, "arr.-1", toolCall)
toolCallsCount++
}
// Handle function responses (Gemini) -> tool role messages (OpenAI)
if functionResponse := part.Get("functionResponse"); functionResponse.Exists() {
// Create tool message for function response
toolMsg := `{"role":"tool","tool_call_id":"","content":""}`
// Convert response.content to JSON string
if response := functionResponse.Get("response"); response.Exists() {
if contentField := response.Get("content"); contentField.Exists() {
toolMsg, _ = sjson.Set(toolMsg, "content", contentField.Raw)
} else {
toolMsg, _ = sjson.Set(toolMsg, "content", response.Raw)
}
}
// Try to match with previous tool call ID
_ = functionResponse.Get("name").String() // functionName not used for now
if len(toolCallIDs) > 0 {
// Use the last tool call ID (simple matching by function name)
// In a real implementation, you might want more sophisticated matching
toolMsg, _ = sjson.Set(toolMsg, "tool_call_id", toolCallIDs[len(toolCallIDs)-1])
} else {
// Generate a tool call ID if none available
toolMsg, _ = sjson.Set(toolMsg, "tool_call_id", genToolCallID())
}
out, _ = sjson.SetRaw(out, "messages.-1", toolMsg)
}
return true
})
}
// Set content
if contentPartsCount > 0 {
if onlyTextContent {
msg, _ = sjson.Set(msg, "content", textBuilder.String())
} else {
msg, _ = sjson.SetRaw(msg, "content", gjson.Get(contentWrapper, "arr").Raw)
}
}
// Set tool calls if any
if toolCallsCount > 0 {
msg, _ = sjson.SetRaw(msg, "tool_calls", gjson.Get(toolCallsWrapper, "arr").Raw)
}
out, _ = sjson.SetRaw(out, "messages.-1", msg)
return true
})
}
// Tools mapping: Gemini tools -> OpenAI tools
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
tools.ForEach(func(_, tool gjson.Result) bool {
if functionDeclarations := tool.Get("functionDeclarations"); functionDeclarations.Exists() && functionDeclarations.IsArray() {
functionDeclarations.ForEach(func(_, funcDecl gjson.Result) bool {
openAITool := `{"type":"function","function":{"name":"","description":""}}`
openAITool, _ = sjson.Set(openAITool, "function.name", funcDecl.Get("name").String())
openAITool, _ = sjson.Set(openAITool, "function.description", funcDecl.Get("description").String())
// Convert parameters schema
if parameters := funcDecl.Get("parameters"); parameters.Exists() {
openAITool, _ = sjson.SetRaw(openAITool, "function.parameters", parameters.Raw)
} else if parameters := funcDecl.Get("parametersJsonSchema"); parameters.Exists() {
openAITool, _ = sjson.SetRaw(openAITool, "function.parameters", parameters.Raw)
}
out, _ = sjson.SetRaw(out, "tools.-1", openAITool)
return true
})
}
return true
})
}
// Tool choice mapping (Gemini doesn't have direct equivalent, but we can handle it)
if toolConfig := root.Get("toolConfig"); toolConfig.Exists() {
if functionCallingConfig := toolConfig.Get("functionCallingConfig"); functionCallingConfig.Exists() {
mode := functionCallingConfig.Get("mode").String()
switch mode {
case "NONE":
out, _ = sjson.Set(out, "tool_choice", "none")
case "AUTO":
out, _ = sjson.Set(out, "tool_choice", "auto")
case "ANY":
out, _ = sjson.Set(out, "tool_choice", "required")
}
}
}
return []byte(out)
}
================================================
FILE: internal/translator/openai/gemini/openai_gemini_response.go
================================================
// Package gemini provides response translation functionality for OpenAI to Gemini API.
// This package handles the conversion of OpenAI Chat Completions API responses into Gemini API-compatible
// JSON format, transforming streaming events and non-streaming responses into the format
// expected by Gemini API clients. It supports both streaming and non-streaming modes,
// handling text content, tool calls, and usage metadata appropriately.
package gemini
import (
"bytes"
"context"
"fmt"
"strconv"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertOpenAIResponseToGeminiParams holds parameters for response conversion
type ConvertOpenAIResponseToGeminiParams struct {
// Tool calls accumulator for streaming
ToolCallsAccumulator map[int]*ToolCallAccumulator
// Content accumulator for streaming
ContentAccumulator strings.Builder
// Track if this is the first chunk
IsFirstChunk bool
}
// ToolCallAccumulator holds the state for accumulating tool call data
type ToolCallAccumulator struct {
ID string
Name string
Arguments strings.Builder
}
// ConvertOpenAIResponseToGemini converts OpenAI Chat Completions streaming response format to Gemini API format.
// This function processes OpenAI streaming chunks and transforms them into Gemini-compatible JSON responses.
// It handles text content, tool calls, and usage metadata, outputting responses that match the Gemini API format.
//
// Parameters:
// - ctx: The context for the request.
// - modelName: The name of the model.
// - rawJSON: The raw JSON response from the OpenAI API.
// - param: A pointer to a parameter object for the conversion.
//
// Returns:
// - []string: A slice of strings, each containing a Gemini-compatible JSON response.
func ConvertOpenAIResponseToGemini(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
if *param == nil {
*param = &ConvertOpenAIResponseToGeminiParams{
ToolCallsAccumulator: nil,
ContentAccumulator: strings.Builder{},
IsFirstChunk: false,
}
}
// Handle [DONE] marker
if strings.TrimSpace(string(rawJSON)) == "[DONE]" {
return []string{}
}
if bytes.HasPrefix(rawJSON, []byte("data:")) {
rawJSON = bytes.TrimSpace(rawJSON[5:])
}
root := gjson.ParseBytes(rawJSON)
// Initialize accumulators if needed
if (*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator == nil {
(*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
}
// Process choices
if choices := root.Get("choices"); choices.Exists() && choices.IsArray() {
// Handle empty choices array (usage-only chunk)
if len(choices.Array()) == 0 {
// This is a usage-only chunk, handle usage and return
if usage := root.Get("usage"); usage.Exists() {
template := `{"candidates":[],"usageMetadata":{}}`
// Set model if available
if model := root.Get("model"); model.Exists() {
template, _ = sjson.Set(template, "model", model.String())
}
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", usage.Get("prompt_tokens").Int())
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", usage.Get("completion_tokens").Int())
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", usage.Get("total_tokens").Int())
if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 {
template, _ = sjson.Set(template, "usageMetadata.thoughtsTokenCount", reasoningTokens)
}
return []string{template}
}
return []string{}
}
var results []string
choices.ForEach(func(choiceIndex, choice gjson.Result) bool {
// Base Gemini response template without finishReason; set when known
template := `{"candidates":[{"content":{"parts":[],"role":"model"},"index":0}]}`
// Set model if available
if model := root.Get("model"); model.Exists() {
template, _ = sjson.Set(template, "model", model.String())
}
_ = int(choice.Get("index").Int()) // choiceIdx not used in streaming
delta := choice.Get("delta")
baseTemplate := template
// Handle role (only in first chunk)
if role := delta.Get("role"); role.Exists() && (*param).(*ConvertOpenAIResponseToGeminiParams).IsFirstChunk {
// OpenAI assistant -> Gemini model
if role.String() == "assistant" {
template, _ = sjson.Set(template, "candidates.0.content.role", "model")
}
(*param).(*ConvertOpenAIResponseToGeminiParams).IsFirstChunk = false
results = append(results, template)
return true
}
var chunkOutputs []string
// Handle reasoning/thinking delta
if reasoning := delta.Get("reasoning_content"); reasoning.Exists() {
for _, reasoningText := range extractReasoningTexts(reasoning) {
if reasoningText == "" {
continue
}
reasoningTemplate := baseTemplate
reasoningTemplate, _ = sjson.Set(reasoningTemplate, "candidates.0.content.parts.0.thought", true)
reasoningTemplate, _ = sjson.Set(reasoningTemplate, "candidates.0.content.parts.0.text", reasoningText)
chunkOutputs = append(chunkOutputs, reasoningTemplate)
}
}
// Handle content delta
if content := delta.Get("content"); content.Exists() && content.String() != "" {
contentText := content.String()
(*param).(*ConvertOpenAIResponseToGeminiParams).ContentAccumulator.WriteString(contentText)
// Create text part for this delta
contentTemplate := baseTemplate
contentTemplate, _ = sjson.Set(contentTemplate, "candidates.0.content.parts.0.text", contentText)
chunkOutputs = append(chunkOutputs, contentTemplate)
}
if len(chunkOutputs) > 0 {
results = append(results, chunkOutputs...)
return true
}
// Handle tool calls delta
if toolCalls := delta.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
toolIndex := int(toolCall.Get("index").Int())
toolID := toolCall.Get("id").String()
toolType := toolCall.Get("type").String()
function := toolCall.Get("function")
// Skip non-function tool calls explicitly marked as other types.
if toolType != "" && toolType != "function" {
return true
}
// OpenAI streaming deltas may omit the type field while still carrying function data.
if !function.Exists() {
return true
}
functionName := function.Get("name").String()
functionArgs := function.Get("arguments").String()
// Initialize accumulator if needed so later deltas without type can append arguments.
if _, exists := (*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator[toolIndex]; !exists {
(*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator[toolIndex] = &ToolCallAccumulator{
ID: toolID,
Name: functionName,
}
}
acc := (*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator[toolIndex]
// Update ID if provided
if toolID != "" {
acc.ID = toolID
}
// Update name if provided
if functionName != "" {
acc.Name = functionName
}
// Accumulate arguments
if functionArgs != "" {
acc.Arguments.WriteString(functionArgs)
}
return true
})
// Don't output anything for tool call deltas - wait for completion
return true
}
// Handle finish reason
if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
geminiFinishReason := mapOpenAIFinishReasonToGemini(finishReason.String())
template, _ = sjson.Set(template, "candidates.0.finishReason", geminiFinishReason)
// If we have accumulated tool calls, output them now
if len((*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator) > 0 {
partIndex := 0
for _, accumulator := range (*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator {
namePath := fmt.Sprintf("candidates.0.content.parts.%d.functionCall.name", partIndex)
argsPath := fmt.Sprintf("candidates.0.content.parts.%d.functionCall.args", partIndex)
template, _ = sjson.Set(template, namePath, accumulator.Name)
template, _ = sjson.SetRaw(template, argsPath, parseArgsToObjectRaw(accumulator.Arguments.String()))
partIndex++
}
// Clear accumulators
(*param).(*ConvertOpenAIResponseToGeminiParams).ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
}
results = append(results, template)
return true
}
// Handle usage information
if usage := root.Get("usage"); usage.Exists() {
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", usage.Get("prompt_tokens").Int())
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", usage.Get("completion_tokens").Int())
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", usage.Get("total_tokens").Int())
if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 {
template, _ = sjson.Set(template, "usageMetadata.thoughtsTokenCount", reasoningTokens)
}
results = append(results, template)
return true
}
return true
})
return results
}
return []string{}
}
// mapOpenAIFinishReasonToGemini maps OpenAI finish reasons to Gemini finish reasons
func mapOpenAIFinishReasonToGemini(openAIReason string) string {
switch openAIReason {
case "stop":
return "STOP"
case "length":
return "MAX_TOKENS"
case "tool_calls":
return "STOP" // Gemini doesn't have a specific tool_calls finish reason
case "content_filter":
return "SAFETY"
default:
return "STOP"
}
}
// parseArgsToObjectRaw safely parses a JSON string of function arguments into an object JSON string.
// It returns "{}" if the input is empty or cannot be parsed as a JSON object.
func parseArgsToObjectRaw(argsStr string) string {
trimmed := strings.TrimSpace(argsStr)
if trimmed == "" || trimmed == "{}" {
return "{}"
}
// First try strict JSON
if gjson.Valid(trimmed) {
strict := gjson.Parse(trimmed)
if strict.IsObject() {
return strict.Raw
}
}
// Tolerant parse: handle streams where values are barewords (e.g., 北京, celsius)
tolerant := tolerantParseJSONObjectRaw(trimmed)
if tolerant != "{}" {
return tolerant
}
// Fallback: return empty object when parsing fails
return "{}"
}
func escapeSjsonPathKey(key string) string {
key = strings.ReplaceAll(key, `\`, `\\`)
key = strings.ReplaceAll(key, `.`, `\.`)
return key
}
// tolerantParseJSONObjectRaw attempts to parse a JSON-like object string into a JSON object string, tolerating
// bareword values (unquoted strings) commonly seen during streamed tool calls.
// Example input: {"location": 北京, "unit": celsius}
func tolerantParseJSONObjectRaw(s string) string {
// Ensure we operate within the outermost braces if present
start := strings.Index(s, "{")
end := strings.LastIndex(s, "}")
if start == -1 || end == -1 || start >= end {
return "{}"
}
content := s[start+1 : end]
runes := []rune(content)
n := len(runes)
i := 0
result := "{}"
for i < n {
// Skip whitespace and commas
for i < n && (runes[i] == ' ' || runes[i] == '\n' || runes[i] == '\r' || runes[i] == '\t' || runes[i] == ',') {
i++
}
if i >= n {
break
}
// Expect quoted key
if runes[i] != '"' {
// Unable to parse this segment reliably; skip to next comma
for i < n && runes[i] != ',' {
i++
}
continue
}
// Parse JSON string for key
keyToken, nextIdx := parseJSONStringRunes(runes, i)
if nextIdx == -1 {
break
}
keyName := jsonStringTokenToRawString(keyToken)
sjsonKey := escapeSjsonPathKey(keyName)
i = nextIdx
// Skip whitespace
for i < n && (runes[i] == ' ' || runes[i] == '\n' || runes[i] == '\r' || runes[i] == '\t') {
i++
}
if i >= n || runes[i] != ':' {
break
}
i++ // skip ':'
// Skip whitespace
for i < n && (runes[i] == ' ' || runes[i] == '\n' || runes[i] == '\r' || runes[i] == '\t') {
i++
}
if i >= n {
break
}
// Parse value (string, number, object/array, bareword)
switch runes[i] {
case '"':
// JSON string
valToken, ni := parseJSONStringRunes(runes, i)
if ni == -1 {
// Malformed; treat as empty string
result, _ = sjson.Set(result, sjsonKey, "")
i = n
} else {
result, _ = sjson.Set(result, sjsonKey, jsonStringTokenToRawString(valToken))
i = ni
}
case '{', '[':
// Bracketed value: attempt to capture balanced structure
seg, ni := captureBracketed(runes, i)
if ni == -1 {
i = n
} else {
if gjson.Valid(seg) {
result, _ = sjson.SetRaw(result, sjsonKey, seg)
} else {
result, _ = sjson.Set(result, sjsonKey, seg)
}
i = ni
}
default:
// Bare token until next comma or end
j := i
for j < n && runes[j] != ',' {
j++
}
token := strings.TrimSpace(string(runes[i:j]))
// Interpret common JSON atoms and numbers; otherwise treat as string
if token == "true" {
result, _ = sjson.Set(result, sjsonKey, true)
} else if token == "false" {
result, _ = sjson.Set(result, sjsonKey, false)
} else if token == "null" {
result, _ = sjson.Set(result, sjsonKey, nil)
} else if numVal, ok := tryParseNumber(token); ok {
result, _ = sjson.Set(result, sjsonKey, numVal)
} else {
result, _ = sjson.Set(result, sjsonKey, token)
}
i = j
}
// Skip trailing whitespace and optional comma before next pair
for i < n && (runes[i] == ' ' || runes[i] == '\n' || runes[i] == '\r' || runes[i] == '\t') {
i++
}
if i < n && runes[i] == ',' {
i++
}
}
return result
}
// parseJSONStringRunes returns the JSON string token (including quotes) and the index just after it.
func parseJSONStringRunes(runes []rune, start int) (string, int) {
if start >= len(runes) || runes[start] != '"' {
return "", -1
}
i := start + 1
escaped := false
for i < len(runes) {
r := runes[i]
if r == '\\' && !escaped {
escaped = true
i++
continue
}
if r == '"' && !escaped {
return string(runes[start : i+1]), i + 1
}
escaped = false
i++
}
return string(runes[start:]), -1
}
// jsonStringTokenToRawString converts a JSON string token (including quotes) to a raw Go string value.
func jsonStringTokenToRawString(token string) string {
r := gjson.Parse(token)
if r.Type == gjson.String {
return r.String()
}
// Fallback: strip surrounding quotes if present
if len(token) >= 2 && token[0] == '"' && token[len(token)-1] == '"' {
return token[1 : len(token)-1]
}
return token
}
// captureBracketed captures a balanced JSON object/array starting at index i.
// Returns the segment string and the index just after it; -1 if malformed.
func captureBracketed(runes []rune, i int) (string, int) {
if i >= len(runes) {
return "", -1
}
startRune := runes[i]
var endRune rune
if startRune == '{' {
endRune = '}'
} else if startRune == '[' {
endRune = ']'
} else {
return "", -1
}
depth := 0
j := i
inStr := false
escaped := false
for j < len(runes) {
r := runes[j]
if inStr {
if r == '\\' && !escaped {
escaped = true
j++
continue
}
if r == '"' && !escaped {
inStr = false
} else {
escaped = false
}
j++
continue
}
if r == '"' {
inStr = true
j++
continue
}
if r == startRune {
depth++
} else if r == endRune {
depth--
if depth == 0 {
return string(runes[i : j+1]), j + 1
}
}
j++
}
return string(runes[i:]), -1
}
// tryParseNumber attempts to parse a string as an int or float.
func tryParseNumber(s string) (interface{}, bool) {
if s == "" {
return nil, false
}
// Try integer
if i64, errParseInt := strconv.ParseInt(s, 10, 64); errParseInt == nil {
return i64, true
}
if u64, errParseUInt := strconv.ParseUint(s, 10, 64); errParseUInt == nil {
return u64, true
}
if f64, errParseFloat := strconv.ParseFloat(s, 64); errParseFloat == nil {
return f64, true
}
return nil, false
}
// ConvertOpenAIResponseToGeminiNonStream converts a non-streaming OpenAI response to a non-streaming Gemini response.
//
// Parameters:
// - ctx: The context for the request.
// - modelName: The name of the model.
// - rawJSON: The raw JSON response from the OpenAI API.
// - param: A pointer to a parameter object for the conversion.
//
// Returns:
// - string: A Gemini-compatible JSON response.
func ConvertOpenAIResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
root := gjson.ParseBytes(rawJSON)
// Base Gemini response template without finishReason; set when known
out := `{"candidates":[{"content":{"parts":[],"role":"model"},"index":0}]}`
// Set model if available
if model := root.Get("model"); model.Exists() {
out, _ = sjson.Set(out, "model", model.String())
}
// Process choices
if choices := root.Get("choices"); choices.Exists() && choices.IsArray() {
choices.ForEach(func(choiceIndex, choice gjson.Result) bool {
choiceIdx := int(choice.Get("index").Int())
message := choice.Get("message")
// Set role
if role := message.Get("role"); role.Exists() {
if role.String() == "assistant" {
out, _ = sjson.Set(out, "candidates.0.content.role", "model")
}
}
partIndex := 0
// Handle reasoning content before visible text
if reasoning := message.Get("reasoning_content"); reasoning.Exists() {
for _, reasoningText := range extractReasoningTexts(reasoning) {
if reasoningText == "" {
continue
}
out, _ = sjson.Set(out, fmt.Sprintf("candidates.0.content.parts.%d.thought", partIndex), true)
out, _ = sjson.Set(out, fmt.Sprintf("candidates.0.content.parts.%d.text", partIndex), reasoningText)
partIndex++
}
}
// Handle content first
if content := message.Get("content"); content.Exists() && content.String() != "" {
out, _ = sjson.Set(out, fmt.Sprintf("candidates.0.content.parts.%d.text", partIndex), content.String())
partIndex++
}
// Handle tool calls
if toolCalls := message.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
if toolCall.Get("type").String() == "function" {
function := toolCall.Get("function")
functionName := function.Get("name").String()
functionArgs := function.Get("arguments").String()
namePath := fmt.Sprintf("candidates.0.content.parts.%d.functionCall.name", partIndex)
argsPath := fmt.Sprintf("candidates.0.content.parts.%d.functionCall.args", partIndex)
out, _ = sjson.Set(out, namePath, functionName)
out, _ = sjson.SetRaw(out, argsPath, parseArgsToObjectRaw(functionArgs))
partIndex++
}
return true
})
}
// Handle finish reason
if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
geminiFinishReason := mapOpenAIFinishReasonToGemini(finishReason.String())
out, _ = sjson.Set(out, "candidates.0.finishReason", geminiFinishReason)
}
// Set index
out, _ = sjson.Set(out, "candidates.0.index", choiceIdx)
return true
})
}
// Handle usage information
if usage := root.Get("usage"); usage.Exists() {
out, _ = sjson.Set(out, "usageMetadata.promptTokenCount", usage.Get("prompt_tokens").Int())
out, _ = sjson.Set(out, "usageMetadata.candidatesTokenCount", usage.Get("completion_tokens").Int())
out, _ = sjson.Set(out, "usageMetadata.totalTokenCount", usage.Get("total_tokens").Int())
if reasoningTokens := reasoningTokensFromUsage(usage); reasoningTokens > 0 {
out, _ = sjson.Set(out, "usageMetadata.thoughtsTokenCount", reasoningTokens)
}
}
return out
}
func GeminiTokenCount(ctx context.Context, count int64) string {
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
}
func reasoningTokensFromUsage(usage gjson.Result) int64 {
if usage.Exists() {
if v := usage.Get("completion_tokens_details.reasoning_tokens"); v.Exists() {
return v.Int()
}
if v := usage.Get("output_tokens_details.reasoning_tokens"); v.Exists() {
return v.Int()
}
}
return 0
}
func extractReasoningTexts(node gjson.Result) []string {
var texts []string
if !node.Exists() {
return texts
}
if node.IsArray() {
node.ForEach(func(_, value gjson.Result) bool {
texts = append(texts, extractReasoningTexts(value)...)
return true
})
return texts
}
switch node.Type {
case gjson.String:
texts = append(texts, node.String())
case gjson.JSON:
if text := node.Get("text"); text.Exists() {
texts = append(texts, text.String())
} else if raw := strings.TrimSpace(node.Raw); raw != "" && !strings.HasPrefix(raw, "{") && !strings.HasPrefix(raw, "[") {
texts = append(texts, raw)
}
}
return texts
}
================================================
FILE: internal/translator/openai/gemini-cli/init.go
================================================
package geminiCLI
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
GeminiCLI,
OpenAI,
ConvertGeminiCLIRequestToOpenAI,
interfaces.TranslateResponse{
Stream: ConvertOpenAIResponseToGeminiCLI,
NonStream: ConvertOpenAIResponseToGeminiCLINonStream,
TokenCount: GeminiCLITokenCount,
},
)
}
================================================
FILE: internal/translator/openai/gemini-cli/openai_gemini_request.go
================================================
// Package geminiCLI provides request translation functionality for Gemini to OpenAI API.
// It handles parsing and transforming Gemini API requests into OpenAI Chat Completions API format,
// extracting model information, generation config, message contents, and tool declarations.
// The package performs JSON data transformation to ensure compatibility
// between Gemini API format and OpenAI API's expected format.
package geminiCLI
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/gemini"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertGeminiCLIRequestToOpenAI parses and transforms a Gemini API request into OpenAI Chat Completions API format.
// It extracts the model name, generation config, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the OpenAI API.
func ConvertGeminiCLIRequestToOpenAI(modelName string, inputRawJSON []byte, stream bool) []byte {
rawJSON := inputRawJSON
rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
if gjson.GetBytes(rawJSON, "systemInstruction").Exists() {
rawJSON, _ = sjson.SetRawBytes(rawJSON, "system_instruction", []byte(gjson.GetBytes(rawJSON, "systemInstruction").Raw))
rawJSON, _ = sjson.DeleteBytes(rawJSON, "systemInstruction")
}
return ConvertGeminiRequestToOpenAI(modelName, rawJSON, stream)
}
================================================
FILE: internal/translator/openai/gemini-cli/openai_gemini_response.go
================================================
// Package geminiCLI provides response translation functionality for OpenAI to Gemini API.
// This package handles the conversion of OpenAI Chat Completions API responses into Gemini API-compatible
// JSON format, transforming streaming events and non-streaming responses into the format
// expected by Gemini API clients. It supports both streaming and non-streaming modes,
// handling text content, tool calls, and usage metadata appropriately.
package geminiCLI
import (
"context"
"fmt"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/gemini"
"github.com/tidwall/sjson"
)
// ConvertOpenAIResponseToGeminiCLI converts OpenAI Chat Completions streaming response format to Gemini API format.
// This function processes OpenAI streaming chunks and transforms them into Gemini-compatible JSON responses.
// It handles text content, tool calls, and usage metadata, outputting responses that match the Gemini API format.
//
// Parameters:
// - ctx: The context for the request.
// - modelName: The name of the model.
// - rawJSON: The raw JSON response from the OpenAI API.
// - param: A pointer to a parameter object for the conversion.
//
// Returns:
// - []string: A slice of strings, each containing a Gemini-compatible JSON response.
func ConvertOpenAIResponseToGeminiCLI(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
outputs := ConvertOpenAIResponseToGemini(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
newOutputs := make([]string, 0)
for i := 0; i < len(outputs); i++ {
json := `{"response": {}}`
output, _ := sjson.SetRaw(json, "response", outputs[i])
newOutputs = append(newOutputs, output)
}
return newOutputs
}
// ConvertOpenAIResponseToGeminiCLINonStream converts a non-streaming OpenAI response to a non-streaming Gemini CLI response.
//
// Parameters:
// - ctx: The context for the request.
// - modelName: The name of the model.
// - rawJSON: The raw JSON response from the OpenAI API.
// - param: A pointer to a parameter object for the conversion.
//
// Returns:
// - string: A Gemini-compatible JSON response.
func ConvertOpenAIResponseToGeminiCLINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
strJSON := ConvertOpenAIResponseToGeminiNonStream(ctx, modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
json := `{"response": {}}`
strJSON, _ = sjson.SetRaw(json, "response", strJSON)
return strJSON
}
// GeminiCLITokenCount renders a Gemini countTokens-style JSON body in which the whole
// count is reported both as totalTokens and as a single TEXT modality entry.
//
// Parameters:
//   - ctx: Unused.
//   - count: The token count to report.
//
// Returns:
//   - string: The JSON body, e.g. {"totalTokens":5,"promptTokensDetails":[{"modality":"TEXT","tokenCount":5}]}.
func GeminiCLITokenCount(ctx context.Context, count int64) string {
	const template = `{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`
	total, textTokens := count, count
	return fmt.Sprintf(template, total, textTokens)
}
================================================
FILE: internal/translator/openai/openai/chat-completions/init.go
================================================
package chat_completions
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
OpenAI,
OpenAI,
ConvertOpenAIRequestToOpenAI,
interfaces.TranslateResponse{
Stream: ConvertOpenAIResponseToOpenAI,
NonStream: ConvertOpenAIResponseToOpenAINonStream,
},
)
}
================================================
FILE: internal/translator/openai/openai/chat-completions/openai_openai_request.go
================================================
// Package chat_completions provides passthrough request translation for OpenAI Chat Completions
// clients talking to an OpenAI-compatible upstream; only the "model" field is rewritten (via sjson).
package chat_completions
import (
"github.com/tidwall/sjson"
)
// ConvertOpenAIRequestToOpenAI prepares an OpenAI Chat Completions request for an
// OpenAI-compatible upstream. The payload is forwarded unchanged except that its
// top-level "model" field is overwritten with modelName.
//
// Parameters:
//   - modelName: The upstream model name written into the "model" field
//   - inputRawJSON: The raw JSON request body in OpenAI Chat Completions format
//   - stream: Ignored; the client's own "stream" field is forwarded as-is
//
// Returns:
//   - []byte: The request with "model" replaced, or inputRawJSON unchanged if sjson
//     cannot update it
func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool) []byte {
	updatedJSON, err := sjson.SetBytes(inputRawJSON, "model", modelName)
	if err != nil {
		// Best effort: forward the client's payload untouched rather than failing the request.
		return inputRawJSON
	}
	return updatedJSON
}
================================================
FILE: internal/translator/openai/openai/chat-completions/openai_openai_response.go
================================================
// Package chat_completions provides passthrough response translation for OpenAI Chat Completions
// upstreams serving OpenAI Chat Completions clients. Streaming chunks are forwarded after the SSE
// "data:" prefix is stripped, and the "[DONE]" sentinel is dropped; non-streaming responses are
// returned unchanged.
package chat_completions
import (
"bytes"
"context"
)
// ConvertOpenAIResponseToOpenAI forwards a single OpenAI Chat Completions streaming chunk
// to the client, only cleaning up SSE framing: an optional leading "data:" prefix is
// removed and surrounding whitespace is trimmed. Blank payloads (e.g. SSE separator lines
// or a bare "data:") and the terminal "[DONE]" sentinel produce no output, so the client
// never receives an empty chunk.
//
// Parameters:
//   - ctx, modelName: Unused
//   - originalRequestRawJSON, requestRawJSON: Unused
//   - rawJSON: One raw chunk/line from the upstream OpenAI-compatible stream
//   - param: Unused; no state is kept between chunks
//
// Returns:
//   - []string: Zero or one JSON chunk in OpenAI Chat Completions streaming format
func ConvertOpenAIResponseToOpenAI(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	payload := bytes.TrimSpace(bytes.TrimPrefix(rawJSON, []byte("data:")))
	if len(payload) == 0 || bytes.Equal(payload, []byte("[DONE]")) {
		return []string{}
	}
	return []string{string(payload)}
}
// ConvertOpenAIResponseToOpenAINonStream returns a complete (non-streaming) OpenAI Chat
// Completions response unchanged: upstream and client speak the same format, so no
// message, tool-call, reasoning, or usage rewriting is performed.
//
// Parameters:
//   - ctx, modelName: Unused
//   - originalRequestRawJSON, requestRawJSON: Unused
//   - rawJSON: The raw JSON response from the OpenAI-compatible upstream
//   - param: Unused
//
// Returns:
//   - string: rawJSON as a string, byte-for-byte
func ConvertOpenAIResponseToOpenAINonStream(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
	return string(rawJSON)
}
================================================
FILE: internal/translator/openai/openai/responses/init.go
================================================
package responses
import (
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/translator"
)
func init() {
translator.Register(
OpenaiResponse,
OpenAI,
ConvertOpenAIResponsesRequestToOpenAIChatCompletions,
interfaces.TranslateResponse{
Stream: ConvertOpenAIChatCompletionsResponseToOpenAIResponses,
NonStream: ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream,
},
)
}
================================================
FILE: internal/translator/openai/openai/responses/openai_openai-responses_request.go
================================================
package responses
import (
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertOpenAIResponsesRequestToOpenAIChatCompletions converts an OpenAI Responses API request
// (instructions + input items) into an OpenAI Chat Completions request (messages array).
//
// The conversion handles:
//  1. Model name and streaming configuration
//  2. max_output_tokens (-> max_tokens) and parallel_tool_calls
//  3. instructions -> a leading system message
//  4. input (a plain string or an item array) -> messages
//  5. Function tool definitions -> Chat Completions tools (built-in tools are dropped)
//  6. reasoning.effort -> reasoning_effort
//  7. tool_choice (a string, or a {"type":"function"} object) -> Chat Completions tool_choice
//
// Consecutive function_call input items are merged into one assistant message. Chat Completions
// requires every tool_call_id of an assistant message to be answered by the tool messages that
// follow it, so emitting one assistant message per parallel call would be rejected upstream.
//
// Parameters:
//   - modelName: The name of the model to use for the request
//   - inputRawJSON: The raw JSON request data in OpenAI Responses format
//   - stream: A boolean indicating if the request is for a streaming response
//
// Returns:
//   - []byte: The transformed request data in OpenAI Chat Completions format
func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inputRawJSON []byte, stream bool) []byte {
	root := gjson.ParseBytes(inputRawJSON)

	// Base OpenAI chat completions template with default values
	out := `{"model":"","messages":[],"stream":false}`
	out, _ = sjson.Set(out, "model", modelName)
	out, _ = sjson.Set(out, "stream", stream)

	// Map generation parameters from responses format to chat completions format
	if maxTokens := root.Get("max_output_tokens"); maxTokens.Exists() {
		out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
	}
	if parallelToolCalls := root.Get("parallel_tool_calls"); parallelToolCalls.Exists() {
		out, _ = sjson.Set(out, "parallel_tool_calls", parallelToolCalls.Bool())
	}

	// Convert instructions to a system message
	if instructions := root.Get("instructions"); instructions.Exists() {
		systemMessage := `{"role":"system","content":""}`
		systemMessage, _ = sjson.Set(systemMessage, "content", instructions.String())
		out, _ = sjson.SetRaw(out, "messages.-1", systemMessage)
	}

	// Convert input to messages
	if input := root.Get("input"); input.Exists() && input.IsArray() {
		// pendingToolCalls holds an assistant message collecting a run of consecutive
		// function_call items; it is appended to messages once the run ends.
		pendingToolCalls := ""
		flushToolCalls := func() {
			if pendingToolCalls != "" {
				out, _ = sjson.SetRaw(out, "messages.-1", pendingToolCalls)
				pendingToolCalls = ""
			}
		}

		input.ForEach(func(_, item gjson.Result) bool {
			itemType := item.Get("type").String()
			if itemType == "" && item.Get("role").String() != "" {
				itemType = "message"
			}

			switch itemType {
			case "message", "":
				flushToolCalls()
				role := item.Get("role").String()
				if role == "developer" {
					role = "user"
				}
				message := `{"role":"","content":[]}`
				message, _ = sjson.Set(message, "role", role)
				if content := item.Get("content"); content.Exists() && content.IsArray() {
					content.ForEach(func(_, contentItem gjson.Result) bool {
						contentType := contentItem.Get("type").String()
						if contentType == "" {
							contentType = "input_text"
						}
						switch contentType {
						case "input_text", "output_text":
							contentPart := `{"type":"text","text":""}`
							contentPart, _ = sjson.Set(contentPart, "text", contentItem.Get("text").String())
							message, _ = sjson.SetRaw(message, "content.-1", contentPart)
						case "input_image":
							contentPart := `{"type":"image_url","image_url":{"url":""}}`
							contentPart, _ = sjson.Set(contentPart, "image_url.url", contentItem.Get("image_url").String())
							message, _ = sjson.SetRaw(message, "content.-1", contentPart)
						}
						return true
					})
				} else if content.Type == gjson.String {
					message, _ = sjson.Set(message, "content", content.String())
				}
				out, _ = sjson.SetRaw(out, "messages.-1", message)
			case "function_call":
				toolCall := `{"id":"","type":"function","function":{"name":"","arguments":""}}`
				if callID := item.Get("call_id"); callID.Exists() {
					toolCall, _ = sjson.Set(toolCall, "id", callID.String())
				}
				if name := item.Get("name"); name.Exists() {
					toolCall, _ = sjson.Set(toolCall, "function.name", name.String())
				}
				if arguments := item.Get("arguments"); arguments.Exists() {
					toolCall, _ = sjson.Set(toolCall, "function.arguments", arguments.String())
				}
				// Group parallel calls into a single assistant message (see function doc).
				if pendingToolCalls == "" {
					pendingToolCalls = `{"role":"assistant","tool_calls":[]}`
				}
				pendingToolCalls, _ = sjson.SetRaw(pendingToolCalls, "tool_calls.-1", toolCall)
			case "function_call_output":
				flushToolCalls()
				toolMessage := `{"role":"tool","tool_call_id":"","content":""}`
				if callID := item.Get("call_id"); callID.Exists() {
					toolMessage, _ = sjson.Set(toolMessage, "tool_call_id", callID.String())
				}
				if output := item.Get("output"); output.Exists() {
					toolMessage, _ = sjson.Set(toolMessage, "content", output.String())
				}
				out, _ = sjson.SetRaw(out, "messages.-1", toolMessage)
			}
			return true
		})
		flushToolCalls()
	} else if input.Type == gjson.String {
		msg := "{}"
		msg, _ = sjson.Set(msg, "role", "user")
		msg, _ = sjson.Set(msg, "content", input.String())
		out, _ = sjson.SetRaw(out, "messages.-1", msg)
	}

	// Convert tools from responses format to chat completions format
	if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
		// Keep raw JSON fragments so tool schemas are forwarded as written (a round-trip
		// through interface{} would reorder keys and turn integers into float64).
		var chatTools []string
		tools.ForEach(func(_, tool gjson.Result) bool {
			toolType := tool.Get("type").String()
			if toolType != "" && toolType != "function" && tool.IsObject() {
				// Built-in tools (e.g. {"type":"web_search"}) are unsupported by almost all
				// providers, so they are dropped.
				return true
			}
			// Chat Completions nests the function details under "function".
			function := `{"name":"","description":"","parameters":{}}`
			if name := tool.Get("name"); name.Exists() {
				function, _ = sjson.Set(function, "name", name.String())
			}
			if description := tool.Get("description"); description.Exists() {
				function, _ = sjson.Set(function, "description", description.String())
			}
			if parameters := tool.Get("parameters"); parameters.Exists() {
				function, _ = sjson.SetRaw(function, "parameters", parameters.Raw)
			}
			chatTool := `{"type":"function","function":{}}`
			chatTool, _ = sjson.SetRaw(chatTool, "function", function)
			chatTools = append(chatTools, chatTool)
			return true
		})
		if len(chatTools) > 0 {
			out, _ = sjson.SetRaw(out, "tools", "["+strings.Join(chatTools, ",")+"]")
		}
	}

	if reasoningEffort := root.Get("reasoning.effort"); reasoningEffort.Exists() {
		effort := strings.ToLower(strings.TrimSpace(reasoningEffort.String()))
		if effort != "" {
			out, _ = sjson.Set(out, "reasoning_effort", effort)
		}
	}

	// Convert tool_choice. String modes ("auto", "none", "required") are shared by both APIs.
	// The Responses function form {"type":"function","name":"x"} becomes
	// {"type":"function","function":{"name":"x"}}. Other object forms refer to built-in tools
	// that were dropped above, so they are omitted instead of being sent as a JSON string.
	if toolChoice := root.Get("tool_choice"); toolChoice.Exists() {
		switch {
		case toolChoice.Type == gjson.String:
			out, _ = sjson.Set(out, "tool_choice", toolChoice.String())
		case toolChoice.IsObject() && toolChoice.Get("type").String() == "function":
			name := toolChoice.Get("name").String()
			if name == "" {
				name = toolChoice.Get("function.name").String()
			}
			choice := `{"type":"function","function":{"name":""}}`
			choice, _ = sjson.Set(choice, "function.name", name)
			out, _ = sjson.SetRaw(out, "tool_choice", choice)
		}
	}

	return []byte(out)
}
================================================
FILE: internal/translator/openai/openai/responses/openai_openai-responses_response.go
================================================
package responses
import (
"bytes"
"context"
"fmt"
"strings"
"sync/atomic"
"time"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// oaiToResponsesStateReasoning records a reasoning item that was closed while streaming
// (appended to oaiToResponsesState.Reasonings when the reasoning item is finished).
type oaiToResponsesStateReasoning struct {
	// ReasoningID is the synthesized reasoning item ID ("rs_<responseID>_<choiceIndex>").
	ReasoningID string
	// ReasoningData is the full reasoning text accumulated for the item.
	ReasoningData string
}
// oaiToResponsesState carries streaming state across successive calls to
// ConvertOpenAIChatCompletionsResponseToOpenAIResponses (stored behind its param pointer).
// It records which Responses SSE events have already been emitted and buffers message
// text, reasoning text, and function arguments for the final "done" events.
// Maps keyed by int are keyed by the Chat Completions choice index ("index" field).
type oaiToResponsesState struct {
	// Seq is the last sequence_number assigned to an emitted event.
	Seq int
	// ResponseID is copied from the first upstream chunk's "id".
	ResponseID string
	// Created is copied from the first upstream chunk's "created".
	Created int64
	// Started reports whether response.created / response.in_progress were emitted.
	Started bool
	// ReasoningID is the ID of the currently open reasoning item, or "" when none is open.
	ReasoningID string
	// ReasoningIndex is the output_index (choice index) of the open reasoning item.
	ReasoningIndex int
	// aggregation buffers for response.output
	// Per-output message text buffers by index
	MsgTextBuf map[int]*strings.Builder
	// ReasoningBuf accumulates reasoning_content deltas for the open reasoning item.
	ReasoningBuf strings.Builder
	// Reasonings lists reasoning items that have already been closed.
	Reasonings []oaiToResponsesStateReasoning
	// NOTE(review): the Func* maps are keyed by choice index, and the streaming converter
	// reads only tool_calls[0] of each delta, so several tool calls within one choice
	// would share a single slot.
	FuncArgsBuf map[int]*strings.Builder // choice index -> accumulated arguments
	FuncNames   map[int]string           // choice index -> function name
	FuncCallIDs map[int]string           // choice index -> first non-empty call_id seen
	// message item state per output index
	MsgItemAdded    map[int]bool // whether response.output_item.added emitted for message
	MsgContentAdded map[int]bool // whether response.content_part.added emitted for message
	MsgItemDone     map[int]bool // whether message done events were emitted
	// function item done state
	FuncArgsDone map[int]bool
	FuncItemDone map[int]bool
	// usage aggregation (copied from the upstream "usage" object)
	// PromptTokens: usage.prompt_tokens
	PromptTokens int64
	// CachedTokens: usage.prompt_tokens_details.cached_tokens
	CachedTokens int64
	// CompletionTokens: usage.completion_tokens, falling back to usage.output_tokens
	CompletionTokens int64
	// TotalTokens: usage.total_tokens
	TotalTokens int64
	// ReasoningTokens: usage.output_tokens_details.reasoning_tokens, falling back to
	// usage.completion_tokens_details.reasoning_tokens
	ReasoningTokens int64
	// UsageSeen is set once any of the usage fields above was present in a chunk.
	UsageSeen bool
}
// responseIDCounter provides a process-wide unique counter for synthesized response identifiers.
// NOTE(review): not referenced in this part of the file; presumably incremented with
// sync/atomic wherever a response ID must be generated — confirm before relying on it.
var responseIDCounter uint64
// emitRespEvent renders one Server-Sent Events frame made of an "event:" line naming the
// event and a "data:" line carrying payload. No trailing blank line is appended.
// It is called once per emitted event, so it uses plain concatenation instead of
// fmt.Sprintf to avoid formatting and boxing overhead; the output is identical.
func emitRespEvent(event string, payload string) string {
	return "event: " + event + "\ndata: " + payload
}
// ConvertOpenAIChatCompletionsResponseToOpenAIResponses converts OpenAI Chat Completions streaming chunks
// to OpenAI Responses SSE events (response.*).
func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
if *param == nil {
*param = &oaiToResponsesState{
FuncArgsBuf: make(map[int]*strings.Builder),
FuncNames: make(map[int]string),
FuncCallIDs: make(map[int]string),
MsgTextBuf: make(map[int]*strings.Builder),
MsgItemAdded: make(map[int]bool),
MsgContentAdded: make(map[int]bool),
MsgItemDone: make(map[int]bool),
FuncArgsDone: make(map[int]bool),
FuncItemDone: make(map[int]bool),
Reasonings: make([]oaiToResponsesStateReasoning, 0),
}
}
st := (*param).(*oaiToResponsesState)
if bytes.HasPrefix(rawJSON, []byte("data:")) {
rawJSON = bytes.TrimSpace(rawJSON[5:])
}
rawJSON = bytes.TrimSpace(rawJSON)
if len(rawJSON) == 0 {
return []string{}
}
if bytes.Equal(rawJSON, []byte("[DONE]")) {
return []string{}
}
root := gjson.ParseBytes(rawJSON)
obj := root.Get("object")
if obj.Exists() && obj.String() != "" && obj.String() != "chat.completion.chunk" {
return []string{}
}
if !root.Get("choices").Exists() || !root.Get("choices").IsArray() {
return []string{}
}
if usage := root.Get("usage"); usage.Exists() {
if v := usage.Get("prompt_tokens"); v.Exists() {
st.PromptTokens = v.Int()
st.UsageSeen = true
}
if v := usage.Get("prompt_tokens_details.cached_tokens"); v.Exists() {
st.CachedTokens = v.Int()
st.UsageSeen = true
}
if v := usage.Get("completion_tokens"); v.Exists() {
st.CompletionTokens = v.Int()
st.UsageSeen = true
} else if v := usage.Get("output_tokens"); v.Exists() {
st.CompletionTokens = v.Int()
st.UsageSeen = true
}
if v := usage.Get("output_tokens_details.reasoning_tokens"); v.Exists() {
st.ReasoningTokens = v.Int()
st.UsageSeen = true
} else if v := usage.Get("completion_tokens_details.reasoning_tokens"); v.Exists() {
st.ReasoningTokens = v.Int()
st.UsageSeen = true
}
if v := usage.Get("total_tokens"); v.Exists() {
st.TotalTokens = v.Int()
st.UsageSeen = true
}
}
nextSeq := func() int { st.Seq++; return st.Seq }
var out []string
if !st.Started {
st.ResponseID = root.Get("id").String()
st.Created = root.Get("created").Int()
// reset aggregation state for a new streaming response
st.MsgTextBuf = make(map[int]*strings.Builder)
st.ReasoningBuf.Reset()
st.ReasoningID = ""
st.ReasoningIndex = 0
st.FuncArgsBuf = make(map[int]*strings.Builder)
st.FuncNames = make(map[int]string)
st.FuncCallIDs = make(map[int]string)
st.MsgItemAdded = make(map[int]bool)
st.MsgContentAdded = make(map[int]bool)
st.MsgItemDone = make(map[int]bool)
st.FuncArgsDone = make(map[int]bool)
st.FuncItemDone = make(map[int]bool)
st.PromptTokens = 0
st.CachedTokens = 0
st.CompletionTokens = 0
st.TotalTokens = 0
st.ReasoningTokens = 0
st.UsageSeen = false
// response.created
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
created, _ = sjson.Set(created, "sequence_number", nextSeq())
created, _ = sjson.Set(created, "response.id", st.ResponseID)
created, _ = sjson.Set(created, "response.created_at", st.Created)
out = append(out, emitRespEvent("response.created", created))
inprog := `{"type":"response.in_progress","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress"}}`
inprog, _ = sjson.Set(inprog, "sequence_number", nextSeq())
inprog, _ = sjson.Set(inprog, "response.id", st.ResponseID)
inprog, _ = sjson.Set(inprog, "response.created_at", st.Created)
out = append(out, emitRespEvent("response.in_progress", inprog))
st.Started = true
}
stopReasoning := func(text string) {
// Emit reasoning done events
textDone := `{"type":"response.reasoning_summary_text.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`
textDone, _ = sjson.Set(textDone, "sequence_number", nextSeq())
textDone, _ = sjson.Set(textDone, "item_id", st.ReasoningID)
textDone, _ = sjson.Set(textDone, "output_index", st.ReasoningIndex)
textDone, _ = sjson.Set(textDone, "text", text)
out = append(out, emitRespEvent("response.reasoning_summary_text.done", textDone))
partDone := `{"type":"response.reasoning_summary_part.done","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.Set(partDone, "item_id", st.ReasoningID)
partDone, _ = sjson.Set(partDone, "output_index", st.ReasoningIndex)
partDone, _ = sjson.Set(partDone, "part.text", text)
out = append(out, emitRespEvent("response.reasoning_summary_part.done", partDone))
outputItemDone := `{"type":"response.output_item.done","item":{"id":"","type":"reasoning","encrypted_content":"","summary":[{"type":"summary_text","text":""}]},"output_index":0,"sequence_number":0}`
outputItemDone, _ = sjson.Set(outputItemDone, "sequence_number", nextSeq())
outputItemDone, _ = sjson.Set(outputItemDone, "item.id", st.ReasoningID)
outputItemDone, _ = sjson.Set(outputItemDone, "output_index", st.ReasoningIndex)
outputItemDone, _ = sjson.Set(outputItemDone, "item.summary.text", text)
out = append(out, emitRespEvent("response.output_item.done", outputItemDone))
st.Reasonings = append(st.Reasonings, oaiToResponsesStateReasoning{ReasoningID: st.ReasoningID, ReasoningData: text})
st.ReasoningID = ""
}
// choices[].delta content / tool_calls / reasoning_content
if choices := root.Get("choices"); choices.Exists() && choices.IsArray() {
choices.ForEach(func(_, choice gjson.Result) bool {
idx := int(choice.Get("index").Int())
delta := choice.Get("delta")
if delta.Exists() {
if c := delta.Get("content"); c.Exists() && c.String() != "" {
// Ensure the message item and its first content part are announced before any text deltas
if st.ReasoningID != "" {
stopReasoning(st.ReasoningBuf.String())
st.ReasoningBuf.Reset()
}
if !st.MsgItemAdded[idx] {
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"in_progress","content":[],"role":"assistant"}}`
item, _ = sjson.Set(item, "sequence_number", nextSeq())
item, _ = sjson.Set(item, "output_index", idx)
item, _ = sjson.Set(item, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
out = append(out, emitRespEvent("response.output_item.added", item))
st.MsgItemAdded[idx] = true
}
if !st.MsgContentAdded[idx] {
part := `{"type":"response.content_part.added","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
part, _ = sjson.Set(part, "sequence_number", nextSeq())
part, _ = sjson.Set(part, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
part, _ = sjson.Set(part, "output_index", idx)
part, _ = sjson.Set(part, "content_index", 0)
out = append(out, emitRespEvent("response.content_part.added", part))
st.MsgContentAdded[idx] = true
}
msg := `{"type":"response.output_text.delta","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"delta":"","logprobs":[]}`
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
msg, _ = sjson.Set(msg, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
msg, _ = sjson.Set(msg, "output_index", idx)
msg, _ = sjson.Set(msg, "content_index", 0)
msg, _ = sjson.Set(msg, "delta", c.String())
out = append(out, emitRespEvent("response.output_text.delta", msg))
// aggregate for response.output
if st.MsgTextBuf[idx] == nil {
st.MsgTextBuf[idx] = &strings.Builder{}
}
st.MsgTextBuf[idx].WriteString(c.String())
}
// reasoning_content (OpenAI reasoning incremental text)
if rc := delta.Get("reasoning_content"); rc.Exists() && rc.String() != "" {
// On first appearance, add reasoning item and part
if st.ReasoningID == "" {
st.ReasoningID = fmt.Sprintf("rs_%s_%d", st.ResponseID, idx)
st.ReasoningIndex = idx
item := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"reasoning","status":"in_progress","summary":[]}}`
item, _ = sjson.Set(item, "sequence_number", nextSeq())
item, _ = sjson.Set(item, "output_index", idx)
item, _ = sjson.Set(item, "item.id", st.ReasoningID)
out = append(out, emitRespEvent("response.output_item.added", item))
part := `{"type":"response.reasoning_summary_part.added","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"part":{"type":"summary_text","text":""}}`
part, _ = sjson.Set(part, "sequence_number", nextSeq())
part, _ = sjson.Set(part, "item_id", st.ReasoningID)
part, _ = sjson.Set(part, "output_index", st.ReasoningIndex)
out = append(out, emitRespEvent("response.reasoning_summary_part.added", part))
}
// Append incremental text to reasoning buffer
st.ReasoningBuf.WriteString(rc.String())
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
msg, _ = sjson.Set(msg, "item_id", st.ReasoningID)
msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
msg, _ = sjson.Set(msg, "delta", rc.String())
out = append(out, emitRespEvent("response.reasoning_summary_text.delta", msg))
}
// tool calls
if tcs := delta.Get("tool_calls"); tcs.Exists() && tcs.IsArray() {
if st.ReasoningID != "" {
stopReasoning(st.ReasoningBuf.String())
st.ReasoningBuf.Reset()
}
// Before emitting any function events, if a message is open for this index,
// close its text/content to match Codex expected ordering.
if st.MsgItemAdded[idx] && !st.MsgItemDone[idx] {
fullText := ""
if b := st.MsgTextBuf[idx]; b != nil {
fullText = b.String()
}
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
done, _ = sjson.Set(done, "sequence_number", nextSeq())
done, _ = sjson.Set(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
done, _ = sjson.Set(done, "output_index", idx)
done, _ = sjson.Set(done, "content_index", 0)
done, _ = sjson.Set(done, "text", fullText)
out = append(out, emitRespEvent("response.output_text.done", done))
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.Set(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
partDone, _ = sjson.Set(partDone, "output_index", idx)
partDone, _ = sjson.Set(partDone, "content_index", 0)
partDone, _ = sjson.Set(partDone, "part.text", fullText)
out = append(out, emitRespEvent("response.content_part.done", partDone))
itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`
itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
itemDone, _ = sjson.Set(itemDone, "output_index", idx)
itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, idx))
itemDone, _ = sjson.Set(itemDone, "item.content.0.text", fullText)
out = append(out, emitRespEvent("response.output_item.done", itemDone))
st.MsgItemDone[idx] = true
}
// Only emit item.added once per tool call and preserve call_id across chunks.
newCallID := tcs.Get("0.id").String()
nameChunk := tcs.Get("0.function.name").String()
if nameChunk != "" {
st.FuncNames[idx] = nameChunk
}
existingCallID := st.FuncCallIDs[idx]
effectiveCallID := existingCallID
shouldEmitItem := false
if existingCallID == "" && newCallID != "" {
// First time seeing a valid call_id for this index
effectiveCallID = newCallID
st.FuncCallIDs[idx] = newCallID
shouldEmitItem = true
}
if shouldEmitItem && effectiveCallID != "" {
o := `{"type":"response.output_item.added","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"in_progress","arguments":"","call_id":"","name":""}}`
o, _ = sjson.Set(o, "sequence_number", nextSeq())
o, _ = sjson.Set(o, "output_index", idx)
o, _ = sjson.Set(o, "item.id", fmt.Sprintf("fc_%s", effectiveCallID))
o, _ = sjson.Set(o, "item.call_id", effectiveCallID)
name := st.FuncNames[idx]
o, _ = sjson.Set(o, "item.name", name)
out = append(out, emitRespEvent("response.output_item.added", o))
}
// Ensure args buffer exists for this index
if st.FuncArgsBuf[idx] == nil {
st.FuncArgsBuf[idx] = &strings.Builder{}
}
// Append arguments delta if available and we have a valid call_id to reference
if args := tcs.Get("0.function.arguments"); args.Exists() && args.String() != "" {
// Prefer an already known call_id; fall back to newCallID if first time
refCallID := st.FuncCallIDs[idx]
if refCallID == "" {
refCallID = newCallID
}
if refCallID != "" {
ad := `{"type":"response.function_call_arguments.delta","sequence_number":0,"item_id":"","output_index":0,"delta":""}`
ad, _ = sjson.Set(ad, "sequence_number", nextSeq())
ad, _ = sjson.Set(ad, "item_id", fmt.Sprintf("fc_%s", refCallID))
ad, _ = sjson.Set(ad, "output_index", idx)
ad, _ = sjson.Set(ad, "delta", args.String())
out = append(out, emitRespEvent("response.function_call_arguments.delta", ad))
}
st.FuncArgsBuf[idx].WriteString(args.String())
}
}
}
// finish_reason triggers finalization, including text done/content done/item done,
// reasoning done/part.done, function args done/item done, and completed
if fr := choice.Get("finish_reason"); fr.Exists() && fr.String() != "" {
// Emit message done events for all indices that started a message
if len(st.MsgItemAdded) > 0 {
// sort indices for deterministic order
idxs := make([]int, 0, len(st.MsgItemAdded))
for i := range st.MsgItemAdded {
idxs = append(idxs, i)
}
for i := 0; i < len(idxs); i++ {
for j := i + 1; j < len(idxs); j++ {
if idxs[j] < idxs[i] {
idxs[i], idxs[j] = idxs[j], idxs[i]
}
}
}
for _, i := range idxs {
if st.MsgItemAdded[i] && !st.MsgItemDone[i] {
fullText := ""
if b := st.MsgTextBuf[i]; b != nil {
fullText = b.String()
}
done := `{"type":"response.output_text.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"text":"","logprobs":[]}`
done, _ = sjson.Set(done, "sequence_number", nextSeq())
done, _ = sjson.Set(done, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
done, _ = sjson.Set(done, "output_index", i)
done, _ = sjson.Set(done, "content_index", 0)
done, _ = sjson.Set(done, "text", fullText)
out = append(out, emitRespEvent("response.output_text.done", done))
partDone := `{"type":"response.content_part.done","sequence_number":0,"item_id":"","output_index":0,"content_index":0,"part":{"type":"output_text","annotations":[],"logprobs":[],"text":""}}`
partDone, _ = sjson.Set(partDone, "sequence_number", nextSeq())
partDone, _ = sjson.Set(partDone, "item_id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
partDone, _ = sjson.Set(partDone, "output_index", i)
partDone, _ = sjson.Set(partDone, "content_index", 0)
partDone, _ = sjson.Set(partDone, "part.text", fullText)
out = append(out, emitRespEvent("response.content_part.done", partDone))
itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}}`
itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
itemDone, _ = sjson.Set(itemDone, "output_index", i)
itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
itemDone, _ = sjson.Set(itemDone, "item.content.0.text", fullText)
out = append(out, emitRespEvent("response.output_item.done", itemDone))
st.MsgItemDone[i] = true
}
}
}
if st.ReasoningID != "" {
stopReasoning(st.ReasoningBuf.String())
st.ReasoningBuf.Reset()
}
// Emit function call done events for any active function calls
if len(st.FuncCallIDs) > 0 {
idxs := make([]int, 0, len(st.FuncCallIDs))
for i := range st.FuncCallIDs {
idxs = append(idxs, i)
}
for i := 0; i < len(idxs); i++ {
for j := i + 1; j < len(idxs); j++ {
if idxs[j] < idxs[i] {
idxs[i], idxs[j] = idxs[j], idxs[i]
}
}
}
for _, i := range idxs {
callID := st.FuncCallIDs[i]
if callID == "" || st.FuncItemDone[i] {
continue
}
args := "{}"
if b := st.FuncArgsBuf[i]; b != nil && b.Len() > 0 {
args = b.String()
}
fcDone := `{"type":"response.function_call_arguments.done","sequence_number":0,"item_id":"","output_index":0,"arguments":""}`
fcDone, _ = sjson.Set(fcDone, "sequence_number", nextSeq())
fcDone, _ = sjson.Set(fcDone, "item_id", fmt.Sprintf("fc_%s", callID))
fcDone, _ = sjson.Set(fcDone, "output_index", i)
fcDone, _ = sjson.Set(fcDone, "arguments", args)
out = append(out, emitRespEvent("response.function_call_arguments.done", fcDone))
itemDone := `{"type":"response.output_item.done","sequence_number":0,"output_index":0,"item":{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}}`
itemDone, _ = sjson.Set(itemDone, "sequence_number", nextSeq())
itemDone, _ = sjson.Set(itemDone, "output_index", i)
itemDone, _ = sjson.Set(itemDone, "item.id", fmt.Sprintf("fc_%s", callID))
itemDone, _ = sjson.Set(itemDone, "item.arguments", args)
itemDone, _ = sjson.Set(itemDone, "item.call_id", callID)
itemDone, _ = sjson.Set(itemDone, "item.name", st.FuncNames[i])
out = append(out, emitRespEvent("response.output_item.done", itemDone))
st.FuncItemDone[i] = true
st.FuncArgsDone[i] = true
}
}
completed := `{"type":"response.completed","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null}}`
completed, _ = sjson.Set(completed, "sequence_number", nextSeq())
completed, _ = sjson.Set(completed, "response.id", st.ResponseID)
completed, _ = sjson.Set(completed, "response.created_at", st.Created)
// Inject original request fields into response as per docs/response.completed.json
if requestRawJSON != nil {
req := gjson.ParseBytes(requestRawJSON)
if v := req.Get("instructions"); v.Exists() {
completed, _ = sjson.Set(completed, "response.instructions", v.String())
}
if v := req.Get("max_output_tokens"); v.Exists() {
completed, _ = sjson.Set(completed, "response.max_output_tokens", v.Int())
}
if v := req.Get("max_tool_calls"); v.Exists() {
completed, _ = sjson.Set(completed, "response.max_tool_calls", v.Int())
}
if v := req.Get("model"); v.Exists() {
completed, _ = sjson.Set(completed, "response.model", v.String())
}
if v := req.Get("parallel_tool_calls"); v.Exists() {
completed, _ = sjson.Set(completed, "response.parallel_tool_calls", v.Bool())
}
if v := req.Get("previous_response_id"); v.Exists() {
completed, _ = sjson.Set(completed, "response.previous_response_id", v.String())
}
if v := req.Get("prompt_cache_key"); v.Exists() {
completed, _ = sjson.Set(completed, "response.prompt_cache_key", v.String())
}
if v := req.Get("reasoning"); v.Exists() {
completed, _ = sjson.Set(completed, "response.reasoning", v.Value())
}
if v := req.Get("safety_identifier"); v.Exists() {
completed, _ = sjson.Set(completed, "response.safety_identifier", v.String())
}
if v := req.Get("service_tier"); v.Exists() {
completed, _ = sjson.Set(completed, "response.service_tier", v.String())
}
if v := req.Get("store"); v.Exists() {
completed, _ = sjson.Set(completed, "response.store", v.Bool())
}
if v := req.Get("temperature"); v.Exists() {
completed, _ = sjson.Set(completed, "response.temperature", v.Float())
}
if v := req.Get("text"); v.Exists() {
completed, _ = sjson.Set(completed, "response.text", v.Value())
}
if v := req.Get("tool_choice"); v.Exists() {
completed, _ = sjson.Set(completed, "response.tool_choice", v.Value())
}
if v := req.Get("tools"); v.Exists() {
completed, _ = sjson.Set(completed, "response.tools", v.Value())
}
if v := req.Get("top_logprobs"); v.Exists() {
completed, _ = sjson.Set(completed, "response.top_logprobs", v.Int())
}
if v := req.Get("top_p"); v.Exists() {
completed, _ = sjson.Set(completed, "response.top_p", v.Float())
}
if v := req.Get("truncation"); v.Exists() {
completed, _ = sjson.Set(completed, "response.truncation", v.String())
}
if v := req.Get("user"); v.Exists() {
completed, _ = sjson.Set(completed, "response.user", v.Value())
}
if v := req.Get("metadata"); v.Exists() {
completed, _ = sjson.Set(completed, "response.metadata", v.Value())
}
}
// Build response.output using aggregated buffers
outputsWrapper := `{"arr":[]}`
if len(st.Reasonings) > 0 {
for _, r := range st.Reasonings {
item := `{"id":"","type":"reasoning","summary":[{"type":"summary_text","text":""}]}`
item, _ = sjson.Set(item, "id", r.ReasoningID)
item, _ = sjson.Set(item, "summary.0.text", r.ReasoningData)
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
}
// Append message items in ascending index order
if len(st.MsgItemAdded) > 0 {
midxs := make([]int, 0, len(st.MsgItemAdded))
for i := range st.MsgItemAdded {
midxs = append(midxs, i)
}
for i := 0; i < len(midxs); i++ {
for j := i + 1; j < len(midxs); j++ {
if midxs[j] < midxs[i] {
midxs[i], midxs[j] = midxs[j], midxs[i]
}
}
}
for _, i := range midxs {
txt := ""
if b := st.MsgTextBuf[i]; b != nil {
txt = b.String()
}
item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
item, _ = sjson.Set(item, "id", fmt.Sprintf("msg_%s_%d", st.ResponseID, i))
item, _ = sjson.Set(item, "content.0.text", txt)
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
}
if len(st.FuncArgsBuf) > 0 {
idxs := make([]int, 0, len(st.FuncArgsBuf))
for i := range st.FuncArgsBuf {
idxs = append(idxs, i)
}
// small-N sort without extra imports
for i := 0; i < len(idxs); i++ {
for j := i + 1; j < len(idxs); j++ {
if idxs[j] < idxs[i] {
idxs[i], idxs[j] = idxs[j], idxs[i]
}
}
}
for _, i := range idxs {
args := ""
if b := st.FuncArgsBuf[i]; b != nil {
args = b.String()
}
callID := st.FuncCallIDs[i]
name := st.FuncNames[i]
item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
item, _ = sjson.Set(item, "arguments", args)
item, _ = sjson.Set(item, "call_id", callID)
item, _ = sjson.Set(item, "name", name)
outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
}
}
if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
completed, _ = sjson.SetRaw(completed, "response.output", gjson.Get(outputsWrapper, "arr").Raw)
}
if st.UsageSeen {
completed, _ = sjson.Set(completed, "response.usage.input_tokens", st.PromptTokens)
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", st.CachedTokens)
completed, _ = sjson.Set(completed, "response.usage.output_tokens", st.CompletionTokens)
if st.ReasoningTokens > 0 {
completed, _ = sjson.Set(completed, "response.usage.output_tokens_details.reasoning_tokens", st.ReasoningTokens)
}
total := st.TotalTokens
if total == 0 {
total = st.PromptTokens + st.CompletionTokens
}
completed, _ = sjson.Set(completed, "response.usage.total_tokens", total)
}
out = append(out, emitRespEvent("response.completed", completed))
}
return true
})
}
return out
}
// ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream builds a single Responses JSON
// from a non-streaming OpenAI Chat Completions response.
//
// Parameters:
//   - originalRequestRawJSON: the client's original request JSON (not consulted here)
//   - requestRawJSON: request JSON whose Responses-style fields (instructions, tools,
//     reasoning, ...) are echoed onto the result, mirroring the streaming path
//   - rawJSON: the upstream Chat Completions response body
//
// Returns the Responses API object as a JSON string. The id comes from the upstream
// response (or is synthesized), output items are built from choices[].message
// (text content and tool_calls), and usage is mapped from Chat Completions counters.
func ConvertOpenAIChatCompletionsResponseToOpenAIResponsesNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
	root := gjson.ParseBytes(rawJSON)

	// Basic response scaffold
	resp := `{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"incomplete_details":null}`

	// id: use provider id if present, otherwise synthesize
	id := root.Get("id").String()
	if id == "" {
		id = fmt.Sprintf("resp_%x_%d", time.Now().UnixNano(), atomic.AddUint64(&responseIDCounter, 1))
	}
	resp, _ = sjson.Set(resp, "id", id)

	// created_at: map from chat.completion created
	created := root.Get("created").Int()
	if created == 0 {
		created = time.Now().Unix()
	}
	resp, _ = sjson.Set(resp, "created_at", created)

	// Echo request fields when available (aligns with streaming path behavior)
	if len(requestRawJSON) > 0 {
		req := gjson.ParseBytes(requestRawJSON)
		if v := req.Get("instructions"); v.Exists() {
			resp, _ = sjson.Set(resp, "instructions", v.String())
		}
		if v := req.Get("max_output_tokens"); v.Exists() {
			resp, _ = sjson.Set(resp, "max_output_tokens", v.Int())
		} else if v = req.Get("max_tokens"); v.Exists() {
			// Also support max_tokens from chat completion style
			resp, _ = sjson.Set(resp, "max_output_tokens", v.Int())
		}
		if v := req.Get("max_tool_calls"); v.Exists() {
			resp, _ = sjson.Set(resp, "max_tool_calls", v.Int())
		}
		if v := req.Get("model"); v.Exists() {
			resp, _ = sjson.Set(resp, "model", v.String())
		} else if v = root.Get("model"); v.Exists() {
			resp, _ = sjson.Set(resp, "model", v.String())
		}
		if v := req.Get("parallel_tool_calls"); v.Exists() {
			resp, _ = sjson.Set(resp, "parallel_tool_calls", v.Bool())
		}
		if v := req.Get("previous_response_id"); v.Exists() {
			resp, _ = sjson.Set(resp, "previous_response_id", v.String())
		}
		if v := req.Get("prompt_cache_key"); v.Exists() {
			resp, _ = sjson.Set(resp, "prompt_cache_key", v.String())
		}
		if v := req.Get("reasoning"); v.Exists() {
			resp, _ = sjson.Set(resp, "reasoning", v.Value())
		}
		if v := req.Get("safety_identifier"); v.Exists() {
			resp, _ = sjson.Set(resp, "safety_identifier", v.String())
		}
		if v := req.Get("service_tier"); v.Exists() {
			resp, _ = sjson.Set(resp, "service_tier", v.String())
		}
		if v := req.Get("store"); v.Exists() {
			resp, _ = sjson.Set(resp, "store", v.Bool())
		}
		if v := req.Get("temperature"); v.Exists() {
			resp, _ = sjson.Set(resp, "temperature", v.Float())
		}
		if v := req.Get("text"); v.Exists() {
			resp, _ = sjson.Set(resp, "text", v.Value())
		}
		if v := req.Get("tool_choice"); v.Exists() {
			resp, _ = sjson.Set(resp, "tool_choice", v.Value())
		}
		if v := req.Get("tools"); v.Exists() {
			resp, _ = sjson.Set(resp, "tools", v.Value())
		}
		if v := req.Get("top_logprobs"); v.Exists() {
			resp, _ = sjson.Set(resp, "top_logprobs", v.Int())
		}
		if v := req.Get("top_p"); v.Exists() {
			resp, _ = sjson.Set(resp, "top_p", v.Float())
		}
		if v := req.Get("truncation"); v.Exists() {
			resp, _ = sjson.Set(resp, "truncation", v.String())
		}
		if v := req.Get("user"); v.Exists() {
			resp, _ = sjson.Set(resp, "user", v.Value())
		}
		if v := req.Get("metadata"); v.Exists() {
			resp, _ = sjson.Set(resp, "metadata", v.Value())
		}
	} else if v := root.Get("model"); v.Exists() {
		// Fallback model from response
		resp, _ = sjson.Set(resp, "model", v.String())
	}

	// Build output list from choices[...]
	outputsWrapper := `{"arr":[]}`

	// Detect and capture reasoning content if present (first choice only).
	rcText := gjson.GetBytes(rawJSON, "choices.0.message.reasoning_content").String()
	includeReasoning := rcText != ""
	if !includeReasoning && len(requestRawJSON) > 0 {
		includeReasoning = gjson.GetBytes(requestRawJSON, "reasoning").Exists()
	}
	if includeReasoning {
		rid := strings.TrimPrefix(id, "resp_")
		// Prefer summary_text from reasoning_content; encrypted_content is optional
		reasoningItem := `{"id":"","type":"reasoning","encrypted_content":"","summary":[]}`
		reasoningItem, _ = sjson.Set(reasoningItem, "id", fmt.Sprintf("rs_%s", rid))
		if rcText != "" {
			reasoningItem, _ = sjson.Set(reasoningItem, "summary.0.type", "summary_text")
			reasoningItem, _ = sjson.Set(reasoningItem, "summary.0.text", rcText)
		}
		outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", reasoningItem)
	}

	if choices := root.Get("choices"); choices.Exists() && choices.IsArray() {
		choices.ForEach(func(_, choice gjson.Result) bool {
			msg := choice.Get("message")
			if !msg.Exists() {
				return true
			}
			// Text message part
			if c := msg.Get("content"); c.Exists() && c.String() != "" {
				item := `{"id":"","type":"message","status":"completed","content":[{"type":"output_text","annotations":[],"logprobs":[],"text":""}],"role":"assistant"}`
				item, _ = sjson.Set(item, "id", fmt.Sprintf("msg_%s_%d", id, int(choice.Get("index").Int())))
				item, _ = sjson.Set(item, "content.0.text", c.String())
				outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
			}
			// Function/tool calls
			if tcs := msg.Get("tool_calls"); tcs.Exists() && tcs.IsArray() {
				tcs.ForEach(func(_, tc gjson.Result) bool {
					callID := tc.Get("id").String()
					item := `{"id":"","type":"function_call","status":"completed","arguments":"","call_id":"","name":""}`
					item, _ = sjson.Set(item, "id", fmt.Sprintf("fc_%s", callID))
					item, _ = sjson.Set(item, "arguments", tc.Get("function.arguments").String())
					item, _ = sjson.Set(item, "call_id", callID)
					item, _ = sjson.Set(item, "name", tc.Get("function.name").String())
					outputsWrapper, _ = sjson.SetRaw(outputsWrapper, "arr.-1", item)
					return true
				})
			}
			return true
		})
	}
	if gjson.Get(outputsWrapper, "arr.#").Int() > 0 {
		resp, _ = sjson.SetRaw(resp, "output", gjson.Get(outputsWrapper, "arr").Raw)
	}

	// usage mapping
	if usage := root.Get("usage"); usage.Exists() {
		promptTokens := usage.Get("prompt_tokens")
		completionTokens := usage.Get("completion_tokens")
		totalTokens := usage.Get("total_tokens")
		if promptTokens.Exists() || completionTokens.Exists() || totalTokens.Exists() {
			resp, _ = sjson.Set(resp, "usage.input_tokens", promptTokens.Int())
			if d := usage.Get("prompt_tokens_details.cached_tokens"); d.Exists() {
				resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", d.Int())
			}
			resp, _ = sjson.Set(resp, "usage.output_tokens", completionTokens.Int())
			// Chat Completions reports reasoning tokens under completion_tokens_details;
			// the Responses-style output_tokens_details location is kept as a fallback.
			reasoning := usage.Get("completion_tokens_details.reasoning_tokens")
			if !reasoning.Exists() {
				reasoning = usage.Get("output_tokens_details.reasoning_tokens")
			}
			if reasoning.Exists() {
				resp, _ = sjson.Set(resp, "usage.output_tokens_details.reasoning_tokens", reasoning.Int())
			}
			// Match the streaming path: derive the total when the provider omits it.
			total := totalTokens.Int()
			if total == 0 {
				total = promptTokens.Int() + completionTokens.Int()
			}
			resp, _ = sjson.Set(resp, "usage.total_tokens", total)
		} else {
			// Fallback to raw usage object if structure differs
			resp, _ = sjson.Set(resp, "usage", usage.Value())
		}
	}
	return resp
}
================================================
FILE: internal/translator/translator/translator.go
================================================
// Package translator provides request and response translation functionality
// between different AI API formats. It acts as a wrapper around the SDK translator
// registry, providing convenient functions for translating requests and responses
// between OpenAI, Claude, Gemini, and other API formats.
package translator
import (
"context"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
)
// registry holds the default translator registry instance, obtained once at
// package initialization from sdktranslator.Default(). Every helper in this file
// registers into or reads from it. NOTE(review): presumably this is the same
// instance other callers of sdktranslator.Default() see — confirm in the SDK.
var registry = sdktranslator.Default()
// Register adds a translator pair to the default registry, keyed by the
// source and target API format identifiers.
//
// Parameters:
//   - from: identifier of the source API format
//   - to: identifier of the target API format
//   - request: function used to translate request payloads
//   - response: translation functions used for response payloads
func Register(from, to string, request interfaces.TranslateRequestFunc, response interfaces.TranslateResponse) {
	source := sdktranslator.FromString(from)
	target := sdktranslator.FromString(to)
	registry.Register(source, target, request, response)
}
// Request translates a request from one API format to another.
//
// Parameters:
// - from: The source API format identifier
// - to: The target API format identifier
// - modelName: The model name for the request
// - rawJSON: The raw JSON request data
// - stream: Whether this is a streaming request
//
// Returns:
// - []byte: The translated request JSON
func Request(from, to, modelName string, rawJSON []byte, stream bool) []byte {
return registry.TranslateRequest(sdktranslator.FromString(from), sdktranslator.FromString(to), modelName, rawJSON, stream)
}
// NeedConvert checks if a response translation is needed between two API formats.
//
// Parameters:
// - from: The source API format identifier
// - to: The target API format identifier
//
// Returns:
// - bool: True if response translation is needed, false otherwise
func NeedConvert(from, to string) bool {
return registry.HasResponseTransformer(sdktranslator.FromString(from), sdktranslator.FromString(to))
}
// Response translates a streaming response from one API format to another.
//
// Parameters:
// - from: The source API format identifier
// - to: The target API format identifier
// - ctx: The context for the translation
// - modelName: The model name for the response
// - originalRequestRawJSON: The original request JSON
// - requestRawJSON: The translated request JSON
// - rawJSON: The raw response JSON
// - param: Additional parameters for translation
//
// Returns:
// - []string: The translated response lines
func Response(from, to string, ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
return registry.TranslateStream(ctx, sdktranslator.FromString(from), sdktranslator.FromString(to), modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
// ResponseNonStream translates a non-streaming response from one API format to another.
//
// Parameters:
// - from: The source API format identifier
// - to: The target API format identifier
// - ctx: The context for the translation
// - modelName: The model name for the response
// - originalRequestRawJSON: The original request JSON
// - requestRawJSON: The translated request JSON
// - rawJSON: The raw response JSON
// - param: Additional parameters for translation
//
// Returns:
// - string: The translated response JSON
func ResponseNonStream(from, to string, ctx context.Context, modelName string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
return registry.TranslateNonStream(ctx, sdktranslator.FromString(from), sdktranslator.FromString(to), modelName, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
================================================
FILE: internal/tui/app.go
================================================
package tui
import (
"fmt"
"io"
"os"
"strings"
"github.com/charmbracelet/bubbles/textinput"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// Tab identifiers. Each value indexes the full title list returned by
// TabNames and the App.initialized array. Because tabLogs is the last value,
// dropping the Logs title (when file logging is disabled) leaves every other
// identifier a valid index into the shortened App.tabs slice.
const (
	tabDashboard = iota
	tabConfig
	tabAuthFiles
	tabAPIKeys
	tabOAuth
	tabUsage
	tabLogs
)
// App is the root bubbletea model that contains all tab sub-models.
type App struct {
	activeTab      int             // selected tab; one of the tab* constants
	tabs           []string        // visible tab titles; omits Logs when logging is disabled
	standalone     bool            // a LogHook was supplied; the password gate is skipped
	logsEnabled    bool            // whether the Logs tab is shown and its polling forwarded
	authenticated  bool            // false until the password gate succeeds
	authInput      textinput.Model // password field shown by the auth gate
	authError      string          // last auth-gate error message, rendered on the gate
	authConnecting bool            // true while a connectWithPassword command is in flight
	dashboard      dashboardModel
	config         configTabModel
	auth           authTabModel
	keys           keysTabModel
	oauth          oauthTabModel
	usage          usageTabModel
	logs           logsTabModel
	client         *Client // management API client shared by every tab model
	width          int
	height         int
	ready          bool // set once the first tea.WindowSizeMsg has been handled
	// Track which tabs have been initialized (fetched data); indexed by tab* constants.
	initialized [7]bool
}
// authConnectMsg carries the result of connectWithPassword: the management
// config fetched with the entered password, or the error that prevented it.
type authConnectMsg struct {
	cfg map[string]any // config from Client.GetConfig; "logging-to-file" is read from it
	err error          // non-nil when fetching the config failed
}
// NewApp creates the root TUI application model.
//
// A non-nil hook selects standalone mode, in which the password gate is skipped
// and the dashboard and logs tabs count as already initialized. Otherwise the
// user must authenticate first; the password field is pre-filled with the
// trimmed secretKey and nothing is marked initialized until that succeeds.
func NewApp(port int, secretKey string, hook *LogHook) App {
	standalone := hook != nil
	needsAuth := !standalone

	pw := textinput.New()
	pw.CharLimit = 512
	pw.EchoMode = textinput.EchoPassword
	pw.EchoCharacter = '*'
	pw.SetValue(strings.TrimSpace(secretKey))
	pw.Focus()

	c := NewClient(port, secretKey)

	var inited [7]bool
	if !needsAuth {
		// Init starts these two immediately when no authentication is pending.
		inited[tabDashboard] = true
		inited[tabLogs] = true
	}

	app := App{
		activeTab:     tabDashboard,
		standalone:    standalone,
		logsEnabled:   true,
		authenticated: !needsAuth,
		authInput:     pw,
		dashboard:     newDashboardModel(c),
		config:        newConfigTabModel(c),
		auth:          newAuthTabModel(c),
		keys:          newKeysTabModel(c),
		oauth:         newOAuthTabModel(c),
		usage:         newUsageTabModel(c),
		logs:          newLogsTabModel(c, hook),
		client:        c,
		initialized:   inited,
	}
	app.refreshTabs()
	app.setAuthInputPrompt()
	return app
}
// Init implements tea.Model. Once authenticated it starts the dashboard (and
// the log stream when logging is enabled); before that it only blinks the
// password cursor.
func (a App) Init() tea.Cmd {
	if a.authenticated {
		startup := []tea.Cmd{a.dashboard.Init()}
		if a.logsEnabled {
			startup = append(startup, a.logs.Init())
		}
		return tea.Batch(startup...)
	}
	return textinput.Blink
}
// Update implements tea.Model.
//
// It handles terminal resizes, the password gate (authConnectMsg plus key input
// while unauthenticated), showing/hiding the Logs tab when "logging-to-file"
// changes, and global key bindings. Every other message is routed to the
// active tab. Log polling messages are also forwarded to the Logs model while
// another tab is active, so the log stream keeps running in the background.
func (a App) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
	switch msg := msg.(type) {
	case tea.WindowSizeMsg:
		a.width = msg.Width
		a.height = msg.Height
		a.ready = true
		if a.width > 0 {
			a.authInput.Width = a.width - 6
		}
		contentH := a.height - 4 // tab bar + status bar
		if contentH < 1 {
			contentH = 1
		}
		contentW := a.width
		a.dashboard.SetSize(contentW, contentH)
		a.config.SetSize(contentW, contentH)
		a.auth.SetSize(contentW, contentH)
		a.keys.SetSize(contentW, contentH)
		a.oauth.SetSize(contentW, contentH)
		a.usage.SetSize(contentW, contentH)
		a.logs.SetSize(contentW, contentH)
		return a, nil

	case authConnectMsg:
		a.authConnecting = false
		if msg.err != nil {
			a.authError = fmt.Sprintf(T("auth_gate_connect_fail"), msg.err.Error())
			return a, nil
		}
		a.authError = ""
		a.authenticated = true
		a.logsEnabled = a.standalone || isLogsEnabledFromConfig(msg.cfg)
		a.refreshTabs()
		a.initialized = [7]bool{}
		a.initialized[tabDashboard] = true
		cmds := []tea.Cmd{a.dashboard.Init()}
		if a.logsEnabled {
			a.initialized[tabLogs] = true
			cmds = append(cmds, a.logs.Init())
		}
		return a, tea.Batch(cmds...)

	case configUpdateMsg:
		var cmdLogs tea.Cmd
		if !a.standalone && msg.err == nil && msg.path == "logging-to-file" {
			if logsEnabledConfig, okConfig := msg.value.(bool); okConfig {
				logsEnabledBefore := a.logsEnabled
				a.logsEnabled = logsEnabledConfig
				if logsEnabledBefore != a.logsEnabled {
					a.refreshTabs()
				}
				if !a.logsEnabled {
					a.initialized[tabLogs] = false
				}
				if !logsEnabledBefore && a.logsEnabled {
					a.initialized[tabLogs] = true
					cmdLogs = a.logs.Init()
				}
			}
		}
		var cmdConfig tea.Cmd
		a.config, cmdConfig = a.config.Update(msg)
		// tea.Batch skips nil commands, so this covers every combination.
		return a, tea.Batch(cmdConfig, cmdLogs)

	case tea.KeyMsg:
		if !a.authenticated {
			// NOTE(review): "q" and "L" are intercepted here, so a password containing
			// either character cannot be typed into the gate. Consider gating these
			// bindings on an empty input or moving them to modifier keys.
			switch msg.String() {
			case "ctrl+c", "q":
				return a, tea.Quit
			case "L":
				ToggleLocale()
				a.refreshTabs()
				a.setAuthInputPrompt()
				return a, nil
			case "enter":
				if a.authConnecting {
					return a, nil
				}
				password := strings.TrimSpace(a.authInput.Value())
				if password == "" {
					a.authError = T("auth_gate_password_required")
					return a, nil
				}
				a.authError = ""
				a.authConnecting = true
				return a, a.connectWithPassword(password)
			default:
				var cmd tea.Cmd
				a.authInput, cmd = a.authInput.Update(msg)
				return a, cmd
			}
		}
		switch msg.String() {
		case "ctrl+c":
			return a, tea.Quit
		case "q":
			// Only quit if not in logs tab (where 'q' might be useful)
			if !a.logsEnabled || a.activeTab != tabLogs {
				return a, tea.Quit
			}
		case "L":
			ToggleLocale()
			a.refreshTabs()
			return a.broadcastToAllTabs(localeChangedMsg{})
		case "tab":
			if len(a.tabs) == 0 {
				return a, nil
			}
			prevTab := a.activeTab
			a.activeTab = (a.activeTab + 1) % len(a.tabs)
			return a, a.initTabIfNeeded(prevTab)
		case "shift+tab":
			if len(a.tabs) == 0 {
				return a, nil
			}
			prevTab := a.activeTab
			a.activeTab = (a.activeTab - 1 + len(a.tabs)) % len(a.tabs)
			return a, a.initTabIfNeeded(prevTab)
		}
	}

	if !a.authenticated {
		var cmd tea.Cmd
		a.authInput, cmd = a.authInput.Update(msg)
		return a, cmd
	}

	// Route msg to active tab
	var cmd tea.Cmd
	switch a.activeTab {
	case tabDashboard:
		a.dashboard, cmd = a.dashboard.Update(msg)
	case tabConfig:
		a.config, cmd = a.config.Update(msg)
	case tabAuthFiles:
		a.auth, cmd = a.auth.Update(msg)
	case tabAPIKeys:
		a.keys, cmd = a.keys.Update(msg)
	case tabOAuth:
		a.oauth, cmd = a.oauth.Update(msg)
	case tabUsage:
		a.usage, cmd = a.usage.Update(msg)
	case tabLogs:
		a.logs, cmd = a.logs.Update(msg)
	}

	// Keep logs polling alive even when logs tab is not active.
	if a.logsEnabled && a.activeTab != tabLogs {
		switch msg.(type) {
		case logsPollMsg, logsTickMsg, logLineMsg:
			var logCmd tea.Cmd
			a.logs, logCmd = a.logs.Update(msg)
			// Batch rather than overwrite, so a command the active tab returned
			// for this same message is not silently dropped.
			cmd = tea.Batch(cmd, logCmd)
		}
	}
	return a, cmd
}
// localeChangedMsg is broadcast to all tabs when the user toggles locale, so
// each tab can re-render its content with the newly selected translations.
type localeChangedMsg struct{}
// refreshTabs rebuilds the visible tab titles from TabNames (re-read after
// locale toggles), omitting the Logs title when file logging is disabled, and
// clamps activeTab so it always indexes a visible tab.
func (a *App) refreshTabs() {
	names := TabNames()
	if a.logsEnabled {
		a.tabs = names
	} else {
		// max guards the capacity: make panics on a negative cap when names is empty,
		// a case the len(a.tabs) == 0 check below explicitly anticipates.
		filtered := make([]string, 0, max(0, len(names)-1))
		for idx, name := range names {
			if idx == tabLogs {
				continue
			}
			filtered = append(filtered, name)
		}
		a.tabs = filtered
	}
	if len(a.tabs) == 0 {
		a.activeTab = tabDashboard
		return
	}
	if a.activeTab >= len(a.tabs) {
		a.activeTab = len(a.tabs) - 1
	}
}
// initTabIfNeeded marks the active tab as initialized and returns its Init
// command the first time the tab is visited; later visits return nil. The
// Logs tab yields nil when logging is disabled, although it is still marked.
// The argument (the previously active tab) is currently unused.
func (a *App) initTabIfNeeded(_ int) tea.Cmd {
	tab := a.activeTab
	if a.initialized[tab] {
		return nil
	}
	a.initialized[tab] = true

	switch tab {
	case tabDashboard:
		return a.dashboard.Init()
	case tabConfig:
		return a.config.Init()
	case tabAuthFiles:
		return a.auth.Init()
	case tabAPIKeys:
		return a.keys.Init()
	case tabOAuth:
		return a.oauth.Init()
	case tabUsage:
		return a.usage.Init()
	case tabLogs:
		if a.logsEnabled {
			return a.logs.Init()
		}
	}
	return nil
}
// View implements tea.Model. It shows the password gate until authenticated,
// a placeholder until the first resize, and otherwise the tab bar, the active
// tab's content and the status bar, separated by newlines.
func (a App) View() string {
	switch {
	case !a.authenticated:
		return a.renderAuthView()
	case !a.ready:
		return T("initializing_tui")
	}

	bar := a.renderTabBar()

	var content string
	switch a.activeTab {
	case tabDashboard:
		content = a.dashboard.View()
	case tabConfig:
		content = a.config.View()
	case tabAuthFiles:
		content = a.auth.View()
	case tabAPIKeys:
		content = a.keys.View()
	case tabOAuth:
		content = a.oauth.View()
	case tabUsage:
		content = a.usage.View()
	case tabLogs:
		if a.logsEnabled {
			content = a.logs.View()
		}
	}

	return strings.Join([]string{bar, content, a.renderStatusBar()}, "\n")
}
// renderAuthView draws the password gate: title, help text, an optional
// "connecting" notice, the last error (if any), the password input and the
// enter hint.
func (a App) renderAuthView() string {
	parts := []string{
		titleStyle.Render(T("auth_gate_title")) + "\n",
		helpStyle.Render(T("auth_gate_help")) + "\n\n",
	}
	if a.authConnecting {
		parts = append(parts, warningStyle.Render(T("auth_gate_connecting"))+"\n\n")
	}
	if strings.TrimSpace(a.authError) != "" {
		parts = append(parts, errorStyle.Render(a.authError)+"\n\n")
	}
	parts = append(parts,
		a.authInput.View()+"\n",
		helpStyle.Render(T("auth_gate_enter")),
	)
	return strings.Join(parts, "")
}
// renderTabBar renders the visible tab titles side by side, highlighting the
// active one, inside a bar stretched to the terminal width.
func (a App) renderTabBar() string {
	rendered := make([]string, len(a.tabs))
	for idx := range a.tabs {
		if idx == a.activeTab {
			rendered[idx] = tabActiveStyle.Render(a.tabs[idx])
		} else {
			rendered[idx] = tabInactiveStyle.Render(a.tabs[idx])
		}
	}
	joined := lipgloss.JoinHorizontal(lipgloss.Top, rendered...)
	return tabBarStyle.Width(a.width).Render(joined)
}
// renderStatusBar lays out the left and right status texts on one line,
// truncating the right text first (and dropping it entirely if the left text
// alone overflows) so the bar never exceeds the terminal width.
func (a App) renderStatusBar() string {
	left := strings.TrimRight(T("status_left"), " ")
	right := strings.TrimRight(T("status_right"), " ")

	width := max(a.width, 1)
	// statusBarStyle has left/right padding(1), so content area is width-2.
	avail := max(width-2, 0)

	if lipgloss.Width(left) > avail {
		left, right = fitStringWidth(left, avail), ""
	}
	room := max(avail-lipgloss.Width(left), 0)
	if lipgloss.Width(right) > room {
		right = fitStringWidth(right, room)
	}
	gap := max(avail-lipgloss.Width(left)-lipgloss.Width(right), 0)

	return statusBarStyle.Width(width).Render(left + strings.Repeat(" ", gap) + right)
}
// fitStringWidth returns the longest rune prefix of text whose display width,
// as measured by lipgloss.Width, is at most maxWidth. It returns "" when
// maxWidth is not positive and text unchanged when it already fits.
func fitStringWidth(text string, maxWidth int) string {
	if maxWidth <= 0 {
		return ""
	}
	if lipgloss.Width(text) <= maxWidth {
		return text
	}
	runes := []rune(text)
	keep := 0
	for keep < len(runes) && lipgloss.Width(string(runes[:keep+1])) <= maxWidth {
		keep++
	}
	return string(runes[:keep])
}
// isLogsEnabledFromConfig reports whether the "logging-to-file" setting in cfg
// is enabled. It defaults to true when cfg is nil, the key is absent, or the
// value is not a bool; only an explicit false disables logs.
func isLogsEnabledFromConfig(cfg map[string]any) bool {
	// Indexing a nil map is safe in Go and yields a nil interface, which fails the assertion.
	enabled, isBool := cfg["logging-to-file"].(bool)
	return !isBool || enabled
}
// setAuthInputPrompt (re)sets the password field's prompt from the current
// locale's "auth_gate_password" label. It is a no-op on a nil receiver.
func (a *App) setAuthInputPrompt() {
	if a != nil {
		a.authInput.Prompt = fmt.Sprintf(" %s: ", T("auth_gate_password"))
	}
}
// connectWithPassword returns a command that installs password as the
// client's secret key, fetches the management config with it, and reports the
// outcome as an authConnectMsg.
func (a App) connectWithPassword(password string) tea.Cmd {
	client := a.client
	return func() tea.Msg {
		client.SetSecretKey(password)
		cfg, err := client.GetConfig()
		return authConnectMsg{cfg: cfg, err: err}
	}
}
// Run starts the TUI application in the terminal's alternate screen and blocks
// until the program exits.
// output specifies where bubbletea renders. If nil, defaults to os.Stdout.
func Run(port int, secretKey string, hook *LogHook, output io.Writer) error {
	w := output
	if w == nil {
		w = os.Stdout
	}
	program := tea.NewProgram(NewApp(port, secretKey, hook), tea.WithAltScreen(), tea.WithOutput(w))
	if _, err := program.Run(); err != nil {
		return err
	}
	return nil
}
// broadcastToAllTabs delivers msg to every tab model, in tab order, regardless
// of which tab is active, and batches any commands they return.
func (a App) broadcastToAllTabs(msg tea.Msg) (tea.Model, tea.Cmd) {
	var pending []tea.Cmd
	keep := func(c tea.Cmd) {
		if c != nil {
			pending = append(pending, c)
		}
	}

	var c tea.Cmd
	a.dashboard, c = a.dashboard.Update(msg)
	keep(c)
	a.config, c = a.config.Update(msg)
	keep(c)
	a.auth, c = a.auth.Update(msg)
	keep(c)
	a.keys, c = a.keys.Update(msg)
	keep(c)
	a.oauth, c = a.oauth.Update(msg)
	keep(c)
	a.usage, c = a.usage.Update(msg)
	keep(c)
	a.logs, c = a.logs.Update(msg)
	keep(c)

	return a, tea.Batch(pending...)
}
================================================
FILE: internal/tui/auth_tab.go
================================================
package tui
import (
"fmt"
"strconv"
"strings"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// editableField represents an editable field on an auth file, i.e. one that
// can be changed inline from the Auth Files tab.
type editableField struct {
	label string // human-readable name; also used as the inline edit prompt
	key   string // API field key: "prefix", "proxy_url", "priority"
}
// authEditableFields lists, in order, the auth-file fields that startEdit can
// open for inline editing; the index into this slice is the editField value.
var authEditableFields = []editableField{
	{"Prefix", "prefix"},
	{"Proxy URL", "proxy_url"},
	{"Priority", "priority"},
}
// authTabModel displays auth credential files with interactive management.
type authTabModel struct {
	client   *Client          // management API client used to list auth files
	viewport viewport.Model   // scrollable area holding the renderContent output
	files    []map[string]any // auth files from the last successful fetch
	err      error            // last fetch error; when set, only the error is rendered
	width    int
	height   int
	ready    bool   // viewport has been created by SetSize
	cursor   int    // index of the highlighted file in files
	expanded int    // -1 = none expanded, >=0 = expanded index
	confirm  int    // -1 = no confirmation, >=0 = confirm delete for index
	status   string // styled result line of the last action; cleared on refresh
	// Editing state
	editing      bool            // true when editing a field
	editField    int             // index into authEditableFields
	editInput    textinput.Model // text input for editing
	editFileName string          // name of file being edited
}
// authFilesMsg carries the result of fetchFiles: the auth file list returned
// by Client.GetAuthFiles, or the error from that call.
type authFilesMsg struct {
	files []map[string]any
	err   error
}
// authActionMsg reports the outcome of a mutating action on an auth file.
// Update shows it as a status line and then re-fetches the file list.
type authActionMsg struct {
	action string // "deleted", "toggled", "updated"
	err    error  // non-nil when the action failed
}
// newAuthTabModel returns an Auth Files tab bound to client, with nothing
// expanded or pending confirmation and a 256-character edit input.
func newAuthTabModel(client *Client) authTabModel {
	m := authTabModel{
		client:   client,
		expanded: -1,
		confirm:  -1,
	}
	m.editInput = textinput.New()
	m.editInput.CharLimit = 256
	return m
}
// Init returns the command that loads the auth file list.
func (m authTabModel) Init() tea.Cmd {
	return func() tea.Msg { return m.fetchFiles() }
}
// fetchFiles queries the management API for auth files and wraps the result
// (or error) in an authFilesMsg.
func (m authTabModel) fetchFiles() tea.Msg {
	var out authFilesMsg
	out.files, out.err = m.client.GetAuthFiles()
	return out
}
// Update handles locale changes, fetched file lists, action results and key
// input. Keys are dispatched to edit mode, delete-confirmation mode or normal
// mode, in that priority. Everything else is forwarded to the viewport.
func (m authTabModel) Update(msg tea.Msg) (authTabModel, tea.Cmd) {
	switch msg := msg.(type) {
	case localeChangedMsg:
		m.viewport.SetContent(m.renderContent())
		return m, nil

	case authFilesMsg:
		m.err = msg.err
		if msg.err == nil {
			m.files = msg.files
			// Keep the cursor on an existing row after the list shrinks.
			if last := len(m.files) - 1; m.cursor > last {
				m.cursor = max(0, last)
			}
			m.status = ""
		}
		m.viewport.SetContent(m.renderContent())
		return m, nil

	case authActionMsg:
		if msg.err == nil {
			m.status = successStyle.Render("✓ " + msg.action)
		} else {
			m.status = errorStyle.Render("✗ " + msg.err.Error())
		}
		m.confirm = -1
		m.viewport.SetContent(m.renderContent())
		return m, m.fetchFiles

	case tea.KeyMsg:
		switch {
		case m.editing:
			return m.handleEditInput(msg)
		case m.confirm >= 0:
			return m.handleConfirmInput(msg)
		default:
			return m.handleNormalInput(msg)
		}
	}

	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
// startEdit activates inline editing for a field on the currently selected auth file.
//
// fieldIdx indexes authEditableFields. The input is pre-filled with the
// field's current value and prompted with its label. It returns
// textinput.Blink to start the cursor, or nil (leaving state untouched) when
// no file is selected or fieldIdx is out of range.
func (m *authTabModel) startEdit(fieldIdx int) tea.Cmd {
	if m.cursor < 0 || m.cursor >= len(m.files) {
		return nil
	}
	// Guard the slice index below; an out-of-range field would otherwise panic.
	if fieldIdx < 0 || fieldIdx >= len(authEditableFields) {
		return nil
	}
	field := authEditableFields[fieldIdx]
	f := m.files[m.cursor]
	m.editFileName = getString(f, "name")
	m.editField = fieldIdx
	m.editing = true
	// Pre-populate with current value
	m.editInput.SetValue(getAnyString(f, field.key))
	m.editInput.Focus()
	m.editInput.Prompt = fmt.Sprintf(" %s: ", field.label)
	m.viewport.SetContent(m.renderContent())
	return textinput.Blink
}
// SetSize records the tab's dimensions and resizes the edit input and the
// viewport, creating the viewport (and rendering into it) on the first call.
func (m *authTabModel) SetSize(w, h int) {
	m.width, m.height = w, h
	m.editInput.Width = w - 20
	if m.ready {
		m.viewport.Width = w
		m.viewport.Height = h
		return
	}
	m.viewport = viewport.New(w, h)
	m.viewport.SetContent(m.renderContent())
	m.ready = true
}
// View returns the viewport contents, or a loading message until SetSize
// has created the viewport.
func (m authTabModel) View() string {
	if m.ready {
		return m.viewport.View()
	}
	return T("loading")
}
// renderContent builds the full text shown in the viewport.
//
// The output starts with the title, help lines and a separator. After that it
// shows either the fetch error, an empty-list notice, or one row per auth file.
// Each row may be followed by its delete confirmation, its inline edit input
// and its expanded detail panel. The last action's status line comes at the end.
func (m authTabModel) renderContent() string {
	// truncate shortens s to at most limit runes, ending in "..." when cut.
	// Counting runes (not bytes) avoids splitting multi-byte UTF-8 characters
	// and matches fmt's %-Ns padding, which is also measured in runes.
	truncate := func(s string, limit int) string {
		runes := []rune(s)
		if len(runes) <= limit {
			return s
		}
		return string(runes[:limit-3]) + "..."
	}

	var sb strings.Builder
	sb.WriteString(titleStyle.Render(T("auth_title")))
	sb.WriteString("\n")
	sb.WriteString(helpStyle.Render(T("auth_help1")))
	sb.WriteString("\n")
	sb.WriteString(helpStyle.Render(T("auth_help2")))
	sb.WriteString("\n")
	sb.WriteString(strings.Repeat("─", m.width))
	sb.WriteString("\n")
	if m.err != nil {
		sb.WriteString(errorStyle.Render("⚠ Error: " + m.err.Error()))
		sb.WriteString("\n")
		return sb.String()
	}
	if len(m.files) == 0 {
		sb.WriteString(subtitleStyle.Render(T("no_auth_files")))
		sb.WriteString("\n")
		return sb.String()
	}
	for i, f := range m.files {
		name := getString(f, "name")
		channel := getString(f, "channel")
		email := getString(f, "email")
		disabled := getBool(f, "disabled")
		statusIcon := successStyle.Render("●")
		statusText := T("status_active")
		if disabled {
			statusIcon = lipgloss.NewStyle().Foreground(colorMuted).Render("○")
			statusText = T("status_disabled")
		}
		cursor := " "
		rowStyle := lipgloss.NewStyle()
		if i == m.cursor {
			cursor = "▸ "
			rowStyle = lipgloss.NewStyle().Bold(true)
		}
		displayName := truncate(name, 24)
		displayEmail := truncate(email, 28)
		row := fmt.Sprintf("%s%s %-24s %-12s %-28s %s",
			cursor, statusIcon, displayName, channel, displayEmail, statusText)
		sb.WriteString(rowStyle.Render(row))
		sb.WriteString("\n")
		// Delete confirmation
		if m.confirm == i {
			sb.WriteString(warningStyle.Render(fmt.Sprintf(" "+T("confirm_delete"), name)))
			sb.WriteString("\n")
		}
		// Inline edit input
		if m.editing && i == m.cursor {
			sb.WriteString(m.editInput.View())
			sb.WriteString("\n")
			sb.WriteString(helpStyle.Render(" " + T("enter_save") + " • " + T("esc_cancel")))
			sb.WriteString("\n")
		}
		// Expanded detail view
		if m.expanded == i {
			sb.WriteString(m.renderDetail(f))
		}
	}
	if m.status != "" {
		sb.WriteString("\n")
		sb.WriteString(m.status)
		sb.WriteString("\n")
	}
	return sb.String()
}
// renderDetail renders the boxed detail view for one auth file entry f.
//
// Read-only fields are listed only when they have a non-empty value. Editable
// fields (prefix, proxy URL, priority) are always listed, show T("not_set")
// when empty, and are marked with "✎".
func (m authTabModel) renderDetail(f map[string]any) string {
	var sb strings.Builder
	// Named detail* so they do not shadow the package-level labelStyle/valueStyle.
	detailLabelStyle := lipgloss.NewStyle().
		Foreground(lipgloss.Color("111")).
		Bold(true)
	detailValueStyle := lipgloss.NewStyle().
		Foreground(lipgloss.Color("252"))
	editableMarker := lipgloss.NewStyle().
		Foreground(lipgloss.Color("214")).
		Render(" ✎")
	sb.WriteString(" ┌─────────────────────────────────────────────\n")
	fields := []struct {
		label    string
		key      string
		editable bool
	}{
		{"Name", "name", false},
		{"Channel", "channel", false},
		{"Email", "email", false},
		{"Status", "status", false},
		{"Status Msg", "status_message", false},
		{"File Name", "file_name", false},
		{"Auth Type", "auth_type", false},
		{"Prefix", "prefix", true},
		{"Proxy URL", "proxy_url", true},
		{"Priority", "priority", true},
		{"Project ID", "project_id", false},
		{"Disabled", "disabled", false},
		{"Created", "created_at", false},
		{"Updated", "updated_at", false},
	}
	for _, field := range fields {
		val := getAnyString(f, field.key)
		if val == "" {
			if !field.editable {
				continue
			}
			val = T("not_set")
		}
		editMark := ""
		if field.editable {
			editMark = editableMarker
		}
		line := fmt.Sprintf(" │ %s %s%s",
			detailLabelStyle.Render(fmt.Sprintf("%-12s:", field.label)),
			detailValueStyle.Render(val),
			editMark)
		sb.WriteString(line)
		sb.WriteString("\n")
	}
	sb.WriteString(" └─────────────────────────────────────────────\n")
	return sb.String()
}
// getAnyString returns the string form of m[key]. It returns "" when the key
// is missing or its value is nil.
//
// Whole-number float64 values are printed as plain integers; encoding/json
// decodes every JSON number into float64 inside a map[string]any. Without
// this, "%v" switches to exponent notation from 1e+06 upward. That is hard to
// read in the detail view, and strconv.Atoi rejects it when the text is used
// to pre-fill an edit field such as priority. Other values use "%v".
func getAnyString(m map[string]any, key string) string {
	v, ok := m[key]
	if !ok || v == nil {
		return ""
	}
	// The range check keeps the int64 conversion well-defined; NaN and ±Inf
	// fail it and fall through to "%v".
	if f, isFloat := v.(float64); isFloat && f >= -1e15 && f <= 1e15 {
		if n := int64(f); float64(n) == f {
			return strconv.FormatInt(n, 10)
		}
	}
	return fmt.Sprintf("%v", v)
}
// max returns the larger of two ints.
func max(a, b int) int {
	if b >= a {
		return b
	}
	return a
}
// handleEditInput handles keys while an inline field edit is open.
//
// enter closes the editor and returns a command that PATCHes the edited field
// of m.editFileName; priority must parse as an int or an error message is
// returned instead. esc cancels the edit. Every other key goes to the text
// input.
func (m authTabModel) handleEditInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
	switch msg.String() {
	case "esc":
		m.editing = false
		m.editInput.Blur()
		m.viewport.SetContent(m.renderContent())
		return m, nil
	case "enter":
		input := m.editInput.Value()
		key := authEditableFields[m.editField].key
		target := m.editFileName
		m.editing = false
		m.editInput.Blur()
		payload := map[string]any{}
		if key != "priority" {
			payload[key] = input
		} else {
			priority, errAtoi := strconv.Atoi(input)
			if errAtoi != nil {
				return m, func() tea.Msg {
					return authActionMsg{err: fmt.Errorf("%s: %s", T("invalid_int"), input)}
				}
			}
			payload[key] = priority
		}
		client := m.client
		return m, func() tea.Msg {
			if err := client.PatchAuthFileFields(target, payload); err != nil {
				return authActionMsg{err: err}
			}
			return authActionMsg{action: fmt.Sprintf(T("updated_field"), key, target)}
		}
	}
	var cmd tea.Cmd
	m.editInput, cmd = m.editInput.Update(msg)
	m.viewport.SetContent(m.renderContent())
	return m, cmd
}
// handleConfirmInput handles keys while a delete confirmation is pending.
//
// y/Y clears the prompt and, if m.confirm indexes a file, returns a command
// that deletes it. n/N/esc cancels. Other keys are ignored. The viewport is
// re-rendered right away so the prompt disappears without waiting for the
// delete to finish.
func (m authTabModel) handleConfirmInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
	switch msg.String() {
	case "y", "Y":
		idx := m.confirm
		m.confirm = -1
		m.viewport.SetContent(m.renderContent())
		// Guard both ends: -1 is the "no pending confirmation" value.
		if idx < 0 || idx >= len(m.files) {
			return m, nil
		}
		name := getString(m.files[idx], "name")
		return m, func() tea.Msg {
			err := m.client.DeleteAuthFile(name)
			if err != nil {
				return authActionMsg{err: err}
			}
			return authActionMsg{action: fmt.Sprintf(T("deleted"), name)}
		}
	case "n", "N", "esc":
		m.confirm = -1
		m.viewport.SetContent(m.renderContent())
		return m, nil
	}
	return m, nil
}
// handleNormalInput handles keys when neither editing nor confirming.
//
//   - j/k (or arrows) move the cursor, wrapping around.
//   - enter/space toggles the detail box.
//   - d asks for delete confirmation.
//   - e toggles enabled/disabled on the server.
//   - 1-3 start editing prefix, proxy URL or priority.
//   - r refreshes the list.
//   - Any other key is forwarded to the viewport.
func (m authTabModel) handleNormalInput(msg tea.KeyMsg) (authTabModel, tea.Cmd) {
	switch msg.String() {
	case "j", "down":
		if len(m.files) > 0 {
			m.cursor = (m.cursor + 1) % len(m.files)
			m.viewport.SetContent(m.renderContent())
		}
		return m, nil
	case "k", "up":
		if len(m.files) > 0 {
			m.cursor = (m.cursor - 1 + len(m.files)) % len(m.files)
			m.viewport.SetContent(m.renderContent())
		}
		return m, nil
	case "enter", " ":
		if m.expanded == m.cursor {
			m.expanded = -1
		} else {
			m.expanded = m.cursor
		}
		m.viewport.SetContent(m.renderContent())
		return m, nil
	case "d", "D":
		if m.cursor < len(m.files) {
			m.confirm = m.cursor
			m.viewport.SetContent(m.renderContent())
		}
		return m, nil
	case "e", "E":
		if m.cursor < len(m.files) {
			f := m.files[m.cursor]
			name := getString(f, "name")
			disabled := getBool(f, "disabled")
			newDisabled := !disabled
			return m, func() tea.Msg {
				err := m.client.ToggleAuthFile(name, newDisabled)
				if err != nil {
					return authActionMsg{err: err}
				}
				action := T("enabled")
				if newDisabled {
					action = T("disabled")
				}
				return authActionMsg{action: fmt.Sprintf("%s %s", action, name)}
			}
		}
		return m, nil
	// startEdit updates m's edit input and viewport, so call it in its own
	// statement. The Go spec leaves unspecified whether `m` is read before or
	// after the call in `return m, m.startEdit(i)`, so that form could return
	// a copy without the edit state.
	case "1":
		cmd := m.startEdit(0) // prefix
		return m, cmd
	case "2":
		cmd := m.startEdit(1) // proxy_url
		return m, cmd
	case "3":
		cmd := m.startEdit(2) // priority
		return m, cmd
	case "r":
		m.status = ""
		return m, m.fetchFiles
	default:
		var cmd tea.Cmd
		m.viewport, cmd = m.viewport.Update(msg)
		return m, cmd
	}
}
================================================
FILE: internal/tui/browser.go
================================================
package tui
import (
"os/exec"
"runtime"
)
// openBrowser opens the specified URL in the user's default browser.
//
// It uses the platform launcher: open on macOS, rundll32 on Windows, and
// xdg-open everywhere else. It returns once the launcher has started.
func openBrowser(url string) error {
	var cmd *exec.Cmd
	switch runtime.GOOS {
	case "darwin":
		cmd = exec.Command("open", url)
	case "windows":
		cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
	default: // linux and other Unix-likes
		cmd = exec.Command("xdg-open", url)
	}
	if err := cmd.Start(); err != nil {
		return err
	}
	// Reap the launcher in the background. Without Wait it lingers as a
	// zombie process for as long as the TUI runs; its exit status is not needed.
	go func() { _ = cmd.Wait() }()
	return nil
}
================================================
FILE: internal/tui/client.go
================================================
package tui
import (
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strconv"
"strings"
"time"
)
// Client wraps HTTP calls to the management API.
//
// Every request path is appended to baseURL (a 127.0.0.1 address built by
// NewClient). When secretKey is non-empty it is sent as a Bearer token.
type Client struct {
// baseURL is the scheme+host prefix that request paths are appended to.
baseURL string
// secretKey is the management bearer token; empty means no Authorization header.
secretKey string
// http is the shared client for every request (NewClient sets a 10s timeout).
http *http.Client
}
// NewClient returns a management API client for the server on 127.0.0.1:port.
// secretKey is trimmed of surrounding whitespace. Requests time out after 10s.
func NewClient(port int, secretKey string) *Client {
	httpClient := &http.Client{Timeout: 10 * time.Second}
	return &Client{
		http:      httpClient,
		secretKey: strings.TrimSpace(secretKey),
		baseURL:   "http://127.0.0.1:" + strconv.Itoa(port),
	}
}
// SetSecretKey replaces the management API bearer token used by this client.
// Surrounding whitespace is stripped; an empty result means later requests
// carry no Authorization header.
func (c *Client) SetSecretKey(secretKey string) {
	trimmed := strings.TrimSpace(secretKey)
	c.secretKey = trimmed
}
// doRequest sends method+path to the management server.
//
// It returns the fully read body and the HTTP status code. Status codes are
// not interpreted here; callers decide what counts as failure. A non-nil body
// is labelled as JSON. The bearer token is attached when one is configured.
func (c *Client) doRequest(method, path string, body io.Reader) ([]byte, int, error) {
	// "endpoint" rather than "url" so the net/url package is not shadowed.
	endpoint := c.baseURL + path
	req, err := http.NewRequest(method, endpoint, body)
	if err != nil {
		return nil, 0, err
	}
	if token := c.secretKey; token != "" {
		req.Header.Set("Authorization", "Bearer "+token)
	}
	if body != nil {
		// NOTE(review): this also labels the raw YAML sent by PutConfigYAML as
		// JSON — confirm the server ignores Content-Type for that endpoint.
		req.Header.Set("Content-Type", "application/json")
	}
	resp, err := c.http.Do(req)
	if err != nil {
		return nil, 0, err
	}
	defer resp.Body.Close()
	payload, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, resp.StatusCode, err
	}
	return payload, resp.StatusCode, nil
}
// get issues a GET and returns the response body.
func (c *Client) get(path string) ([]byte, error) {
	return c.sendExpectOK("GET", path, nil)
}

// put issues a PUT with body and returns the response body.
func (c *Client) put(path string, body io.Reader) ([]byte, error) {
	return c.sendExpectOK("PUT", path, body)
}

// patch issues a PATCH with body and returns the response body.
func (c *Client) patch(path string, body io.Reader) ([]byte, error) {
	return c.sendExpectOK("PATCH", path, body)
}

// sendExpectOK performs the request. Any status >= 400 becomes an error of
// the form "HTTP <code>: <trimmed body>".
func (c *Client) sendExpectOK(method, path string, body io.Reader) ([]byte, error) {
	data, status, err := c.doRequest(method, path, body)
	switch {
	case err != nil:
		return nil, err
	case status >= 400:
		return nil, fmt.Errorf("HTTP %d: %s", status, strings.TrimSpace(string(data)))
	}
	return data, nil
}
// getJSON GETs path and decodes the JSON response into a generic map.
// HTTP errors and decode errors are both returned as-is.
func (c *Client) getJSON(path string) (map[string]any, error) {
	raw, err := c.get(path)
	if err != nil {
		return nil, err
	}
	var decoded map[string]any
	if err = json.Unmarshal(raw, &decoded); err != nil {
		return nil, err
	}
	return decoded, nil
}
// postJSON marshals body to JSON and POSTs it to path.
//
// A status >= 400 is returned as an error that includes the server's response
// text when there is any. This matches get/put/patch, so the user sees why
// the request was rejected.
func (c *Client) postJSON(path string, body any) error {
	jsonBody, err := json.Marshal(body)
	if err != nil {
		return err
	}
	data, code, err := c.doRequest("POST", path, strings.NewReader(string(jsonBody)))
	if err != nil {
		return err
	}
	if code >= 400 {
		if detail := strings.TrimSpace(string(data)); detail != "" {
			return fmt.Errorf("HTTP %d: %s", code, detail)
		}
		return fmt.Errorf("HTTP %d", code)
	}
	return nil
}
// GetConfig fetches the server's parsed configuration as a generic JSON map.
func (c *Client) GetConfig() (map[string]any, error) {
	const endpoint = "/v0/management/config"
	return c.getJSON(endpoint)
}
// GetConfigYAML fetches the raw config.yaml text exactly as the server returns it.
func (c *Client) GetConfigYAML() (string, error) {
	raw, err := c.get("/v0/management/config.yaml")
	if err != nil {
		return "", err
	}
	return string(raw), nil
}
// PutConfigYAML replaces the server's config.yaml with yamlContent.
func (c *Client) PutConfigYAML(yamlContent string) error {
	if _, err := c.put("/v0/management/config.yaml", strings.NewReader(yamlContent)); err != nil {
		return err
	}
	return nil
}
// GetUsage fetches the server's usage statistics as a generic JSON map.
func (c *Client) GetUsage() (map[string]any, error) {
	const endpoint = "/v0/management/usage"
	return c.getJSON(endpoint)
}
// GetAuthFiles lists auth credential files.
// The API wraps the list as {"files": [...]}; a missing or null list yields nil.
func (c *Client) GetAuthFiles() ([]map[string]any, error) {
	return c.getWrappedKeyList("/v0/management/auth-files", "files")
}
// DeleteAuthFile deletes a single auth file by name. The name is sent as the
// URL-encoded "name" query parameter.
//
// On a status >= 400 the returned error includes the server's response text
// when there is any, so the user can see why the delete was refused.
func (c *Client) DeleteAuthFile(name string) error {
	query := url.Values{}
	query.Set("name", name)
	path := "/v0/management/auth-files?" + query.Encode()
	data, code, err := c.doRequest("DELETE", path, nil)
	if err != nil {
		return err
	}
	if code >= 400 {
		if detail := strings.TrimSpace(string(data)); detail != "" {
			return fmt.Errorf("delete failed (HTTP %d): %s", code, detail)
		}
		return fmt.Errorf("delete failed (HTTP %d)", code)
	}
	return nil
}
// ToggleAuthFile sets the disabled flag of the named auth file by PATCHing
// {"name": name, "disabled": disabled} to the status endpoint.
func (c *Client) ToggleAuthFile(name string, disabled bool) error {
	// Marshalling a string and a bool cannot fail, so the error is dropped.
	payload, _ := json.Marshal(map[string]any{"name": name, "disabled": disabled})
	_, err := c.patch("/v0/management/auth-files/status", strings.NewReader(string(payload)))
	return err
}
// PatchAuthFileFields updates editable fields on the named auth file.
//
// The request body is fields plus "name": name; the name overrides any "name"
// key already in fields. The body is built in a fresh map, so the caller's map
// is left untouched and may be nil. Writing into a nil map would panic.
// Fields that cannot be JSON-encoded produce an error instead of an empty
// request body.
func (c *Client) PatchAuthFileFields(name string, fields map[string]any) error {
	payload := make(map[string]any, len(fields)+1)
	for k, v := range fields {
		payload[k] = v
	}
	payload["name"] = name
	body, err := json.Marshal(payload)
	if err != nil {
		return fmt.Errorf("encoding auth file fields: %w", err)
	}
	_, err = c.patch("/v0/management/auth-files/fields", strings.NewReader(string(body)))
	return err
}
// GetLogs fetches log lines from the server.
//
// limit (when > 0) and after (when > 0) are sent as query parameters. It
// returns the lines and the newest timestamp reported by the server, which
// is never lower than after. On error it returns after unchanged, so callers
// can keep polling from the same point.
func (c *Client) GetLogs(after int64, limit int) ([]string, int64, error) {
	params := url.Values{}
	if limit > 0 {
		params.Set("limit", strconv.Itoa(limit))
	}
	if after > 0 {
		params.Set("after", strconv.FormatInt(after, 10))
	}
	endpoint := "/v0/management/logs"
	if qs := params.Encode(); qs != "" {
		endpoint = endpoint + "?" + qs
	}
	payload, err := c.getJSON(endpoint)
	if err != nil {
		return nil, after, err
	}

	// Round-trip "lines" through JSON to convert []any into []string.
	lines := []string{}
	if raw := payload["lines"]; raw != nil {
		encoded, errMarshal := json.Marshal(raw)
		if errMarshal != nil {
			return nil, after, errMarshal
		}
		if errUnmarshal := json.Unmarshal(encoded, &lines); errUnmarshal != nil {
			return nil, after, errUnmarshal
		}
	}

	latest := after
	switch ts := payload["latest-timestamp"].(type) {
	case float64:
		latest = int64(ts)
	case json.Number:
		if parsed, errParse := ts.Int64(); errParse == nil {
			latest = parsed
		}
	case int64:
		latest = ts
	case int:
		latest = int64(ts)
	}
	if after > latest {
		latest = after
	}
	return lines, latest, nil
}
// GetAPIKeys fetches the list of management API keys.
// The API wraps them as {"api-keys": [...]}; a missing key yields nil, nil.
func (c *Client) GetAPIKeys() ([]string, error) {
	payload, err := c.getJSON("/v0/management/api-keys")
	if err != nil {
		return nil, err
	}
	raw, present := payload["api-keys"]
	if !present {
		return nil, nil
	}
	// Round-trip through JSON to convert []any into []string.
	encoded, err := json.Marshal(raw)
	if err != nil {
		return nil, err
	}
	var keys []string
	if err = json.Unmarshal(encoded, &keys); err != nil {
		return nil, err
	}
	return keys, nil
}
// AddAPIKey appends key to the management API key list by PATCHing
// {"old": null, "new": key}; a null "old" is what requests an append.
func (c *Client) AddAPIKey(key string) error {
	// Marshalling nil and a string cannot fail, so the error is dropped.
	payload, _ := json.Marshal(map[string]any{"old": nil, "new": key})
	_, err := c.patch("/v0/management/api-keys", strings.NewReader(string(payload)))
	return err
}
// EditAPIKey replaces the API key at index with newValue by PATCHing
// {"index": index, "value": newValue}.
func (c *Client) EditAPIKey(index int, newValue string) error {
	// Marshalling an int and a string cannot fail, so the error is dropped.
	payload, _ := json.Marshal(map[string]any{"index": index, "value": newValue})
	_, err := c.patch("/v0/management/api-keys", strings.NewReader(string(payload)))
	return err
}
// DeleteAPIKey deletes the API key at the given index.
//
// On a status >= 400 the returned error includes the server's response text
// when there is any, so the user can see why the delete was refused.
func (c *Client) DeleteAPIKey(index int) error {
	data, code, err := c.doRequest("DELETE", fmt.Sprintf("/v0/management/api-keys?index=%d", index), nil)
	if err != nil {
		return err
	}
	if code >= 400 {
		if detail := strings.TrimSpace(string(data)); detail != "" {
			return fmt.Errorf("delete failed (HTTP %d): %s", code, detail)
		}
		return fmt.Errorf("delete failed (HTTP %d)", code)
	}
	return nil
}
// GetGeminiKeys fetches Gemini API keys.
// API returns {"gemini-api-key": [...]}.
func (c *Client) GetGeminiKeys() ([]map[string]any, error) {
	const key = "gemini-api-key"
	return c.getWrappedKeyList("/v0/management/"+key, key)
}

// GetClaudeKeys fetches Claude API keys ({"claude-api-key": [...]}).
func (c *Client) GetClaudeKeys() ([]map[string]any, error) {
	const key = "claude-api-key"
	return c.getWrappedKeyList("/v0/management/"+key, key)
}

// GetCodexKeys fetches Codex API keys ({"codex-api-key": [...]}).
func (c *Client) GetCodexKeys() ([]map[string]any, error) {
	const key = "codex-api-key"
	return c.getWrappedKeyList("/v0/management/"+key, key)
}

// GetVertexKeys fetches Vertex API keys ({"vertex-api-key": [...]}).
func (c *Client) GetVertexKeys() ([]map[string]any, error) {
	const key = "vertex-api-key"
	return c.getWrappedKeyList("/v0/management/"+key, key)
}

// GetOpenAICompat fetches OpenAI compatibility entries
// ({"openai-compatibility": [...]}).
func (c *Client) GetOpenAICompat() ([]map[string]any, error) {
	const key = "openai-compatibility"
	return c.getWrappedKeyList("/v0/management/"+key, key)
}
// getWrappedKeyList GETs path and returns the list of objects stored under key
// in the JSON response (see extractList for missing/null handling).
func (c *Client) getWrappedKeyList(path, key string) ([]map[string]any, error) {
	payload, err := c.getJSON(path)
	if err != nil {
		return nil, err
	}
	return extractList(payload, key)
}
// extractList returns wrapper[key] as a slice of JSON objects.
//
// A missing or nil entry yields (nil, nil). The value is round-tripped
// through encoding/json, so anything that is not an array of objects (or
// nulls) produces an unmarshal error.
func extractList(wrapper map[string]any, key string) ([]map[string]any, error) {
	value := wrapper[key]
	if value == nil {
		return nil, nil
	}
	encoded, err := json.Marshal(value)
	if err != nil {
		return nil, err
	}
	var items []map[string]any
	if err := json.Unmarshal(encoded, &items); err != nil {
		return nil, err
	}
	return items, nil
}
// GetDebug fetches the current debug setting.
// A missing or non-boolean "debug" entry is reported as false.
func (c *Client) GetDebug() (bool, error) {
	payload, err := c.getJSON("/v0/management/debug")
	if err != nil {
		return false, err
	}
	enabled, _ := payload["debug"].(bool)
	return enabled, nil
}
// GetAuthStatus polls the OAuth session identified by state.
// It returns the "status" field ("wait", "ok", "error") and the optional
// "error" message from the response.
func (c *Client) GetAuthStatus(state string) (string, string, error) {
	params := url.Values{"state": {state}}
	payload, err := c.getJSON("/v0/management/get-auth-status?" + params.Encode())
	if err != nil {
		return "", "", err
	}
	return getString(payload, "status"), getString(payload, "error"), nil
}
// ----- Config field update methods -----
// PutBoolField updates a boolean config field at /v0/management/<path>.
func (c *Client) PutBoolField(path string, value bool) error {
	return c.putValueField(path, value)
}

// PutIntField updates an integer config field at /v0/management/<path>.
func (c *Client) PutIntField(path string, value int) error {
	return c.putValueField(path, value)
}

// PutStringField updates a string config field at /v0/management/<path>.
func (c *Client) PutStringField(path string, value string) error {
	return c.putValueField(path, value)
}

// putValueField PUTs {"value": value} to /v0/management/<path>.
func (c *Client) putValueField(path string, value any) error {
	// value is always a bool, int or string here, so Marshal cannot fail.
	body, _ := json.Marshal(map[string]any{"value": value})
	_, err := c.put("/v0/management/"+path, strings.NewReader(string(body)))
	return err
}
// DeleteField sends a DELETE request for the config field at
// /v0/management/<path>.
//
// A status >= 400 is returned as an error with the server's response text,
// like get/put/patch. Before, HTTP failures were silently reported as
// success.
func (c *Client) DeleteField(path string) error {
	data, code, err := c.doRequest("DELETE", "/v0/management/"+path, nil)
	if err != nil {
		return err
	}
	if code >= 400 {
		return fmt.Errorf("HTTP %d: %s", code, strings.TrimSpace(string(data)))
	}
	return nil
}
================================================
FILE: internal/tui/config_tab.go
================================================
package tui
import (
"fmt"
"strconv"
"strings"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// configField represents a single row in the config tab.
//
// parseConfig builds these with positional (unkeyed) composite literals, so the
// field order below is part of the contract and must not change.
type configField struct {
// label is the human-readable name shown in the left column.
label string
apiPath string // path under /v0/management/ for the Put*Field calls (e.g. "debug", "routing/strategy")
kind string // "bool" toggles on enter, "int"/"string" open the inline editor, "readonly" is not editable
value string // display text (may be masked, e.g. the AMP upstream API key)
rawValue any // unmasked value to pre-fill the editor with when it is a string; otherwise value is used
}
// configTabModel displays parsed config with interactive editing.
//
// Update and the key handlers use value receivers and return the modified
// copy, so all tab state lives in these fields.
type configTabModel struct {
// client is the management API client used for reads and writes.
client *Client
// viewport scrolls the text produced by renderContent.
viewport viewport.Model
// fields are the rows built by parseConfig from the last successful load.
fields []configField
// cursor is the index into fields of the selected row.
cursor int
// editing is true while textInput is open for the selected row.
editing bool
// textInput is the inline editor for int/string fields.
textInput textinput.Model
// err is the last load error; when set, renderContent shows it instead of fields.
err error
message string // rendered result of the last update (success/error); cleared when the user presses r
// width and height are the size last passed to SetSize.
width int
height int
// ready is set once SetSize has created the viewport.
ready bool
}
// configDataMsg carries the result of fetchConfig: the parsed config map,
// or the error from the request.
type configDataMsg struct {
config map[string]any
err error
}
// configUpdateMsg reports the outcome of a single field write issued by
// toggleBool or submitEdit.
type configUpdateMsg struct {
path string // apiPath of the field that was written
value any // value that was sent (nil when int validation failed before sending)
err error // validation or request error; nil on success
}
// newConfigTabModel returns a config tab bound to client. Its inline editor
// accepts up to 256 characters. The viewport is created later by SetSize.
func newConfigTabModel(client *Client) configTabModel {
	input := textinput.New()
	input.CharLimit = 256
	var m configTabModel
	m.client = client
	m.textInput = input
	return m
}
// Init starts the first configuration load.
func (m configTabModel) Init() tea.Cmd {
	return tea.Cmd(m.fetchConfig)
}
// fetchConfig loads the parsed config from the server and wraps the result
// (or error) in a configDataMsg.
func (m configTabModel) fetchConfig() tea.Msg {
	var msg configDataMsg
	msg.config, msg.err = m.client.GetConfig()
	return msg
}
// Update handles messages for the config tab:
//
//   - Locale changes re-render the view.
//   - configDataMsg replaces the field list.
//   - configUpdateMsg shows the result of a write and reloads the config.
//   - Key presses go to the editing or navigation handler.
//   - Anything else is forwarded to the viewport.
func (m configTabModel) Update(msg tea.Msg) (configTabModel, tea.Cmd) {
	switch msg := msg.(type) {
	case localeChangedMsg:
		m.viewport.SetContent(m.renderContent())
		return m, nil
	case configDataMsg:
		if msg.err != nil {
			m.err = msg.err
			m.fields = nil
		} else {
			m.err = nil
			m.fields = m.parseConfig(msg.config)
		}
		// The field list can shrink between loads: the AMP rows exist only
		// when the config has an "ampcode" section, and an error clears the
		// list. Keep the cursor on a valid row. If the row being edited
		// disappeared, drop the edit, because submitEdit indexes m.fields by
		// cursor and would otherwise hit another field or go out of range.
		if m.cursor >= len(m.fields) {
			m.cursor = len(m.fields) - 1
			if m.cursor < 0 {
				m.cursor = 0
			}
			if m.editing {
				m.editing = false
				m.textInput.Blur()
			}
		}
		m.viewport.SetContent(m.renderContent())
		return m, nil
	case configUpdateMsg:
		if msg.err != nil {
			m.message = errorStyle.Render("✗ " + msg.err.Error())
		} else {
			m.message = successStyle.Render(T("updated_ok"))
		}
		m.viewport.SetContent(m.renderContent())
		// Refresh config from server
		return m, m.fetchConfig
	case tea.KeyMsg:
		if m.editing {
			return m.handleEditingKey(msg)
		}
		return m.handleNormalKey(msg)
	}
	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
// handleNormalKey handles keys while no field is being edited:
//
//   - r clears the status message and reloads.
//   - up/k and down/j move the selection and keep it on screen.
//   - enter/space toggles a bool field directly, or opens the inline editor
//     for int/string fields (readonly fields ignore it).
//   - Any other key scrolls the viewport.
func (m configTabModel) handleNormalKey(msg tea.KeyMsg) (configTabModel, tea.Cmd) {
	switch msg.String() {
	case "r":
		m.message = ""
		return m, m.fetchConfig
	case "up", "k":
		if m.cursor <= 0 {
			return m, nil
		}
		m.cursor--
		m.viewport.SetContent(m.renderContent())
		m.ensureCursorVisible()
		return m, nil
	case "down", "j":
		if m.cursor >= len(m.fields)-1 {
			return m, nil
		}
		m.cursor++
		m.viewport.SetContent(m.renderContent())
		m.ensureCursorVisible()
		return m, nil
	case "enter", " ":
		if m.cursor < 0 || m.cursor >= len(m.fields) {
			return m, nil
		}
		selected := m.fields[m.cursor]
		switch selected.kind {
		case "readonly":
			return m, nil
		case "bool":
			return m, m.toggleBool(m.cursor)
		}
		// int and string fields open the inline editor.
		m.editing = true
		m.textInput.SetValue(configFieldEditValue(selected))
		m.textInput.Focus()
		m.viewport.SetContent(m.renderContent())
		return m, textinput.Blink
	}
	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
// handleEditingKey handles keys while the inline editor is open. enter
// closes it and submits the value, esc closes it without saving, and other
// keys are typed into the editor.
func (m configTabModel) handleEditingKey(msg tea.KeyMsg) (configTabModel, tea.Cmd) {
	key := msg.String()
	if key != "enter" && key != "esc" {
		var cmd tea.Cmd
		m.textInput, cmd = m.textInput.Update(msg)
		m.viewport.SetContent(m.renderContent())
		return m, cmd
	}
	m.editing = false
	m.textInput.Blur()
	if key == "enter" {
		return m, m.submitEdit(m.cursor, m.textInput.Value())
	}
	m.viewport.SetContent(m.renderContent())
	return m, nil
}
// toggleBool returns a command that writes the inverse of the bool field at
// idx to the server and reports the outcome as a configUpdateMsg.
func (m configTabModel) toggleBool(idx int) tea.Cmd {
	return func() tea.Msg {
		field := m.fields[idx]
		flipped := field.value != "true"
		return configUpdateMsg{
			path:  field.apiPath,
			value: flipped,
			err:   m.client.PutBoolField(field.apiPath, flipped),
		}
	}
}
// submitEdit returns a command that writes newValue to the field at idx.
//
// int fields must parse with strconv.Atoi, or a validation error is reported
// without contacting the server. string fields are sent verbatim. The
// outcome is delivered as a configUpdateMsg.
func (m configTabModel) submitEdit(idx int, newValue string) tea.Cmd {
	return func() tea.Msg {
		field := m.fields[idx]
		result := configUpdateMsg{path: field.apiPath}
		switch field.kind {
		case "int":
			n, errAtoi := strconv.Atoi(newValue)
			if errAtoi != nil {
				result.err = fmt.Errorf("%s: %s", T("invalid_int"), newValue)
				return result
			}
			result.value = n
			result.err = m.client.PutIntField(field.apiPath, n)
		case "string":
			result.value = newValue
			result.err = m.client.PutStringField(field.apiPath, newValue)
		}
		return result
	}
}
// configFieldEditValue returns the text to pre-fill the editor with. It
// prefers the unmasked raw string when there is one, otherwise the display value.
func configFieldEditValue(f configField) string {
	raw, isString := f.rawValue.(string)
	if !isString {
		return f.value
	}
	return raw
}
// SetSize records the tab's available area. The first call creates and fills
// the viewport; later calls only resize it.
func (m *configTabModel) SetSize(w, h int) {
	m.width, m.height = w, h
	if m.ready {
		m.viewport.Width = w
		m.viewport.Height = h
		return
	}
	m.viewport = viewport.New(w, h)
	m.viewport.SetContent(m.renderContent())
	m.ready = true
}
// ensureCursorVisible scrolls the viewport just enough to keep the selected
// field's row on screen.
func (m *configTabModel) ensureCursorVisible() {
	// Locate the selected row in the rendered text rather than estimating it
	// from the cursor index. Each section header adds a blank line plus a
	// title line, and the optional status message adds one more, so a fixed
	// offset drifts further off the further down the list the cursor goes.
	// The selected row is the only field row rendered with the "▸ " prefix;
	// taking the last match skips anything in the header area.
	targetLine := -1
	for i, line := range strings.Split(m.renderContent(), "\n") {
		if strings.Contains(line, "▸ ") {
			targetLine = i
		}
	}
	if targetLine < 0 {
		// Fallback: the previous rough estimate (header lines plus one per field).
		targetLine = m.cursor + 5
	}
	if targetLine < m.viewport.YOffset {
		m.viewport.SetYOffset(targetLine)
	}
	if targetLine >= m.viewport.YOffset+m.viewport.Height {
		m.viewport.SetYOffset(targetLine - m.viewport.Height + 1)
	}
}
// View returns the viewport contents, or a loading placeholder until SetSize
// has created the viewport.
func (m configTabModel) View() string {
	if m.ready {
		return m.viewport.View()
	}
	return T("loading")
}
// renderContent builds the config tab text.
//
// Layout: the title, the optional status message, two help lines, then either
// the load error, a "no config" notice, or the fields grouped under section
// headings. The selected row is marked "▸ " and highlighted unless it is being
// edited, in which case the editor replaces its value.
func (m configTabModel) renderContent() string {
	var out strings.Builder
	out.WriteString(titleStyle.Render(T("config_title")) + "\n")
	if m.message != "" {
		out.WriteString(" " + m.message + "\n")
	}
	out.WriteString(helpStyle.Render(T("config_help1")) + "\n")
	out.WriteString(helpStyle.Render(T("config_help2")) + "\n\n")

	switch {
	case m.err != nil:
		out.WriteString(errorStyle.Render(" ⚠ Error: " + m.err.Error()))
		return out.String()
	case len(m.fields) == 0:
		out.WriteString(subtitleStyle.Render(T("no_config")))
		return out.String()
	}

	lastSection := ""
	for idx := range m.fields {
		field := m.fields[idx]
		if sec := fieldSection(field.apiPath); sec != lastSection {
			lastSection = sec
			heading := lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(" ── " + sec + " ")
			out.WriteString("\n" + heading + "\n")
		}

		selected := idx == m.cursor
		prefix := " "
		if selected {
			prefix = "▸ "
		}
		label := lipgloss.NewStyle().
			Foreground(colorInfo).
			Bold(selected).
			Width(32).
			Render(field.label)

		var rendered string
		switch {
		case m.editing && selected:
			rendered = m.textInput.View()
		case field.kind == "bool" && field.value == "true":
			rendered = successStyle.Render("● ON")
		case field.kind == "bool":
			rendered = lipgloss.NewStyle().Foreground(colorMuted).Render("○ OFF")
		case field.kind == "readonly":
			rendered = lipgloss.NewStyle().Foreground(colorSubtext).Render(field.value)
		default:
			rendered = valueStyle.Render(field.value)
		}

		row := prefix + label + " " + rendered
		if selected && !m.editing {
			row = lipgloss.NewStyle().Background(colorSurface).Render(row)
		}
		out.WriteString(row + "\n")
	}
	return out.String()
}
// parseConfig turns the server's config map into the ordered list of rows the
// tab displays. The three AMP rows are added only when the config contains an
// "ampcode" object. Numbers are shown without decimals, and the AMP upstream
// API key is masked (its raw value is kept for editing).
func (m configTabModel) parseConfig(cfg map[string]any) []configField {
	var fields []configField
	add := func(label, apiPath, kind, value string, raw any) {
		fields = append(fields, configField{label, apiPath, kind, value, raw})
	}
	boolText := func(src map[string]any, key string) string {
		return fmt.Sprintf("%v", getBool(src, key))
	}
	numberText := func(key string) string {
		return fmt.Sprintf("%.0f", getFloat(cfg, key))
	}

	// Server settings
	add("Port", "port", "readonly", numberText("port"), nil)
	add("Host", "host", "readonly", getString(cfg, "host"), nil)
	add("Debug", "debug", "bool", boolText(cfg, "debug"), nil)
	add("Proxy URL", "proxy-url", "string", getString(cfg, "proxy-url"), nil)
	add("Request Retry", "request-retry", "int", numberText("request-retry"), nil)
	add("Max Retry Interval (s)", "max-retry-interval", "int", numberText("max-retry-interval"), nil)
	add("Force Model Prefix", "force-model-prefix", "string", getString(cfg, "force-model-prefix"), nil)
	// Logging
	add("Logging to File", "logging-to-file", "bool", boolText(cfg, "logging-to-file"), nil)
	add("Logs Max Total Size (MB)", "logs-max-total-size-mb", "int", numberText("logs-max-total-size-mb"), nil)
	add("Error Logs Max Files", "error-logs-max-files", "int", numberText("error-logs-max-files"), nil)
	add("Usage Stats Enabled", "usage-statistics-enabled", "bool", boolText(cfg, "usage-statistics-enabled"), nil)
	add("Request Log", "request-log", "bool", boolText(cfg, "request-log"), nil)
	// Quota exceeded
	add("Switch Project on Quota", "quota-exceeded/switch-project", "bool", fmt.Sprintf("%v", getBoolNested(cfg, "quota-exceeded", "switch-project")), nil)
	add("Switch Preview Model", "quota-exceeded/switch-preview-model", "bool", fmt.Sprintf("%v", getBoolNested(cfg, "quota-exceeded", "switch-preview-model")), nil)
	// Routing
	strategy := ""
	if routing, ok := cfg["routing"].(map[string]any); ok {
		strategy = getString(routing, "strategy")
	}
	add("Routing Strategy", "routing/strategy", "string", strategy, nil)
	// WebSocket auth
	add("WebSocket Auth", "ws-auth", "bool", boolText(cfg, "ws-auth"), nil)
	// AMP settings
	if amp, ok := cfg["ampcode"].(map[string]any); ok {
		upstreamURL := getString(amp, "upstream-url")
		upstreamKey := getString(amp, "upstream-api-key")
		add("AMP Upstream URL", "ampcode/upstream-url", "string", upstreamURL, upstreamURL)
		add("AMP Upstream API Key", "ampcode/upstream-api-key", "string", maskIfNotEmpty(upstreamKey), upstreamKey)
		add("AMP Restrict Mgmt Localhost", "ampcode/restrict-management-to-localhost", "bool", boolText(amp, "restrict-management-to-localhost"), nil)
	}
	return fields
}
// fieldSection returns the localized section heading a field is grouped under.
// Nested paths are grouped by prefix, known top-level keys are listed
// explicitly, and anything else falls into the "other" section.
func fieldSection(apiPath string) string {
	switch {
	case strings.HasPrefix(apiPath, "ampcode/"):
		return T("section_ampcode")
	case strings.HasPrefix(apiPath, "quota-exceeded/"):
		return T("section_quota")
	case strings.HasPrefix(apiPath, "routing/"):
		return T("section_routing")
	}
	var key string
	switch apiPath {
	case "port", "host", "debug", "proxy-url", "request-retry", "max-retry-interval", "force-model-prefix":
		key = "section_server"
	case "logging-to-file", "logs-max-total-size-mb", "error-logs-max-files", "usage-statistics-enabled", "request-log":
		key = "section_logging"
	case "ws-auth":
		key = "section_websocket"
	default:
		key = "section_other"
	}
	return T(key)
}
// getBoolNested follows keys through nested objects and returns the bool at
// the last key. It returns false when no keys are given or when an
// intermediate value is not an object.
func getBoolNested(m map[string]any, keys ...string) bool {
	if len(keys) == 0 {
		return false
	}
	current := m
	for _, key := range keys[:len(keys)-1] {
		nested, ok := current[key].(map[string]any)
		if !ok {
			return false
		}
		current = nested
	}
	return getBool(current, keys[len(keys)-1])
}
func maskIfNotEmpty(s string) string {
if s == "" {
return T("not_set")
}
return maskKey(s)
}
================================================
FILE: internal/tui/dashboard.go
================================================
package tui
import (
"encoding/json"
"fmt"
"strings"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// dashboardModel displays server info, stats cards, and config overview.
type dashboardModel struct {
// client is the management API client used by fetchData.
client *Client
// viewport scrolls content.
viewport viewport.Model
// content is the last rendered dashboard text (or the error text).
content string
// err is the error from the most recent fetch; nil after a successful one.
err error
// width and height come from SetSize; width also sizes the stats cards.
width int
height int
// ready is set once SetSize has created the viewport.
ready bool
// Data from the last successful fetch, cached so a locale change can
// re-render immediately while a fresh fetch runs in the background.
lastConfig map[string]any
lastUsage map[string]any
lastAuthFiles []map[string]any
lastAPIKeys []string
}
// dashboardDataMsg carries everything fetchData loaded in one round.
// err is the first non-nil error in the order config, usage, auth files,
// API keys; the other fields hold whatever the respective calls returned.
type dashboardDataMsg struct {
config map[string]any
usage map[string]any
authFiles []map[string]any
apiKeys []string
err error
}
// newDashboardModel returns a dashboard bound to client. SetSize creates its
// viewport, and Init triggers the first fetch.
func newDashboardModel(client *Client) dashboardModel {
	var m dashboardModel
	m.client = client
	return m
}
// Init starts the first dashboard data fetch.
func (m dashboardModel) Init() tea.Cmd {
	return tea.Cmd(m.fetchData)
}
// fetchData loads config, usage, auth files and API keys one after another.
// All results are returned; err is the first failure in that order.
func (m dashboardModel) fetchData() tea.Msg {
	var out dashboardDataMsg
	var errs [4]error
	out.config, errs[0] = m.client.GetConfig()
	out.usage, errs[1] = m.client.GetUsage()
	out.authFiles, errs[2] = m.client.GetAuthFiles()
	out.apiKeys, errs[3] = m.client.GetAPIKeys()
	for _, e := range errs {
		if e != nil {
			out.err = e
			break
		}
	}
	return out
}
// Update handles dashboard messages:
//
//   - A locale change re-renders from cached data and refetches.
//   - New data is cached and rendered, or an error is shown.
//   - "r" refetches.
//   - Everything else, including other keys, is forwarded to the viewport.
func (m dashboardModel) Update(msg tea.Msg) (dashboardModel, tea.Cmd) {
	switch msg := msg.(type) {
	case localeChangedMsg:
		m.content = m.renderDashboard(m.lastConfig, m.lastUsage, m.lastAuthFiles, m.lastAPIKeys)
		m.viewport.SetContent(m.content)
		return m, m.fetchData
	case dashboardDataMsg:
		m.err = msg.err
		if msg.err != nil {
			m.content = errorStyle.Render("⚠ Error: " + msg.err.Error())
		} else {
			m.lastConfig, m.lastUsage = msg.config, msg.usage
			m.lastAuthFiles, m.lastAPIKeys = msg.authFiles, msg.apiKeys
			m.content = m.renderDashboard(msg.config, msg.usage, msg.authFiles, msg.apiKeys)
		}
		m.viewport.SetContent(m.content)
		return m, nil
	case tea.KeyMsg:
		if msg.String() == "r" {
			return m, m.fetchData
		}
	}
	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
// SetSize records the dashboard's available area. The first call creates the
// viewport with the current content; later calls only resize it.
func (m *dashboardModel) SetSize(w, h int) {
	m.width, m.height = w, h
	if m.ready {
		m.viewport.Width = w
		m.viewport.Height = h
		return
	}
	m.viewport = viewport.New(w, h)
	m.viewport.SetContent(m.content)
	m.ready = true
}
// View returns the viewport contents, or a loading placeholder until SetSize
// has created the viewport.
func (m dashboardModel) View() string {
	if m.ready {
		return m.viewport.View()
	}
	return T("loading")
}
func (m dashboardModel) renderDashboard(cfg, usage map[string]any, authFiles []map[string]any, apiKeys []string) string {
var sb strings.Builder
sb.WriteString(titleStyle.Render(T("dashboard_title")))
sb.WriteString("\n")
sb.WriteString(helpStyle.Render(T("dashboard_help")))
sb.WriteString("\n\n")
// ━━━ Connection Status ━━━
connStyle := lipgloss.NewStyle().Bold(true).Foreground(colorSuccess)
sb.WriteString(connStyle.Render(T("connected")))
sb.WriteString(fmt.Sprintf(" %s", m.client.baseURL))
sb.WriteString("\n\n")
// ━━━ Stats Cards ━━━
cardWidth := 25
if m.width > 0 {
cardWidth = (m.width - 6) / 4
if cardWidth < 18 {
cardWidth = 18
}
}
cardStyle := lipgloss.NewStyle().
Border(lipgloss.RoundedBorder()).
BorderForeground(lipgloss.Color("240")).
Padding(0, 1).
Width(cardWidth).
Height(2)
// Card 1: API Keys
keyCount := len(apiKeys)
card1 := cardStyle.Render(fmt.Sprintf(
"%s\n%s",
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("111")).Render(fmt.Sprintf("🔑 %d", keyCount)),
lipgloss.NewStyle().Foreground(colorMuted).Render(T("mgmt_keys")),
))
// Card 2: Auth Files
authCount := len(authFiles)
activeAuth := 0
for _, f := range authFiles {
if !getBool(f, "disabled") {
activeAuth++
}
}
card2 := cardStyle.Render(fmt.Sprintf(
"%s\n%s",
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("76")).Render(fmt.Sprintf("📄 %d", authCount)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (%d %s)", T("auth_files_label"), activeAuth, T("active_suffix"))),
))
// Card 3: Total Requests
totalReqs := int64(0)
successReqs := int64(0)
failedReqs := int64(0)
totalTokens := int64(0)
if usage != nil {
if usageMap, ok := usage["usage"].(map[string]any); ok {
totalReqs = int64(getFloat(usageMap, "total_requests"))
successReqs = int64(getFloat(usageMap, "success_count"))
failedReqs = int64(getFloat(usageMap, "failure_count"))
totalTokens = int64(getFloat(usageMap, "total_tokens"))
}
}
card3 := cardStyle.Render(fmt.Sprintf(
"%s\n%s",
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(fmt.Sprintf("📈 %d", totalReqs)),
lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s (✓%d ✗%d)", T("total_requests"), successReqs, failedReqs)),
))
// Card 4: Total Tokens
tokenStr := formatLargeNumber(totalTokens)
card4 := cardStyle.Render(fmt.Sprintf(
"%s\n%s",
lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("🔤 %s", tokenStr)),
lipgloss.NewStyle().Foreground(colorMuted).Render(T("total_tokens")),
))
sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
sb.WriteString("\n\n")
// ━━━ Current Config ━━━
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("current_config")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
if cfg != nil {
debug := getBool(cfg, "debug")
retry := getFloat(cfg, "request-retry")
proxyURL := getString(cfg, "proxy-url")
loggingToFile := getBool(cfg, "logging-to-file")
usageEnabled := true
if v, ok := cfg["usage-statistics-enabled"]; ok {
if b, ok2 := v.(bool); ok2 {
usageEnabled = b
}
}
configItems := []struct {
label string
value string
}{
{T("debug_mode"), boolEmoji(debug)},
{T("usage_stats"), boolEmoji(usageEnabled)},
{T("log_to_file"), boolEmoji(loggingToFile)},
{T("retry_count"), fmt.Sprintf("%.0f", retry)},
}
if proxyURL != "" {
configItems = append(configItems, struct {
label string
value string
}{T("proxy_url"), proxyURL})
}
// Render config items as a compact row
for _, item := range configItems {
sb.WriteString(fmt.Sprintf(" %s %s\n",
labelStyle.Render(item.label+":"),
valueStyle.Render(item.value)))
}
// Routing strategy
strategy := "round-robin"
if routing, ok := cfg["routing"].(map[string]any); ok {
if s := getString(routing, "strategy"); s != "" {
strategy = s
}
}
sb.WriteString(fmt.Sprintf(" %s %s\n",
labelStyle.Render(T("routing_strategy")+":"),
valueStyle.Render(strategy)))
}
sb.WriteString("\n")
// ━━━ Per-Model Usage ━━━
if usage != nil {
if usageMap, ok := usage["usage"].(map[string]any); ok {
if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("model_stats")))
sb.WriteString("\n")
sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
sb.WriteString("\n")
header := fmt.Sprintf(" %-40s %10s %12s", T("model"), T("requests"), T("tokens"))
sb.WriteString(tableHeaderStyle.Render(header))
sb.WriteString("\n")
for _, apiSnap := range apis {
if apiMap, ok := apiSnap.(map[string]any); ok {
if models, ok := apiMap["models"].(map[string]any); ok {
for model, v := range models {
if stats, ok := v.(map[string]any); ok {
reqs := int64(getFloat(stats, "total_requests"))
toks := int64(getFloat(stats, "total_tokens"))
row := fmt.Sprintf(" %-40s %10d %12s", truncate(model, 40), reqs, formatLargeNumber(toks))
sb.WriteString(tableCellStyle.Render(row))
sb.WriteString("\n")
}
}
}
}
}
}
}
}
return sb.String()
}
func formatKV(key, value string) string {
return fmt.Sprintf(" %s %s\n", labelStyle.Render(key+":"), valueStyle.Render(value))
}
// getString returns m[key] when it holds a string, and "" otherwise
// (missing key, nil map, or a non-string value).
func getString(m map[string]any, key string) string {
	s, _ := m[key].(string)
	return s
}
func getFloat(m map[string]any, key string) float64 {
if v, ok := m[key]; ok {
switch n := v.(type) {
case float64:
return n
case json.Number:
f, _ := n.Float64()
return f
}
}
return 0
}
// getBool returns m[key] when it holds a bool, and false otherwise
// (missing key, nil map, or a non-bool value).
func getBool(m map[string]any, key string) bool {
	b, _ := m[key].(bool)
	return b
}
// boolEmoji renders a boolean using the localized "bool_yes"/"bool_no" labels.
func boolEmoji(b bool) string {
	key := "bool_no"
	if b {
		key = "bool_yes"
	}
	return T(key)
}
// formatLargeNumber abbreviates n with one decimal and a K/M/B suffix
// (e.g. 1500 -> "1.5K", 1234567 -> "1.2M"); values below 1000, including
// negatives, are printed as plain integers.
//
// Each threshold sits at the point where the smaller unit would round up to
// "1000.0" under %.1f, so 999_999 renders as "1.0M" rather than "1000.0K".
func formatLargeNumber(n int64) string {
	switch {
	case n >= 999_950_000:
		return fmt.Sprintf("%.1fB", float64(n)/1_000_000_000)
	case n >= 999_950:
		return fmt.Sprintf("%.1fM", float64(n)/1_000_000)
	case n >= 1_000:
		return fmt.Sprintf("%.1fK", float64(n)/1_000)
	default:
		return fmt.Sprintf("%d", n)
	}
}
// truncate shortens s to at most maxLen characters (runes), replacing the
// tail with "..." when it does not fit.
//
// Runes are counted instead of bytes so multi-byte (e.g. CJK) model names are
// never cut mid-character, which would produce invalid UTF-8. When maxLen is
// too small to hold the ellipsis (< 3), the first maxLen runes are returned
// without one, and a negative maxLen yields "". The original byte-slicing
// version panicked in that case.
func truncate(s string, maxLen int) string {
	// Fast path: byte length bounds rune count, so a short-enough string fits.
	if len(s) <= maxLen {
		return s
	}
	runes := []rune(s)
	if len(runes) <= maxLen {
		return s
	}
	if maxLen < 3 {
		return string(runes[:max(maxLen, 0)])
	}
	return string(runes[:maxLen-3]) + "..."
}
// minInt returns the smaller of a and b. It is kept for existing callers and
// delegates to the Go 1.21 builtin min, which this package already relies on
// (builtin max).
func minInt(a, b int) int {
	return min(a, b)
}
================================================
FILE: internal/tui/i18n.go
================================================
package tui
// i18n provides a simple internationalization system for the TUI.
// Supported locales: "en" (English, the default) and "zh" (Chinese).
// Lookups that miss in the active locale fall back to English (see T).
var currentLocale = "en"
// SetLocale switches the active locale. Codes that have no translation table
// are ignored and leave the current locale unchanged.
func SetLocale(locale string) {
	if _, supported := locales[locale]; !supported {
		return
	}
	currentLocale = locale
}
// CurrentLocale reports the code of the active locale, e.g. "en" or "zh".
func CurrentLocale() string {
	locale := currentLocale
	return locale
}
// ToggleLocale flips between the two supported locales: "zh" becomes "en",
// and any other current value becomes "zh".
func ToggleLocale() {
	switch currentLocale {
	case "zh":
		currentLocale = "en"
	default:
		currentLocale = "zh"
	}
}
// T returns the translated string for the given key.
func T(key string) string {
if m, ok := locales[currentLocale]; ok {
if v, ok := m[key]; ok {
return v
}
}
// Fallback to English
if m, ok := locales["en"]; ok {
if v, ok := m[key]; ok {
return v
}
}
return key
}
var locales = map[string]map[string]string{
"zh": zhStrings,
"en": enStrings,
}
// ──────────────────────────────────────────
// Tab names
// ──────────────────────────────────────────

// Localized tab labels. Both slices list the tabs in the same order.
var (
	zhTabNames = []string{"仪表盘", "配置", "认证文件", "API 密钥", "OAuth", "使用统计", "日志"}
	enTabNames = []string{"Dashboard", "Config", "Auth Files", "API Keys", "OAuth", "Usage", "Logs"}
)
// TabNames returns the tab labels for the active locale: Chinese when the
// locale is "zh", English otherwise. The shared package slice is returned,
// not a copy.
func TabNames() []string {
	switch currentLocale {
	case "zh":
		return zhTabNames
	default:
		return enTabNames
	}
}
// zhStrings is the Chinese ("zh") translation table consulted by T. Its keys
// mirror enStrings; T falls back to enStrings for any key missing here.
// Values containing %s are format templates that callers fill in.
var zhStrings = map[string]string{
// ── Common ──
"loading": "加载中...",
"refresh": "刷新",
"save": "保存",
"cancel": "取消",
"confirm": "确认",
"yes": "是",
"no": "否",
"error": "错误",
"success": "成功",
"navigate": "导航",
"scroll": "滚动",
"enter_save": "Enter: 保存",
"esc_cancel": "Esc: 取消",
"enter_submit": "Enter: 提交",
"press_r": "[r] 刷新",
"press_scroll": "[↑↓] 滚动",
"not_set": "(未设置)",
"error_prefix": "⚠ 错误: ",
// ── Status bar ──
"status_left": " CLIProxyAPI 管理终端",
"status_right": "Tab/Shift+Tab: 切换 • L: 语言 • q/Ctrl+C: 退出 ",
"initializing_tui": "正在初始化...",
"auth_gate_title": "🔐 连接管理 API",
"auth_gate_help": " 请输入管理密码并按 Enter 连接",
"auth_gate_password": "密码",
"auth_gate_enter": " Enter: 连接 • q/Ctrl+C: 退出 • L: 语言",
"auth_gate_connecting": "正在连接...",
"auth_gate_connect_fail": "连接失败:%s",
"auth_gate_password_required": "请输入密码",
// ── Dashboard ──
"dashboard_title": "📊 仪表盘",
"dashboard_help": " [r] 刷新 • [↑↓] 滚动",
"connected": "● 已连接",
"mgmt_keys": "管理密钥",
"auth_files_label": "认证文件",
"active_suffix": "活跃",
"total_requests": "请求",
"success_label": "成功",
"failure_label": "失败",
"total_tokens": "总 Tokens",
"current_config": "当前配置",
"debug_mode": "启用调试模式",
"usage_stats": "启用使用统计",
"log_to_file": "启用日志记录到文件",
"retry_count": "重试次数",
"proxy_url": "代理 URL",
"routing_strategy": "路由策略",
"model_stats": "模型统计",
"model": "模型",
"requests": "请求数",
"tokens": "Tokens",
"bool_yes": "是 ✓",
"bool_no": "否",
// ── Config ──
"config_title": "⚙ 配置",
"config_help1": " [↑↓/jk] 导航 • [Enter/Space] 编辑 • [r] 刷新",
"config_help2": " 布尔: Enter 切换 • 文本/数字: Enter 输入, Enter 确认, Esc 取消",
"updated_ok": "✓ 更新成功",
"no_config": " 未加载配置",
"invalid_int": "无效整数",
"section_server": "服务器",
"section_logging": "日志与统计",
"section_quota": "配额超限处理",
"section_routing": "路由",
"section_websocket": "WebSocket",
"section_ampcode": "AMP Code",
"section_other": "其他",
// ── Auth Files ──
"auth_title": "🔑 认证文件",
"auth_help1": " [↑↓/jk] 导航 • [Enter] 展开 • [e] 启用/停用 • [d] 删除 • [r] 刷新",
"auth_help2": " [1] 编辑 prefix • [2] 编辑 proxy_url • [3] 编辑 priority",
"no_auth_files": " 无认证文件",
"confirm_delete": "⚠ 删除 %s? [y/n]",
"deleted": "已删除 %s",
"enabled": "已启用",
"disabled": "已停用",
// NOTE(review): this template reads "updated <1st>'s <2nd>", while the en
// template "Updated %s on %s" reads "updated <1st> on <2nd>". For the same
// arguments the two locales name the field and the file in opposite roles.
// Confirm the caller's argument order and consider indexed verbs (%[2]s).
"updated_field": "已更新 %s 的 %s",
"status_active": "活跃",
"status_disabled": "已停用",
// ── API Keys ──
"keys_title": "🔐 API 密钥",
"keys_help": " [↑↓/jk] 导航 • [a] 添加 • [e] 编辑 • [d] 删除 • [c] 复制 • [r] 刷新",
"no_keys": " 无 API Key,按 [a] 添加",
"access_keys": "Access API Keys",
"confirm_delete_key": "⚠ 确认删除 %s? [y/n]",
"key_added": "已添加 API Key",
"key_updated": "已更新 API Key",
"key_deleted": "已删除 API Key",
"copied": "✓ 已复制到剪贴板",
"copy_failed": "✗ 复制失败",
"new_key_prompt": " New Key: ",
"edit_key_prompt": " Edit Key: ",
"enter_add": " Enter: 添加 • Esc: 取消",
"enter_save_esc": " Enter: 保存 • Esc: 取消",
// ── OAuth ──
"oauth_title": "🔐 OAuth 登录",
"oauth_select": " 选择提供商并按 [Enter] 开始 OAuth 登录:",
"oauth_help": " [↑↓/jk] 导航 • [Enter] 登录 • [Esc] 清除状态",
"oauth_initiating": "⏳ 正在初始化 %s 登录...",
"oauth_success": "认证成功! 请刷新 Auth Files 标签查看新凭证。",
"oauth_completed": "认证流程已完成。",
"oauth_failed": "认证失败",
"oauth_timeout": "OAuth 流程超时 (5 分钟)",
"oauth_press_esc": " 按 [Esc] 取消",
"oauth_auth_url": " 授权链接:",
"oauth_remote_hint": " 远程浏览器模式:在浏览器中打开上述链接完成授权后,将回调 URL 粘贴到下方。",
"oauth_callback_url": " 回调 URL:",
"oauth_press_c": " 按 [c] 输入回调 URL • [Esc] 返回",
"oauth_submitting": "⏳ 提交回调中...",
"oauth_submit_ok": "✓ 回调已提交,等待处理...",
"oauth_submit_fail": "✗ 提交回调失败",
"oauth_waiting": " 等待认证中...",
// ── Usage ──
"usage_title": "📈 使用统计",
"usage_help": " [r] 刷新 • [↑↓] 滚动",
"usage_no_data": " 使用数据不可用",
"usage_total_reqs": "总请求数",
"usage_total_tokens": "总 Token 数",
"usage_success": "成功",
"usage_failure": "失败",
"usage_total_token_l": "总Token",
"usage_rpm": "RPM",
"usage_tpm": "TPM",
"usage_req_by_hour": "请求趋势 (按小时)",
"usage_tok_by_hour": "Token 使用趋势 (按小时)",
"usage_req_by_day": "请求趋势 (按天)",
"usage_api_detail": "API 详细统计",
"usage_input": "输入",
"usage_output": "输出",
"usage_cached": "缓存",
"usage_reasoning": "思考",
// ── Logs ──
"logs_title": "📋 日志",
"logs_auto_scroll": "● 自动滚动",
"logs_paused": "○ 已暂停",
"logs_filter": "过滤",
"logs_lines": "行数",
"logs_help": " [a] 自动滚动 • [c] 清除 • [1] 全部 [2] info+ [3] warn+ [4] error • [↑↓] 滚动",
"logs_waiting": " 等待日志输出...",
}
// enStrings is the English ("en") translation table. T also uses it as the
// fallback when the active locale lacks a key, so every key should exist here.
// Values containing %s are format templates that callers fill in.
var enStrings = map[string]string{
// ── Common ──
"loading": "Loading...",
"refresh": "Refresh",
"save": "Save",
"cancel": "Cancel",
"confirm": "Confirm",
"yes": "Yes",
"no": "No",
"error": "Error",
"success": "Success",
"navigate": "Navigate",
"scroll": "Scroll",
"enter_save": "Enter: Save",
"esc_cancel": "Esc: Cancel",
"enter_submit": "Enter: Submit",
"press_r": "[r] Refresh",
"press_scroll": "[↑↓] Scroll",
"not_set": "(not set)",
"error_prefix": "⚠ Error: ",
// ── Status bar ──
"status_left": " CLIProxyAPI Management TUI",
"status_right": "Tab/Shift+Tab: switch • L: lang • q/Ctrl+C: quit ",
"initializing_tui": "Initializing...",
"auth_gate_title": "🔐 Connect Management API",
"auth_gate_help": " Enter management password and press Enter to connect",
"auth_gate_password": "Password",
"auth_gate_enter": " Enter: connect • q/Ctrl+C: quit • L: lang",
"auth_gate_connecting": "Connecting...",
"auth_gate_connect_fail": "Connection failed: %s",
"auth_gate_password_required": "password is required",
// ── Dashboard ──
"dashboard_title": "📊 Dashboard",
"dashboard_help": " [r] Refresh • [↑↓] Scroll",
"connected": "● Connected",
"mgmt_keys": "Mgmt Keys",
"auth_files_label": "Auth Files",
"active_suffix": "active",
"total_requests": "Requests",
"success_label": "Success",
"failure_label": "Failed",
"total_tokens": "Total Tokens",
"current_config": "Current Config",
"debug_mode": "Debug Mode",
"usage_stats": "Usage Statistics",
"log_to_file": "Log to File",
"retry_count": "Retry Count",
"proxy_url": "Proxy URL",
"routing_strategy": "Routing Strategy",
"model_stats": "Model Stats",
"model": "Model",
"requests": "Requests",
"tokens": "Tokens",
"bool_yes": "Yes ✓",
"bool_no": "No",
// ── Config ──
"config_title": "⚙ Configuration",
"config_help1": " [↑↓/jk] Navigate • [Enter/Space] Edit • [r] Refresh",
"config_help2": " Bool: Enter to toggle • String/Int: Enter to type, Enter to confirm, Esc to cancel",
"updated_ok": "✓ Updated successfully",
"no_config": " No configuration loaded",
"invalid_int": "invalid integer",
"section_server": "Server",
"section_logging": "Logging & Stats",
"section_quota": "Quota Exceeded Handling",
"section_routing": "Routing",
"section_websocket": "WebSocket",
"section_ampcode": "AMP Code",
"section_other": "Other",
// ── Auth Files ──
"auth_title": "🔑 Auth Files",
"auth_help1": " [↑↓/jk] Navigate • [Enter] Expand • [e] Enable/Disable • [d] Delete • [r] Refresh",
"auth_help2": " [1] Edit prefix • [2] Edit proxy_url • [3] Edit priority",
"no_auth_files": " No auth files found",
"confirm_delete": "⚠ Delete %s? [y/n]",
"deleted": "Deleted %s",
"enabled": "Enabled",
"disabled": "Disabled",
// NOTE(review): this template reads "updated <1st> on <2nd>", while the zh
// template "已更新 %s 的 %s" reads "updated <1st>'s <2nd>". For the same
// arguments the two locales name the field and the file in opposite roles.
// Confirm the caller's argument order and consider indexed verbs (%[2]s).
"updated_field": "Updated %s on %s",
"status_active": "active",
"status_disabled": "disabled",
// ── API Keys ──
"keys_title": "🔐 API Keys",
"keys_help": " [↑↓/jk] Navigate • [a] Add • [e] Edit • [d] Delete • [c] Copy • [r] Refresh",
"no_keys": " No API Keys. Press [a] to add",
"access_keys": "Access API Keys",
"confirm_delete_key": "⚠ Delete %s? [y/n]",
"key_added": "API Key added",
"key_updated": "API Key updated",
"key_deleted": "API Key deleted",
"copied": "✓ Copied to clipboard",
"copy_failed": "✗ Copy failed",
"new_key_prompt": " New Key: ",
"edit_key_prompt": " Edit Key: ",
"enter_add": " Enter: Add • Esc: Cancel",
"enter_save_esc": " Enter: Save • Esc: Cancel",
// ── OAuth ──
"oauth_title": "🔐 OAuth Login",
"oauth_select": " Select a provider and press [Enter] to start OAuth login:",
"oauth_help": " [↑↓/jk] Navigate • [Enter] Login • [Esc] Clear status",
"oauth_initiating": "⏳ Initiating %s login...",
"oauth_success": "Authentication successful! Refresh Auth Files tab to see the new credential.",
"oauth_completed": "Authentication flow completed.",
"oauth_failed": "Authentication failed",
"oauth_timeout": "OAuth flow timed out (5 minutes)",
"oauth_press_esc": " Press [Esc] to cancel",
"oauth_auth_url": " Authorization URL:",
"oauth_remote_hint": " Remote browser mode: Open the URL above in browser, paste the callback URL below after authorization.",
"oauth_callback_url": " Callback URL:",
"oauth_press_c": " Press [c] to enter callback URL • [Esc] to go back",
"oauth_submitting": "⏳ Submitting callback...",
"oauth_submit_ok": "✓ Callback submitted, waiting...",
"oauth_submit_fail": "✗ Callback submission failed",
"oauth_waiting": " Waiting for authentication...",
// ── Usage ──
"usage_title": "📈 Usage Statistics",
"usage_help": " [r] Refresh • [↑↓] Scroll",
"usage_no_data": " Usage data not available",
"usage_total_reqs": "Total Requests",
"usage_total_tokens": "Total Tokens",
"usage_success": "Success",
"usage_failure": "Failed",
"usage_total_token_l": "Total Tokens",
"usage_rpm": "RPM",
"usage_tpm": "TPM",
"usage_req_by_hour": "Requests by Hour",
"usage_tok_by_hour": "Token Usage by Hour",
"usage_req_by_day": "Requests by Day",
"usage_api_detail": "API Detail Statistics",
"usage_input": "Input",
"usage_output": "Output",
"usage_cached": "Cached",
"usage_reasoning": "Reasoning",
// ── Logs ──
"logs_title": "📋 Logs",
"logs_auto_scroll": "● AUTO-SCROLL",
"logs_paused": "○ PAUSED",
"logs_filter": "Filter",
"logs_lines": "Lines",
"logs_help": " [a] Auto-scroll • [c] Clear • [1] All [2] info+ [3] warn+ [4] error • [↑↓] Scroll",
"logs_waiting": " Waiting for log output...",
}
================================================
FILE: internal/tui/keys_tab.go
================================================
package tui
import (
"fmt"
"strings"
"github.com/atotto/clipboard"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// keysTabModel is the Bubble Tea model for the API keys tab. The access keys
// can be added, edited, deleted and copied; provider keys are shown read-only.
type keysTabModel struct {
	client   *Client
	viewport viewport.Model

	// Results of the last successful fetch.
	keys   []string
	gemini []map[string]any
	claude []map[string]any
	codex  []map[string]any
	vertex []map[string]any
	openai []map[string]any
	err    error // last fetch error; when set, only the error is rendered

	width  int
	height int
	ready  bool // viewport has been created by SetSize

	cursor  int    // index into keys of the highlighted row
	confirm int    // index awaiting delete confirmation; -1 = none pending
	status  string // pre-rendered outcome of the last action

	// Add/edit state. editInput is shared by both modes.
	editing   bool
	adding    bool
	editIdx   int
	editInput textinput.Model
}
type (
	// keysDataMsg carries the result of fetchKeys. When err is set the other
	// fields are left empty.
	keysDataMsg struct {
		apiKeys []string
		gemini  []map[string]any
		claude  []map[string]any
		codex   []map[string]any
		vertex  []map[string]any
		openai  []map[string]any
		err     error
	}

	// keyActionMsg reports the outcome of an add/edit/delete request; action
	// holds the localized success text.
	keyActionMsg struct {
		action string
		err    error
	}
)
// newKeysTabModel builds the API keys tab with an idle text input and no
// deletion pending. The viewport is created later by SetSize.
func newKeysTabModel(client *Client) keysTabModel {
	input := textinput.New()
	input.Prompt = " Key: "
	input.CharLimit = 512

	m := keysTabModel{client: client, editInput: input}
	m.confirm = -1 // nothing awaiting delete confirmation
	return m
}
// Init starts the first key fetch.
func (m keysTabModel) Init() tea.Cmd {
	return tea.Cmd(m.fetchKeys)
}
// fetchKeys loads the access keys and every provider key list, returning a
// keysDataMsg. Only a failure to load the access keys is reported. Provider
// lookups are best-effort: their errors are discarded and whatever each call
// returned is used.
func (m keysTabModel) fetchKeys() tea.Msg {
	apiKeys, err := m.client.GetAPIKeys()
	if err != nil {
		return keysDataMsg{err: err}
	}

	msg := keysDataMsg{apiKeys: apiKeys}
	msg.gemini, _ = m.client.GetGeminiKeys()
	msg.claude, _ = m.client.GetClaudeKeys()
	msg.codex, _ = m.client.GetCodexKeys()
	msg.vertex, _ = m.client.GetVertexKeys()
	msg.openai, _ = m.client.GetOpenAICompat()
	return msg
}
// Update is the Bubble Tea update function for the API keys tab.
//
// Data and action messages update state and re-render. Key presses are routed
// by mode, in priority order: the add/edit text input, then a pending delete
// confirmation, then normal navigation. Any other message goes to the viewport.
func (m keysTabModel) Update(msg tea.Msg) (keysTabModel, tea.Cmd) {
	switch msg := msg.(type) {
	case localeChangedMsg:
		m.viewport.SetContent(m.renderContent())
		return m, nil
	case keysDataMsg:
		if msg.err != nil {
			// Previously fetched lists are kept; renderContent shows only the error.
			m.err = msg.err
		} else {
			m.err = nil
			m.keys = msg.apiKeys
			m.gemini = msg.gemini
			m.claude = msg.claude
			m.codex = msg.codex
			m.vertex = msg.vertex
			m.openai = msg.openai
			// Keep the cursor in range if keys were removed.
			if m.cursor >= len(m.keys) {
				m.cursor = max(0, len(m.keys)-1)
			}
		}
		m.viewport.SetContent(m.renderContent())
		return m, nil
	case keyActionMsg:
		if msg.err != nil {
			m.status = errorStyle.Render("✗ " + msg.err.Error())
		} else {
			m.status = successStyle.Render("✓ " + msg.action)
		}
		m.confirm = -1
		m.viewport.SetContent(m.renderContent())
		// Refetch so the list reflects the change made on the server.
		return m, m.fetchKeys
	case tea.KeyMsg:
		switch {
		case m.editing || m.adding:
			return m.updateEditInput(msg)
		case m.confirm >= 0:
			return m.updateDeleteConfirm(msg)
		default:
			return m.updateNormalKey(msg)
		}
	}
	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}

// closeEditInput leaves add/edit mode and re-renders so the input disappears.
func (m *keysTabModel) closeEditInput() {
	m.editing = false
	m.adding = false
	m.editInput.Blur()
	m.viewport.SetContent(m.renderContent())
}

// updateEditInput handles keys while the add/edit text input is focused.
// Enter submits a non-empty value (an empty value just cancels), Esc cancels,
// and every other key is forwarded to the input.
func (m keysTabModel) updateEditInput(msg tea.KeyMsg) (keysTabModel, tea.Cmd) {
	switch msg.String() {
	case "enter":
		value := strings.TrimSpace(m.editInput.Value())
		isAdding := m.adding
		editIdx := m.editIdx
		// Re-render before the request runs so the input does not linger on
		// screen while it is in flight.
		m.closeEditInput()
		if value == "" {
			return m, nil
		}
		client := m.client
		if isAdding {
			return m, func() tea.Msg {
				if err := client.AddAPIKey(value); err != nil {
					return keyActionMsg{err: err}
				}
				return keyActionMsg{action: T("key_added")}
			}
		}
		return m, func() tea.Msg {
			if err := client.EditAPIKey(editIdx, value); err != nil {
				return keyActionMsg{err: err}
			}
			return keyActionMsg{action: T("key_updated")}
		}
	case "esc":
		m.closeEditInput()
		return m, nil
	default:
		var cmd tea.Cmd
		m.editInput, cmd = m.editInput.Update(msg)
		m.viewport.SetContent(m.renderContent())
		return m, cmd
	}
}

// updateDeleteConfirm handles the y/n answer to a pending delete. Other keys
// are ignored until the user answers.
func (m keysTabModel) updateDeleteConfirm(msg tea.KeyMsg) (keysTabModel, tea.Cmd) {
	switch msg.String() {
	case "y", "Y":
		idx := m.confirm
		m.confirm = -1
		// Clear the confirmation prompt immediately rather than after the
		// delete round-trip completes.
		m.viewport.SetContent(m.renderContent())
		client := m.client
		return m, func() tea.Msg {
			if err := client.DeleteAPIKey(idx); err != nil {
				return keyActionMsg{err: err}
			}
			return keyActionMsg{action: T("key_deleted")}
		}
	case "n", "N", "esc":
		m.confirm = -1
		m.viewport.SetContent(m.renderContent())
	}
	return m, nil
}

// updateNormalKey handles navigation and action keys when no input or
// confirmation prompt is active. Unhandled keys scroll the viewport.
func (m keysTabModel) updateNormalKey(msg tea.KeyMsg) (keysTabModel, tea.Cmd) {
	switch msg.String() {
	case "j", "down":
		if len(m.keys) > 0 {
			m.cursor = (m.cursor + 1) % len(m.keys)
			m.viewport.SetContent(m.renderContent())
		}
		return m, nil
	case "k", "up":
		if len(m.keys) > 0 {
			m.cursor = (m.cursor - 1 + len(m.keys)) % len(m.keys)
			m.viewport.SetContent(m.renderContent())
		}
		return m, nil
	case "a":
		// Start adding a new key.
		m.adding = true
		m.editing = false
		m.editInput.SetValue("")
		m.editInput.Prompt = T("new_key_prompt")
		m.editInput.Focus()
		m.viewport.SetContent(m.renderContent())
		return m, textinput.Blink
	case "e":
		// Edit the selected key in place.
		if m.cursor < len(m.keys) {
			m.editing = true
			m.adding = false
			m.editIdx = m.cursor
			m.editInput.SetValue(m.keys[m.cursor])
			m.editInput.Prompt = T("edit_key_prompt")
			m.editInput.Focus()
			m.viewport.SetContent(m.renderContent())
			return m, textinput.Blink
		}
		return m, nil
	case "d":
		// Ask for confirmation before deleting the selected key.
		if m.cursor < len(m.keys) {
			m.confirm = m.cursor
			m.viewport.SetContent(m.renderContent())
		}
		return m, nil
	case "c":
		// Copy the selected key to the system clipboard.
		if m.cursor < len(m.keys) {
			if err := clipboard.WriteAll(m.keys[m.cursor]); err != nil {
				m.status = errorStyle.Render(T("copy_failed") + ": " + err.Error())
			} else {
				m.status = successStyle.Render(T("copied"))
			}
			m.viewport.SetContent(m.renderContent())
		}
		return m, nil
	case "r":
		m.status = ""
		return m, m.fetchKeys
	default:
		var cmd tea.Cmd
		m.viewport, cmd = m.viewport.Update(msg)
		return m, cmd
	}
}
// SetSize records the tab's dimensions, creates the viewport on the first
// call and resizes it afterwards. Content is re-rendered on every call
// because renderContent draws a separator m.width characters wide; without
// this the separator kept its old width after a terminal resize.
func (m *keysTabModel) SetSize(w, h int) {
	m.width = w
	m.height = h
	m.editInput.Width = w - 16
	if !m.ready {
		m.viewport = viewport.New(w, h)
		m.ready = true
	} else {
		m.viewport.Width = w
		m.viewport.Height = h
	}
	m.viewport.SetContent(m.renderContent())
}
// View renders the viewport, or a localized loading message until SetSize has
// created it.
func (m keysTabModel) View() string {
	if m.ready {
		return m.viewport.View()
	}
	return T("loading")
}
// renderContent builds the full text shown in the keys viewport.
//
// Layout, top to bottom:
//   - title, help line and a separator m.width wide
//   - if the last fetch failed, only the error, and nothing else
//   - the interactive access key list, with inline delete prompt and edit input
//   - the add input, while adding
//   - read-only provider key sections and the OpenAI-compatibility list
//   - the last status message, if any
func (m keysTabModel) renderContent() string {
	var sb strings.Builder
	writeLine := func(s string) {
		sb.WriteString(s)
		sb.WriteString("\n")
	}

	writeLine(titleStyle.Render(T("keys_title")))
	writeLine(helpStyle.Render(T("keys_help")))
	writeLine(strings.Repeat("─", m.width))

	if m.err != nil {
		writeLine(errorStyle.Render(T("error_prefix") + m.err.Error()))
		return sb.String()
	}

	// ━━━ Access API Keys (interactive) ━━━
	writeLine(tableHeaderStyle.Render(fmt.Sprintf(" %s (%d)", T("access_keys"), len(m.keys))))
	if len(m.keys) == 0 {
		writeLine(subtitleStyle.Render(T("no_keys")))
	}
	for idx, key := range m.keys {
		marker := " "
		style := lipgloss.NewStyle()
		if idx == m.cursor {
			marker = "▸ "
			style = style.Bold(true)
		}
		masked := maskKey(key)
		writeLine(style.Render(fmt.Sprintf("%s%d. %s", marker, idx+1, masked)))

		if m.confirm == idx {
			writeLine(warningStyle.Render(fmt.Sprintf(" "+T("confirm_delete_key"), masked)))
		}
		if m.editing && m.editIdx == idx {
			writeLine(m.editInput.View())
			writeLine(helpStyle.Render(T("enter_save_esc")))
		}
	}

	if m.adding {
		sb.WriteString("\n")
		writeLine(m.editInput.View())
		writeLine(helpStyle.Render(T("enter_add")))
	}
	sb.WriteString("\n")

	// ━━━ Provider Keys (read-only display) ━━━
	providerGroups := []struct {
		title string
		keys  []map[string]any
	}{
		{"Gemini API Keys", m.gemini},
		{"Claude API Keys", m.claude},
		{"Codex API Keys", m.codex},
		{"Vertex API Keys", m.vertex},
	}
	for _, group := range providerGroups {
		renderProviderKeys(&sb, group.title, group.keys)
	}

	if len(m.openai) > 0 {
		renderSection(&sb, "OpenAI Compatibility", len(m.openai))
		for idx, entry := range m.openai {
			info := getString(entry, "name")
			if prefix := getString(entry, "prefix"); prefix != "" {
				info += " (prefix: " + prefix + ")"
			}
			if baseURL := getString(entry, "base-url"); baseURL != "" {
				info += " → " + baseURL
			}
			fmt.Fprintf(&sb, " %d. %s\n", idx+1, info)
		}
		sb.WriteString("\n")
	}

	if m.status != "" {
		writeLine(m.status)
	}
	return sb.String()
}
// renderSection writes a table-styled " title (count)" header line to sb.
func renderSection(sb *strings.Builder, title string, count int) {
	sb.WriteString(tableHeaderStyle.Render(fmt.Sprintf(" %s (%d)", title, count)))
	sb.WriteString("\n")
}
// renderProviderKeys writes a read-only, numbered section listing provider
// keys. Each entry shows its masked "api-key", plus the optional "prefix" and
// "base-url". Nothing is written when keys is empty.
func renderProviderKeys(sb *strings.Builder, title string, keys []map[string]any) {
	if len(keys) == 0 {
		return
	}
	renderSection(sb, title, len(keys))
	for idx, entry := range keys {
		text := maskKey(getString(entry, "api-key"))
		if p := getString(entry, "prefix"); p != "" {
			text += " (prefix: " + p + ")"
		}
		if u := getString(entry, "base-url"); u != "" {
			text += " → " + u
		}
		fmt.Fprintf(sb, " %d. %s\n", idx+1, text)
	}
	sb.WriteString("\n")
}
// maskKey hides a key for display. Keys of 8 bytes or fewer are fully
// replaced by '*'. Longer keys keep their first and last 4 bytes, with the
// middle replaced by one '*' per byte.
func maskKey(key string) string {
	const keep = 4
	n := len(key)
	if n <= 2*keep {
		return strings.Repeat("*", n)
	}
	head, tail := key[:keep], key[n-keep:]
	return head + strings.Repeat("*", n-2*keep) + tail
}
================================================
FILE: internal/tui/loghook.go
================================================
package tui
import (
"fmt"
"strings"
"sync"
log "github.com/sirupsen/logrus"
)
// LogHook is a logrus hook that formats each log entry and queues the text on
// a buffered channel, from which the TUI logs tab reads lines.
type LogHook struct {
	ch        chan string   // formatted lines; Fire never blocks on it
	formatter log.Formatter // guarded by mu; nil means a plain "[level] message" line
	mu        sync.Mutex
	levels    []log.Level
}
// NewLogHook returns a hook that fires on every logrus level. It buffers up to
// bufSize lines and formats entries with an uncolored TextFormatter that
// prints full timestamps.
func NewLogHook(bufSize int) *LogHook {
	h := new(LogHook)
	h.ch = make(chan string, bufSize)
	h.formatter = &log.TextFormatter{DisableColors: true, FullTimestamp: true}
	h.levels = log.AllLevels
	return h
}
// SetFormatter replaces the formatter used by Fire. It is safe to call while
// entries are being fired; passing nil selects the plain "[level] message"
// fallback.
func (h *LogHook) SetFormatter(f log.Formatter) {
	h.mu.Lock()
	h.formatter = f
	h.mu.Unlock()
}
// Levels implements logrus.Hook, reporting which levels trigger Fire
// (NewLogHook sets all of them).
func (h *LogHook) Levels() []log.Level {
	levels := h.levels
	return levels
}
// Fire implements logrus.Hook. It formats entry and queues the line without
// ever blocking the logger.
//
// Formatting uses the configured formatter with trailing CR/LF trimmed. A nil
// formatter or a formatting error falls back to "[level] message". When the
// channel is full, the oldest queued line is discarded and the send is
// retried once. If it still cannot be queued, the new line is dropped.
// Fire always returns nil.
func (h *LogHook) Fire(entry *log.Entry) error {
	h.mu.Lock()
	formatter := h.formatter
	h.mu.Unlock()

	var line string
	formatted := false
	if formatter != nil {
		if raw, err := formatter.Format(entry); err == nil {
			line = strings.TrimRight(string(raw), "\n\r")
			formatted = true
		}
	}
	if !formatted {
		line = fmt.Sprintf("[%s] %s", entry.Level, entry.Message)
	}

	select {
	case h.ch <- line:
		return nil
	default:
	}

	// Buffer full: make room by evicting the oldest line, then retry once.
	select {
	case <-h.ch:
	default:
	}
	select {
	case h.ch <- line:
	default:
	}
	return nil
}
// Chan exposes the receive-only end of the channel that Fire writes
// formatted lines to.
func (h *LogHook) Chan() <-chan string {
	var out <-chan string = h.ch
	return out
}
================================================
FILE: internal/tui/logs_tab.go
================================================
package tui
import (
"fmt"
"strings"
"time"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
)
// logsTabModel is the Bubble Tea model for the logs tab. Lines come from an
// in-process LogHook when hook is non-nil, and from polling the management API
// otherwise.
type logsTabModel struct {
	client     *Client
	hook       *LogHook
	viewport   viewport.Model
	lines      []string // retained lines, oldest first, at most maxLines
	maxLines   int
	autoScroll bool // keep the view pinned to the newest line
	width      int
	height     int
	ready      bool   // viewport has been created by SetSize
	filter     string // "", "debug", "info", "warn", "error"
	after      int64  // cursor passed to GetLogs; updated from each poll's latest value
	lastErr    error  // error from the most recent poll, if any
}
type (
	// logsPollMsg carries the result of one GetLogs poll.
	logsPollMsg struct {
		lines  []string
		latest int64
		err    error
	}

	// logsTickMsg signals that the next API poll is due.
	logsTickMsg struct{}

	// logLineMsg carries one line received from the LogHook channel.
	logLineMsg string
)
// newLogsTabModel builds the logs tab with auto-scroll on and at most 5000
// retained lines. Pass a nil hook to poll the API instead.
func newLogsTabModel(client *Client, hook *LogHook) logsTabModel {
	m := logsTabModel{client: client, hook: hook}
	m.maxLines = 5000
	m.autoScroll = true
	return m
}
// Init starts reading log lines: it polls the API when there is no hook, and
// waits on the hook's channel otherwise.
func (m logsTabModel) Init() tea.Cmd {
	if m.hook == nil {
		return m.fetchLogs
	}
	return m.waitForLog
}
// fetchLogs asks the API for up to 200 lines after the current cursor and
// wraps the outcome in a logsPollMsg.
func (m logsTabModel) fetchLogs() tea.Msg {
	var msg logsPollMsg
	msg.lines, msg.latest, msg.err = m.client.GetLogs(m.after, 200)
	return msg
}
// waitForNextPoll schedules a logsTickMsg two seconds from now.
func (m logsTabModel) waitForNextPoll() tea.Cmd {
	const pollInterval = 2 * time.Second
	return tea.Tick(pollInterval, func(time.Time) tea.Msg {
		return logsTickMsg{}
	})
}
// waitForLog blocks until the hook delivers a line and returns it as a
// logLineMsg. It returns nil when there is no hook or the channel is closed.
func (m logsTabModel) waitForLog() tea.Msg {
	if m.hook != nil {
		if line, ok := <-m.hook.Chan(); ok {
			return logLineMsg(line)
		}
	}
	return nil
}
// Update is the Bubble Tea update function for the logs tab.
//
// In hook mode (hook != nil) each logLineMsg appends a line and re-arms
// waitForLog; tick and poll messages are ignored. In API mode a tick triggers
// fetchLogs, and each poll result schedules the next tick. Key presses toggle
// auto-scroll, clear the buffer, change the level filter, or scroll.
func (m logsTabModel) Update(msg tea.Msg) (logsTabModel, tea.Cmd) {
	switch msg := msg.(type) {
	case localeChangedMsg:
		m.viewport.SetContent(m.renderLogs())
		return m, nil
	case logsTickMsg:
		if m.hook != nil {
			return m, nil
		}
		return m, m.fetchLogs
	case logsPollMsg:
		if m.hook != nil {
			return m, nil
		}
		if msg.err != nil {
			m.lastErr = msg.err
		} else {
			m.lastErr = nil
			m.after = msg.latest
			m.appendLines(msg.lines...)
		}
		m.syncViewport()
		return m, m.waitForNextPoll()
	case logLineMsg:
		m.appendLines(string(msg))
		m.syncViewport()
		return m, m.waitForLog
	case tea.KeyMsg:
		return m.updateKey(msg)
	}
	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}

// appendLines adds lines to the buffer, dropping the oldest ones beyond maxLines.
func (m *logsTabModel) appendLines(lines ...string) {
	if len(lines) == 0 {
		return
	}
	m.lines = append(m.lines, lines...)
	if over := len(m.lines) - m.maxLines; over > 0 {
		m.lines = m.lines[over:]
	}
}

// syncViewport re-renders the log view and, when auto-scroll is on, jumps to
// the newest line.
func (m *logsTabModel) syncViewport() {
	m.viewport.SetContent(m.renderLogs())
	if m.autoScroll {
		m.viewport.GotoBottom()
	}
}

// updateKey handles key presses for the logs tab. Every state change
// re-renders, because the header shows the auto-scroll state, the active
// filter and the line count. Previously, toggling with 'a' left the header
// stale.
func (m logsTabModel) updateKey(msg tea.KeyMsg) (logsTabModel, tea.Cmd) {
	switch msg.String() {
	case "a":
		m.autoScroll = !m.autoScroll
	case "c":
		m.lines = nil
		m.lastErr = nil
	case "1":
		m.filter = ""
	case "2":
		m.filter = "info"
	case "3":
		m.filter = "warn"
	case "4":
		m.filter = "error"
	default:
		return m.scrollViewport(msg)
	}
	m.syncViewport()
	return m, nil
}

// scrollViewport forwards a key to the viewport and keeps autoScroll in step
// with the position: scrolling away from the bottom pauses it, and reaching
// the bottom resumes it. The header is re-rendered only when that state flips,
// without GotoBottom, so the user's position is kept.
func (m logsTabModel) scrollViewport(msg tea.KeyMsg) (logsTabModel, tea.Cmd) {
	wasAtBottom := m.viewport.AtBottom()
	wasAuto := m.autoScroll

	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)

	atBottom := m.viewport.AtBottom()
	if wasAtBottom && !atBottom {
		m.autoScroll = false
	}
	if atBottom {
		m.autoScroll = true
	}
	if m.autoScroll != wasAuto {
		m.viewport.SetContent(m.renderLogs())
	}
	return m, cmd
}
// SetSize records the tab's dimensions, creates the viewport on the first
// call and resizes it afterwards. Content is re-rendered on every call
// because renderLogs draws a separator m.width characters wide, and the view
// is re-pinned to the newest line when auto-scroll is on.
func (m *logsTabModel) SetSize(w, h int) {
	m.width = w
	m.height = h
	if !m.ready {
		m.viewport = viewport.New(w, h)
		m.ready = true
	} else {
		m.viewport.Width = w
		m.viewport.Height = h
	}
	m.viewport.SetContent(m.renderLogs())
	if m.autoScroll {
		m.viewport.GotoBottom()
	}
}
// View renders the viewport, or a localized loading message until SetSize has
// created it.
func (m logsTabModel) View() string {
	if m.ready {
		return m.viewport.View()
	}
	return T("loading")
}
// renderLogs builds the logs viewport content.
//
// It writes a header (auto-scroll state, active filter, line count), the help
// line and a separator m.width wide. Then come the last poll error, if any,
// and the retained lines that pass the level filter, each styled by level.
// The error prefix is localized via T("error_prefix"), as in the keys tab; it
// was previously hard-coded in English.
func (m logsTabModel) renderLogs() string {
	var sb strings.Builder
	scrollStatus := successStyle.Render(T("logs_auto_scroll"))
	if !m.autoScroll {
		scrollStatus = warningStyle.Render(T("logs_paused"))
	}
	filterLabel := "ALL"
	if m.filter != "" {
		filterLabel = strings.ToUpper(m.filter) + "+"
	}
	header := fmt.Sprintf(" %s %s %s: %s %s: %d",
		T("logs_title"), scrollStatus, T("logs_filter"), filterLabel, T("logs_lines"), len(m.lines))
	sb.WriteString(titleStyle.Render(header))
	sb.WriteString("\n")
	sb.WriteString(helpStyle.Render(T("logs_help")))
	sb.WriteString("\n")
	sb.WriteString(strings.Repeat("─", m.width))
	sb.WriteString("\n")
	if m.lastErr != nil {
		sb.WriteString(errorStyle.Render(T("error_prefix") + m.lastErr.Error()))
		sb.WriteString("\n")
	}
	if len(m.lines) == 0 {
		sb.WriteString(subtitleStyle.Render(T("logs_waiting")))
		return sb.String()
	}
	for _, line := range m.lines {
		if m.filter != "" && !m.matchLevel(line) {
			continue
		}
		sb.WriteString(m.styleLine(line))
		sb.WriteString("\n")
	}
	return sb.String()
}
// matchLevel reports whether line passes the active level filter. Levels are
// detected by bracketed tags such as "[error]". "[warn" is matched as a
// prefix, so "[warning]" counts too.
//
//   - "error": error, fatal and panic lines
//   - "warn":  warn plus everything the error filter keeps (panic was missing)
//   - "info":  everything except debug and trace lines (trace was let through)
//   - "":      everything
//
// NOTE(review): this relies on the installed formatter emitting "[level]"
// tags, as LogHook's fallback does; confirm for the configured formatter.
func (m logsTabModel) matchLevel(line string) bool {
	hasAny := func(tags ...string) bool {
		for _, tag := range tags {
			if strings.Contains(line, tag) {
				return true
			}
		}
		return false
	}
	switch m.filter {
	case "error":
		return hasAny("[error]", "[fatal]", "[panic]")
	case "warn":
		return hasAny("[warn", "[error]", "[fatal]", "[panic]")
	case "info":
		return !hasAny("[debug]", "[trace]")
	default:
		return true
	}
}
// styleLine colors a log line by its bracketed level tag. Error, fatal and
// panic lines use the error style, matching what matchLevel treats as
// error-level; warn, info, and debug/trace lines use their own styles.
// Untagged lines are returned unchanged.
func (m logsTabModel) styleLine(line string) string {
	switch {
	case strings.Contains(line, "[error]"),
		strings.Contains(line, "[fatal]"),
		strings.Contains(line, "[panic]"):
		return logErrorStyle.Render(line)
	case strings.Contains(line, "[warn"):
		return logWarnStyle.Render(line)
	case strings.Contains(line, "[info"):
		return logInfoStyle.Render(line)
	case strings.Contains(line, "[debug]"), strings.Contains(line, "[trace]"):
		return logDebugStyle.Render(line)
	default:
		return line
	}
}
================================================
FILE: internal/tui/oauth_tab.go
================================================
package tui
import (
"fmt"
"strings"
"time"
"github.com/charmbracelet/bubbles/textinput"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// oauthProvider is one selectable OAuth login option in the TUI.
type oauthProvider struct {
	name    string // display name
	apiPath string // management API path, e.g. "gemini-cli-auth-url"
	emoji   string // icon shown with the name
}

// oauthProviders lists the OAuth options in display order.
var oauthProviders = []oauthProvider{
	{name: "Gemini CLI", apiPath: "gemini-cli-auth-url", emoji: "🟦"},
	{name: "Claude (Anthropic)", apiPath: "anthropic-auth-url", emoji: "🟧"},
	{name: "Codex (OpenAI)", apiPath: "codex-auth-url", emoji: "🟩"},
	{name: "Antigravity", apiPath: "antigravity-auth-url", emoji: "🟪"},
	{name: "Qwen", apiPath: "qwen-auth-url", emoji: "🟨"},
	{name: "Kimi", apiPath: "kimi-auth-url", emoji: "🟫"},
	{name: "IFlow", apiPath: "iflow-auth-url", emoji: "⬜"},
}
// oauthTabModel is the Bubble Tea model for the OAuth login tab.
type oauthTabModel struct {
	client   *Client
	viewport viewport.Model
	cursor   int        // selected option, presumably an index into oauthProviders
	state    oauthState // phase of the current login flow
	message  string     // pre-rendered status line
	err      error
	width    int
	height   int
	ready    bool

	// Remote browser mode.
	authURL       string // auth URL to display
	authState     string // OAuth state parameter
	providerName  string // current provider name
	callbackInput textinput.Model
	inputActive   bool // true while the user is typing the callback URL
}
// oauthState is the phase of an OAuth login flow in the TUI.
type oauthState int

// OAuth flow phases. The numeric values are sequential from 0.
const (
	oauthIdle    oauthState = iota // no flow in progress
	oauthPending                   // flow requested
	oauthRemote                    // remote browser mode: waiting for manual callback
	oauthSuccess                   // flow completed successfully
	oauthError                     // flow failed
)
// Messages exchanged by the OAuth tab.
type (
	// oauthStartMsg carries the auth URL and state returned when a flow starts.
	oauthStartMsg struct {
		url          string
		state        string
		providerName string
		err          error
	}

	// oauthPollMsg reports flow progress; done marks completion.
	oauthPollMsg struct {
		done    bool
		message string
		err     error
	}

	// oauthCallbackSubmitMsg reports the result of submitting a callback URL.
	oauthCallbackSubmitMsg struct {
		err error
	}
)
// newOAuthTabModel returns an idle OAuth tab bound to client. Its callback-URL
// text input comes with a placeholder, a 2048-character limit and a prompt.
func newOAuthTabModel(client *Client) oauthTabModel {
	m := oauthTabModel{client: client, callbackInput: textinput.New()}
	m.callbackInput.Placeholder = "http://localhost:.../auth/callback?code=...&state=..."
	m.callbackInput.CharLimit = 2048
	// NOTE(review): the prompt is hard-coded Chinese ("callback URL") rather
	// than localised through T like the rest of the tab — confirm intent.
	m.callbackInput.Prompt = " 回调 URL: "
	return m
}
// Init implements the Bubble Tea model contract. The OAuth tab needs no
// start-up command; all work is triggered by key presses.
func (m oauthTabModel) Init() tea.Cmd {
	return nil
}
// Update is the Bubble Tea update function for the OAuth tab. It applies the
// results of the asynchronous commands (login start, status polling, callback
// submission) and routes key presses to handleKey. Anything else goes to the
// viewport, for example for scrolling.
func (m oauthTabModel) Update(msg tea.Msg) (oauthTabModel, tea.Cmd) {
	switch msg := msg.(type) {
	case localeChangedMsg:
		m.viewport.SetContent(m.renderContent())
		return m, nil
	case oauthStartMsg:
		// Start results are only expected while pending. If the user pressed
		// Esc while the auth URL was being fetched, the flow was cancelled and
		// the late reply must not bring it back.
		if m.state != oauthPending {
			return m, nil
		}
		if msg.err != nil {
			m.state = oauthError
			m.err = msg.err
			m.message = errorStyle.Render("✗ " + msg.err.Error())
			m.viewport.SetContent(m.renderContent())
			return m, nil
		}
		m.authURL = msg.url
		m.authState = msg.state
		m.providerName = msg.providerName
		m.state = oauthRemote
		m.callbackInput.SetValue("")
		m.callbackInput.Focus()
		m.inputActive = true
		m.message = ""
		m.viewport.SetContent(m.renderContent())
		// Also start polling in the background
		return m, tea.Batch(textinput.Blink, m.pollOAuthStatus(msg.state))
	case oauthPollMsg:
		// The poll command cannot be cancelled and is only started in remote
		// mode. Once the user has left remote mode (Esc), its eventual result,
		// including the 5-minute timeout, must not overwrite the current view.
		// oauthPollMsg carries no OAuth state, so a result from an earlier
		// login that arrives during a newer one cannot be told apart here.
		if m.state != oauthRemote {
			return m, nil
		}
		switch {
		case msg.err != nil:
			m.state = oauthError
			m.err = msg.err
			m.message = errorStyle.Render("✗ " + msg.err.Error())
			m.inputActive = false
			m.callbackInput.Blur()
		case msg.done:
			m.state = oauthSuccess
			m.message = successStyle.Render("✓ " + msg.message)
			m.inputActive = false
			m.callbackInput.Blur()
		default:
			m.message = warningStyle.Render("⏳ " + msg.message)
		}
		m.viewport.SetContent(m.renderContent())
		return m, nil
	case oauthCallbackSubmitMsg:
		if msg.err != nil {
			m.message = errorStyle.Render(T("oauth_submit_fail") + ": " + msg.err.Error())
		} else {
			m.message = successStyle.Render(T("oauth_submit_ok"))
		}
		m.viewport.SetContent(m.renderContent())
		return m, nil
	case tea.KeyMsg:
		return m.handleKey(msg)
	}
	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}

// handleKey routes a key press according to the tab's mode: typing a callback
// URL, remote browser mode, waiting for the auth URL, or provider selection
// (idle, success and error states).
func (m oauthTabModel) handleKey(msg tea.KeyMsg) (oauthTabModel, tea.Cmd) {
	key := msg.String()

	// ---- Input active: typing callback URL ----
	if m.inputActive {
		switch key {
		case "enter":
			callbackURL := m.callbackInput.Value()
			if callbackURL == "" {
				return m, nil
			}
			m.inputActive = false
			m.callbackInput.Blur()
			m.message = warningStyle.Render(T("oauth_submitting"))
			m.viewport.SetContent(m.renderContent())
			return m, m.submitCallback(callbackURL)
		case "esc":
			m.inputActive = false
			m.callbackInput.Blur()
			m.viewport.SetContent(m.renderContent())
			return m, nil
		default:
			var cmd tea.Cmd
			m.callbackInput, cmd = m.callbackInput.Update(msg)
			m.viewport.SetContent(m.renderContent())
			return m, cmd
		}
	}

	// ---- Remote mode but not typing ----
	if m.state == oauthRemote {
		switch key {
		case "c", "C":
			// Re-activate input
			m.inputActive = true
			m.callbackInput.Focus()
			m.viewport.SetContent(m.renderContent())
			return m, textinput.Blink
		case "esc":
			m.state = oauthIdle
			m.message = ""
			m.authURL = ""
			m.authState = ""
			m.viewport.SetContent(m.renderContent())
			return m, nil
		}
		var cmd tea.Cmd
		m.viewport, cmd = m.viewport.Update(msg)
		return m, cmd
	}

	// ---- Pending: waiting for the auth URL; only Esc is accepted ----
	if m.state == oauthPending {
		if key == "esc" {
			m.state = oauthIdle
			m.message = ""
			m.viewport.SetContent(m.renderContent())
		}
		return m, nil
	}

	// ---- Idle (also success/error): provider selection ----
	switch key {
	case "up", "k":
		if m.cursor > 0 {
			m.cursor--
			m.viewport.SetContent(m.renderContent())
		}
		return m, nil
	case "down", "j":
		if m.cursor < len(oauthProviders)-1 {
			m.cursor++
			m.viewport.SetContent(m.renderContent())
		}
		return m, nil
	case "enter":
		if m.cursor >= 0 && m.cursor < len(oauthProviders) {
			provider := oauthProviders[m.cursor]
			m.state = oauthPending
			m.message = warningStyle.Render(fmt.Sprintf(T("oauth_initiating"), provider.name))
			m.viewport.SetContent(m.renderContent())
			return m, m.startOAuth(provider)
		}
		return m, nil
	case "esc":
		m.state = oauthIdle
		m.message = ""
		m.err = nil
		m.viewport.SetContent(m.renderContent())
		return m, nil
	}
	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
// startOAuth returns a command that asks the management API for provider's
// authorization URL (with is_webui=true), tries to open it in the local
// browser, and reports the URL, OAuth state and provider name as an
// oauthStartMsg. It reports an error if the request fails or no URL comes back.
func (m oauthTabModel) startOAuth(provider oauthProvider) tea.Cmd {
	return func() tea.Msg {
		data, err := m.client.getJSON("/v0/management/" + provider.apiPath + "?is_webui=true")
		if err != nil {
			return oauthStartMsg{err: fmt.Errorf("failed to start %s login: %w", provider.name, err)}
		}
		started := oauthStartMsg{
			url:          getString(data, "url"),
			state:        getString(data, "state"),
			providerName: provider.name,
		}
		if started.url == "" {
			return oauthStartMsg{err: fmt.Errorf("no auth URL returned for %s", provider.name)}
		}
		// Best effort only: the URL is also shown in the tab (remote browser
		// mode), so a failure to launch a browser is deliberately ignored.
		_ = openBrowser(started.url)
		return started
	}
}
// submitCallback returns a command that posts a manually pasted callback URL,
// together with the provider key and the OAuth state of the current flow, to
// the management oauth-callback endpoint. The outcome is reported as an
// oauthCallbackSubmitMsg. If the current provider is not recognised, an empty
// provider key is sent.
func (m oauthTabModel) submitCallback(callbackURL string) tea.Cmd {
	return func() tea.Msg {
		// Canonical provider keys expected by the API, by auth-url path.
		canonicalKeys := map[string]string{
			"gemini-cli-auth-url":  "gemini",
			"anthropic-auth-url":   "anthropic",
			"codex-auth-url":       "codex",
			"antigravity-auth-url": "antigravity",
			"qwen-auth-url":        "qwen",
			"kimi-auth-url":        "kimi",
			"iflow-auth-url":       "iflow",
		}
		var providerKey string
		for _, p := range oauthProviders {
			if p.name != m.providerName {
				continue
			}
			providerKey = canonicalKeys[p.apiPath]
			break
		}
		payload := map[string]string{
			"provider":     providerKey,
			"redirect_url": callbackURL,
			"state":        m.authState,
		}
		if err := m.client.postJSON("/v0/management/oauth-callback", payload); err != nil {
			return oauthCallbackSubmitMsg{err: err}
		}
		return oauthCallbackSubmitMsg{}
	}
}
// pollOAuthStatus returns a command that checks the auth status for state
// every 2 seconds for up to 5 minutes (measured from when the command starts
// running). Request errors and "wait" are retried. "ok" reports success,
// "error" reports the server's message, any other status counts as completed,
// and running out of time reports a timeout error.
func (m oauthTabModel) pollOAuthStatus(state string) tea.Cmd {
	const (
		pollInterval = 2 * time.Second
		pollTimeout  = 5 * time.Minute
	)
	return func() tea.Msg {
		deadline := time.Now().Add(pollTimeout)
		for !time.Now().After(deadline) {
			time.Sleep(pollInterval)
			status, errMsg, err := m.client.GetAuthStatus(state)
			if err != nil || status == "wait" {
				// Transient errors and "still waiting" both mean: try again.
				continue
			}
			if status == "ok" {
				return oauthPollMsg{done: true, message: T("oauth_success")}
			}
			if status == "error" {
				return oauthPollMsg{err: fmt.Errorf("%s: %s", T("oauth_failed"), errMsg)}
			}
			return oauthPollMsg{done: true, message: T("oauth_completed")}
		}
		return oauthPollMsg{err: fmt.Errorf("%s", T("oauth_timeout"))}
	}
}
// SetSize resizes the tab to w×h. The first call creates the viewport; later
// calls resize it. The callback input is made 16 columns narrower than the tab.
// Content is re-rendered on every call because the auth URL is wrapped to
// m.width in remote mode; without this the wrap stayed at the old width after
// a terminal resize.
func (m *oauthTabModel) SetSize(w, h int) {
	m.width = w
	m.height = h
	m.callbackInput.Width = w - 16
	if !m.ready {
		m.viewport = viewport.New(w, h)
		m.ready = true
	} else {
		m.viewport.Width = w
		m.viewport.Height = h
	}
	m.viewport.SetContent(m.renderContent())
}
// View renders the scrollable tab content, or a loading placeholder until
// SetSize has created the viewport.
func (m oauthTabModel) View() string {
	if m.ready {
		return m.viewport.View()
	}
	return T("loading")
}
// renderContent builds the tab body: the title, the optional status message,
// and then, depending on state, the remote browser view, the "press Esc" hint
// while the auth URL is being fetched, or the selectable provider list.
func (m oauthTabModel) renderContent() string {
	var out strings.Builder
	out.WriteString(titleStyle.Render(T("oauth_title")) + "\n\n")
	if m.message != "" {
		out.WriteString(" " + m.message + "\n\n")
	}

	switch m.state {
	case oauthRemote:
		out.WriteString(m.renderRemoteMode())
		return out.String()
	case oauthPending:
		out.WriteString(helpStyle.Render(T("oauth_press_esc")))
		return out.String()
	}

	out.WriteString(helpStyle.Render(T("oauth_select")) + "\n\n")
	selectedStyle := lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("#FFFFFF")).Background(colorPrimary).Padding(0, 1)
	normalStyle := lipgloss.NewStyle().Foreground(colorText).Padding(0, 1)
	for i, p := range oauthProviders {
		marker, style := " ", normalStyle
		if i == m.cursor {
			marker, style = "▸ ", selectedStyle
		}
		out.WriteString(marker + style.Render(p.emoji+" "+p.name) + "\n")
	}
	out.WriteString("\n" + helpStyle.Render(T("oauth_help")))
	return out.String()
}
// renderRemoteMode renders the remote browser view: the provider heading, the
// auth URL wrapped to the tab width (at least 40 columns), a hint, the
// callback-URL input (or how to activate it), and a waiting notice.
func (m oauthTabModel) renderRemoteMode() string {
	providerStyle := lipgloss.NewStyle().Bold(true).Foreground(colorHighlight)
	headingStyle := lipgloss.NewStyle().Bold(true).Foreground(colorInfo)
	urlStyle := lipgloss.NewStyle().Foreground(lipgloss.Color("252"))

	urlWidth := m.width - 6
	if urlWidth < 40 {
		urlWidth = 40
	}

	var out strings.Builder
	out.WriteString(providerStyle.Render(fmt.Sprintf(" ✦ %s OAuth", m.providerName)) + "\n\n")

	// Auth URL section
	out.WriteString(headingStyle.Render(T("oauth_auth_url")) + "\n")
	for _, chunk := range wrapText(m.authURL, urlWidth) {
		out.WriteString(" " + urlStyle.Render(chunk) + "\n")
	}
	out.WriteString("\n" + helpStyle.Render(T("oauth_remote_hint")) + "\n\n")

	// Callback URL input
	out.WriteString(headingStyle.Render(T("oauth_callback_url")) + "\n")
	if m.inputActive {
		out.WriteString(m.callbackInput.View() + "\n")
		out.WriteString(helpStyle.Render(" " + T("enter_submit") + " • " + T("esc_cancel")))
	} else {
		out.WriteString(helpStyle.Render(T("oauth_press_c")))
	}
	out.WriteString("\n\n" + warningStyle.Render(T("oauth_waiting")))
	return out.String()
}
// wrapText splits s into consecutive chunks of at most maxWidth characters
// (Unicode code points), so multi-byte UTF-8 sequences are never cut in half.
// Chunks are sub-slices of s, so the original bytes, including any invalid
// UTF-8, are preserved. A non-positive maxWidth disables wrapping and returns
// s as the only element. An empty s yields no lines.
//
// Widths are counted in code points, not terminal cells, so wide (e.g. CJK)
// characters may still render wider than maxWidth columns.
func wrapText(s string, maxWidth int) []string {
	if maxWidth <= 0 {
		return []string{s}
	}
	var lines []string
	start, count := 0, 0
	// Ranging over a string visits the byte offset of each code point (an
	// invalid byte counts as one), which gives safe cut positions.
	for i := range s {
		if count == maxWidth {
			lines = append(lines, s[start:i])
			start, count = i, 0
		}
		count++
	}
	if start < len(s) {
		lines = append(lines, s[start:])
	}
	return lines
}
================================================
FILE: internal/tui/styles.go
================================================
// Package tui provides a terminal-based management interface for CLIProxyAPI.
package tui
import "github.com/charmbracelet/lipgloss"
// Color palette shared by the styles in this package. Values are hex colour
// strings passed to lipgloss.Color.
var (
	colorPrimary   = lipgloss.Color("#7C3AED") // violet
	colorSecondary = lipgloss.Color("#6366F1") // indigo
	colorSuccess   = lipgloss.Color("#22C55E") // green
	colorWarning   = lipgloss.Color("#EAB308") // yellow
	colorError     = lipgloss.Color("#EF4444") // red
	colorInfo      = lipgloss.Color("#3B82F6") // blue
	colorMuted     = lipgloss.Color("#6B7280") // gray
	colorBg        = lipgloss.Color("#1E1E2E") // dark bg
	colorSurface   = lipgloss.Color("#313244") // slightly lighter
	colorText      = lipgloss.Color("#CDD6F4") // light text
	colorSubtext   = lipgloss.Color("#A6ADC8") // dimmer text
	colorBorder    = lipgloss.Color("#45475A") // border
	colorHighlight = lipgloss.Color("#F5C2E7") // pink highlight
)
// Tab bar styles
var (
	// tabActiveStyle: bold white on the primary colour.
	tabActiveStyle = lipgloss.NewStyle().
			Bold(true).
			Foreground(lipgloss.Color("#FFFFFF")).
			Background(colorPrimary).
			Padding(0, 2)
	// tabInactiveStyle: dimmer text on the surface colour.
	tabInactiveStyle = lipgloss.NewStyle().
				Foreground(colorSubtext).
				Background(colorSurface).
				Padding(0, 2)
	// tabBarStyle: the bar behind the tabs, on the same surface background.
	tabBarStyle = lipgloss.NewStyle().
			Background(colorSurface).
			PaddingLeft(1).
			PaddingBottom(0)
)
// Content styles
var (
	// titleStyle: bold highlight-coloured tab titles, followed by a blank line.
	titleStyle = lipgloss.NewStyle().
			Bold(true).
			Foreground(colorHighlight).
			MarginBottom(1)
	// subtitleStyle: italic secondary text, e.g. "no data" notices.
	subtitleStyle = lipgloss.NewStyle().
			Foreground(colorSubtext).
			Italic(true)
	// labelStyle: bold label column with a fixed width of 24 cells.
	labelStyle = lipgloss.NewStyle().
			Foreground(colorInfo).
			Bold(true).
			Width(24)
	// valueStyle: plain value text.
	valueStyle = lipgloss.NewStyle().
			Foreground(colorText)
	// sectionStyle: rounded, padded box around a section.
	sectionStyle = lipgloss.NewStyle().
			Border(lipgloss.RoundedBorder()).
			BorderForeground(colorBorder).
			Padding(1, 2)
	// errorStyle, successStyle and warningStyle colour status messages.
	errorStyle = lipgloss.NewStyle().
			Foreground(colorError).
			Bold(true)
	successStyle = lipgloss.NewStyle().
			Foreground(colorSuccess)
	warningStyle = lipgloss.NewStyle().
			Foreground(colorWarning)
	// statusBarStyle: dim text on the surface colour with one cell of side padding.
	statusBarStyle = lipgloss.NewStyle().
			Foreground(colorSubtext).
			Background(colorSurface).
			PaddingLeft(1).
			PaddingRight(1)
	// helpStyle: muted key hints and help text.
	helpStyle = lipgloss.NewStyle().
			Foreground(colorMuted)
)
// Log level styles, selected by logLevelStyle and the logs tab's line styling.
var (
	logDebugStyle = lipgloss.NewStyle().Foreground(colorMuted)
	logInfoStyle  = lipgloss.NewStyle().Foreground(colorInfo)
	logWarnStyle  = lipgloss.NewStyle().Foreground(colorWarning)
	logErrorStyle = lipgloss.NewStyle().Foreground(colorError)
)
// Table styles
var (
	// tableHeaderStyle: bold highlight-coloured header with a bottom border.
	tableHeaderStyle = lipgloss.NewStyle().
				Bold(true).
				Foreground(colorHighlight).
				BorderBottom(true).
				BorderStyle(lipgloss.NormalBorder()).
				BorderForeground(colorBorder)
	// tableCellStyle: body cells, with 2 cells of right padding between columns.
	tableCellStyle = lipgloss.NewStyle().
			Foreground(colorText).
			PaddingRight(2)
	// tableSelectedStyle: bold white on the primary colour for the selected row.
	tableSelectedStyle = lipgloss.NewStyle().
				Foreground(lipgloss.Color("#FFFFFF")).
				Background(colorPrimary).
				Bold(true)
)
// logLevelStyle maps a lowercase log level name to its display style.
// "warning" is accepted as an alias of "warn", and "fatal" and "panic" share
// the error style. "info" and any unrecognised level use the info style.
func logLevelStyle(level string) lipgloss.Style {
	switch level {
	case "error", "fatal", "panic":
		return logErrorStyle
	case "warn", "warning":
		return logWarnStyle
	case "debug":
		return logDebugStyle
	}
	return logInfoStyle
}
================================================
FILE: internal/tui/usage_tab.go
================================================
package tui
import (
"fmt"
"sort"
"strings"
"github.com/charmbracelet/bubbles/viewport"
tea "github.com/charmbracelet/bubbletea"
"github.com/charmbracelet/lipgloss"
)
// usageTabModel displays usage statistics with charts and breakdowns.
type usageTabModel struct {
	client   *Client
	viewport viewport.Model
	usage    map[string]any // last payload from GetUsage; renderContent reads its "usage" key
	err      error          // last fetch error; shown instead of the data while set
	width    int
	height   int
	ready    bool // set once SetSize has created the viewport
}
// usageDataMsg carries the result of fetchData: the decoded usage payload, or
// the error returned by the client.
type usageDataMsg struct {
	usage map[string]any
	err   error
}
// newUsageTabModel returns a usage tab bound to client. No data is loaded and
// no viewport exists until Init's command runs and SetSize is called.
func newUsageTabModel(client *Client) usageTabModel {
	var m usageTabModel
	m.client = client
	return m
}
// Init returns fetchData as the start-up command, so the first usage
// snapshot is requested as soon as the tab starts.
func (m usageTabModel) Init() tea.Cmd {
	return m.fetchData
}
// fetchData is a tea.Cmd that loads usage statistics through the client and
// delivers them, together with any error, as a usageDataMsg.
func (m usageTabModel) fetchData() tea.Msg {
	var result usageDataMsg
	result.usage, result.err = m.client.GetUsage()
	return result
}
// Update handles locale changes and fetched data by re-rendering, refetches
// on "r", and passes every other message to the viewport (scrolling etc.).
// A failed fetch keeps the previously loaded data but records the error.
func (m usageTabModel) Update(msg tea.Msg) (usageTabModel, tea.Cmd) {
	switch msg := msg.(type) {
	case localeChangedMsg:
		m.viewport.SetContent(m.renderContent())
		return m, nil
	case usageDataMsg:
		m.err = msg.err
		if msg.err == nil {
			m.usage = msg.usage
		}
		m.viewport.SetContent(m.renderContent())
		return m, nil
	case tea.KeyMsg:
		if msg.String() == "r" {
			return m, m.fetchData
		}
	}
	var cmd tea.Cmd
	m.viewport, cmd = m.viewport.Update(msg)
	return m, cmd
}
// SetSize resizes the tab to w×h. The first call creates the viewport; later
// calls resize it. Content is re-rendered on every call because the overview
// card widths and separator lengths are computed from m.width; previously a
// terminal resize left the old layout in place until the next refresh.
func (m *usageTabModel) SetSize(w, h int) {
	m.width = w
	m.height = h
	if !m.ready {
		m.viewport = viewport.New(w, h)
		m.ready = true
	} else {
		m.viewport.Width = w
		m.viewport.Height = h
	}
	m.viewport.SetContent(m.renderContent())
}
// View renders the scrollable usage content, or a loading placeholder until
// SetSize has created the viewport.
func (m usageTabModel) View() string {
	if m.ready {
		return m.viewport.View()
	}
	return T("loading")
}
// renderContent builds the usage tab body from the last fetched payload:
// overview cards (requests, tokens, RPM, TPM), bar charts for requests and
// tokens by hour and requests by day, and a per-API table with per-model rows
// and token-type breakdowns. An error or a missing "usage" object replaces the
// body with a single notice.
//
// API and model rows are emitted in ascending name order. Go map iteration is
// randomised, so the unsorted table reshuffled on every refresh.
//
// NOTE(review): RPM/TPM divide the totals by the number of keys in
// requests_by_hour / tokens_by_hour, then by 60 — confirm this matches the
// intended definition of "per minute".
func (m usageTabModel) renderContent() string {
	var sb strings.Builder
	sb.WriteString(titleStyle.Render(T("usage_title")))
	sb.WriteString("\n")
	sb.WriteString(helpStyle.Render(T("usage_help")))
	sb.WriteString("\n\n")
	if m.err != nil {
		sb.WriteString(errorStyle.Render("⚠ Error: " + m.err.Error()))
		sb.WriteString("\n")
		return sb.String()
	}
	if m.usage == nil {
		sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
		sb.WriteString("\n")
		return sb.String()
	}
	usageMap, _ := m.usage["usage"].(map[string]any)
	if usageMap == nil {
		sb.WriteString(subtitleStyle.Render(T("usage_no_data")))
		sb.WriteString("\n")
		return sb.String()
	}
	totalReqs := int64(getFloat(usageMap, "total_requests"))
	successCnt := int64(getFloat(usageMap, "success_count"))
	failureCnt := int64(getFloat(usageMap, "failure_count"))
	totalTokens := int64(getFloat(usageMap, "total_tokens"))

	// ━━━ Overview Cards ━━━
	cardWidth := 20
	if m.width > 0 {
		cardWidth = (m.width - 6) / 4
		if cardWidth < 16 {
			cardWidth = 16
		}
	}
	cardStyle := lipgloss.NewStyle().
		Border(lipgloss.RoundedBorder()).
		BorderForeground(lipgloss.Color("240")).
		Padding(0, 1).
		Width(cardWidth).
		Height(3)
	// Total Requests
	card1 := cardStyle.Copy().BorderForeground(lipgloss.Color("111")).Render(fmt.Sprintf(
		"%s\n%s\n%s",
		lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_reqs")),
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("111")).Render(fmt.Sprintf("%d", totalReqs)),
		lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("● %s: %d ● %s: %d", T("usage_success"), successCnt, T("usage_failure"), failureCnt)),
	))
	// Total Tokens
	card2 := cardStyle.Copy().BorderForeground(lipgloss.Color("214")).Render(fmt.Sprintf(
		"%s\n%s\n%s",
		lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_total_tokens")),
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("214")).Render(formatLargeNumber(totalTokens)),
		lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_token_l"), formatLargeNumber(totalTokens))),
	))
	// RPM
	rpm := float64(0)
	if totalReqs > 0 {
		if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
			rpm = float64(totalReqs) / float64(len(rByH)) / 60.0
		}
	}
	card3 := cardStyle.Copy().BorderForeground(lipgloss.Color("76")).Render(fmt.Sprintf(
		"%s\n%s\n%s",
		lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_rpm")),
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("76")).Render(fmt.Sprintf("%.2f", rpm)),
		lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %d", T("usage_total_reqs"), totalReqs)),
	))
	// TPM
	tpm := float64(0)
	if totalTokens > 0 {
		if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
			tpm = float64(totalTokens) / float64(len(tByH)) / 60.0
		}
	}
	card4 := cardStyle.Copy().BorderForeground(lipgloss.Color("170")).Render(fmt.Sprintf(
		"%s\n%s\n%s",
		lipgloss.NewStyle().Foreground(colorMuted).Render(T("usage_tpm")),
		lipgloss.NewStyle().Bold(true).Foreground(lipgloss.Color("170")).Render(fmt.Sprintf("%.2f", tpm)),
		lipgloss.NewStyle().Foreground(colorMuted).Render(fmt.Sprintf("%s: %s", T("usage_total_tokens"), formatLargeNumber(totalTokens))),
	))
	sb.WriteString(lipgloss.JoinHorizontal(lipgloss.Top, card1, " ", card2, " ", card3, " ", card4))
	sb.WriteString("\n\n")

	// ━━━ Requests by Hour (ASCII bar chart) ━━━
	if rByH, ok := usageMap["requests_by_hour"].(map[string]any); ok && len(rByH) > 0 {
		sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_hour")))
		sb.WriteString("\n")
		sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
		sb.WriteString("\n")
		sb.WriteString(renderBarChart(rByH, m.width-6, lipgloss.Color("111")))
		sb.WriteString("\n")
	}
	// ━━━ Tokens by Hour ━━━
	if tByH, ok := usageMap["tokens_by_hour"].(map[string]any); ok && len(tByH) > 0 {
		sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_tok_by_hour")))
		sb.WriteString("\n")
		sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
		sb.WriteString("\n")
		sb.WriteString(renderBarChart(tByH, m.width-6, lipgloss.Color("214")))
		sb.WriteString("\n")
	}
	// ━━━ Requests by Day ━━━
	if rByD, ok := usageMap["requests_by_day"].(map[string]any); ok && len(rByD) > 0 {
		sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_req_by_day")))
		sb.WriteString("\n")
		sb.WriteString(strings.Repeat("─", minInt(m.width, 60)))
		sb.WriteString("\n")
		sb.WriteString(renderBarChart(rByD, m.width-6, lipgloss.Color("76")))
		sb.WriteString("\n")
	}

	// sortedKeys returns the keys of mp in ascending order, giving the table a
	// stable row order across renders.
	sortedKeys := func(mp map[string]any) []string {
		keys := make([]string, 0, len(mp))
		for k := range mp {
			keys = append(keys, k)
		}
		sort.Strings(keys)
		return keys
	}

	// ━━━ API Detail Stats ━━━
	if apis, ok := usageMap["apis"].(map[string]any); ok && len(apis) > 0 {
		sb.WriteString(lipgloss.NewStyle().Bold(true).Foreground(colorHighlight).Render(T("usage_api_detail")))
		sb.WriteString("\n")
		sb.WriteString(strings.Repeat("─", minInt(m.width, 80)))
		sb.WriteString("\n")
		header := fmt.Sprintf(" %-30s %10s %12s", "API", T("requests"), T("tokens"))
		sb.WriteString(tableHeaderStyle.Render(header))
		sb.WriteString("\n")
		for _, apiName := range sortedKeys(apis) {
			apiMap, ok := apis[apiName].(map[string]any)
			if !ok {
				continue
			}
			apiReqs := int64(getFloat(apiMap, "total_requests"))
			apiToks := int64(getFloat(apiMap, "total_tokens"))
			row := fmt.Sprintf(" %-30s %10d %12s",
				truncate(maskKey(apiName), 30), apiReqs, formatLargeNumber(apiToks))
			sb.WriteString(lipgloss.NewStyle().Bold(true).Render(row))
			sb.WriteString("\n")
			// Per-model breakdown
			models, ok := apiMap["models"].(map[string]any)
			if !ok {
				continue
			}
			for _, model := range sortedKeys(models) {
				stats, ok := models[model].(map[string]any)
				if !ok {
					continue
				}
				mReqs := int64(getFloat(stats, "total_requests"))
				mToks := int64(getFloat(stats, "total_tokens"))
				mRow := fmt.Sprintf(" ├─ %-28s %10d %12s",
					truncate(model, 28), mReqs, formatLargeNumber(mToks))
				sb.WriteString(tableCellStyle.Render(mRow))
				sb.WriteString("\n")
				// Token type breakdown from details
				sb.WriteString(m.renderTokenBreakdown(stats))
			}
		}
	}
	sb.WriteString("\n")
	return sb.String()
}
// renderTokenBreakdown sums the input, output, cached and reasoning token
// counts over a model's "details" entries and renders the positive ones as a
// muted sub-row, in that order. It returns "" when there are no details or
// every sum is zero.
func (m usageTabModel) renderTokenBreakdown(modelStats map[string]any) string {
	detailList, ok := modelStats["details"].([]any)
	if !ok || len(detailList) == 0 {
		return ""
	}
	// Parallel tables: payload key to sum, and translation key for its label.
	fields := [4]string{"input_tokens", "output_tokens", "cached_tokens", "reasoning_tokens"}
	labels := [4]string{"usage_input", "usage_output", "usage_cached", "usage_reasoning"}
	var totals [4]int64
	for _, entry := range detailList {
		// A non-map entry yields a nil map, whose "tokens" lookup fails below.
		dm, _ := entry.(map[string]any)
		tokens, ok := dm["tokens"].(map[string]any)
		if !ok {
			continue
		}
		for i, field := range fields {
			totals[i] += int64(getFloat(tokens, field))
		}
	}
	if totals == ([4]int64{}) {
		return ""
	}
	var parts []string
	for i, total := range totals {
		if total > 0 {
			parts = append(parts, fmt.Sprintf("%s:%s", T(labels[i]), formatLargeNumber(total)))
		}
	}
	return fmt.Sprintf(" │ %s\n",
		lipgloss.NewStyle().Foreground(colorMuted).Render(strings.Join(parts, " ")))
}
// renderBarChart renders data as a horizontal ASCII bar chart, one row per key
// in ascending key order: a 12-character label, a bar, and the value. Bars are
// scaled against the largest value, using maxBarWidth (raised to at least 10)
// minus 24 columns for label and value (at least 5). Positive values always
// get at least one block; zero and negative values get an empty bar instead of
// making strings.Repeat panic on a negative count. Returns "" when no value is
// positive.
func renderBarChart(data map[string]any, maxBarWidth int, barColor lipgloss.Color) string {
	if maxBarWidth < 10 {
		maxBarWidth = 10
	}
	// Sort keys
	keys := make([]string, 0, len(data))
	for k := range data {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	// Find max value
	maxVal := float64(0)
	for _, k := range keys {
		if v := getFloat(data, k); v > maxVal {
			maxVal = v
		}
	}
	if maxVal == 0 {
		return ""
	}
	barStyle := lipgloss.NewStyle().Foreground(barColor)
	numStyle := lipgloss.NewStyle().Foreground(colorMuted)
	const labelWidth = 12
	barAvail := maxBarWidth - labelWidth - 12
	if barAvail < 5 {
		barAvail = 5
	}
	var sb strings.Builder
	for _, k := range keys {
		v := getFloat(data, k)
		barLen := int(v / maxVal * float64(barAvail))
		switch {
		case barLen < 0:
			barLen = 0
		case barLen < 1 && v > 0:
			barLen = 1
		}
		// Truncate the label to labelWidth code points so a multi-byte UTF-8
		// character is never cut in half.
		label := k
		count := 0
		for i := range k {
			if count == labelWidth {
				label = k[:i]
				break
			}
			count++
		}
		fmt.Fprintf(&sb, " %-*s %s %s\n",
			labelWidth, label,
			barStyle.Render(strings.Repeat("█", barLen)),
			numStyle.Render(fmt.Sprintf("%.0f", v)),
		)
	}
	return sb.String()
}
================================================
FILE: internal/usage/logger_plugin.go
================================================
// Package usage provides usage tracking and logging functionality for the CLI Proxy API server.
// It includes plugins for monitoring API usage, token consumption, and other metrics
// to help with observability and billing purposes.
package usage
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gin-gonic/gin"
coreusage "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
)
// statisticsEnabled gates all in-memory statistics recording. It is an
// atomic.Bool so it can be toggled and read without taking the store's mutex.
var statisticsEnabled atomic.Bool

// init enables statistics by default and registers a LoggerPlugin backed by
// the shared store with the core usage plugin registry.
func init() {
	statisticsEnabled.Store(true)
	coreusage.RegisterPlugin(NewLoggerPlugin())
}
// LoggerPlugin collects in-memory request statistics for usage analysis.
// It implements coreusage.Plugin to receive usage records emitted by the runtime.
type LoggerPlugin struct {
	stats *RequestStatistics // destination store; NewLoggerPlugin uses the shared default
}
// NewLoggerPlugin constructs a logger plugin that records into the shared
// statistics store returned by GetRequestStatistics.
func NewLoggerPlugin() *LoggerPlugin {
	plugin := new(LoggerPlugin)
	plugin.stats = defaultRequestStatistics
	return plugin
}
// HandleUsage implements coreusage.Plugin by forwarding record (and ctx, from
// which the route and response status may be derived) to the plugin's
// statistics store. It does nothing while statistics are disabled or when the
// plugin or its store is nil.
func (p *LoggerPlugin) HandleUsage(ctx context.Context, record coreusage.Record) {
	if !statisticsEnabled.Load() || p == nil || p.stats == nil {
		return
	}
	p.stats.Record(ctx, record)
}
// SetStatisticsEnabled toggles whether in-memory statistics are recorded.
// Disabling stops new records from being aggregated; data already collected
// is kept.
func SetStatisticsEnabled(enabled bool) { statisticsEnabled.Store(enabled) }

// StatisticsEnabled reports the current recording state.
func StatisticsEnabled() bool { return statisticsEnabled.Load() }
// RequestStatistics maintains aggregated request metrics in memory.
// mu guards every other field: Record and MergeSnapshot take the write lock,
// Snapshot the read lock. Construct it with NewRequestStatistics, because the
// maps are written without nil checks.
type RequestStatistics struct {
	mu sync.RWMutex

	totalRequests int64
	successCount  int64
	failureCount  int64
	totalTokens   int64

	// apis is keyed by the record's API key, or by the identifier from
	// resolveAPIIdentifier when the record has none.
	apis map[string]*apiStats

	requestsByDay  map[string]int64 // keyed by the request date formatted as "2006-01-02"
	requestsByHour map[int]int64    // keyed by hour of day (0-23), summed across all days
	tokensByDay    map[string]int64 // same keys as requestsByDay
	tokensByHour   map[int]int64    // same keys as requestsByHour
}
// apiStats holds aggregated metrics for a single API key (or resolved
// route/provider identifier, when the record carried no key).
type apiStats struct {
	TotalRequests int64
	TotalTokens   int64
	Models        map[string]*modelStats // keyed by model name ("unknown" when absent)
}
// modelStats holds aggregated metrics for a specific model within an API.
type modelStats struct {
	TotalRequests int64
	TotalTokens   int64
	Details       []RequestDetail // one entry per recorded request, appended in arrival order
}
// RequestDetail stores the timestamp and token usage for a single request.
type RequestDetail struct {
	Timestamp time.Time  `json:"timestamp"`  // request time; time.Now() when the record had none
	Source    string     `json:"source"`     // copied from the usage record
	AuthIndex string     `json:"auth_index"` // copied from the usage record
	Tokens    TokenStats `json:"tokens"`
	Failed    bool       `json:"failed"`
}
// TokenStats captures the token usage breakdown for a request.
// When the upstream total is zero it is derived from the other fields (see
// normaliseDetail / normaliseTokenStats).
type TokenStats struct {
	InputTokens     int64 `json:"input_tokens"`
	OutputTokens    int64 `json:"output_tokens"`
	ReasoningTokens int64 `json:"reasoning_tokens"`
	CachedTokens    int64 `json:"cached_tokens"`
	TotalTokens     int64 `json:"total_tokens"`
}
// StatisticsSnapshot represents an immutable view of the aggregated metrics.
// It is a copy: mutating it does not affect the store it was taken from.
type StatisticsSnapshot struct {
	TotalRequests int64 `json:"total_requests"`
	SuccessCount  int64 `json:"success_count"`
	FailureCount  int64 `json:"failure_count"`
	TotalTokens   int64 `json:"total_tokens"`

	APIs map[string]APISnapshot `json:"apis"`

	RequestsByDay  map[string]int64 `json:"requests_by_day"`
	RequestsByHour map[string]int64 `json:"requests_by_hour"` // hour-of-day keys rendered by formatHour
	TokensByDay    map[string]int64 `json:"tokens_by_day"`
	TokensByHour   map[string]int64 `json:"tokens_by_hour"` // hour-of-day keys rendered by formatHour
}
// APISnapshot summarises metrics for a single API key.
type APISnapshot struct {
	TotalRequests int64                    `json:"total_requests"`
	TotalTokens   int64                    `json:"total_tokens"`
	Models        map[string]ModelSnapshot `json:"models"` // keyed by model name
}
// ModelSnapshot summarises metrics for a specific model. Details is what
// MergeSnapshot imports; the totals are informational on import.
type ModelSnapshot struct {
	TotalRequests int64           `json:"total_requests"`
	TotalTokens   int64           `json:"total_tokens"`
	Details       []RequestDetail `json:"details"`
}
// defaultRequestStatistics is the process-wide store fed by the LoggerPlugin
// registered in init.
var defaultRequestStatistics = NewRequestStatistics()

// GetRequestStatistics returns the shared statistics store.
func GetRequestStatistics() *RequestStatistics { return defaultRequestStatistics }
// NewRequestStatistics constructs an empty statistics store with every
// aggregate map allocated, ready for Record and MergeSnapshot.
func NewRequestStatistics() *RequestStatistics {
	store := new(RequestStatistics)
	store.apis = make(map[string]*apiStats)
	store.requestsByDay = make(map[string]int64)
	store.requestsByHour = make(map[int]int64)
	store.tokensByDay = make(map[string]int64)
	store.tokensByHour = make(map[int]int64)
	return store
}
// Record ingests a new usage record and updates the aggregates.
//
// A zero RequestedAt is replaced with the current time. The statistics key is
// the record's API key, falling back to resolveAPIIdentifier. A request counts
// as failed if the record says so or, failing that, if resolveSuccess reports
// a failing HTTP status. An empty model name is recorded as "unknown". Nothing
// is recorded for a nil store or while statistics are disabled.
func (s *RequestStatistics) Record(ctx context.Context, record coreusage.Record) {
	if s == nil || !statisticsEnabled.Load() {
		return
	}

	at := record.RequestedAt
	if at.IsZero() {
		at = time.Now()
	}
	tokens := normaliseDetail(record.Detail)

	key := record.APIKey
	if key == "" {
		key = resolveAPIIdentifier(ctx, record)
	}
	// Short-circuit: the response status is only consulted when the record
	// does not already report a failure.
	failed := record.Failed || !resolveSuccess(ctx)

	model := record.Model
	if model == "" {
		model = "unknown"
	}
	day, hour := at.Format("2006-01-02"), at.Hour()

	s.mu.Lock()
	defer s.mu.Unlock()

	s.totalRequests++
	if failed {
		s.failureCount++
	} else {
		s.successCount++
	}
	s.totalTokens += tokens.TotalTokens

	api, exists := s.apis[key]
	if !exists {
		api = &apiStats{Models: make(map[string]*modelStats)}
		s.apis[key] = api
	}
	s.updateAPIStats(api, model, RequestDetail{
		Timestamp: at,
		Source:    record.Source,
		AuthIndex: record.AuthIndex,
		Tokens:    tokens,
		Failed:    failed,
	})

	s.requestsByDay[day]++
	s.requestsByHour[hour]++
	s.tokensByDay[day] += tokens.TotalTokens
	s.tokensByHour[hour] += tokens.TotalTokens
}
// updateAPIStats adds one request, described by detail, to the API-level
// totals in stats and to the per-model totals for model, and appends detail to
// that model's history. A missing Models map and a missing or nil model entry
// are created on demand. MergeSnapshot's indexing loop already tolerates nil
// model entries, and previously such an entry made this write path panic.
//
// It does not lock. Record calls it with s.mu held for writing.
func (s *RequestStatistics) updateAPIStats(stats *apiStats, model string, detail RequestDetail) {
	stats.TotalRequests++
	stats.TotalTokens += detail.Tokens.TotalTokens
	if stats.Models == nil {
		stats.Models = make(map[string]*modelStats)
	}
	entry := stats.Models[model]
	if entry == nil {
		entry = &modelStats{}
		stats.Models[model] = entry
	}
	entry.TotalRequests++
	entry.TotalTokens += detail.Tokens.TotalTokens
	entry.Details = append(entry.Details, detail)
}
// Snapshot returns a copy of the aggregated metrics for external consumption.
// Maps and per-model detail slices are freshly allocated, so the result can be
// read or serialised without holding the store's lock. Hour buckets are
// re-keyed from ints to strings via formatHour. A nil store yields the zero
// snapshot.
func (s *RequestStatistics) Snapshot() StatisticsSnapshot {
	var snap StatisticsSnapshot
	if s == nil {
		return snap
	}
	s.mu.RLock()
	defer s.mu.RUnlock()

	cloneDays := func(src map[string]int64) map[string]int64 {
		dst := make(map[string]int64, len(src))
		for day, n := range src {
			dst[day] = n
		}
		return dst
	}
	cloneHours := func(src map[int]int64) map[string]int64 {
		dst := make(map[string]int64, len(src))
		for hour, n := range src {
			dst[formatHour(hour)] = n
		}
		return dst
	}

	apis := make(map[string]APISnapshot, len(s.apis))
	for name, api := range s.apis {
		models := make(map[string]ModelSnapshot, len(api.Models))
		for modelName, ms := range api.Models {
			// Always a non-nil slice, so an empty history encodes as [].
			details := make([]RequestDetail, len(ms.Details))
			copy(details, ms.Details)
			models[modelName] = ModelSnapshot{
				TotalRequests: ms.TotalRequests,
				TotalTokens:   ms.TotalTokens,
				Details:       details,
			}
		}
		apis[name] = APISnapshot{
			TotalRequests: api.TotalRequests,
			TotalTokens:   api.TotalTokens,
			Models:        models,
		}
	}

	snap = StatisticsSnapshot{
		TotalRequests:  s.totalRequests,
		SuccessCount:   s.successCount,
		FailureCount:   s.failureCount,
		TotalTokens:    s.totalTokens,
		APIs:           apis,
		RequestsByDay:  cloneDays(s.requestsByDay),
		RequestsByHour: cloneHours(s.requestsByHour),
		TokensByDay:    cloneDays(s.tokensByDay),
		TokensByHour:   cloneHours(s.tokensByHour),
	}
	return snap
}
// MergeResult reports the outcome of MergeSnapshot: how many request details
// were imported (Added) and how many were skipped as duplicates (Skipped).
type MergeResult struct {
	Added   int64 `json:"added"`
	Skipped int64 `json:"skipped"`
}
// MergeSnapshot merges an exported statistics snapshot into the current store.
// Existing data is preserved and duplicate request details are skipped.
//
// Only the per-model Details of the snapshot are read. Its precomputed totals
// and day/hour buckets are ignored, and every counter is instead updated from
// the imported details via recordImported. A detail is a duplicate when its
// dedupKey (API name, model name, UTC timestamp, source, auth index, failure
// flag and normalised token counts) matches a detail already in the store or
// one imported earlier in the same call. API and model names are trimmed.
// Blank API names are dropped and blank model names become "unknown". An API
// entry is created for each non-blank API name even if none of its details end
// up being added.
//
// NOTE(review): details with a zero Timestamp are stamped with time.Now()
// before their key is computed, so re-importing the same snapshot adds them
// again instead of skipping them — confirm whether that is acceptable.
//
// The write lock is held for the whole merge.
func (s *RequestStatistics) MergeSnapshot(snapshot StatisticsSnapshot) MergeResult {
	result := MergeResult{}
	if s == nil {
		return result
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	// Index every detail already in the store so imports can be deduplicated.
	seen := make(map[string]struct{})
	for apiName, stats := range s.apis {
		if stats == nil {
			continue
		}
		for modelName, modelStatsValue := range stats.Models {
			if modelStatsValue == nil {
				continue
			}
			for _, detail := range modelStatsValue.Details {
				seen[dedupKey(apiName, modelName, detail)] = struct{}{}
			}
		}
	}
	for apiName, apiSnapshot := range snapshot.APIs {
		apiName = strings.TrimSpace(apiName)
		if apiName == "" {
			continue
		}
		stats, ok := s.apis[apiName]
		if !ok || stats == nil {
			stats = &apiStats{Models: make(map[string]*modelStats)}
			s.apis[apiName] = stats
		} else if stats.Models == nil {
			stats.Models = make(map[string]*modelStats)
		}
		for modelName, modelSnapshot := range apiSnapshot.Models {
			modelName = strings.TrimSpace(modelName)
			if modelName == "" {
				modelName = "unknown"
			}
			for _, detail := range modelSnapshot.Details {
				// Normalise before keying so the key matches what is stored.
				detail.Tokens = normaliseTokenStats(detail.Tokens)
				if detail.Timestamp.IsZero() {
					detail.Timestamp = time.Now()
				}
				key := dedupKey(apiName, modelName, detail)
				if _, exists := seen[key]; exists {
					result.Skipped++
					continue
				}
				// Mark as seen so duplicates within the snapshot are skipped too.
				seen[key] = struct{}{}
				s.recordImported(apiName, modelName, stats, detail)
				result.Added++
			}
		}
	}
	return result
}
// recordImported folds one imported request detail into every aggregate:
// the global counters, the API entry stats and its model entry for modelName,
// and the day/hour buckets keyed by the detail's own timestamp.
//
// A negative TotalTokens (e.g. from a malformed snapshot) is clamped to zero
// for all of these totals. Before, the per-API and per-model totals still added
// the raw negative value, so they disagreed with the global and time-bucket
// totals. The detail itself is stored unmodified so that dedupKey still matches
// it on a later import. A missing or nil model entry is created on demand.
// apiName is currently unused; stats must be the entry for that API.
//
// MergeSnapshot calls it with s.mu held for writing.
func (s *RequestStatistics) recordImported(apiName, modelName string, stats *apiStats, detail RequestDetail) {
	totalTokens := detail.Tokens.TotalTokens
	if totalTokens < 0 {
		totalTokens = 0
	}

	s.totalRequests++
	if detail.Failed {
		s.failureCount++
	} else {
		s.successCount++
	}
	s.totalTokens += totalTokens

	stats.TotalRequests++
	stats.TotalTokens += totalTokens
	if stats.Models == nil {
		stats.Models = make(map[string]*modelStats)
	}
	entry := stats.Models[modelName]
	if entry == nil {
		entry = &modelStats{}
		stats.Models[modelName] = entry
	}
	entry.TotalRequests++
	entry.TotalTokens += totalTokens
	entry.Details = append(entry.Details, detail)

	dayKey := detail.Timestamp.Format("2006-01-02")
	hourKey := detail.Timestamp.Hour()
	s.requestsByDay[dayKey]++
	s.requestsByHour[hourKey]++
	s.tokensByDay[dayKey] += totalTokens
	s.tokensByHour[hourKey] += totalTokens
}
// dedupKey builds the identity string used to detect duplicate request details.
// It joins, with '|', the API name, model name, the timestamp in UTC RFC3339Nano,
// the source, auth index, failure flag and the normalised token counters.
func dedupKey(apiName, modelName string, detail RequestDetail) string {
	t := normaliseTokenStats(detail.Tokens)
	const layout = "%s|%s|%s|%s|%s|%t|%d|%d|%d|%d|%d"
	fields := []interface{}{
		apiName,
		modelName,
		detail.Timestamp.UTC().Format(time.RFC3339Nano),
		detail.Source,
		detail.AuthIndex,
		detail.Failed,
		t.InputTokens,
		t.OutputTokens,
		t.ReasoningTokens,
		t.CachedTokens,
		t.TotalTokens,
	}
	return fmt.Sprintf(layout, fields...)
}
// resolveAPIIdentifier names the API a usage record belongs to. When ctx carries a
// *gin.Context under the "gin" key it returns "METHOD route" (the registered route
// pattern, falling back to the raw URL path), or just the route if no method is
// known. Otherwise it falls back to record.Provider, then to "unknown".
func resolveAPIIdentifier(ctx context.Context, record coreusage.Record) string {
	var ginCtx *gin.Context
	if ctx != nil {
		ginCtx, _ = ctx.Value("gin").(*gin.Context)
	}
	if ginCtx != nil {
		route := ginCtx.FullPath()
		verb := ""
		if req := ginCtx.Request; req != nil {
			verb = req.Method
			if route == "" {
				route = req.URL.Path
			}
		}
		switch {
		case route != "" && verb != "":
			return verb + " " + route
		case route != "":
			return route
		}
	}
	if record.Provider == "" {
		return "unknown"
	}
	return record.Provider
}
// resolveSuccess reports whether the request tied to ctx succeeded. Without a
// *gin.Context under the "gin" key it assumes success; otherwise any response
// status below httpStatusBadRequest (including an unset status of 0) is success.
func resolveSuccess(ctx context.Context) bool {
	if ctx == nil {
		return true
	}
	if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
		code := ginCtx.Writer.Status()
		// 0 (status not recorded) is below the threshold and therefore a success.
		return code < httpStatusBadRequest
	}
	return true
}
// httpStatusBadRequest is the lowest HTTP status code resolveSuccess treats as a failure.
const httpStatusBadRequest = 400
// normaliseDetail converts a coreusage.Detail into TokenStats and fills in a
// missing TotalTokens exactly as normaliseTokenStats does. The fallback rule is
// delegated rather than duplicated so the two conversions cannot drift apart.
func normaliseDetail(detail coreusage.Detail) TokenStats {
	return normaliseTokenStats(TokenStats{
		InputTokens:     detail.InputTokens,
		OutputTokens:    detail.OutputTokens,
		ReasoningTokens: detail.ReasoningTokens,
		CachedTokens:    detail.CachedTokens,
		TotalTokens:     detail.TotalTokens,
	})
}
// normaliseTokenStats returns tokens with TotalTokens filled in when it is zero:
// first with input+output+reasoning, and if that sum is also zero, with the
// cached-token count. A non-zero TotalTokens is left untouched.
func normaliseTokenStats(tokens TokenStats) TokenStats {
	if tokens.TotalTokens != 0 {
		return tokens
	}
	sum := tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens
	if sum == 0 {
		// Adding cached tokens to a zero sum is just the cached-token count.
		sum = tokens.CachedTokens
	}
	tokens.TotalTokens = sum
	return tokens
}
// formatHour renders an hour bucket as two digits ("00".."23"). Negative input is
// clamped to 0 and values of 24 or more wrap around modulo 24.
func formatHour(hour int) string {
	h := hour
	if h < 0 {
		h = 0
	}
	return fmt.Sprintf("%02d", h%24)
}
================================================
FILE: internal/util/claude_model.go
================================================
package util
import "strings"
// IsClaudeThinkingModel reports whether model names a Claude "thinking" variant,
// i.e. the name contains both "claude" and "thinking" case-insensitively. Such
// models require the interleaved-thinking beta header.
func IsClaudeThinkingModel(model string) bool {
	normalized := strings.ToLower(model)
	if !strings.Contains(normalized, "claude") {
		return false
	}
	return strings.Contains(normalized, "thinking")
}
================================================
FILE: internal/util/claude_model_test.go
================================================
package util
import "testing"
// TestIsClaudeThinkingModel is a table test: a model qualifies only when its name
// contains both "claude" and "thinking", compared case-insensitively.
func TestIsClaudeThinkingModel(t *testing.T) {
tests := []struct {
name string
model string
expected bool
}{
// Claude thinking models - should return true
{"claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
{"claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
{"claude-opus-4-6-thinking", "claude-opus-4-6-thinking", true},
{"Claude-Sonnet-Thinking uppercase", "Claude-Sonnet-4-5-Thinking", true},
{"claude thinking mixed case", "Claude-THINKING-Model", true},
// Non-thinking Claude models - should return false
{"claude-sonnet-4-5 (no thinking)", "claude-sonnet-4-5", false},
{"claude-opus-4-5 (no thinking)", "claude-opus-4-5", false},
{"claude-3-5-sonnet", "claude-3-5-sonnet-20240620", false},
// Non-Claude models - should return false
{"gemini-3-pro-preview", "gemini-3-pro-preview", false},
{"gemini-thinking model", "gemini-3-pro-thinking", false}, // not Claude
{"gpt-4o", "gpt-4o", false},
{"empty string", "", false},
// Edge cases: each keyword alone is not enough
{"thinking without claude", "thinking-model", false},
{"claude without thinking", "claude-model", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsClaudeThinkingModel(tt.model)
if result != tt.expected {
t.Errorf("IsClaudeThinkingModel(%q) = %v, expected %v", tt.model, result, tt.expected)
}
})
}
}
================================================
FILE: internal/util/claude_tool_id.go
================================================
package util
import (
"fmt"
"regexp"
"sync/atomic"
"time"
)
// claudeToolUseIDSanitizer matches every character that Claude's tool_use.id
// pattern ^[a-zA-Z0-9_-]+$ does not allow.
var claudeToolUseIDSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_-]`)

// claudeToolUseIDCounter is bumped atomically to keep generated fallback IDs distinct.
var claudeToolUseIDCounter uint64

// SanitizeClaudeToolID ensures the given id conforms to Claude's tool_use.id regex
// ^[a-zA-Z0-9_-]+$ by replacing each disallowed character with '_'. When the
// result is empty (empty input), a fallback of the form
// "toolu_<unix-nanos>_<counter>" is generated instead.
func SanitizeClaudeToolID(id string) string {
	if cleaned := claudeToolUseIDSanitizer.ReplaceAllString(id, "_"); cleaned != "" {
		return cleaned
	}
	now := time.Now().UnixNano()
	seq := atomic.AddUint64(&claudeToolUseIDCounter, 1)
	return fmt.Sprintf("toolu_%d_%d", now, seq)
}
================================================
FILE: internal/util/gemini_schema.go
================================================
// Package util provides utility functions for the CLI Proxy API server.
package util
import (
"fmt"
"sort"
"strconv"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// gjsonPathKeyReplacer backslash-escapes the gjson path metacharacters '.', '*'
// and '?' so that a raw object key can be embedded in a gjson/sjson path.
var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
// placeholderReasonDescription is the description of the placeholder "reason"
// property added by addEmptySchemaPlaceholder; removePlaceholderFields matches on
// this exact text to recognise that placeholder.
const placeholderReasonDescription = "Brief explanation of why you are calling this tool"
// CleanJSONSchemaForAntigravity rewrites a JSON schema so the Antigravity API
// accepts it: unsupported keywords are removed or turned into description hints,
// composite types are flattened, and empty object schemas receive a required
// placeholder property.
func CleanJSONSchemaForAntigravity(jsonStr string) string {
	const withPlaceholder = true
	return cleanJSONSchema(jsonStr, withPlaceholder)
}
// CleanJSONSchemaForGemini rewrites a JSON schema for Gemini tool calling. It
// removes unsupported keywords and simplifies the schema, and strips placeholder
// properties instead of adding them.
func CleanJSONSchemaForGemini(jsonStr string) string {
	const withPlaceholder = false
	return cleanJSONSchema(jsonStr, withPlaceholder)
}
// cleanJSONSchema runs the schema-cleaning pipeline in a fixed order.
//   - Hints: $ref/const/enum/additionalProperties/constraints become description hints.
//   - Flattening: allOf, anyOf/oneOf and type arrays are collapsed.
//   - Cleanup: unsupported keywords are removed. When addPlaceholder is false (the
//     Gemini variant), nullable/title and placeholder fields are also stripped.
//     "required" lists are then pruned to existing properties.
//   - Placeholders: with addPlaceholder (Claude VALIDATED mode), empty object
//     schemas receive a required placeholder property.
func cleanJSONSchema(jsonStr string, addPlaceholder bool) string {
	steps := []func(string) string{
		convertRefsToHints,
		convertConstToEnum,
		convertEnumValuesToStrings,
		addEnumHints,
		addAdditionalPropertiesHints,
		moveConstraintsToDescription,
		mergeAllOf,
		flattenAnyOfOneOf,
		flattenTypeArrays,
		removeUnsupportedKeywords,
	}
	if !addPlaceholder {
		steps = append(steps,
			func(s string) string { return removeKeywords(s, []string{"nullable", "title"}) },
			removePlaceholderFields,
		)
	}
	steps = append(steps, cleanupRequiredFields)
	if addPlaceholder {
		steps = append(steps, addEmptySchemaPlaceholder)
	}
	for _, step := range steps {
		jsonStr = step(jsonStr)
	}
	return jsonStr
}
// removeKeywords deletes every occurrence of the given keywords from the schema.
// A key whose parent is a "properties" map is a property name, not a keyword, and
// is kept. Deletions run deepest-path first.
func removeKeywords(jsonStr string, keywords []string) string {
	byField := findPathsByFields(jsonStr, keywords)
	var doomed []string
	for _, kw := range keywords {
		for _, p := range byField[kw] {
			if !isPropertyDefinition(trimSuffix(p, "."+kw)) {
				doomed = append(doomed, p)
			}
		}
	}
	sortByDepth(doomed)
	for _, p := range doomed {
		// Best effort: a path already removed with an ancestor simply fails to delete.
		jsonStr, _ = sjson.Delete(jsonStr, p)
	}
	return jsonStr
}
// removePlaceholderFields removes placeholder-only properties ("_" and "reason") and their required entries.
func removePlaceholderFields(jsonStr string) string {
// Remove "_" placeholder properties.
paths := findPaths(jsonStr, "_")
sortByDepth(paths)
for _, p := range paths {
if !strings.HasSuffix(p, ".properties._") {
continue
}
jsonStr, _ = sjson.Delete(jsonStr, p)
parentPath := trimSuffix(p, ".properties._")
reqPath := joinPath(parentPath, "required")
req := gjson.Get(jsonStr, reqPath)
if req.IsArray() {
var filtered []string
for _, r := range req.Array() {
if r.String() != "_" {
filtered = append(filtered, r.String())
}
}
if len(filtered) == 0 {
jsonStr, _ = sjson.Delete(jsonStr, reqPath)
} else {
jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered)
}
}
}
// Remove placeholder-only "reason" objects.
reasonPaths := findPaths(jsonStr, "reason")
sortByDepth(reasonPaths)
for _, p := range reasonPaths {
if !strings.HasSuffix(p, ".properties.reason") {
continue
}
parentPath := trimSuffix(p, ".properties.reason")
props := gjson.Get(jsonStr, joinPath(parentPath, "properties"))
if !props.IsObject() || len(props.Map()) != 1 {
continue
}
desc := gjson.Get(jsonStr, p+".description").String()
if desc != placeholderReasonDescription {
continue
}
jsonStr, _ = sjson.Delete(jsonStr, p)
reqPath := joinPath(parentPath, "required")
req := gjson.Get(jsonStr, reqPath)
if req.IsArray() {
var filtered []string
for _, r := range req.Array() {
if r.String() != "reason" {
filtered = append(filtered, r.String())
}
}
if len(filtered) == 0 {
jsonStr, _ = sjson.Delete(jsonStr, reqPath)
} else {
jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered)
}
}
}
return jsonStr
}
// convertRefsToHints replaces each schema containing "$ref" with an object stub
// whose description reads "See: <Name>". <Name> is the last "/"-separated segment
// of the reference. Any existing description is kept in front as
// "<existing> (See: <Name>)". References are not resolved (Lazy Hint strategy), and
// the stub replaces the whole schema object that held the "$ref".
func convertRefsToHints(jsonStr string) string {
	refPaths := findPaths(jsonStr, "$ref")
	sortByDepth(refPaths)
	for _, refPath := range refPaths {
		ref := gjson.Get(jsonStr, refPath).String()
		// LastIndex returns -1 when there is no '/', so this keeps the whole ref.
		name := ref[strings.LastIndex(ref, "/")+1:]
		owner := trimSuffix(refPath, ".$ref")
		text := "See: " + name
		if prior := gjson.Get(jsonStr, descriptionPath(owner)).String(); prior != "" {
			text = prior + " (" + text + ")"
		}
		stub, _ := sjson.Set(`{"type":"object","description":""}`, "description", text)
		jsonStr = setRawAt(jsonStr, owner, stub)
	}
	return jsonStr
}
// convertConstToEnum rewrites each "const" keyword as a single-element "enum"
// sibling, unless the schema already has an "enum". The "const" itself is removed
// later by removeUnsupportedKeywords.
//
// A property literally named "const" (a key directly under "properties") is a
// schema, not the keyword, and is skipped. Otherwise its schema object would be
// injected into the properties map as a bogus "enum" entry. The enum path is built
// with joinPath, so a root-level "const" yields "enum" rather than ".enum"; the
// latter would write under an empty key.
func convertConstToEnum(jsonStr string) string {
	for _, p := range findPaths(jsonStr, "const") {
		parentPath := trimSuffix(p, ".const")
		if isPropertyDefinition(parentPath) {
			continue
		}
		val := gjson.Get(jsonStr, p)
		if !val.Exists() {
			continue
		}
		enumPath := joinPath(parentPath, "enum")
		if !gjson.Get(jsonStr, enumPath).Exists() {
			jsonStr, _ = sjson.Set(jsonStr, enumPath, []interface{}{val.Value()})
		}
	}
	return jsonStr
}
// convertEnumValuesToStrings ensures all enum values are strings and the schema type is set to string.
// Gemini API requires enum values to be of type string, not numbers or booleans.
// Each value is converted with gjson's Result.String(). An empty enum array stays
// an empty array instead of becoming null.
func convertEnumValuesToStrings(jsonStr string) string {
	for _, p := range findPaths(jsonStr, "enum") {
		arr := gjson.Get(jsonStr, p)
		if !arr.IsArray() {
			continue
		}
		items := arr.Array()
		// Non-nil on purpose: sjson marshals a nil []string as JSON null, which
		// would turn `"enum": []` into `"enum": null`.
		stringVals := make([]string, 0, len(items))
		for _, item := range items {
			stringVals = append(stringVals, item.String())
		}
		// Always update enum values to strings and set type to "string"
		// This ensures compatibility with Antigravity Gemini which only allows enum for STRING type
		jsonStr, _ = sjson.Set(jsonStr, p, stringVals)
		parentPath := trimSuffix(p, ".enum")
		jsonStr, _ = sjson.Set(jsonStr, joinPath(parentPath, "type"), "string")
	}
	return jsonStr
}
// addEnumHints appends "Allowed: v1, v2, ..." to the description of every schema
// whose "enum" array holds between 2 and 10 values (inclusive).
func addEnumHints(jsonStr string) string {
	for _, enumPath := range findPaths(jsonStr, "enum") {
		enum := gjson.Get(jsonStr, enumPath)
		if !enum.IsArray() {
			continue
		}
		values := enum.Array()
		if n := len(values); n < 2 || n > 10 {
			continue
		}
		labels := make([]string, len(values))
		for i, v := range values {
			labels[i] = v.String()
		}
		jsonStr = appendHint(jsonStr, trimSuffix(enumPath, ".enum"), "Allowed: "+strings.Join(labels, ", "))
	}
	return jsonStr
}
// addAdditionalPropertiesHints adds the hint "No extra properties allowed" to every
// schema whose "additionalProperties" is the literal false.
func addAdditionalPropertiesHints(jsonStr string) string {
	const hint = "No extra properties allowed"
	for _, p := range findPaths(jsonStr, "additionalProperties") {
		if gjson.Get(jsonStr, p).Type != gjson.False {
			continue
		}
		jsonStr = appendHint(jsonStr, trimSuffix(p, ".additionalProperties"), hint)
	}
	return jsonStr
}
// unsupportedConstraints lists validation keywords that moveConstraintsToDescription
// copies into description hints (scalar values only) and that
// removeUnsupportedKeywords subsequently deletes from the schema.
var unsupportedConstraints = []string{
"minLength", "maxLength", "exclusiveMinimum", "exclusiveMaximum",
"pattern", "minItems", "maxItems", "format",
"default", "examples", // Claude rejects these in VALIDATED mode
}
// moveConstraintsToDescription records every scalar-valued unsupported constraint
// as a "<keyword>: <value>" hint in the owning schema's description. Object/array
// values are skipped. Keys that are property names (children of "properties") are
// not constraints and are ignored.
func moveConstraintsToDescription(jsonStr string) string {
	found := findPathsByFields(jsonStr, unsupportedConstraints)
	for _, keyword := range unsupportedConstraints {
		suffix := "." + keyword
		for _, p := range found[keyword] {
			owner := trimSuffix(p, suffix)
			if isPropertyDefinition(owner) {
				continue
			}
			v := gjson.Get(jsonStr, p)
			if !v.Exists() || v.IsObject() || v.IsArray() {
				continue
			}
			jsonStr = appendHint(jsonStr, owner, keyword+": "+v.String())
		}
	}
	return jsonStr
}
// mergeAllOf folds each "allOf" array into its parent schema, deepest first.
// Every member's "properties" are copied into the parent; later members win on
// name clashes. Their "required" names are unioned into the parent's "required"
// list. The "allOf" key is then deleted. Other keywords inside the members are
// not carried over.
func mergeAllOf(jsonStr string) string {
	paths := findPaths(jsonStr, "allOf")
	sortByDepth(paths)
	for _, p := range paths {
		allOf := gjson.Get(jsonStr, p)
		if !allOf.IsArray() {
			continue
		}
		parentPath := trimSuffix(p, ".allOf")
		for _, item := range allOf.Array() {
			if props := item.Get("properties"); props.IsObject() {
				props.ForEach(func(key, value gjson.Result) bool {
					destPath := joinPath(parentPath, "properties."+escapeGJSONPathKey(key.String()))
					jsonStr, _ = sjson.SetRaw(jsonStr, destPath, value.Raw)
					return true
				})
			}
			if req := item.Get("required"); req.IsArray() {
				reqPath := joinPath(parentPath, "required")
				current := getStrings(jsonStr, reqPath)
				for _, r := range req.Array() {
					if s := r.String(); !contains(current, s) {
						current = append(current, s)
					}
				}
				// Nothing to write: getStrings returned nil and the member's list was
				// empty. Setting a nil slice would store `"required": null`.
				if len(current) > 0 {
					jsonStr, _ = sjson.Set(jsonStr, reqPath, current)
				}
			}
		}
		jsonStr, _ = sjson.Delete(jsonStr, p)
	}
	return jsonStr
}
// flattenAnyOfOneOf replaces each non-empty "anyOf"/"oneOf" with a single
// alternative chosen by selectBest, and that alternative replaces the whole
// parent schema. The parent's description is merged into the chosen alternative.
// When more than one alternative type was seen, an "Accepts: a | b" hint is
// appended. anyOf is processed fully before oneOf, deepest paths first.
func flattenAnyOfOneOf(jsonStr string) string {
	for _, keyword := range [...]string{"anyOf", "oneOf"} {
		candidates := findPaths(jsonStr, keyword)
		sortByDepth(candidates)
		for _, unionPath := range candidates {
			union := gjson.Get(jsonStr, unionPath)
			if !union.IsArray() {
				continue
			}
			alternatives := union.Array()
			if len(alternatives) == 0 {
				continue
			}
			parentPath := trimSuffix(unionPath, "."+keyword)
			best, kinds := selectBest(alternatives)
			chosen := alternatives[best].Raw
			if desc := gjson.Get(jsonStr, descriptionPath(parentPath)).String(); desc != "" {
				chosen = mergeDescriptionRaw(chosen, desc)
			}
			if len(kinds) > 1 {
				chosen = appendHintRaw(chosen, "Accepts: "+strings.Join(kinds, " | "))
			}
			jsonStr = setRawAt(jsonStr, parentPath, chosen)
		}
	}
	return jsonStr
}
// selectBest picks the richest alternative from an anyOf/oneOf list, ranked as
// follows: object (type "object" or has "properties") = 3, array (type "array" or
// has "items") = 2, any other non-null type = 1, null or untyped = 0. The first
// alternative with the highest rank wins. It also returns one type label per
// alternative, in order; untyped alternatives are labelled "null".
func selectBest(items []gjson.Result) (bestIdx int, types []string) {
	top := -1
	for idx, alt := range items {
		kind := alt.Get("type").String()
		rank := 0
		if kind == "object" || alt.Get("properties").Exists() {
			rank, kind = 3, orDefault(kind, "object")
		} else if kind == "array" || alt.Get("items").Exists() {
			rank, kind = 2, orDefault(kind, "array")
		} else if kind != "" && kind != "null" {
			rank = 1
		} else {
			kind = orDefault(kind, "null")
		}
		// Every branch above leaves kind non-empty, so a label is always recorded.
		types = append(types, kind)
		if rank > top {
			top, bestIdx = rank, idx
		}
	}
	return bestIdx, types
}
// flattenTypeArrays collapses array-valued "type" keywords (e.g. ["string","null"])
// into a single type string: the first non-null entry, or "string" when no
// non-null, non-empty entry exists. When more than one non-null type was listed,
// an "Accepts: a | b" hint is appended to that schema's description. When "null"
// was listed and the schema is a named property (path
// "<object>.properties.<field>.type"), the property gets a "(nullable)" hint.
// The field is also removed from the enclosing object's "required" array, which
// is deleted if it becomes empty.
func flattenTypeArrays(jsonStr string) string {
paths := findPaths(jsonStr, "type")
sortByDepth(paths)
// object path -> names of properties under it that were declared nullable.
nullableFields := make(map[string][]string)
for _, p := range paths {
res := gjson.Get(jsonStr, p)
// Only non-empty array-valued "type" keywords need flattening.
if !res.IsArray() || len(res.Array()) == 0 {
continue
}
hasNull := false
var nonNullTypes []string
for _, item := range res.Array() {
s := item.String()
if s == "null" {
hasNull = true
} else if s != "" {
nonNullTypes = append(nonNullTypes, s)
}
}
firstType := "string"
if len(nonNullTypes) > 0 {
firstType = nonNullTypes[0]
}
jsonStr, _ = sjson.Set(jsonStr, p, firstType)
parentPath := trimSuffix(p, ".type")
if len(nonNullTypes) > 1 {
hint := "Accepts: " + strings.Join(nonNullTypes, " | ")
jsonStr = appendHint(jsonStr, parentPath, hint)
}
if hasNull {
// p holds escaped keys; split on unescaped dots so a field name that
// itself contains dots is kept as one segment.
parts := splitGJSONPath(p)
if len(parts) >= 3 && parts[len(parts)-3] == "properties" {
fieldNameEscaped := parts[len(parts)-2]
fieldName := unescapeGJSONPathKey(fieldNameEscaped)
objectPath := strings.Join(parts[:len(parts)-3], ".")
nullableFields[objectPath] = append(nullableFields[objectPath], fieldName)
propPath := joinPath(objectPath, "properties."+fieldNameEscaped)
jsonStr = appendHint(jsonStr, propPath, "(nullable)")
}
}
}
// Second pass: nullable properties become optional by leaving "required".
for objectPath, fields := range nullableFields {
reqPath := joinPath(objectPath, "required")
req := gjson.Get(jsonStr, reqPath)
if !req.IsArray() {
continue
}
var filtered []string
for _, r := range req.Array() {
if !contains(fields, r.String()) {
filtered = append(filtered, r.String())
}
}
if len(filtered) == 0 {
jsonStr, _ = sjson.Delete(jsonStr, reqPath)
} else {
jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered)
}
}
return jsonStr
}
// removeUnsupportedKeywords deletes every keyword Gemini/Claude reject: the
// unsupportedConstraints list plus schema plumbing ($schema, $defs, definitions,
// const, $ref, $id, additionalProperties, ...) and metadata fields. Keys that are
// property names are kept. All x-* extension fields are stripped afterwards.
func removeUnsupportedKeywords(jsonStr string) string {
	extra := []string{
		"$schema", "$defs", "definitions", "const", "$ref", "$id", "additionalProperties",
		"propertyNames", "patternProperties", // Gemini doesn't support these schema keywords
		"enumTitles", "prefill", "deprecated", // Schema metadata fields unsupported by Gemini
	}
	// Build a fresh slice. A plain append onto the package-level
	// unsupportedConstraints would write into its backing array whenever that
	// slice has spare capacity.
	keywords := make([]string, 0, len(unsupportedConstraints)+len(extra))
	keywords = append(keywords, unsupportedConstraints...)
	keywords = append(keywords, extra...)
	// removeKeywords applies the same rules: skip property names, delete deepest first.
	jsonStr = removeKeywords(jsonStr, keywords)
	// Remove x-* extension fields (e.g., x-google-enum-descriptions) that are not supported by Gemini API
	return removeExtensionFields(jsonStr)
}
// removeExtensionFields removes all x-* extension fields from the JSON schema.
// These are OpenAPI/JSON Schema extension fields that Google APIs don't recognize.
// walkForExtensions does not descend into a matched x-* value, so no collected
// path lies inside another. Only object keys are collected, so deleting one never
// shifts the array indices used by another path. Deletions therefore need no
// particular order.
func removeExtensionFields(jsonStr string) string {
	var targets []string
	walkForExtensions(gjson.Parse(jsonStr), "", &targets)
	out := []byte(jsonStr)
	for i := range targets {
		out, _ = sjson.DeleteBytes(out, targets[i])
	}
	return string(out)
}
// walkForExtensions recursively collects the escaped path of every object key that
// starts with "x-". Keys directly under a "properties" map are property names and
// are not collected. A collected key's value is not descended into. Array
// elements are visited from last to first.
func walkForExtensions(value gjson.Result, path string, paths *[]string) {
	switch {
	case value.IsArray():
		items := value.Array()
		for idx := len(items) - 1; idx >= 0; idx-- {
			walkForExtensions(items[idx], joinPath(path, strconv.Itoa(idx)), paths)
		}
	case value.IsObject():
		inProperties := isPropertyDefinition(path)
		value.ForEach(func(k, v gjson.Result) bool {
			name := k.String()
			child := joinPath(path, escapeGJSONPathKey(name))
			if !inProperties && strings.HasPrefix(name, "x-") {
				*paths = append(*paths, child)
			} else {
				walkForExtensions(v, child, paths)
			}
			return true
		})
	}
}
// cleanupRequiredFields prunes each "required" array down to names that exist in
// the sibling "properties" object, and deletes the array if nothing survives.
// Arrays without an object-valued "properties" sibling are left untouched.
func cleanupRequiredFields(jsonStr string) string {
	for _, reqPath := range findPaths(jsonStr, "required") {
		req := gjson.Get(jsonStr, reqPath)
		props := gjson.Get(jsonStr, joinPath(trimSuffix(reqPath, ".required"), "properties"))
		if !req.IsArray() || !props.IsObject() {
			continue
		}
		entries := req.Array()
		var kept []string
		for _, entry := range entries {
			name := entry.String()
			if props.Get(escapeGJSONPathKey(name)).Exists() {
				kept = append(kept, name)
			}
		}
		switch {
		case len(kept) == len(entries):
			// Nothing was dropped; leave the array as is.
		case len(kept) == 0:
			jsonStr, _ = sjson.Delete(jsonStr, reqPath)
		default:
			jsonStr, _ = sjson.Set(jsonStr, reqPath, kept)
		}
	}
	return jsonStr
}
// addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas.
// Claude VALIDATED mode requires at least one required property in tool schemas.
// For every schema with "type":"object":
//   - With no "properties", or an empty "properties" object, it adds a string
//     property "reason" (description placeholderReasonDescription) and sets
//     "required" to ["reason"].
//   - A non-root object that has properties but no non-empty "required" array
//     gets a boolean property "_" (if absent), and "required" is set to ["_"].
func addEmptySchemaPlaceholder(jsonStr string) string {
// Find all "type" fields
paths := findPaths(jsonStr, "type")
// Process from deepest to shallowest (to handle nested objects properly)
sortByDepth(paths)
for _, p := range paths {
typeVal := gjson.Get(jsonStr, p)
if typeVal.String() != "object" {
continue
}
// Get the parent path (the object containing "type")
parentPath := trimSuffix(p, ".type")
// Check if properties exists and is empty or missing
propsPath := joinPath(parentPath, "properties")
propsVal := gjson.Get(jsonStr, propsPath)
reqPath := joinPath(parentPath, "required")
reqVal := gjson.Get(jsonStr, reqPath)
hasRequiredProperties := reqVal.IsArray() && len(reqVal.Array()) > 0
needsPlaceholder := false
if !propsVal.Exists() {
// No properties field at all
needsPlaceholder = true
} else if propsVal.IsObject() && len(propsVal.Map()) == 0 {
// Empty properties object
needsPlaceholder = true
}
if needsPlaceholder {
// Add placeholder "reason" property
reasonPath := joinPath(propsPath, "reason")
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".type", "string")
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", placeholderReasonDescription)
// Add to required array (replaces any existing value)
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
continue
}
// If schema has properties but none are required, add a minimal placeholder.
if propsVal.IsObject() && !hasRequiredProperties {
// Never add "_" to the top-level schema (parentPath is empty). Objects that
// received the "reason" placeholder above have already hit `continue`.
if parentPath == "" {
continue
}
placeholderPath := joinPath(propsPath, "_")
if !gjson.Get(jsonStr, placeholderPath).Exists() {
jsonStr, _ = sjson.Set(jsonStr, placeholderPath+".type", "boolean")
}
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"_"})
}
}
return jsonStr
}
// --- Helpers ---
func findPaths(jsonStr, field string) []string {
var paths []string
Walk(gjson.Parse(jsonStr), "", field, &paths)
return paths
}
// findPathsByFields walks the document once and returns, for each requested
// field name, the escaped paths of all keys with that name.
func findPathsByFields(jsonStr string, fields []string) map[string][]string {
	wanted := make(map[string]struct{}, len(fields))
	for i := range fields {
		wanted[fields[i]] = struct{}{}
	}
	found := make(map[string][]string, len(wanted))
	walkForFields(gjson.Parse(jsonStr), "", wanted, found)
	return found
}
// walkForFields recursively visits every object/array member. Whenever a member's
// key is in fields, its escaped path is appended to paths[key]. Scalar values end
// the recursion.
func walkForFields(value gjson.Result, path string, fields map[string]struct{}, paths map[string][]string) {
	if value.Type != gjson.JSON {
		return // string, number, bool and null have no children
	}
	value.ForEach(func(key, child gjson.Result) bool {
		name := key.String()
		childPath := joinPath(path, escapeGJSONPathKey(name))
		if _, hit := fields[name]; hit {
			paths[name] = append(paths[name], childPath)
		}
		walkForFields(child, childPath, fields, paths)
		return true
	})
}
// sortByDepth orders paths longest first, in place. A descendant path always
// extends its ancestor's path, so descendants sort before their ancestors.
func sortByDepth(paths []string) {
	deeperFirst := func(a, b int) bool { return len(paths[a]) > len(paths[b]) }
	sort.Slice(paths, deeperFirst)
}
// trimSuffix strips suffix (e.g. ".enum") from path to get the parent path. When
// path is exactly the suffix without its leading dot (a root-level key), it
// returns "" to denote the document root.
func trimSuffix(path, suffix string) string {
	bare := strings.TrimPrefix(suffix, ".")
	if path != bare {
		return strings.TrimSuffix(path, suffix)
	}
	return ""
}
// joinPath appends suffix to base with a '.' separator; an empty base (the
// document root) yields suffix unchanged.
func joinPath(base, suffix string) string {
	if base != "" {
		return base + "." + suffix
	}
	return suffix
}
// setRawAt stores the raw JSON value at path. An empty path means the document
// root, so value itself becomes the new document. An sjson error leaves the
// result as returned by sjson, which is ignored here as in the original.
func setRawAt(jsonStr, path, value string) string {
	if path != "" {
		updated, _ := sjson.SetRaw(jsonStr, path, value)
		return updated
	}
	return value
}
// isPropertyDefinition reports whether path points at a "properties" map, i.e. is
// exactly "properties" or ends in ".properties". Keys directly under such a map
// are user property names rather than schema keywords.
func isPropertyDefinition(path string) bool {
	const key = "properties"
	if !strings.HasSuffix(path, key) {
		return false
	}
	rest := path[:len(path)-len(key)]
	return rest == "" || strings.HasSuffix(rest, ".")
}
// descriptionPath returns the path of the "description" key belonging to the schema
// at parentPath. Both "" and "@this" denote the document root.
func descriptionPath(parentPath string) string {
	switch parentPath {
	case "", "@this":
		return "description"
	default:
		return parentPath + ".description"
	}
}
func appendHint(jsonStr, parentPath, hint string) string {
descPath := parentPath + ".description"
if parentPath == "" || parentPath == "@this" {
descPath = "description"
}
existing := gjson.Get(jsonStr, descPath).String()
if existing != "" {
hint = fmt.Sprintf("%s (%s)", existing, hint)
}
jsonStr, _ = sjson.Set(jsonStr, descPath, hint)
return jsonStr
}
// appendHintRaw is appendHint for a standalone schema document: it writes hint
// into the top-level "description" of jsonRaw, keeping any existing text in front
// as "<existing> (<hint>)".
func appendHintRaw(jsonRaw, hint string) string {
	text := hint
	if prev := gjson.Get(jsonRaw, "description").String(); prev != "" {
		text = prev + " (" + hint + ")"
	}
	jsonRaw, _ = sjson.Set(jsonRaw, "description", text)
	return jsonRaw
}
// getStrings returns the elements of the array at path as strings. It returns nil
// when path is not an array or the array is empty.
func getStrings(jsonStr, path string) []string {
	arr := gjson.Get(jsonStr, path)
	if !arr.IsArray() {
		return nil
	}
	var out []string
	for _, elem := range arr.Array() {
		out = append(out, elem.String())
	}
	return out
}
// contains reports whether item occurs in slice (linear scan).
func contains(slice []string, item string) bool {
	for i := range slice {
		if slice[i] == item {
			return true
		}
	}
	return false
}
// orDefault returns val, or def when val is empty.
func orDefault(val, def string) string {
	if val != "" {
		return val
	}
	return def
}
// escapeGJSONPathKey backslash-escapes every byte of an object key that gjson or
// sjson would otherwise interpret as path syntax, so the key can be used as a single
// path component. Escaped characters:
//   - '.' (separator), '*' and '?' (wildcards)
//   - '|' (pipe), '#' (array length/query), '@' (modifier), '!' (literal)
//   - '[' and '{' (multipath)
//   - '\' itself
//
// sjson treats unescaped '|', '#' and '@' as "complex" paths that it cannot create
// or delete. Escaping only ". * ?" silently broke edits of keys like "@type" or
// "a|b". '$' and '-' are deliberately left alone, so paths for keywords like
// "$ref" or "x-*" keep their literal suffixes.
func escapeGJSONPathKey(key string) string {
	const special = `.*?|#@!\[{`
	if !strings.ContainsAny(key, special) {
		return key
	}
	var b strings.Builder
	b.Grow(len(key) + 4)
	// Byte-wise is UTF-8 safe: every special character is ASCII and multi-byte
	// sequences never contain ASCII bytes.
	for i := 0; i < len(key); i++ {
		if strings.IndexByte(special, key[i]) >= 0 {
			b.WriteByte('\\')
		}
		b.WriteByte(key[i])
	}
	return b.String()
}
// unescapeGJSONPathKey reverses backslash escaping in a single path component:
// every "\x" becomes "x". A trailing lone backslash is kept as-is.
func unescapeGJSONPathKey(key string) string {
	if strings.IndexByte(key, '\\') < 0 {
		return key
	}
	out := make([]byte, 0, len(key))
	for i := 0; i < len(key); i++ {
		ch := key[i]
		if ch == '\\' && i+1 < len(key) {
			i++
			ch = key[i]
		}
		out = append(out, ch)
	}
	return string(out)
}
// splitGJSONPath splits a gjson path on unescaped dots. Escape sequences are kept
// verbatim inside each component (so `a\.b.c` yields [`a\.b`, "c"]). An empty path
// yields nil.
func splitGJSONPath(path string) []string {
	if path == "" {
		return nil
	}
	parts := make([]string, 0, strings.Count(path, ".")+1)
	start := 0
	for i := 0; i < len(path); i++ {
		switch path[i] {
		case '\\':
			// Skip the escaped byte so an escaped '.' is not treated as a separator.
			if i+1 < len(path) {
				i++
			}
		case '.':
			parts = append(parts, path[start:i])
			start = i + 1
		}
	}
	return append(parts, path[start:])
}
// mergeDescriptionRaw combines parentDesc with the top-level description of
// schemaRaw. An empty child description is replaced by parentDesc; an identical
// one is left untouched. Otherwise the result is "<parent> (<child>)".
func mergeDescriptionRaw(schemaRaw, parentDesc string) string {
	childDesc := gjson.Get(schemaRaw, "description").String()
	if childDesc != "" && childDesc == parentDesc {
		return schemaRaw
	}
	merged := parentDesc
	if childDesc != "" {
		merged = parentDesc + " (" + childDesc + ")"
	}
	out, _ := sjson.Set(schemaRaw, "description", merged)
	return out
}
================================================
FILE: internal/util/gemini_schema_test.go
================================================
package util
import (
"encoding/json"
"reflect"
"strings"
"testing"
"github.com/tidwall/gjson"
)
// TestCleanJSONSchemaForAntigravity_ConstToEnum checks that a "const" keyword is
// replaced by a single-element "enum" carrying the same value.
func TestCleanJSONSchemaForAntigravity_ConstToEnum(t *testing.T) {
input := `{
"type": "object",
"properties": {
"kind": {
"type": "string",
"const": "InsightVizNode"
}
}
}`
expected := `{
"type": "object",
"properties": {
"kind": {
"type": "string",
"enum": ["InsightVizNode"]
}
}
}`
result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result)
}
// TestCleanJSONSchemaForAntigravity_TypeFlattening_Nullable checks that
// ["string","null"] collapses to "string" with a "(nullable)" hint, and that the
// nullable property is dropped from "required".
func TestCleanJSONSchemaForAntigravity_TypeFlattening_Nullable(t *testing.T) {
input := `{
"type": "object",
"properties": {
"name": {
"type": ["string", "null"]
},
"other": {
"type": "string"
}
},
"required": ["name", "other"]
}`
expected := `{
"type": "object",
"properties": {
"name": {
"type": "string",
"description": "(nullable)"
},
"other": {
"type": "string"
}
},
"required": ["other"]
}`
result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result)
}
// TestCleanJSONSchemaForAntigravity_ConstraintsToDescription checks that scalar
// constraints (minItems, minLength) are removed as keywords but preserved as
// "<keyword>: <value>" hints in the description.
func TestCleanJSONSchemaForAntigravity_ConstraintsToDescription(t *testing.T) {
input := `{
"type": "object",
"properties": {
"tags": {
"type": "array",
"description": "List of tags",
"minItems": 1
},
"name": {
"type": "string",
"description": "User name",
"minLength": 3
}
}
}`
result := CleanJSONSchemaForAntigravity(input)
// minItems should be REMOVED and moved to description
if strings.Contains(result, `"minItems"`) {
t.Errorf("minItems keyword should be removed")
}
if !strings.Contains(result, "minItems: 1") {
t.Errorf("minItems hint missing in description")
}
// minLength should be moved to description
if !strings.Contains(result, "minLength: 3") {
t.Errorf("minLength hint missing in description")
}
if strings.Contains(result, `"minLength":`) || strings.Contains(result, `"minLength" :`) {
t.Errorf("minLength keyword should be removed")
}
}
// TestCleanJSONSchemaForAntigravity_AnyOfFlattening_SmartSelection checks that
// anyOf picks the object alternative over null and records "Accepts: null | object".
// Because the chosen nested object has properties but no "required", the
// placeholder step also adds a required boolean "_" property.
func TestCleanJSONSchemaForAntigravity_AnyOfFlattening_SmartSelection(t *testing.T) {
input := `{
"type": "object",
"properties": {
"query": {
"anyOf": [
{ "type": "null" },
{
"type": "object",
"properties": {
"kind": { "type": "string" }
}
}
]
}
}
}`
expected := `{
"type": "object",
"properties": {
"query": {
"type": "object",
"description": "Accepts: null | object",
"properties": {
"_": { "type": "boolean" },
"kind": { "type": "string" }
},
"required": ["_"]
}
}
}`
result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result)
}
// TestCleanJSONSchemaForAntigravity_OneOfFlattening checks that with equally ranked
// scalar alternatives the first one wins and all types are listed in the hint.
func TestCleanJSONSchemaForAntigravity_OneOfFlattening(t *testing.T) {
input := `{
"type": "object",
"properties": {
"config": {
"oneOf": [
{ "type": "string" },
{ "type": "integer" }
]
}
}
}`
expected := `{
"type": "object",
"properties": {
"config": {
"type": "string",
"description": "Accepts: string | integer"
}
}
}`
result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result)
}
// TestCleanJSONSchemaForAntigravity_AllOfMerging checks that allOf members'
// properties and required lists are merged into the parent schema.
func TestCleanJSONSchemaForAntigravity_AllOfMerging(t *testing.T) {
input := `{
"type": "object",
"allOf": [
{
"properties": {
"a": { "type": "string" }
},
"required": ["a"]
},
{
"properties": {
"b": { "type": "integer" }
},
"required": ["b"]
}
]
}`
expected := `{
"type": "object",
"properties": {
"a": { "type": "string" },
"b": { "type": "integer" }
},
"required": ["a", "b"]
}`
result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result)
}
// TestCleanJSONSchemaForAntigravity_RefHandling checks that a $ref becomes an
// object stub described "See: User". "definitions" is dropped, and the now-empty
// object stub receives the "reason" placeholder.
func TestCleanJSONSchemaForAntigravity_RefHandling(t *testing.T) {
input := `{
"definitions": {
"User": {
"type": "object",
"properties": {
"name": { "type": "string" }
}
}
},
"type": "object",
"properties": {
"customer": { "$ref": "#/definitions/User" }
}
}`
// After $ref is converted to placeholder object, empty schema placeholder is also added
expected := `{
"type": "object",
"properties": {
"customer": {
"type": "object",
"description": "See: User",
"properties": {
"reason": {
"type": "string",
"description": "Brief explanation of why you are calling this tool"
}
},
"required": ["reason"]
}
}
}`
result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result)
}
// TestCleanJSONSchemaForAntigravity_RefHandling_DescriptionEscaping checks that an
// existing description containing JSON escapes (quotes, backslash-n) survives
// verbatim when the $ref hint is appended to it.
func TestCleanJSONSchemaForAntigravity_RefHandling_DescriptionEscaping(t *testing.T) {
input := `{
"definitions": {
"User": {
"type": "object",
"properties": {
"name": { "type": "string" }
}
}
},
"type": "object",
"properties": {
"customer": {
"description": "He said \"hi\"\\nsecond line",
"$ref": "#/definitions/User"
}
}
}`
// After $ref is converted, empty schema placeholder is also added
expected := `{
"type": "object",
"properties": {
"customer": {
"type": "object",
"description": "He said \"hi\"\\nsecond line (See: User)",
"properties": {
"reason": {
"type": "string",
"description": "Brief explanation of why you are calling this tool"
}
},
"required": ["reason"]
}
}
}`
result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result)
}
// TestCleanJSONSchemaForAntigravity_CyclicRefDefaults checks that a self-referencing
// root $ref is collapsed into an object stub whose description names the
// definition, rather than being expanded.
func TestCleanJSONSchemaForAntigravity_CyclicRefDefaults(t *testing.T) {
	input := `{
"definitions": {
"Node": {
"type": "object",
"properties": {
"child": { "$ref": "#/definitions/Node" }
}
}
},
"$ref": "#/definitions/Node"
}`
	result := CleanJSONSchemaForAntigravity(input)
	var resMap map[string]interface{}
	// Fail loudly on malformed output. Ignoring the error would let the assertions
	// below run against a nil map and report a misleading cause.
	if err := json.Unmarshal([]byte(result), &resMap); err != nil {
		t.Fatalf("Failed to unmarshal result: %v (result: %s)", err, result)
	}
	if resMap["type"] != "object" {
		t.Errorf("Expected type: object, got: %v", resMap["type"])
	}
	desc, ok := resMap["description"].(string)
	if !ok || !strings.Contains(desc, "Node") {
		t.Errorf("Expected description hint containing 'Node', got: %v", resMap["description"])
	}
}
// TestCleanJSONSchemaForAntigravity_RequiredCleanup checks that "required" entries
// naming non-existent properties ("c") are pruned.
func TestCleanJSONSchemaForAntigravity_RequiredCleanup(t *testing.T) {
input := `{
"type": "object",
"properties": {
"a": {"type": "string"},
"b": {"type": "string"}
},
"required": ["a", "b", "c"]
}`
expected := `{
"type": "object",
"properties": {
"a": {"type": "string"},
"b": {"type": "string"}
},
"required": ["a", "b"]
}`
result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result)
}
// TestCleanJSONSchemaForAntigravity_AllOfMerging_DotKeys checks that allOf merging
// keeps a property whose name contains a dot ("my.param") as a single key instead
// of splitting it into nested objects.
func TestCleanJSONSchemaForAntigravity_AllOfMerging_DotKeys(t *testing.T) {
input := `{
"type": "object",
"allOf": [
{
"properties": {
"my.param": { "type": "string" }
},
"required": ["my.param"]
},
{
"properties": {
"b": { "type": "integer" }
},
"required": ["b"]
}
]
}`
expected := `{
"type": "object",
"properties": {
"my.param": { "type": "string" },
"b": { "type": "integer" }
},
"required": ["my.param", "b"]
}`
result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result)
}
// TestCleanJSONSchemaForAntigravity_PropertyNameCollision checks that a property
// whose name matches an unsupported keyword ("pattern") is kept as a property. It
// must not be removed or turned into a hint, and no stray "description" key may be
// injected into the properties map.
func TestCleanJSONSchemaForAntigravity_PropertyNameCollision(t *testing.T) {
	// A tool has an argument named "pattern" - should NOT be treated as a constraint
	input := `{
"type": "object",
"properties": {
"pattern": {
"type": "string",
"description": "The regex pattern"
}
},
"required": ["pattern"]
}`
	expected := `{
"type": "object",
"properties": {
"pattern": {
"type": "string",
"description": "The regex pattern"
}
},
"required": ["pattern"]
}`
	result := CleanJSONSchemaForAntigravity(input)
	compareJSON(t, expected, result)
	var resMap map[string]interface{}
	// An ignored decode error would leave props nil and make the check below
	// pass vacuously.
	if err := json.Unmarshal([]byte(result), &resMap); err != nil {
		t.Fatalf("Failed to unmarshal result: %v (result: %s)", err, result)
	}
	props, _ := resMap["properties"].(map[string]interface{})
	if _, ok := props["description"]; ok {
		t.Errorf("Invalid 'description' property injected into properties map")
	}
}
func TestCleanJSONSchemaForAntigravity_DotKeys(t *testing.T) {
	// Resolving a $ref inside a dotted property name must not split the name
	// into nested keys (a classic sjson path pitfall).
	input := `{
"type": "object",
"properties": {
"my.param": {
"type": "string",
"$ref": "#/definitions/MyType"
}
},
"definitions": {
"MyType": { "type": "string" }
}
}`
	result := CleanJSONSchemaForAntigravity(input)

	var decoded map[string]interface{}
	if err := json.Unmarshal([]byte(result), &decoded); err != nil {
		t.Fatalf("Failed to unmarshal result: %v", err)
	}
	properties, found := decoded["properties"].(map[string]interface{})
	if !found {
		t.Fatalf("properties missing")
	}
	param, present := properties["my.param"]
	if !present {
		t.Fatalf("Key 'my.param' is missing. Result: %s", result)
	}

	paramSchema, _ := param.(map[string]interface{})
	if _, hasRef := paramSchema["$ref"]; hasRef {
		t.Errorf("Key 'my.param' still contains $ref")
	}
	if _, split := properties["my"]; split {
		t.Errorf("Artifact key 'my' created by sjson splitting")
	}
}
func TestCleanJSONSchemaForAntigravity_AnyOfAlternativeHints(t *testing.T) {
	// Flattening an anyOf loses the alternatives, so they must be kept as an
	// "Accepts:" hint.
	input := `{
"type": "object",
"properties": {
"value": {
"anyOf": [
{ "type": "string" },
{ "type": "integer" },
{ "type": "null" }
]
}
}
}`
	result := CleanJSONSchemaForAntigravity(input)

	hasHint := strings.Contains(result, "Accepts:")
	listsTypes := strings.Contains(result, "string") && strings.Contains(result, "integer")

	if !hasHint {
		t.Errorf("Expected alternative types hint, got: %s", result)
	}
	if !listsTypes {
		t.Errorf("Expected all alternative types in hint, got: %s", result)
	}
}
func TestCleanJSONSchemaForAntigravity_NullableHint(t *testing.T) {
	// A ["string","null"] type union collapses to "string"; nullability has to be
	// preserved as a description hint next to the original description.
	input := `{
"type": "object",
"properties": {
"name": {
"type": ["string", "null"],
"description": "User name"
}
},
"required": ["name"]
}`
	result := CleanJSONSchemaForAntigravity(input)

	hasNullableHint := strings.Contains(result, "(nullable)")
	keptDescription := strings.Contains(result, "User name")

	if !hasNullableHint {
		t.Errorf("Expected nullable hint, got: %s", result)
	}
	if !keptDescription {
		t.Errorf("Expected original description to be preserved, got: %s", result)
	}
}
func TestCleanJSONSchemaForAntigravity_TypeFlattening_Nullable_DotKey(t *testing.T) {
input := `{
"type": "object",
"properties": {
"my.param": {
"type": ["string", "null"]
},
"other": {
"type": "string"
}
},
"required": ["my.param", "other"]
}`
expected := `{
"type": "object",
"properties": {
"my.param": {
"type": "string",
"description": "(nullable)"
},
"other": {
"type": "string"
}
},
"required": ["other"]
}`
result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result)
}
func TestCleanJSONSchemaForAntigravity_EnumHint(t *testing.T) {
	// Multi-value enums are surfaced to the model as an "Allowed:" hint.
	input := `{
"type": "object",
"properties": {
"status": {
"type": "string",
"enum": ["active", "inactive", "pending"],
"description": "Current status"
}
}
}`
	result := CleanJSONSchemaForAntigravity(input)

	hasHint := strings.Contains(result, "Allowed:")
	listsValues := strings.Contains(result, "active") && strings.Contains(result, "inactive")

	if !hasHint {
		t.Errorf("Expected enum values hint, got: %s", result)
	}
	if !listsValues {
		t.Errorf("Expected enum values in hint, got: %s", result)
	}
}
func TestCleanJSONSchemaForAntigravity_AdditionalPropertiesHint(t *testing.T) {
	// additionalProperties:false is not forwarded as a keyword; it is kept as a
	// human-readable hint instead.
	input := `{
"type": "object",
"properties": {
"name": { "type": "string" }
},
"additionalProperties": false
}`
	if result := CleanJSONSchemaForAntigravity(input); !strings.Contains(result, "No extra properties allowed") {
		t.Errorf("Expected additionalProperties hint, got: %s", result)
	}
}
func TestCleanJSONSchemaForAntigravity_AnyOfFlattening_PreservesDescription(t *testing.T) {
input := `{
"type": "object",
"properties": {
"config": {
"description": "Parent desc",
"anyOf": [
{ "type": "string", "description": "Child desc" },
{ "type": "integer" }
]
}
}
}`
expected := `{
"type": "object",
"properties": {
"config": {
"type": "string",
"description": "Parent desc (Child desc) (Accepts: string | integer)"
}
}
}`
result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result)
}
func TestCleanJSONSchemaForAntigravity_SingleEnumNoHint(t *testing.T) {
	// A one-value enum carries no choice, so no "Allowed:" hint should be added.
	input := `{
"type": "object",
"properties": {
"kind": {
"type": "string",
"enum": ["fixed"]
}
}
}`
	if result := CleanJSONSchemaForAntigravity(input); strings.Contains(result, "Allowed:") {
		t.Errorf("Single value enum should not add Allowed hint, got: %s", result)
	}
}
func TestCleanJSONSchemaForAntigravity_MultipleNonNullTypes(t *testing.T) {
	// A union of several non-null types cannot be represented directly, so all
	// members must be listed in an "Accepts:" hint.
	input := `{
"type": "object",
"properties": {
"value": {
"type": ["string", "integer", "boolean"]
}
}
}`
	result := CleanJSONSchemaForAntigravity(input)
	if !strings.Contains(result, "Accepts:") {
		t.Errorf("Expected multiple types hint, got: %s", result)
	}

	allListed := true
	for _, typeName := range []string{"string", "integer", "boolean"} {
		if !strings.Contains(result, typeName) {
			allListed = false
			break
		}
	}
	if !allListed {
		t.Errorf("Expected all types in hint, got: %s", result)
	}
}
// compareJSON fails the test when expectedJSON and actualJSON do not decode to
// deeply-equal values. Both documents are unmarshalled before comparison, so key
// order and whitespace are irrelevant. Any JSON value (object, array, scalar) is
// accepted. If either side fails to parse, the test is aborted with t.Fatalf.
//
// On a mismatch, both sides are re-marshalled with indentation so the diff is
// readable in the failure output.
func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
	// Attribute failures to the calling test line rather than to this helper.
	t.Helper()
	var expVal, actVal interface{}
	errExp := json.Unmarshal([]byte(expectedJSON), &expVal)
	errAct := json.Unmarshal([]byte(actualJSON), &actVal)
	if errExp != nil || errAct != nil {
		t.Fatalf("JSON Unmarshal error. Exp: %v, Act: %v", errExp, errAct)
	}
	if !reflect.DeepEqual(expVal, actVal) {
		// Marshal errors are impossible here: the values were just decoded from JSON.
		expBytes, _ := json.MarshalIndent(expVal, "", " ")
		actBytes, _ := json.MarshalIndent(actVal, "", " ")
		t.Errorf("JSON mismatch:\nExpected:\n%s\n\nActual:\n%s", string(expBytes), string(actBytes))
	}
}
// ============================================================================
// Empty Schema Placeholder Tests
// ============================================================================
func TestCleanJSONSchemaForAntigravity_EmptySchemaPlaceholder(t *testing.T) {
	// An object schema with no properties must get a placeholder "reason"
	// property, and that placeholder must be listed in "required".
	input := `{
"type": "object"
}`
	result := CleanJSONSchemaForAntigravity(input)

	// Decode the result instead of substring matching: the bare `"required"`
	// check previously passed even when "reason" was not actually required.
	var schema struct {
		Properties map[string]json.RawMessage `json:"properties"`
		Required   []string                   `json:"required"`
	}
	if err := json.Unmarshal([]byte(result), &schema); err != nil {
		t.Fatalf("Failed to unmarshal result: %v. Result: %s", err, result)
	}
	if _, ok := schema.Properties["reason"]; !ok {
		t.Errorf("Empty schema should have 'reason' placeholder property, got: %s", result)
	}
	requiresReason := false
	for _, name := range schema.Required {
		if name == "reason" {
			requiresReason = true
			break
		}
	}
	if !requiresReason {
		t.Errorf("Empty schema should have 'required' with 'reason', got: %s", result)
	}
}
func TestCleanJSONSchemaForAntigravity_EmptyPropertiesPlaceholder(t *testing.T) {
	// An explicit but empty "properties" object counts as empty and must still
	// receive the placeholder property.
	input := `{
"type": "object",
"properties": {}
}`
	if result := CleanJSONSchemaForAntigravity(input); !strings.Contains(result, `"reason"`) {
		t.Errorf("Empty properties should have 'reason' placeholder, got: %s", result)
	}
}
func TestCleanJSONSchemaForAntigravity_NonEmptySchemaUnchanged(t *testing.T) {
	// A schema that already declares properties must not receive the placeholder,
	// and its own properties must survive.
	input := `{
"type": "object",
"properties": {
"name": {"type": "string"}
},
"required": ["name"]
}`
	result := CleanJSONSchemaForAntigravity(input)

	gotPlaceholder := strings.Contains(result, `"reason"`)
	keptName := strings.Contains(result, `"name"`)

	if gotPlaceholder {
		t.Errorf("Non-empty schema should NOT have 'reason' placeholder, got: %s", result)
	}
	if !keptName {
		t.Errorf("Original property 'name' should be preserved, got: %s", result)
	}
}
func TestCleanJSONSchemaForAntigravity_NestedEmptySchema(t *testing.T) {
	// The placeholder rule applies recursively, e.g. to an empty object used as
	// an array's item schema.
	input := `{
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "object"
}
}
}
}`
	result := CleanJSONSchemaForAntigravity(input)
	// "reason" can only exist when its parent properties object exists, so one
	// path lookup covers both conditions.
	if !gjson.Get(result, "properties.items.items.properties.reason").Exists() {
		t.Errorf("Nested empty object should have 'reason' placeholder, got: %s", result)
	}
}
func TestCleanJSONSchemaForAntigravity_EmptySchemaWithDescription(t *testing.T) {
	// Adding the placeholder must not discard an existing schema description.
	input := `{
"type": "object",
"description": "An empty object"
}`
	result := CleanJSONSchemaForAntigravity(input)

	keptDescription := strings.Contains(result, `"An empty object"`)
	gotPlaceholder := strings.Contains(result, `"reason"`)

	if !keptDescription {
		t.Errorf("Description should be preserved, got: %s", result)
	}
	if !gotPlaceholder {
		t.Errorf("Empty schema should have 'reason' placeholder, got: %s", result)
	}
}
// ============================================================================
// Format field handling (ad-hoc patch removal)
// ============================================================================
func TestCleanJSONSchemaForAntigravity_FormatFieldRemoval(t *testing.T) {
	// "format" is stripped from the schema and re-expressed as a description hint
	// appended to the existing description.
	input := `{
"type": "object",
"properties": {
"url": {
"type": "string",
"format": "uri",
"description": "A URL"
}
}
}`
	result := CleanJSONSchemaForAntigravity(input)

	stillHasFormat := strings.Contains(result, `"format"`)
	hasHint := strings.Contains(result, "format: uri")
	keptDescription := strings.Contains(result, "A URL")

	if stillHasFormat {
		t.Errorf("format field should be removed, got: %s", result)
	}
	if !hasHint {
		t.Errorf("format hint should be added to description, got: %s", result)
	}
	if !keptDescription {
		t.Errorf("Original description should be preserved, got: %s", result)
	}
}
func TestCleanJSONSchemaForAntigravity_FormatFieldNoDescription(t *testing.T) {
	// Without an existing description, the format hint must create one.
	input := `{
"type": "object",
"properties": {
"email": {
"type": "string",
"format": "email"
}
}
}`
	result := CleanJSONSchemaForAntigravity(input)

	stillHasFormat := strings.Contains(result, `"format"`)
	hasHint := strings.Contains(result, "format: email")

	if stillHasFormat {
		t.Errorf("format field should be removed, got: %s", result)
	}
	if !hasHint {
		t.Errorf("format hint should be added, got: %s", result)
	}
}
func TestCleanJSONSchemaForAntigravity_MultipleFormats(t *testing.T) {
	// Every property's "format" is handled independently.
	input := `{
"type": "object",
"properties": {
"url": {"type": "string", "format": "uri"},
"email": {"type": "string", "format": "email"},
"date": {"type": "string", "format": "date-time"}
}
}`
	result := CleanJSONSchemaForAntigravity(input)
	if strings.Contains(result, `"format"`) {
		t.Errorf("All format fields should be removed, got: %s", result)
	}
	for _, format := range []string{"uri", "email", "date-time"} {
		if !strings.Contains(result, "format: "+format) {
			t.Errorf("%s format hint should be added, got: %s", format, result)
		}
	}
}
func TestCleanJSONSchemaForAntigravity_NumericEnumToString(t *testing.T) {
	// Gemini API requires enum values to be strings, not numbers
	input := `{
"type": "object",
"properties": {
"priority": {"type": "integer", "enum": [0, 1, 2]},
"level": {"type": "number", "enum": [1.5, 2.5, 3.5]},
"status": {"type": "string", "enum": ["active", "inactive"]}
}
}`
	result := CleanJSONSchemaForAntigravity(input)

	// Inspect the decoded enum arrays. Matching compact substrings such as
	// `"enum":[0,1,2]` passes vacuously whenever the output keeps the input's
	// spacing, and bare `"0"` matches can hit unrelated text.
	var schema struct {
		Properties map[string]struct {
			Enum []interface{} `json:"enum"`
		} `json:"properties"`
	}
	if err := json.Unmarshal([]byte(result), &schema); err != nil {
		t.Fatalf("Failed to unmarshal result: %v. Result: %s", err, result)
	}
	allStrings := func(values []interface{}) bool {
		for _, v := range values {
			if _, ok := v.(string); !ok {
				return false
			}
		}
		return true
	}

	priority := schema.Properties["priority"].Enum
	if !allStrings(priority) {
		t.Errorf("Integer enum values should be converted to strings, got: %s", result)
	}
	if !allStrings(schema.Properties["level"].Enum) {
		t.Errorf("Float enum values should be converted to strings, got: %s", result)
	}
	if !reflect.DeepEqual(priority, []interface{}{"0", "1", "2"}) {
		t.Errorf("Integer enum values should be converted to string format, got: %s", result)
	}
	// String enum values should remain unchanged
	if !reflect.DeepEqual(schema.Properties["status"].Enum, []interface{}{"active", "inactive"}) {
		t.Errorf("String enum values should remain unchanged, got: %s", result)
	}
}
func TestCleanJSONSchemaForAntigravity_BooleanEnumToString(t *testing.T) {
	// Boolean enum values should also be converted to strings
	input := `{
"type": "object",
"properties": {
"enabled": {"type": "boolean", "enum": [true, false]}
}
}`
	result := CleanJSONSchemaForAntigravity(input)

	// Check the decoded array: a compact `"enum":[true,false]` substring never
	// matches output that keeps the input's spacing, so that check was vacuous.
	var schema struct {
		Properties map[string]struct {
			Enum []interface{} `json:"enum"`
		} `json:"properties"`
	}
	if err := json.Unmarshal([]byte(result), &schema); err != nil {
		t.Fatalf("Failed to unmarshal result: %v. Result: %s", err, result)
	}
	values := schema.Properties["enabled"].Enum
	for _, v := range values {
		if _, ok := v.(string); !ok {
			t.Errorf("Boolean enum values should be converted to strings, got: %s", result)
			break
		}
	}
	if !reflect.DeepEqual(values, []interface{}{"true", "false"}) {
		t.Errorf("Boolean enum values should be converted to string format, got: %s", result)
	}
}
func TestCleanJSONSchemaForGemini_RemovesGeminiUnsupportedMetadataFields(t *testing.T) {
input := `{
"$schema": "http://json-schema.org/draft-07/schema#",
"$id": "root-schema",
"type": "object",
"properties": {
"payload": {
"type": "object",
"prefill": "hello",
"properties": {
"mode": {
"type": "string",
"enum": ["a", "b"],
"enumTitles": ["A", "B"]
}
},
"patternProperties": {
"^x-": {"type": "string"}
}
},
"$id": {
"type": "string",
"description": "property name should not be removed"
}
}
}`
expected := `{
"type": "object",
"properties": {
"payload": {
"type": "object",
"properties": {
"mode": {
"type": "string",
"enum": ["a", "b"],
"description": "Allowed: a, b"
}
}
},
"$id": {
"type": "string",
"description": "property name should not be removed"
}
}
}`
result := CleanJSONSchemaForGemini(input)
compareJSON(t, expected, result)
}
func TestRemoveExtensionFields(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "removes x- fields at root",
input: `{
"type": "object",
"x-custom-meta": "value",
"properties": {
"foo": { "type": "string" }
}
}`,
expected: `{
"type": "object",
"properties": {
"foo": { "type": "string" }
}
}`,
},
{
name: "removes x- fields in nested properties",
input: `{
"type": "object",
"properties": {
"foo": {
"type": "string",
"x-internal-id": 123
}
}
}`,
expected: `{
"type": "object",
"properties": {
"foo": {
"type": "string"
}
}
}`,
},
{
name: "does NOT remove properties named x-",
input: `{
"type": "object",
"properties": {
"x-data": { "type": "string" },
"normal": { "type": "number", "x-meta": "remove" }
},
"required": ["x-data"]
}`,
expected: `{
"type": "object",
"properties": {
"x-data": { "type": "string" },
"normal": { "type": "number" }
},
"required": ["x-data"]
}`,
},
{
name: "does NOT remove $schema and other meta fields (as requested)",
input: `{
"$schema": "http://json-schema.org/draft-07/schema#",
"$id": "test",
"type": "object",
"properties": {
"foo": { "type": "string" }
}
}`,
expected: `{
"$schema": "http://json-schema.org/draft-07/schema#",
"$id": "test",
"type": "object",
"properties": {
"foo": { "type": "string" }
}
}`,
},
{
name: "handles properties named $schema",
input: `{
"type": "object",
"properties": {
"$schema": { "type": "string" }
}
}`,
expected: `{
"type": "object",
"properties": {
"$schema": { "type": "string" }
}
}`,
},
{
name: "handles escaping in paths",
input: `{
"type": "object",
"properties": {
"foo.bar": {
"type": "string",
"x-meta": "remove"
}
},
"x-root.meta": "remove"
}`,
expected: `{
"type": "object",
"properties": {
"foo.bar": {
"type": "string"
}
}
}`,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual := removeExtensionFields(tt.input)
compareJSON(t, tt.expected, actual)
})
}
}
================================================
FILE: internal/util/header_helpers.go
================================================
package util
import (
"net/http"
"strings"
)
// ApplyCustomHeadersFromAttrs applies user-defined headers stored in the provided attributes map.
// Custom headers override built-in defaults when conflicts occur.
func ApplyCustomHeadersFromAttrs(r *http.Request, attrs map[string]string) {
if r == nil {
return
}
applyCustomHeaders(r, extractCustomHeaders(attrs))
}
// extractCustomHeaders collects the attributes whose key starts with the literal
// prefix "header:" into a header-name -> value map. The remainder of the key
// (trimmed) is the header name and the trimmed attribute value is the header value.
// Entries with an empty name or value are skipped. It returns nil when nothing
// qualifies, including when attrs is nil or empty.
func extractCustomHeaders(attrs map[string]string) map[string]string {
	const prefix = "header:"
	var collected map[string]string
	for key, raw := range attrs {
		if !strings.HasPrefix(key, prefix) {
			continue
		}
		name := strings.TrimSpace(key[len(prefix):])
		value := strings.TrimSpace(raw)
		if name == "" || value == "" {
			continue
		}
		// Allocate lazily so that "no custom headers" is always reported as nil.
		if collected == nil {
			collected = make(map[string]string)
		}
		collected[name] = value
	}
	return collected
}
// applyCustomHeaders sets each non-empty name/value pair from headers on r using
// Header.Set, so it replaces any existing values for that (canonicalized) name.
// A nil request or an empty headers map is a no-op. A request whose Header map
// is nil (e.g. a zero-value &http.Request{}) gets a fresh map. Without it,
// Header.Set would panic on assignment to a nil map.
func applyCustomHeaders(r *http.Request, headers map[string]string) {
	if r == nil || len(headers) == 0 {
		return
	}
	if r.Header == nil {
		r.Header = make(http.Header, len(headers))
	}
	for k, v := range headers {
		if k == "" || v == "" {
			continue
		}
		r.Header.Set(k, v)
	}
}
================================================
FILE: internal/util/image.go
================================================
package util
import (
"bytes"
"encoding/base64"
"image"
"image/draw"
"image/png"
)
// CreateWhiteImageBase64 renders a solid white PNG and returns it base64-encoded
// with the standard alphabet. The canvas size is chosen from aspectRatio. Known
// ratios are "1:1", "2:3", "3:2", "3:4", "4:3", "4:5", "5:4", "9:16", "16:9" and
// "21:9"; any other value produces a 1024x1024 square. The only error returned
// is from PNG encoding.
func CreateWhiteImageBase64(aspectRatio string) (string, error) {
	type size struct{ w, h int }
	canvasSizes := map[string]size{
		"1:1":  {1024, 1024},
		"2:3":  {832, 1248},
		"3:2":  {1248, 832},
		"3:4":  {864, 1184},
		"4:3":  {1184, 864},
		"4:5":  {896, 1152},
		"5:4":  {1152, 896},
		"9:16": {768, 1344},
		"16:9": {1344, 768},
		"21:9": {1536, 672},
	}
	dims, known := canvasSizes[aspectRatio]
	if !known {
		dims = size{1024, 1024}
	}

	canvas := image.NewRGBA(image.Rect(0, 0, dims.w, dims.h))
	draw.Draw(canvas, canvas.Bounds(), image.White, image.Point{}, draw.Src)

	var encoded bytes.Buffer
	if err := png.Encode(&encoded, canvas); err != nil {
		return "", err
	}
	return base64.StdEncoding.EncodeToString(encoded.Bytes()), nil
}
================================================
FILE: internal/util/provider.go
================================================
// Package util provides utility functions used across the CLIProxyAPI application.
// These functions handle common tasks such as determining AI service providers
// from model names and managing HTTP proxies.
package util
import (
"net/url"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
log "github.com/sirupsen/logrus"
)
// GetProviderName returns the identifiers of every provider that the global model
// registry reports as able to serve modelName. Providers appear in the order the
// registry returns them, with empty names and duplicates removed.
//
// Typical identifiers include "gemini", "codex", "claude", "qwen" and
// "openai-compatibility", but the set is whatever the registry holds.
//
// Parameters:
//   - modelName: The name of the model to look up. An empty name yields nil.
//
// Returns:
//   - []string: The provider identifiers. The slice is empty but non-nil when the
//     model is not registered.
func GetProviderName(modelName string) []string {
	if modelName == "" {
		return nil
	}

	providers := make([]string, 0, 4)
	seen := make(map[string]struct{})
	for _, provider := range registry.GetGlobalRegistry().GetModelProviders(modelName) {
		if provider == "" {
			continue
		}
		if _, duplicate := seen[provider]; duplicate {
			continue
		}
		seen[provider] = struct{}{}
		providers = append(providers, provider)
	}
	return providers
}
// ResolveAutoModel maps the special model name "auto" to a concrete model. It asks
// the global registry for the first available model with an empty handler type,
// meaning no handler filter. Any other name is returned unchanged.
//
// Parameters:
//   - modelName: The requested model name.
//
// Returns:
//   - string: The resolved model for "auto". The input is returned unchanged
//     when it is not "auto" or when the registry lookup fails (a warning is logged).
func ResolveAutoModel(modelName string) string {
	if modelName == "auto" {
		resolved, err := registry.GetGlobalRegistry().GetFirstAvailableModel("")
		if err == nil {
			log.Infof("Resolved 'auto' model to: %s", resolved)
			return resolved
		}
		log.Warnf("Failed to resolve 'auto' model: %v, falling back to original model name", err)
	}
	return modelName
}
// IsOpenAICompatibilityAlias reports whether modelName equals the Alias of any
// model under any OpenAI-compatibility provider in cfg. The comparison is exact
// and case-sensitive.
//
// Parameters:
//   - modelName: The model name to look for.
//   - cfg: The application configuration. A nil cfg never matches.
//
// Returns:
//   - bool: True when a matching alias exists.
func IsOpenAICompatibilityAlias(modelName string, cfg *config.Config) bool {
	if cfg == nil {
		return false
	}
	for i := range cfg.OpenAICompatibility {
		models := cfg.OpenAICompatibility[i].Models
		for j := range models {
			if models[j].Alias == modelName {
				return true
			}
		}
	}
	return false
}
// GetOpenAICompatibilityConfig finds the first OpenAI-compatibility provider in cfg
// that declares a model whose Alias equals alias, and returns that provider and
// model.
//
// The returned pointers refer to copies taken during iteration, not to the
// elements stored in cfg, so modifying them does not change the configuration.
//
// Parameters:
//   - alias: The model alias to look up (exact, case-sensitive match).
//   - cfg: The application configuration. A nil cfg yields (nil, nil).
//
// Returns:
//   - *config.OpenAICompatibility: The matching provider, or nil if not found.
//   - *config.OpenAICompatibilityModel: The matching model, or nil if not found.
func GetOpenAICompatibilityConfig(alias string, cfg *config.Config) (*config.OpenAICompatibility, *config.OpenAICompatibilityModel) {
	if cfg == nil {
		return nil, nil
	}
	for _, provider := range cfg.OpenAICompatibility {
		for _, candidate := range provider.Models {
			if candidate.Alias != alias {
				continue
			}
			found, foundModel := provider, candidate
			return &found, &foundModel
		}
	}
	return nil, nil
}
// InArray reports whether needle is an element of haystack, using exact string
// equality. A nil or empty haystack contains nothing.
//
// Parameters:
//   - haystack: The slice of strings to search in
//   - needle: The string to search for
//
// Returns:
//   - bool: True if the string is found, false otherwise
func InArray(haystack []string, needle string) bool {
	for _, item := range haystack {
		if item == needle {
			return true
		}
	}
	return false
}
// HideAPIKey masks a secret for logging by keeping a few bytes at each end and
// replacing the middle with "...".
//
// The number of bytes kept on each side depends on the length:
//   - more than 8 bytes: 4
//   - 5 to 8 bytes:      2
//   - 3 or 4 bytes:      1
//   - 2 bytes or fewer:  the value is returned unchanged
//
// Parameters:
//   - apiKey: The API key to hide.
//
// Returns:
//   - string: The obscured API key.
func HideAPIKey(apiKey string) string {
	var keep int
	switch n := len(apiKey); {
	case n > 8:
		keep = 4
	case n > 4:
		keep = 2
	case n > 2:
		keep = 1
	default:
		return apiKey
	}
	return apiKey[:keep] + "..." + apiKey[len(apiKey)-keep:]
}
// MaskAuthorizationHeader masks an Authorization header value while preserving
// the auth scheme prefix, e.g. "Bearer ", "Basic " or "ApiKey ".
//
// The value is whitespace-trimmed and split at the first space. The scheme is
// kept verbatim and only the credential part, also trimmed, is masked with
// HideAPIKey. A value with no space is masked as a whole, using the trimmed form
// just like the split path.
//
// Parameters:
//   - value: The Authorization header value
//
// Returns:
//   - string: The masked Authorization value with prefix preserved
func MaskAuthorizationHeader(value string) string {
	trimmed := strings.TrimSpace(value)
	parts := strings.SplitN(trimmed, " ", 2)
	if len(parts) < 2 {
		return HideAPIKey(trimmed)
	}
	// Trim the credential so extra separators ("Bearer  tok") are not counted as
	// the visible leading characters of the mask.
	return parts[0] + " " + HideAPIKey(strings.TrimSpace(parts[1]))
}
// MaskSensitiveHeaderValue masks the value of headers that are likely to carry
// credentials. The header name is trimmed and compared case-insensitively.
//
// Behavior by header name:
//   - contains "authorization": masked via MaskAuthorizationHeader, which keeps
//     the auth scheme prefix.
//   - contains "api-key", "apikey", "token" or "secret": the whole value is
//     masked via HideAPIKey.
//   - anything else: the value is returned unchanged.
//
// Parameters:
//   - key: The HTTP header name to inspect.
//   - value: The header value to mask when sensitive.
//
// Returns:
//   - string: The masked value, or the original value for non-sensitive headers.
func MaskSensitiveHeaderValue(key, value string) string {
	name := strings.ToLower(strings.TrimSpace(key))
	if strings.Contains(name, "authorization") {
		return MaskAuthorizationHeader(value)
	}
	for _, marker := range []string{"api-key", "apikey", "token", "secret"} {
		if strings.Contains(name, marker) {
			return HideAPIKey(value)
		}
	}
	return value
}
// MaskSensitiveQuery masks the values of credential-like parameters (as decided by
// shouldMaskQueryParam, e.g. "key", "api_key" or "auth_token") in a raw
// "a=1&b=2" query string.
//
// Keys are URL-unescaped only for the sensitivity check; the original key text is
// kept in the output. A sensitive value is unescaped, trimmed, masked with
// HideAPIKey and re-escaped. A sensitive key with no "=" gains an "=". All other
// segments, including empty ones, are untouched. When nothing is masked, raw is
// returned as-is.
func MaskSensitiveQuery(raw string) string {
	if raw == "" {
		return ""
	}
	segments := strings.Split(raw, "&")
	masked := false
	for i := range segments {
		segment := segments[i]
		if segment == "" {
			continue
		}
		rawKey, rawValue := segment, ""
		if eq := strings.IndexByte(segment, '='); eq >= 0 {
			rawKey, rawValue = segment[:eq], segment[eq+1:]
		}

		key, err := url.QueryUnescape(rawKey)
		if err != nil {
			key = rawKey
		}
		if !shouldMaskQueryParam(key) {
			continue
		}

		value, err := url.QueryUnescape(rawValue)
		if err != nil {
			value = rawValue
		}
		segments[i] = rawKey + "=" + url.QueryEscape(HideAPIKey(strings.TrimSpace(value)))
		masked = true
	}
	if masked {
		return strings.Join(segments, "&")
	}
	return raw
}
// shouldMaskQueryParam reports whether a (decoded) query parameter name looks like
// it carries a credential. The name is trimmed, lower-cased and stripped of a
// trailing "[]" array suffix. It matches when it is exactly "key" or contains
// "api-key", "apikey", "api_key", "token" or "secret".
func shouldMaskQueryParam(key string) bool {
	name := strings.TrimSuffix(strings.ToLower(strings.TrimSpace(key)), "[]")
	switch {
	case name == "":
		return false
	case name == "key":
		return true
	}
	for _, marker := range []string{"api-key", "apikey", "api_key", "token", "secret"} {
		if strings.Contains(name, marker) {
			return true
		}
	}
	return false
}
================================================
FILE: internal/util/proxy.go
================================================
// Package util provides utility functions for the CLI Proxy API server.
// It includes helper functions for proxy configuration, HTTP client setup,
// log level management, and other common operations used across the application.
package util
import (
"net/http"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
)
// SetProxy installs a proxy-aware transport on httpClient, built by
// proxyutil.BuildHTTPTransport from cfg.ProxyURL, and returns the same client.
//
// A nil cfg or nil client is returned untouched. A build error is logged but does
// not abort: the transport is still applied whenever the builder returned a
// non-nil one. Otherwise the client keeps its existing transport.
func SetProxy(cfg *config.SDKConfig, httpClient *http.Client) *http.Client {
	if cfg != nil && httpClient != nil {
		transport, _, err := proxyutil.BuildHTTPTransport(cfg.ProxyURL)
		if err != nil {
			log.Errorf("%v", err)
		}
		if transport != nil {
			httpClient.Transport = transport
		}
	}
	return httpClient
}
================================================
FILE: internal/util/sanitize_test.go
================================================
package util
import (
"testing"
)
// TestSanitizeFunctionName checks exact outputs of SanitizeFunctionName. It also
// checks that every result is Gemini-compliant: at most 64 bytes, starting with
// an ASCII letter or an underscore.
func TestSanitizeFunctionName(t *testing.T) {
	cases := []struct {
		name     string
		input    string
		expected string
	}{
		{"Normal", "valid_name", "valid_name"},
		{"With Dots", "name.with.dots", "name.with.dots"},
		{"With Colons", "name:with:colons", "name:with:colons"},
		{"With Dashes", "name-with-dashes", "name-with-dashes"},
		{"Mixed Allowed", "name.with_dots:colons-dashes", "name.with_dots:colons-dashes"},
		{"Invalid Characters", "name!with@invalid#chars", "name_with_invalid_chars"},
		{"Spaces", "name with spaces", "name_with_spaces"},
		{"Non-ASCII", "name_with_你好_chars", "name_with____chars"},
		{"Starts with digit", "123name", "_123name"},
		{"Starts with dot", ".name", "_.name"},
		{"Starts with colon", ":name", "_:name"},
		{"Starts with dash", "-name", "_-name"},
		{"Starts with invalid char", "!name", "_name"},
		{"Exactly 64 chars", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charact", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charact"},
		{"Too long (65 chars)", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charactX", "this_is_a_very_long_name_that_exactly_reaches_sixty_four_charact"},
		{"Very long", "this_is_a_very_long_name_that_exceeds_the_sixty_four_character_limit_for_function_names", "this_is_a_very_long_name_that_exceeds_the_sixty_four_character_l"},
		{"Starts with digit (64 chars total)", "1234567890123456789012345678901234567890123456789012345678901234", "_123456789012345678901234567890123456789012345678901234567890123"},
		{"Starts with invalid char (64 chars total)", "!234567890123456789012345678901234567890123456789012345678901234", "_234567890123456789012345678901234567890123456789012345678901234"},
		{"Empty", "", ""},
		{"Single character invalid", "@", "_"},
		{"Single character valid", "a", "a"},
		{"Single character digit", "1", "_1"},
		{"Single character underscore", "_", "_"},
	}
	validStart := func(c byte) bool {
		return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || c == '_'
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := SanitizeFunctionName(tc.input)
			if got != tc.expected {
				t.Errorf("SanitizeFunctionName(%q) = %v, want %v", tc.input, got, tc.expected)
			}
			// Gemini compliance holds independently of the exact expected value.
			if len(got) > 64 {
				t.Errorf("SanitizeFunctionName(%q) result too long: %d", tc.input, len(got))
			}
			if len(got) > 0 && !validStart(got[0]) {
				t.Errorf("SanitizeFunctionName(%q) result starts with invalid char: %c", tc.input, got[0])
			}
		})
	}
}
================================================
FILE: internal/util/ssh_helper.go
================================================
// Package util provides helper functions for SSH tunnel instructions and network-related tasks.
// This includes detecting the appropriate IP address and printing commands
// to help users connect to the local server from a remote machine.
package util
import (
"context"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
// ipServices lists the HTTPS endpoints that getPublicIP queries to discover this
// host's public IP address. They are tried in slice order and the first success
// wins, so order sets priority. Each is presumably expected to answer with the
// bare address as the response body (getPublicIP only trims whitespace from it).
var ipServices = []string{
	"https://api.ipify.org",
	"https://ifconfig.me/ip",
	"https://icanhazip.com",
	"https://ipinfo.io/ip",
}
// getPublicIP attempts to retrieve the public IP address from a list of external services.
// Endpoints in ipServices are tried in order. The first one that answers HTTP 200
// with a body that parses as an IP address wins.
//
// Returns:
//   - string: The public IP address as a string
//   - error: An error if all services fail, nil otherwise
func getPublicIP() (string, error) {
	for _, service := range ipServices {
		if ip, ok := queryIPService(service); ok {
			return ip, nil
		}
	}
	return "", fmt.Errorf("all IP services failed")
}

// queryIPService performs one lookup against service with a 2-second timeout and
// reports (ip, true) on success. Failures are logged at debug level and reported
// as ("", false).
//
// It is separate from getPublicIP so that each attempt's context and response body
// are released when the attempt finishes. Deferring them inside the loop would
// keep them alive until every service had been tried.
func queryIPService(service string) (string, bool) {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, "GET", service, nil)
	if err != nil {
		log.Debugf("Failed to create request to %s: %v", service, err)
		return "", false
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		log.Debugf("Failed to get public IP from %s: %v", service, err)
		return "", false
	}
	defer func() {
		if closeErr := resp.Body.Close(); closeErr != nil {
			log.Warnf("Failed to close response body from %s: %v", service, closeErr)
		}
	}()

	if resp.StatusCode != http.StatusOK {
		log.Debugf("bad status code from %s: %d", service, resp.StatusCode)
		return "", false
	}
	// The body comes from a third-party host; a textual IP address is at most a
	// few dozen bytes, so cap the read instead of trusting the response size.
	body, err := io.ReadAll(io.LimitReader(resp.Body, 1024))
	if err != nil {
		log.Debugf("Failed to read response body from %s: %v", service, err)
		return "", false
	}
	ip := strings.TrimSpace(string(body))
	// Reject empty bodies and non-IP payloads (e.g. an HTML error or captive
	// portal page) so the next service gets a chance.
	if net.ParseIP(ip) == nil {
		log.Debugf("Invalid IP address %q from %s", ip, service)
		return "", false
	}
	return ip, true
}
// getOutboundIP reports the local IP address this host would use for outbound
// traffic. It "connects" a UDP socket to 8.8.8.8:80 and reads back the socket's
// local address.
//
// Returns:
//   - string: The outbound IP address as a string
//   - error: The dial error, or an error if the local address is not a UDP address
func getOutboundIP() (string, error) {
	conn, err := net.Dial("udp", "8.8.8.8:80")
	if err != nil {
		return "", err
	}
	defer func() {
		if cerr := conn.Close(); cerr != nil {
			log.Warnf("Failed to close UDP connection: %v", cerr)
		}
	}()

	if addr, isUDP := conn.LocalAddr().(*net.UDPAddr); isUDP {
		return addr.IP.String(), nil
	}
	return "", fmt.Errorf("could not assert UDP address type")
}
// GetIPAddress returns the best IP address it can determine for this host. It
// tries, in order:
//  1. the public address reported by getPublicIP,
//  2. the local outbound address from getOutboundIP,
//  3. "127.0.0.1" if both lookups fail.
//
// Each fallback is logged.
//
// Returns:
//   - string: The determined IP address
func GetIPAddress() string {
	publicIP, publicErr := getPublicIP()
	if publicErr == nil {
		log.Debugf("Public IP detected: %s", publicIP)
		return publicIP
	}
	log.Warnf("Failed to get public IP, falling back to outbound IP: %v", publicErr)

	outboundIP, outboundErr := getOutboundIP()
	if outboundErr != nil {
		log.Errorf("Failed to get any IP address: %v", outboundErr)
		return "127.0.0.1" // Fallback
	}
	log.Debugf("Outbound IP detected: %s", outboundIP)
	return outboundIP
}
// PrintSSHTunnelInstructions detects this host's IP address and prints SSH
// local-port-forwarding commands. A user on another machine can run these to
// reach the OAuth callback server listening on 127.0.0.1:port here.
//
// Parameters:
//   - port: The port forwarded on both ends of the SSH tunnel
func PrintSSHTunnelInstructions(port int) {
	ipAddress := GetIPAddress()
	border := "================================================================================"
	fmt.Println("To authenticate from a remote machine, an SSH tunnel may be required.")
	fmt.Println(border)
	fmt.Println(" Run one of the following commands on your local machine (NOT the server):")
	fmt.Println()
	fmt.Println(" # Standard SSH command (assumes SSH port 22):")
	fmt.Printf(" ssh -L %d:127.0.0.1:%d root@%s -p 22\n", port, port, ipAddress)
	fmt.Println()
	fmt.Println(" # If using an SSH key (assumes SSH port 22):")
	// -i requires the identity file path as its argument. Without the
	// placeholder, ssh would consume "-L" as the key path.
	fmt.Printf(" ssh -i <path-to-private-key> -L %d:127.0.0.1:%d root@%s -p 22\n", port, port, ipAddress)
	fmt.Println()
	fmt.Println(" NOTE: If your server's SSH port is not 22, please modify the '-p 22' part accordingly.")
	fmt.Println(border)
}
================================================
FILE: internal/util/translator.go
================================================
// Package util provides utility functions for the CLI Proxy API server.
// It includes helper functions for JSON manipulation, proxy configuration,
// and other common operations used across the application.
package util
import (
"bytes"
"fmt"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Walk recursively descends through a JSON value and records the path of every
// key equal to field.
//
// Parameters:
//   - value: The gjson.Result to traverse
//   - path: The path of value within the document ("" for the root)
//   - field: The key name to look for
//   - paths: Destination slice; each matching path is appended to it
//
// Paths are dot-joined. Each key segment is passed through escapeGJSONPathKey
// so that gjson/sjson path metacharacters in keys (e.g. '.', '*', '?') are
// escaped. Array elements use their index as the key. Scalars (strings,
// numbers, booleans, null) have no children and end the recursion.
func Walk(value gjson.Result, path, field string, paths *[]string) {
	if value.Type != gjson.JSON {
		return
	}
	value.ForEach(func(key, child gjson.Result) bool {
		name := key.String()
		segment := escapeGJSONPathKey(name)
		childPath := segment
		if path != "" {
			childPath = path + "." + segment
		}
		if name == field {
			*paths = append(*paths, childPath)
		}
		Walk(child, childPath, field, paths)
		return true
	})
}
// RenameKey renames a key in a JSON string by moving its value to a new key path
// and then deleting the old key path.
//
// Parameters:
//   - jsonStr: The JSON string to modify
//   - oldKeyPath: The gjson/sjson path of the key to rename
//   - newKeyPath: The gjson/sjson path the value should be moved to
//
// Returns:
//   - string: The modified JSON string with the key renamed
//   - error: An error if oldKeyPath does not exist or sjson fails to set/delete
//
// The rename is done in two steps: the raw value is set at newKeyPath, then
// oldKeyPath is deleted. When both paths are identical the input is returned
// unchanged; otherwise the second step would delete the value just written.
func RenameKey(jsonStr, oldKeyPath, newKeyPath string) (string, error) {
	value := gjson.Get(jsonStr, oldKeyPath)
	if !value.Exists() {
		return "", fmt.Errorf("old key '%s' does not exist", oldKeyPath)
	}
	if oldKeyPath == newKeyPath {
		return jsonStr, nil
	}
	interimJSON, err := sjson.SetRaw(jsonStr, newKeyPath, value.Raw)
	if err != nil {
		return "", fmt.Errorf("failed to set new key '%s': %w", newKeyPath, err)
	}
	finalJSON, err := sjson.Delete(interimJSON, oldKeyPath)
	if err != nil {
		return "", fmt.Errorf("failed to delete old key '%s': %w", oldKeyPath, err)
	}
	return finalJSON, nil
}
// FixJSON converts non-standard JSON that uses single quotes for strings into
// RFC 8259-compliant JSON by converting those single-quoted strings to
// double-quoted strings with proper escaping.
//
// Examples:
//
//	{'a': 1, 'b': '2'} => {"a": 1, "b": "2"}
//	{"t": 'He said "hi"'} => {"t": "He said \"hi\""}
//
// Rules:
//   - Existing double-quoted JSON strings are preserved as-is.
//   - Single-quoted strings are converted to double-quoted strings.
//   - Inside converted strings, double quotes are escaped (\"). Raw control
//     characters are escaped (\n, \r, \t, or \u00XX), because JSON forbids them
//     unescaped inside strings.
//   - The JSON-valid escapes \n, \r, \t, \b, \f, \/, \\ and \" are preserved.
//   - \' inside single-quoted strings becomes a literal ' in the output (no
//     escaping needed inside double quotes).
//   - A complete \uXXXX escape is forwarded. An incomplete one (fewer than four
//     hex digits) is emitted as literal text (\\u...).
//   - Any other escape (e.g. \q) is not valid JSON. Its backslash is kept as a
//     literal character (\\q).
//   - A single-quoted string left open at end of input is closed.
//   - The function does not attempt to fix other non-JSON features beyond quotes.
func FixJSON(input string) string {
	var out bytes.Buffer
	out.Grow(len(input) + 2)
	inDouble := false
	inSingle := false
	escaped := false // applies within the current string state

	// isHex reports whether r is an ASCII hexadecimal digit.
	isHex := func(r rune) bool {
		return (r >= '0' && r <= '9') || (r >= 'a' && r <= 'f') || (r >= 'A' && r <= 'F')
	}

	// writeConverted writes a literal rune inside a converted single-quoted
	// string (which becomes a double-quoted string in the output), escaping
	// the characters JSON does not allow raw inside a string.
	writeConverted := func(r rune) {
		switch {
		case r == '"':
			out.WriteString(`\"`)
		case r == '\n':
			out.WriteString(`\n`)
		case r == '\r':
			out.WriteString(`\r`)
		case r == '\t':
			out.WriteString(`\t`)
		case r < 0x20:
			fmt.Fprintf(&out, `\u%04x`, r)
		default:
			out.WriteRune(r)
		}
	}

	runes := []rune(input)
	for i := 0; i < len(runes); i++ {
		r := runes[i]
		if inDouble {
			out.WriteRune(r)
			if escaped {
				// end of escape sequence in a standard JSON string
				escaped = false
				continue
			}
			if r == '\\' {
				escaped = true
				continue
			}
			if r == '"' {
				inDouble = false
			}
			continue
		}
		if inSingle {
			if escaped {
				// Character following a backslash inside a single-quoted string.
				escaped = false
				switch r {
				case 'n', 'r', 't', 'b', 'f', '/', '"':
					// Already a valid JSON escape: keep backslash and char.
					out.WriteByte('\\')
					out.WriteRune(r)
				case '\\':
					out.WriteString(`\\`)
				case '\'':
					// \' inside single-quoted becomes a literal '
					out.WriteRune('\'')
				case 'u':
					// Only a full \uXXXX is a valid JSON escape; count the
					// hex digits that follow before deciding.
					n := 0
					for n < 4 && i+1+n < len(runes) && isHex(runes[i+1+n]) {
						n++
					}
					if n == 4 {
						out.WriteString(`\u`)
					} else {
						out.WriteString(`\\u`)
					}
					for _, h := range runes[i+1 : i+1+n] {
						out.WriteRune(h)
					}
					i += n
				default:
					// Unknown escape: keep the backslash as literal text.
					out.WriteString(`\\`)
					writeConverted(r)
				}
				continue
			}
			if r == '\\' { // start escape sequence
				escaped = true
				continue
			}
			if r == '\'' { // end of single-quoted string
				out.WriteByte('"')
				inSingle = false
				continue
			}
			// regular char inside converted string
			writeConverted(r)
			continue
		}
		// Outside any string
		if r == '"' {
			inDouble = true
			out.WriteRune(r)
			continue
		}
		if r == '\'' { // start of non-standard single-quoted string
			inSingle = true
			out.WriteByte('"')
			continue
		}
		out.WriteRune(r)
	}
	// If input ended while still inside a single-quoted string, close it to
	// produce the best-effort valid JSON.
	if inSingle {
		out.WriteByte('"')
	}
	return out.String()
}
// CanonicalToolName normalizes a tool name for case-insensitive lookups. It
// trims surrounding whitespace, strips leading underscores, and lower-cases
// the result.
func CanonicalToolName(name string) string {
	return strings.ToLower(strings.TrimLeft(strings.TrimSpace(name), "_"))
}
// ToolNameMapFromClaudeRequest extracts a canonical-name -> original-name map
// from the "tools" array of a Claude request body.
//
// Callers use it to restore exact tool name casing for clients that require
// strict tool name matching (e.g. Claude Code). Names are whitespace-trimmed
// before being stored. When several tools share a canonical name, the first
// one wins. It returns nil when the body is empty or invalid JSON, when
// "tools" is missing or not an array, or when no usable names are found.
func ToolNameMapFromClaudeRequest(rawJSON []byte) map[string]string {
	if len(rawJSON) == 0 || !gjson.ValidBytes(rawJSON) {
		return nil
	}
	tools := gjson.GetBytes(rawJSON, "tools")
	if !tools.Exists() || !tools.IsArray() {
		return nil
	}

	entries := tools.Array()
	byCanonical := make(map[string]string, len(entries))
	for _, entry := range entries {
		original := strings.TrimSpace(entry.Get("name").String())
		if original == "" {
			continue
		}
		canonical := CanonicalToolName(original)
		if canonical == "" {
			continue
		}
		if _, seen := byCanonical[canonical]; seen {
			continue
		}
		byCanonical[canonical] = original
	}

	if len(byCanonical) == 0 {
		return nil
	}
	return byCanonical
}
// MapToolName returns the original tool name recorded in toolNameMap for the
// canonical form of name. It returns name unchanged when name is empty, the
// map is nil, or no non-empty mapping exists.
func MapToolName(toolNameMap map[string]string, name string) string {
	if name != "" && toolNameMap != nil {
		// A missing key yields "", so a single non-empty check covers both
		// "absent" and "mapped to empty".
		if original := toolNameMap[CanonicalToolName(name)]; original != "" {
			return original
		}
	}
	return name
}
================================================
FILE: internal/util/util.go
================================================
// Package util provides utility functions for the CLI Proxy API server.
// It includes helper functions for logging configuration, file system operations,
// and other common utilities used throughout the application.
package util
import (
"context"
"fmt"
"os"
"path/filepath"
"regexp"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
)
// functionNameSanitizer matches every character not allowed in a Gemini/Vertex
// AI function name.
var functionNameSanitizer = regexp.MustCompile(`[^a-zA-Z0-9_.:-]`)

// SanitizeFunctionName makes name acceptable as a Gemini/Vertex AI function name.
//
// Every character outside [a-zA-Z0-9_.:-] becomes '_'. If the result does not
// start with an ASCII letter or '_', an underscore is prepended; the name is
// first cut to 63 bytes so that the prefix still fits. The final result is at
// most 64 bytes. An empty input yields an empty string.
func SanitizeFunctionName(name string) string {
	if name == "" {
		return ""
	}
	const maxLen = 64

	out := functionNameSanitizer.ReplaceAllString(name, "_")
	if len(out) == 0 {
		out = "_"
	} else {
		first := out[0]
		validStart := first == '_' ||
			('a' <= first && first <= 'z') ||
			('A' <= first && first <= 'Z')
		if !validStart {
			// Digits, '.', ':' and '-' are allowed inside but not as the
			// first character, so prefix an underscore.
			if len(out) >= maxLen {
				out = out[:maxLen-1]
			}
			out = "_" + out
		}
	}

	// All characters are ASCII at this point, so byte slicing is rune-safe.
	if len(out) > maxLen {
		out = out[:maxLen]
	}
	return out
}
// SetLogLevel applies the logrus level implied by cfg: DebugLevel when
// cfg.Debug is true, InfoLevel otherwise. The level is only touched, and the
// change only logged, when it actually differs from the current level.
func SetLogLevel(cfg *config.Config) {
	previous := log.GetLevel()
	target := log.InfoLevel
	if cfg.Debug {
		target = log.DebugLevel
	}
	if previous == target {
		return
	}
	log.SetLevel(target)
	log.Infof("log level changed from %s to %s (debug=%t)", previous, target, cfg.Debug)
}
// ResolveAuthDir normalizes the auth directory path for consistent reuse throughout the app.
//
// A bare "~", or "~" followed by '/' or '\', is expanded to the current
// user's home directory. Backslashes in the remainder are treated as path
// separators. Any other path, including "~name/..." (another user's home in
// shell syntax, which is not supported here), is only cleaned.
//
// Returns "" for an empty input, and an error only if the home directory
// cannot be determined for a path that needs expansion.
func ResolveAuthDir(authDir string) (string, error) {
	if authDir == "" {
		return "", nil
	}
	expandHome := authDir == "~" ||
		strings.HasPrefix(authDir, "~/") ||
		strings.HasPrefix(authDir, `~\`)
	if !expandHome {
		return filepath.Clean(authDir), nil
	}
	home, err := os.UserHomeDir()
	if err != nil {
		return "", fmt.Errorf("resolve auth dir: %w", err)
	}
	remainder := strings.TrimLeft(strings.TrimPrefix(authDir, "~"), `/\`)
	if remainder == "" {
		return filepath.Clean(home), nil
	}
	normalized := strings.ReplaceAll(remainder, "\\", "/")
	return filepath.Clean(filepath.Join(home, filepath.FromSlash(normalized))), nil
}
// CountAuthFiles reports how many auth records store.List returns.
//
// It returns 0 when store is nil or when listing fails; the failure is only
// logged at debug level. A nil ctx is replaced with context.Background().
// For filesystem-backed stores, this reflects the number of JSON auth files
// under the configured directory.
func CountAuthFiles[T any](ctx context.Context, store interface {
	List(context.Context) ([]T, error)
}) int {
	if store == nil {
		return 0
	}
	listCtx := ctx
	if listCtx == nil {
		listCtx = context.Background()
	}
	records, err := store.List(listCtx)
	if err != nil {
		log.Debugf("countAuthFiles: failed to list auth records: %v", err)
		return 0
	}
	return len(records)
}
// WritablePath returns the cleaned value of the WRITABLE_PATH environment
// variable, or "" when it is unset or blank.
//
// The uppercase name is checked first. The lowercase "writable_path" is
// consulted only if the uppercase one is missing or whitespace-only.
func WritablePath() string {
	candidates := [...]string{"WRITABLE_PATH", "writable_path"}
	for i := range candidates {
		// Getenv yields "" for unset variables, which the blank check skips.
		if v := strings.TrimSpace(os.Getenv(candidates[i])); v != "" {
			return filepath.Clean(v)
		}
	}
	return ""
}
================================================
FILE: internal/watcher/clients.go
================================================
// clients.go implements watcher client lifecycle logic and persistence helpers.
// It reloads clients, handles incremental auth file changes, and persists updates when supported.
package watcher
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io/fs"
"os"
"path/filepath"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// reloadClients rebuilds client bookkeeping from the current config snapshot.
//
// Parameters:
//   - rescanAuth: when true, the auth directory is walked again and the
//     per-path hash, parsed-content, and synthesized-auth caches are rebuilt
//     from disk. When false, those caches are kept and only their size is
//     reported as the auth file count.
//   - affectedOAuthProviders: providers whose entries are removed from
//     w.currentAuths (matched via matchProvider) before the refresh.
//     NOTE(review): presumably so that refreshAuthState re-adds them with the
//     updated oauth-excluded-models applied — confirm.
//   - forceAuthRefresh: forwarded unchanged to refreshAuthState.
//
// The reload callback, when set, runs before refreshAuthState. A summary of
// the client counts is logged at Info level at the end.
func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string, forceAuthRefresh bool) {
	log.Debugf("starting full client load process")
	// Snapshot the config pointer under the read lock; the rest of the reload
	// works from this snapshot.
	w.clientsMutex.RLock()
	cfg := w.config
	w.clientsMutex.RUnlock()
	if cfg == nil {
		log.Error("config is nil, cannot reload clients")
		return
	}
	if len(affectedOAuthProviders) > 0 {
		w.clientsMutex.Lock()
		if w.currentAuths != nil {
			// Keep every non-nil auth except those of an affected provider.
			filtered := make(map[string]*coreauth.Auth, len(w.currentAuths))
			for id, auth := range w.currentAuths {
				if auth == nil {
					continue
				}
				provider := strings.ToLower(strings.TrimSpace(auth.Provider))
				if _, match := matchProvider(provider, affectedOAuthProviders); match {
					continue
				}
				filtered[id] = auth
			}
			w.currentAuths = filtered
			log.Debugf("applying oauth-excluded-models to providers %v", affectedOAuthProviders)
		} else {
			w.currentAuths = nil
		}
		w.clientsMutex.Unlock()
	}
	// API-key clients are only counted here (BuildAPIKeyClients returns counts).
	geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg)
	totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
	log.Debugf("loaded %d API key clients", totalAPIKeyClients)
	var authFileCount int
	if rescanAuth {
		authFileCount = w.loadFileClients(cfg)
		log.Debugf("loaded %d file-based clients", authFileCount)
	} else {
		w.clientsMutex.RLock()
		authFileCount = len(w.lastAuthHashes)
		w.clientsMutex.RUnlock()
		log.Debugf("skipping auth directory rescan; retaining %d existing auth files", authFileCount)
	}
	if rescanAuth {
		// Rebuild the per-path caches from scratch, so entries for files that
		// no longer exist are dropped.
		// NOTE(review): the write lock is held for the whole directory walk,
		// including file reads, JSON parsing, and synthesis.
		w.clientsMutex.Lock()
		w.lastAuthHashes = make(map[string]string)
		w.lastAuthContents = make(map[string]*coreauth.Auth)
		w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth)
		if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
			log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
		} else if resolvedAuthDir != "" {
			_ = filepath.Walk(resolvedAuthDir, func(path string, info fs.FileInfo, err error) error {
				// Best-effort: entries that cannot be accessed are skipped.
				if err != nil {
					return nil
				}
				if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
					// Unreadable or empty files are not cached.
					if data, errReadFile := os.ReadFile(path); errReadFile == nil && len(data) > 0 {
						sum := sha256.Sum256(data)
						normalizedPath := w.normalizeAuthPath(path)
						w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:])
						// Parse and cache auth content for future diff comparisons
						var auth coreauth.Auth
						if errParse := json.Unmarshal(data, &auth); errParse == nil {
							w.lastAuthContents[normalizedPath] = &auth
						}
						// Synthesize the auth entries this file yields and
						// index them by ID under the file's normalized path.
						ctx := &synthesizer.SynthesisContext{
							Config:      cfg,
							AuthDir:     resolvedAuthDir,
							Now:         time.Now(),
							IDGenerator: synthesizer.NewStableIDGenerator(),
						}
						if generated := synthesizer.SynthesizeAuthFile(ctx, path, data); len(generated) > 0 {
							if pathAuths := authSliceToMap(generated); len(pathAuths) > 0 {
								w.fileAuthsByPath[normalizedPath] = pathAuths
							}
						}
					}
				}
				return nil
			})
		}
		w.clientsMutex.Unlock()
	}
	totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
	// The server update callback runs before the auth state refresh.
	if w.reloadCallback != nil {
		log.Debugf("triggering server update callback before auth refresh")
		w.reloadCallback(cfg)
	}
	w.refreshAuthState(forceAuthRefresh)
	log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)",
		totalNewClients,
		authFileCount,
		geminiAPIKeyCount,
		vertexCompatAPIKeyCount,
		claudeAPIKeyCount,
		codexAPIKeyCount,
		openAICompatCount,
	)
}
// addOrUpdateClient handles a create/write event for a single auth file.
//
// It reads and parses the file and skips all work when the content hash
// matches the cached hash. Otherwise it logs the redacted field-level changes,
// refreshes the per-path caches, and re-synthesizes the auth entries for this
// file only. Then, outside the lock, it persists the file through the store
// persister and dispatches the resulting add/modify/delete updates.
// Unreadable, empty, or unparseable files are logged and ignored.
func (w *Watcher) addOrUpdateClient(path string) {
	data, errRead := os.ReadFile(path)
	if errRead != nil {
		log.Errorf("failed to read auth file %s: %v", filepath.Base(path), errRead)
		return
	}
	if len(data) == 0 {
		log.Debugf("ignoring empty auth file: %s", filepath.Base(path))
		return
	}
	sum := sha256.Sum256(data)
	curHash := hex.EncodeToString(sum[:])
	normalized := w.normalizeAuthPath(path)
	// Parse new auth content for diff comparison
	var newAuth coreauth.Auth
	if errParse := json.Unmarshal(data, &newAuth); errParse != nil {
		log.Errorf("failed to parse auth file %s: %v", filepath.Base(path), errParse)
		return
	}
	w.clientsMutex.Lock()
	if w.config == nil {
		log.Error("config is nil, cannot add or update client")
		w.clientsMutex.Unlock()
		return
	}
	// Lazily initialize every per-path cache: all three are written below,
	// and writing into a nil map panics.
	if w.lastAuthHashes == nil {
		w.lastAuthHashes = make(map[string]string)
	}
	if w.lastAuthContents == nil {
		w.lastAuthContents = make(map[string]*coreauth.Auth)
	}
	if w.fileAuthsByPath == nil {
		w.fileAuthsByPath = make(map[string]map[string]*coreauth.Auth)
	}
	if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash {
		log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path))
		w.clientsMutex.Unlock()
		return
	}
	// Compute and log field changes against the previously cached content
	// (nil for a file seen for the first time).
	oldAuth := w.lastAuthContents[normalized]
	if changes := diff.BuildAuthChangeDetails(oldAuth, &newAuth); len(changes) > 0 {
		log.Debugf("auth field changes for %s:", filepath.Base(path))
		for _, c := range changes {
			log.Debugf("  %s", c)
		}
	}
	// Update caches
	w.lastAuthHashes[normalized] = curHash
	w.lastAuthContents[normalized] = &newAuth
	// Copy the previous per-path entries before replacing them so the update
	// computation can compare old against new.
	oldByID := make(map[string]*coreauth.Auth, len(w.fileAuthsByPath[normalized]))
	for id, a := range w.fileAuthsByPath[normalized] {
		oldByID[id] = a
	}
	// Build synthesized auth entries for this single file only.
	sctx := &synthesizer.SynthesisContext{
		Config:      w.config,
		AuthDir:     w.authDir,
		Now:         time.Now(),
		IDGenerator: synthesizer.NewStableIDGenerator(),
	}
	generated := synthesizer.SynthesizeAuthFile(sctx, path, data)
	newByID := authSliceToMap(generated)
	if len(newByID) > 0 {
		w.fileAuthsByPath[normalized] = newByID
	} else {
		delete(w.fileAuthsByPath, normalized)
	}
	updates := w.computePerPathUpdatesLocked(oldByID, newByID)
	w.clientsMutex.Unlock()
	w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path)
	w.dispatchAuthUpdates(updates)
}
// removeClient handles deletion of the auth file at path.
//
// It drops every cache entry kept for the file's normalized path. It emits
// delete updates for all auth IDs previously synthesized from that file, then
// (outside the lock) requests persistence and dispatches those updates.
func (w *Watcher) removeClient(path string) {
	key := w.normalizeAuthPath(path)

	w.clientsMutex.Lock()
	previous := w.fileAuthsByPath[key]
	snapshot := make(map[string]*coreauth.Auth, len(previous))
	for id, entry := range previous {
		snapshot[id] = entry
	}
	delete(w.lastAuthHashes, key)
	delete(w.lastAuthContents, key)
	delete(w.fileAuthsByPath, key)
	// An empty "new" set turns every previously known ID into a delete.
	updates := w.computePerPathUpdatesLocked(snapshot, map[string]*coreauth.Auth{})
	w.clientsMutex.Unlock()

	w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path)
	w.dispatchAuthUpdates(updates)
}
// computePerPathUpdatesLocked reconciles w.currentAuths with the auth entries
// synthesized from one file and returns the resulting updates. The caller must
// hold w.clientsMutex for writing.
//
// IDs in newByID that are unknown become adds. Known IDs whose content differs
// (per authEqual) become modifies. IDs in oldByID that are absent from newByID
// become deletes. Stored and emitted auths are independent clones. The order
// of the returned updates follows map iteration and is therefore unspecified.
func (w *Watcher) computePerPathUpdatesLocked(oldByID, newByID map[string]*coreauth.Auth) []AuthUpdate {
	if w.currentAuths == nil {
		w.currentAuths = make(map[string]*coreauth.Auth)
	}
	updates := make([]AuthUpdate, 0, len(oldByID)+len(newByID))

	for id, incoming := range newByID {
		current, known := w.currentAuths[id]
		if known && authEqual(current, incoming) {
			continue
		}
		w.currentAuths[id] = incoming.Clone()
		if known {
			updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: incoming.Clone()})
		} else {
			updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: incoming.Clone()})
		}
	}

	for id := range oldByID {
		if _, kept := newByID[id]; kept {
			continue
		}
		delete(w.currentAuths, id)
		updates = append(updates, AuthUpdate{Action: AuthUpdateActionDelete, ID: id})
	}
	return updates
}
// authSliceToMap indexes auths by ID. Nil entries and entries whose ID is blank
// (after trimming) are skipped; a later duplicate ID overwrites an earlier one.
func authSliceToMap(auths []*coreauth.Auth) map[string]*coreauth.Auth {
	result := make(map[string]*coreauth.Auth, len(auths))
	for i := range auths {
		entry := auths[i]
		if entry != nil && strings.TrimSpace(entry.ID) != "" {
			result[entry.ID] = entry
		}
	}
	return result
}
// loadFileClients walks the resolved auth directory and counts the .json files
// it contains (the suffix match is case-insensitive).
//
// It returns the number of .json files found. The number that are readable and
// non-empty is only logged. It returns 0 when the auth directory cannot be
// resolved or is empty. An error on the root directory is logged. Errors on
// nested entries are skipped, so a single inaccessible entry does not hide the
// remaining files.
func (w *Watcher) loadFileClients(cfg *config.Config) int {
	authFileCount := 0
	successfulAuthCount := 0
	authDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir)
	if errResolveAuthDir != nil {
		log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir)
		return 0
	}
	if authDir == "" {
		return 0
	}
	errWalk := filepath.Walk(authDir, func(path string, info fs.FileInfo, err error) error {
		if err != nil {
			log.Debugf("error accessing path %s: %v", path, err)
			// Failing on the root means there is nothing to scan; surface it.
			if path == authDir {
				return err
			}
			return nil
		}
		if info.IsDir() || !strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
			return nil
		}
		authFileCount++
		log.Debugf("processing auth file %d: %s", authFileCount, filepath.Base(path))
		if data, errReadFile := os.ReadFile(path); errReadFile == nil && len(data) > 0 {
			successfulAuthCount++
		}
		return nil
	})
	if errWalk != nil {
		log.Errorf("error walking auth directory: %v", errWalk)
	}
	log.Debugf("auth directory scan complete - found %d .json files, %d readable", authFileCount, successfulAuthCount)
	return authFileCount
}
// BuildAPIKeyClients counts the API-key based clients declared in cfg.
//
// Despite its name, it does not construct any clients. It returns, in order,
// the number of Gemini API keys, Vertex-compatible API keys, Claude API keys,
// Codex API keys, and OpenAI-compatibility API key entries (summed across all
// OpenAI-compatible providers).
func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) {
	openAICompat := 0
	for _, provider := range cfg.OpenAICompatibility {
		openAICompat += len(provider.APIKeyEntries)
	}
	return len(cfg.GeminiKey),
		len(cfg.VertexCompatAPIKey),
		len(cfg.ClaudeKey),
		len(cfg.CodexKey),
		openAICompat
}
// persistConfigAsync asks the store persister, if one is configured, to
// persist the config. The call runs on a background goroutine with a
// 30-second timeout, and failures are only logged.
func (w *Watcher) persistConfigAsync() {
	if w == nil {
		return
	}
	if w.storePersister == nil {
		return
	}
	go func() {
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		err := w.storePersister.PersistConfig(ctx)
		if err == nil {
			return
		}
		log.Errorf("failed to persist config change: %v", err)
	}()
}
// persistAuthAsync asks the store persister, if one is configured, to persist
// the given auth file paths with the given commit-style message.
//
// Paths are whitespace-trimmed and blank ones are dropped. Nothing happens if
// no path remains. The call runs on a background goroutine with a 30-second
// timeout, and failures are only logged.
func (w *Watcher) persistAuthAsync(message string, paths ...string) {
	if w == nil {
		return
	}
	if w.storePersister == nil {
		return
	}
	targets := make([]string, 0, len(paths))
	for _, candidate := range paths {
		candidate = strings.TrimSpace(candidate)
		if candidate == "" {
			continue
		}
		targets = append(targets, candidate)
	}
	if len(targets) == 0 {
		return
	}
	go func() {
		ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer cancel()
		err := w.storePersister.PersistAuthFiles(ctx, message, targets...)
		if err == nil {
			return
		}
		log.Errorf("failed to persist auth changes: %v", err)
	}()
}
// stopServerUpdateTimer cancels any pending debounced server update and clears
// the pending flag, under serverUpdateMu.
func (w *Watcher) stopServerUpdateTimer() {
	w.serverUpdateMu.Lock()
	if pending := w.serverUpdateTimer; pending != nil {
		pending.Stop()
		w.serverUpdateTimer = nil
	}
	w.serverUpdatePend = false
	w.serverUpdateMu.Unlock()
}
// triggerServerUpdate invokes w.reloadCallback, rate-limited to at most once
// per serverUpdateDebounce.
//
// If no callback has run within the debounce window, cfg is delivered
// immediately on the calling goroutine and any armed timer is cancelled.
// Otherwise, a single trailing timer is armed (at least 10ms out) unless one is
// already pending. When that timer fires, it delivers the latest w.config,
// not the cfg passed here. Calls are ignored when the watcher is nil, has no
// callback, cfg is nil, or the watcher has been stopped.
func (w *Watcher) triggerServerUpdate(cfg *config.Config) {
	if w == nil || w.reloadCallback == nil || cfg == nil {
		return
	}
	if w.stopped.Load() {
		return
	}
	now := time.Now()
	w.serverUpdateMu.Lock()
	// Leading edge: outside the debounce window, deliver right away.
	if w.serverUpdateLast.IsZero() || now.Sub(w.serverUpdateLast) >= serverUpdateDebounce {
		w.serverUpdateLast = now
		if w.serverUpdateTimer != nil {
			w.serverUpdateTimer.Stop()
			w.serverUpdateTimer = nil
		}
		w.serverUpdatePend = false
		// The callback runs after serverUpdateMu is released.
		w.serverUpdateMu.Unlock()
		w.reloadCallback(cfg)
		return
	}
	// A trailing update is already scheduled; it will pick up the newest config.
	if w.serverUpdatePend {
		w.serverUpdateMu.Unlock()
		return
	}
	delay := serverUpdateDebounce - now.Sub(w.serverUpdateLast)
	if delay < 10*time.Millisecond {
		delay = 10 * time.Millisecond
	}
	w.serverUpdatePend = true
	if w.serverUpdateTimer != nil {
		w.serverUpdateTimer.Stop()
		w.serverUpdateTimer = nil
	}
	// timer is declared before AfterFunc so the closure can compare its own
	// identity with w.serverUpdateTimer. The assignment below happens while
	// serverUpdateMu is held, and the closure takes the same lock before
	// reading timer, so it observes the assigned value.
	var timer *time.Timer
	timer = time.AfterFunc(delay, func() {
		if w.stopped.Load() {
			return
		}
		w.clientsMutex.RLock()
		latestCfg := w.config
		w.clientsMutex.RUnlock()
		w.serverUpdateMu.Lock()
		// Only the most recently armed, still-pending timer may deliver; a
		// superseded or cancelled timer exits here.
		if w.serverUpdateTimer != timer || !w.serverUpdatePend {
			w.serverUpdateMu.Unlock()
			return
		}
		w.serverUpdateTimer = nil
		w.serverUpdatePend = false
		if latestCfg == nil || w.reloadCallback == nil || w.stopped.Load() {
			w.serverUpdateMu.Unlock()
			return
		}
		w.serverUpdateLast = time.Now()
		w.serverUpdateMu.Unlock()
		w.reloadCallback(latestCfg)
	})
	w.serverUpdateTimer = timer
	w.serverUpdateMu.Unlock()
}
================================================
FILE: internal/watcher/config_reload.go
================================================
// config_reload.go implements debounced configuration hot reload.
// It detects material changes and reloads clients when the config changes.
package watcher
import (
"crypto/sha256"
"encoding/hex"
"os"
"reflect"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
"gopkg.in/yaml.v3"
log "github.com/sirupsen/logrus"
)
// stopConfigReloadTimer cancels any pending debounced config reload.
func (w *Watcher) stopConfigReloadTimer() {
	w.configReloadMu.Lock()
	defer w.configReloadMu.Unlock()
	if pending := w.configReloadTimer; pending != nil {
		pending.Stop()
		w.configReloadTimer = nil
	}
}
// scheduleConfigReload (re)arms the debounced config reload. Each call replaces
// any pending timer, so reloadConfigIfChanged runs once, configReloadDebounce
// after the last call.
//
// A timer that already fired but was superseded by a newer call, or cancelled
// by stopConfigReloadTimer, does nothing. Without this check, a stale callback
// could clear the reference to the newer timer, leaving it unstoppable and
// causing duplicate reloads.
func (w *Watcher) scheduleConfigReload() {
	w.configReloadMu.Lock()
	defer w.configReloadMu.Unlock()
	if w.configReloadTimer != nil {
		w.configReloadTimer.Stop()
	}
	// timer is assigned while configReloadMu is held, and the callback takes
	// the same lock before reading it, so the identity check below is safe.
	var timer *time.Timer
	timer = time.AfterFunc(configReloadDebounce, func() {
		w.configReloadMu.Lock()
		if w.configReloadTimer != timer {
			w.configReloadMu.Unlock()
			return
		}
		w.configReloadTimer = nil
		w.configReloadMu.Unlock()
		if w.stopped.Load() {
			return
		}
		w.reloadConfigIfChanged()
	})
	w.configReloadTimer = timer
}
// reloadConfigIfChanged reloads the config when the file's SHA-256 differs from
// the last recorded hash.
//
// Read errors and empty files (e.g. a truncate-then-write in progress) are
// ignored. After a successful reload, the file is hashed again before the hash
// is recorded, so the stored hash reflects the file as it is after the reload.
// If that re-read fails, the pre-reload hash is kept. Finally, the config is
// persisted asynchronously.
func (w *Watcher) reloadConfigIfChanged() {
	raw, errRead := os.ReadFile(w.configPath)
	switch {
	case errRead != nil:
		log.Errorf("failed to read config file for hash check: %v", errRead)
		return
	case len(raw) == 0:
		log.Debugf("ignoring empty config file write event")
		return
	}

	digest := sha256.Sum256(raw)
	incomingHash := hex.EncodeToString(digest[:])

	w.clientsMutex.RLock()
	knownHash := w.lastConfigHash
	w.clientsMutex.RUnlock()
	if knownHash != "" && knownHash == incomingHash {
		log.Debugf("config file content unchanged (hash match), skipping reload")
		return
	}

	log.Infof("config file changed, reloading: %s", w.configPath)
	if !w.reloadConfig() {
		return
	}

	storedHash := incomingHash
	updated, errReread := os.ReadFile(w.configPath)
	if errReread != nil {
		log.WithError(errReread).Debug("failed to compute updated config hash after reload")
	} else if len(updated) > 0 {
		updatedDigest := sha256.Sum256(updated)
		storedHash = hex.EncodeToString(updatedDigest[:])
	}

	w.clientsMutex.Lock()
	w.lastConfigHash = storedHash
	w.clientsMutex.Unlock()
	w.persistConfigAsync()
}
// reloadConfig loads the config file at w.configPath, installs it as w.config,
// and triggers a client reload.
//
// It returns false, leaving the current config in place, if loading fails. It
// returns true once the new config has been installed and reloadClients has
// run.
//
// The previous config is recovered by unmarshalling w.oldConfigYaml, a YAML
// snapshot of the config installed by the previous reload. oldConfig stays nil
// when there is no snapshot yet; YAML marshal/unmarshal errors are ignored.
// With no previous config, the auth directory is treated as changed (forcing a
// rescan) and no field diff is logged.
func (w *Watcher) reloadConfig() bool {
	log.Debug("=========================== CONFIG RELOAD ============================")
	log.Debugf("starting config reload from: %s", w.configPath)
	newConfig, errLoadConfig := config.LoadConfig(w.configPath)
	if errLoadConfig != nil {
		log.Errorf("failed to reload config: %v", errLoadConfig)
		return false
	}
	// A mirrored auth dir, when set, overrides the configured one. Otherwise
	// the configured path is resolved; on failure it is kept unresolved.
	if w.mirroredAuthDir != "" {
		newConfig.AuthDir = w.mirroredAuthDir
	} else {
		if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(newConfig.AuthDir); errResolveAuthDir != nil {
			log.Errorf("failed to resolve auth directory from config: %v", errResolveAuthDir)
		} else {
			newConfig.AuthDir = resolvedAuthDir
		}
	}
	// Swap in the new config and take a YAML snapshot of it for the next
	// reload's diff. Both happen under the same write lock.
	w.clientsMutex.Lock()
	var oldConfig *config.Config
	_ = yaml.Unmarshal(w.oldConfigYaml, &oldConfig)
	w.oldConfigYaml, _ = yaml.Marshal(newConfig)
	w.config = newConfig
	w.clientsMutex.Unlock()
	var affectedOAuthProviders []string
	if oldConfig != nil {
		_, affectedOAuthProviders = diff.DiffOAuthExcludedModelChanges(oldConfig.OAuthExcludedModels, newConfig.OAuthExcludedModels)
	}
	util.SetLogLevel(newConfig)
	if oldConfig != nil && oldConfig.Debug != newConfig.Debug {
		log.Debugf("log level updated - debug mode changed from %t to %t", oldConfig.Debug, newConfig.Debug)
	}
	if oldConfig != nil {
		details := diff.BuildConfigChangeDetails(oldConfig, newConfig)
		if len(details) > 0 {
			log.Debugf("config changes detected:")
			for _, d := range details {
				log.Debugf("  %s", d)
			}
		} else {
			log.Debugf("no material config field changes detected")
		}
	}
	// Rescan auth files when the auth dir changed, or when there is nothing
	// to compare against.
	authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
	retryConfigChanged := oldConfig != nil && (oldConfig.RequestRetry != newConfig.RequestRetry || oldConfig.MaxRetryInterval != newConfig.MaxRetryInterval || oldConfig.MaxRetryCredentials != newConfig.MaxRetryCredentials)
	// Force an auth state refresh when settings that affect already-loaded
	// auths change: model prefix enforcement, OAuth model aliases, or retry settings.
	forceAuthRefresh := oldConfig != nil && (oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix || !reflect.DeepEqual(oldConfig.OAuthModelAlias, newConfig.OAuthModelAlias) || retryConfigChanged)
	log.Infof("config successfully reloaded, triggering client reload")
	w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)
	return true
}
================================================
FILE: internal/watcher/diff/auth_diff.go
================================================
// auth_diff.go computes human-readable diffs for auth file field changes.
package diff
import (
"fmt"
"strings"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// BuildAuthChangeDetails lists the human-readable differences between two
// versions of an auth record.
//
// Only three fields are compared: prefix (whitespace-trimmed), proxy_url
// (whitespace-trimmed and rendered through formatProxyURL, which presumably
// redacts credentials — confirm), and disabled. A nil oldAuth is treated as an
// empty record. A nil newAuth yields an empty (non-nil) slice.
func BuildAuthChangeDetails(oldAuth, newAuth *coreauth.Auth) []string {
	changes := make([]string, 0, 3)
	if newAuth == nil {
		return changes
	}
	previous := oldAuth
	if previous == nil {
		previous = &coreauth.Auth{}
	}

	if before, after := strings.TrimSpace(previous.Prefix), strings.TrimSpace(newAuth.Prefix); before != after {
		changes = append(changes, fmt.Sprintf("prefix: %s -> %s", before, after))
	}
	if before, after := strings.TrimSpace(previous.ProxyURL), strings.TrimSpace(newAuth.ProxyURL); before != after {
		changes = append(changes, fmt.Sprintf("proxy_url: %s -> %s", formatProxyURL(before), formatProxyURL(after)))
	}
	if previous.Disabled != newAuth.Disabled {
		changes = append(changes, fmt.Sprintf("disabled: %t -> %t", previous.Disabled, newAuth.Disabled))
	}
	return changes
}
================================================
FILE: internal/watcher/diff/config_diff.go
================================================
package diff
import (
"fmt"
"net/url"
"reflect"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// BuildConfigChangeDetails computes a redacted, human-readable list of the
// differences between oldCfg and newCfg.
//
// Secrets are never printed. API keys, the ampcode upstream key(s), the
// remote-management secret and claude cloak sensitive words are reported only
// as counts or as added/removed/updated markers. Proxy URLs are reduced to
// scheme://host by formatProxyURL. Per-provider key lists (gemini, claude,
// codex, vertex) are compared index by index only when both lists have the
// same length; otherwise only the count change is reported.
//
// It returns an empty, non-nil slice when either config is nil or when no
// tracked field changed.
func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
	changes := make([]string, 0, 16)
	if oldCfg == nil || newCfg == nil {
		return changes
	}

	// Simple scalars
	if oldCfg.Port != newCfg.Port {
		changes = append(changes, fmt.Sprintf("port: %d -> %d", oldCfg.Port, newCfg.Port))
	}
	if oldCfg.AuthDir != newCfg.AuthDir {
		changes = append(changes, fmt.Sprintf("auth-dir: %s -> %s", oldCfg.AuthDir, newCfg.AuthDir))
	}
	if oldCfg.Debug != newCfg.Debug {
		changes = append(changes, fmt.Sprintf("debug: %t -> %t", oldCfg.Debug, newCfg.Debug))
	}
	if oldCfg.Pprof.Enable != newCfg.Pprof.Enable {
		changes = append(changes, fmt.Sprintf("pprof.enable: %t -> %t", oldCfg.Pprof.Enable, newCfg.Pprof.Enable))
	}
	if strings.TrimSpace(oldCfg.Pprof.Addr) != strings.TrimSpace(newCfg.Pprof.Addr) {
		changes = append(changes, fmt.Sprintf("pprof.addr: %s -> %s", strings.TrimSpace(oldCfg.Pprof.Addr), strings.TrimSpace(newCfg.Pprof.Addr)))
	}
	if oldCfg.LoggingToFile != newCfg.LoggingToFile {
		changes = append(changes, fmt.Sprintf("logging-to-file: %t -> %t", oldCfg.LoggingToFile, newCfg.LoggingToFile))
	}
	if oldCfg.UsageStatisticsEnabled != newCfg.UsageStatisticsEnabled {
		changes = append(changes, fmt.Sprintf("usage-statistics-enabled: %t -> %t", oldCfg.UsageStatisticsEnabled, newCfg.UsageStatisticsEnabled))
	}
	if oldCfg.DisableCooling != newCfg.DisableCooling {
		changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling))
	}
	if oldCfg.RequestLog != newCfg.RequestLog {
		changes = append(changes, fmt.Sprintf("request-log: %t -> %t", oldCfg.RequestLog, newCfg.RequestLog))
	}
	if oldCfg.LogsMaxTotalSizeMB != newCfg.LogsMaxTotalSizeMB {
		changes = append(changes, fmt.Sprintf("logs-max-total-size-mb: %d -> %d", oldCfg.LogsMaxTotalSizeMB, newCfg.LogsMaxTotalSizeMB))
	}
	if oldCfg.ErrorLogsMaxFiles != newCfg.ErrorLogsMaxFiles {
		changes = append(changes, fmt.Sprintf("error-logs-max-files: %d -> %d", oldCfg.ErrorLogsMaxFiles, newCfg.ErrorLogsMaxFiles))
	}
	if oldCfg.RequestRetry != newCfg.RequestRetry {
		changes = append(changes, fmt.Sprintf("request-retry: %d -> %d", oldCfg.RequestRetry, newCfg.RequestRetry))
	}
	if oldCfg.MaxRetryCredentials != newCfg.MaxRetryCredentials {
		changes = append(changes, fmt.Sprintf("max-retry-credentials: %d -> %d", oldCfg.MaxRetryCredentials, newCfg.MaxRetryCredentials))
	}
	if oldCfg.MaxRetryInterval != newCfg.MaxRetryInterval {
		changes = append(changes, fmt.Sprintf("max-retry-interval: %d -> %d", oldCfg.MaxRetryInterval, newCfg.MaxRetryInterval))
	}
	// Compare trimmed values, like the per-provider proxy-url checks below.
	// formatProxyURL trims before rendering, so an untrimmed comparison would
	// log a whitespace-only edit as "proxy-url: X -> X".
	if strings.TrimSpace(oldCfg.ProxyURL) != strings.TrimSpace(newCfg.ProxyURL) {
		changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", formatProxyURL(oldCfg.ProxyURL), formatProxyURL(newCfg.ProxyURL)))
	}
	if oldCfg.WebsocketAuth != newCfg.WebsocketAuth {
		changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth))
	}
	if oldCfg.ForceModelPrefix != newCfg.ForceModelPrefix {
		changes = append(changes, fmt.Sprintf("force-model-prefix: %t -> %t", oldCfg.ForceModelPrefix, newCfg.ForceModelPrefix))
	}
	if oldCfg.NonStreamKeepAliveInterval != newCfg.NonStreamKeepAliveInterval {
		changes = append(changes, fmt.Sprintf("nonstream-keepalive-interval: %d -> %d", oldCfg.NonStreamKeepAliveInterval, newCfg.NonStreamKeepAliveInterval))
	}

	// Quota-exceeded behavior
	if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {
		changes = append(changes, fmt.Sprintf("quota-exceeded.switch-project: %t -> %t", oldCfg.QuotaExceeded.SwitchProject, newCfg.QuotaExceeded.SwitchProject))
	}
	if oldCfg.QuotaExceeded.SwitchPreviewModel != newCfg.QuotaExceeded.SwitchPreviewModel {
		changes = append(changes, fmt.Sprintf("quota-exceeded.switch-preview-model: %t -> %t", oldCfg.QuotaExceeded.SwitchPreviewModel, newCfg.QuotaExceeded.SwitchPreviewModel))
	}
	if oldCfg.Routing.Strategy != newCfg.Routing.Strategy {
		changes = append(changes, fmt.Sprintf("routing.strategy: %s -> %s", oldCfg.Routing.Strategy, newCfg.Routing.Strategy))
	}

	// API keys (redacted) and counts
	if len(oldCfg.APIKeys) != len(newCfg.APIKeys) {
		changes = append(changes, fmt.Sprintf("api-keys count: %d -> %d", len(oldCfg.APIKeys), len(newCfg.APIKeys)))
	} else if !reflect.DeepEqual(trimStrings(oldCfg.APIKeys), trimStrings(newCfg.APIKeys)) {
		changes = append(changes, "api-keys: values updated (count unchanged, redacted)")
	}

	// Gemini keys (do not print key material)
	if len(oldCfg.GeminiKey) != len(newCfg.GeminiKey) {
		changes = append(changes, fmt.Sprintf("gemini-api-key count: %d -> %d", len(oldCfg.GeminiKey), len(newCfg.GeminiKey)))
	} else {
		for i := range oldCfg.GeminiKey {
			o := oldCfg.GeminiKey[i]
			n := newCfg.GeminiKey[i]
			if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
				changes = append(changes, fmt.Sprintf("gemini[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
			}
			if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
				changes = append(changes, fmt.Sprintf("gemini[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
			}
			if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
				changes = append(changes, fmt.Sprintf("gemini[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
			}
			if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
				changes = append(changes, fmt.Sprintf("gemini[%d].api-key: updated", i))
			}
			if !equalStringMap(o.Headers, n.Headers) {
				changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i))
			}
			oldModels := SummarizeGeminiModels(o.Models)
			newModels := SummarizeGeminiModels(n.Models)
			if oldModels.hash != newModels.hash {
				changes = append(changes, fmt.Sprintf("gemini[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
			}
			oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
			newExcluded := SummarizeExcludedModels(n.ExcludedModels)
			if oldExcluded.hash != newExcluded.hash {
				changes = append(changes, fmt.Sprintf("gemini[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
			}
		}
	}

	// Claude keys (do not print key material)
	if len(oldCfg.ClaudeKey) != len(newCfg.ClaudeKey) {
		changes = append(changes, fmt.Sprintf("claude-api-key count: %d -> %d", len(oldCfg.ClaudeKey), len(newCfg.ClaudeKey)))
	} else {
		for i := range oldCfg.ClaudeKey {
			o := oldCfg.ClaudeKey[i]
			n := newCfg.ClaudeKey[i]
			if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
				changes = append(changes, fmt.Sprintf("claude[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
			}
			if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
				changes = append(changes, fmt.Sprintf("claude[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
			}
			if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
				changes = append(changes, fmt.Sprintf("claude[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
			}
			if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
				changes = append(changes, fmt.Sprintf("claude[%d].api-key: updated", i))
			}
			if !equalStringMap(o.Headers, n.Headers) {
				changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i))
			}
			oldModels := SummarizeClaudeModels(o.Models)
			newModels := SummarizeClaudeModels(n.Models)
			if oldModels.hash != newModels.hash {
				changes = append(changes, fmt.Sprintf("claude[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
			}
			oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
			newExcluded := SummarizeExcludedModels(n.ExcludedModels)
			if oldExcluded.hash != newExcluded.hash {
				changes = append(changes, fmt.Sprintf("claude[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
			}
			// Cloak settings: report nil <-> non-nil transitions as well, otherwise
			// enabling or removing a cloak block would go unlogged.
			switch {
			case o.Cloak == nil && n.Cloak != nil:
				changes = append(changes, fmt.Sprintf("claude[%d].cloak: added", i))
			case o.Cloak != nil && n.Cloak == nil:
				changes = append(changes, fmt.Sprintf("claude[%d].cloak: removed", i))
			case o.Cloak != nil && n.Cloak != nil:
				oldMode := strings.TrimSpace(o.Cloak.Mode)
				newMode := strings.TrimSpace(n.Cloak.Mode)
				if oldMode != newMode {
					changes = append(changes, fmt.Sprintf("claude[%d].cloak.mode: %s -> %s", i, oldMode, newMode))
				}
				if o.Cloak.StrictMode != n.Cloak.StrictMode {
					changes = append(changes, fmt.Sprintf("claude[%d].cloak.strict-mode: %t -> %t", i, o.Cloak.StrictMode, n.Cloak.StrictMode))
				}
				// Sensitive words are never printed: only counts, or a redacted
				// marker when the values change but the count does not.
				if len(o.Cloak.SensitiveWords) != len(n.Cloak.SensitiveWords) {
					changes = append(changes, fmt.Sprintf("claude[%d].cloak.sensitive-words: %d -> %d", i, len(o.Cloak.SensitiveWords), len(n.Cloak.SensitiveWords)))
				} else if !reflect.DeepEqual(trimStrings(o.Cloak.SensitiveWords), trimStrings(n.Cloak.SensitiveWords)) {
					changes = append(changes, fmt.Sprintf("claude[%d].cloak.sensitive-words: values updated (count unchanged, redacted)", i))
				}
			}
		}
	}

	// Codex keys (do not print key material)
	if len(oldCfg.CodexKey) != len(newCfg.CodexKey) {
		changes = append(changes, fmt.Sprintf("codex-api-key count: %d -> %d", len(oldCfg.CodexKey), len(newCfg.CodexKey)))
	} else {
		for i := range oldCfg.CodexKey {
			o := oldCfg.CodexKey[i]
			n := newCfg.CodexKey[i]
			if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
				changes = append(changes, fmt.Sprintf("codex[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
			}
			if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
				changes = append(changes, fmt.Sprintf("codex[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
			}
			if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
				changes = append(changes, fmt.Sprintf("codex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
			}
			if o.Websockets != n.Websockets {
				changes = append(changes, fmt.Sprintf("codex[%d].websockets: %t -> %t", i, o.Websockets, n.Websockets))
			}
			if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
				changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i))
			}
			if !equalStringMap(o.Headers, n.Headers) {
				changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i))
			}
			oldModels := SummarizeCodexModels(o.Models)
			newModels := SummarizeCodexModels(n.Models)
			if oldModels.hash != newModels.hash {
				changes = append(changes, fmt.Sprintf("codex[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
			}
			oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
			newExcluded := SummarizeExcludedModels(n.ExcludedModels)
			if oldExcluded.hash != newExcluded.hash {
				changes = append(changes, fmt.Sprintf("codex[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
			}
		}
	}

	// AmpCode settings (redacted where needed)
	oldAmpURL := strings.TrimSpace(oldCfg.AmpCode.UpstreamURL)
	newAmpURL := strings.TrimSpace(newCfg.AmpCode.UpstreamURL)
	if oldAmpURL != newAmpURL {
		changes = append(changes, fmt.Sprintf("ampcode.upstream-url: %s -> %s", oldAmpURL, newAmpURL))
	}
	oldAmpKey := strings.TrimSpace(oldCfg.AmpCode.UpstreamAPIKey)
	newAmpKey := strings.TrimSpace(newCfg.AmpCode.UpstreamAPIKey)
	switch {
	case oldAmpKey == "" && newAmpKey != "":
		changes = append(changes, "ampcode.upstream-api-key: added")
	case oldAmpKey != "" && newAmpKey == "":
		changes = append(changes, "ampcode.upstream-api-key: removed")
	case oldAmpKey != newAmpKey:
		changes = append(changes, "ampcode.upstream-api-key: updated")
	}
	if oldCfg.AmpCode.RestrictManagementToLocalhost != newCfg.AmpCode.RestrictManagementToLocalhost {
		changes = append(changes, fmt.Sprintf("ampcode.restrict-management-to-localhost: %t -> %t", oldCfg.AmpCode.RestrictManagementToLocalhost, newCfg.AmpCode.RestrictManagementToLocalhost))
	}
	oldMappings := SummarizeAmpModelMappings(oldCfg.AmpCode.ModelMappings)
	newMappings := SummarizeAmpModelMappings(newCfg.AmpCode.ModelMappings)
	if oldMappings.hash != newMappings.hash {
		changes = append(changes, fmt.Sprintf("ampcode.model-mappings: updated (%d -> %d entries)", oldMappings.count, newMappings.count))
	}
	if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings {
		changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings))
	}
	oldUpstreamAPIKeysCount := len(oldCfg.AmpCode.UpstreamAPIKeys)
	newUpstreamAPIKeysCount := len(newCfg.AmpCode.UpstreamAPIKeys)
	if !equalUpstreamAPIKeys(oldCfg.AmpCode.UpstreamAPIKeys, newCfg.AmpCode.UpstreamAPIKeys) {
		changes = append(changes, fmt.Sprintf("ampcode.upstream-api-keys: updated (%d -> %d entries)", oldUpstreamAPIKeysCount, newUpstreamAPIKeysCount))
	}

	// OAuth per-provider excluded models and model aliases
	if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
		changes = append(changes, entries...)
	}
	if entries, _ := DiffOAuthModelAliasChanges(oldCfg.OAuthModelAlias, newCfg.OAuthModelAlias); len(entries) > 0 {
		changes = append(changes, entries...)
	}

	// Remote management (never print the key)
	if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote {
		changes = append(changes, fmt.Sprintf("remote-management.allow-remote: %t -> %t", oldCfg.RemoteManagement.AllowRemote, newCfg.RemoteManagement.AllowRemote))
	}
	if oldCfg.RemoteManagement.DisableControlPanel != newCfg.RemoteManagement.DisableControlPanel {
		changes = append(changes, fmt.Sprintf("remote-management.disable-control-panel: %t -> %t", oldCfg.RemoteManagement.DisableControlPanel, newCfg.RemoteManagement.DisableControlPanel))
	}
	oldPanelRepo := strings.TrimSpace(oldCfg.RemoteManagement.PanelGitHubRepository)
	newPanelRepo := strings.TrimSpace(newCfg.RemoteManagement.PanelGitHubRepository)
	if oldPanelRepo != newPanelRepo {
		changes = append(changes, fmt.Sprintf("remote-management.panel-github-repository: %s -> %s", oldPanelRepo, newPanelRepo))
	}
	if oldCfg.RemoteManagement.SecretKey != newCfg.RemoteManagement.SecretKey {
		switch {
		case oldCfg.RemoteManagement.SecretKey == "" && newCfg.RemoteManagement.SecretKey != "":
			changes = append(changes, "remote-management.secret-key: created")
		case oldCfg.RemoteManagement.SecretKey != "" && newCfg.RemoteManagement.SecretKey == "":
			changes = append(changes, "remote-management.secret-key: deleted")
		default:
			changes = append(changes, "remote-management.secret-key: updated")
		}
	}

	// OpenAI compatibility providers (summarized)
	if compat := DiffOpenAICompatibility(oldCfg.OpenAICompatibility, newCfg.OpenAICompatibility); len(compat) > 0 {
		changes = append(changes, "openai-compatibility:")
		for _, c := range compat {
			changes = append(changes, " "+c)
		}
	}

	// Vertex-compatible API keys
	if len(oldCfg.VertexCompatAPIKey) != len(newCfg.VertexCompatAPIKey) {
		changes = append(changes, fmt.Sprintf("vertex-api-key count: %d -> %d", len(oldCfg.VertexCompatAPIKey), len(newCfg.VertexCompatAPIKey)))
	} else {
		for i := range oldCfg.VertexCompatAPIKey {
			o := oldCfg.VertexCompatAPIKey[i]
			n := newCfg.VertexCompatAPIKey[i]
			if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
				changes = append(changes, fmt.Sprintf("vertex[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
			}
			if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
				changes = append(changes, fmt.Sprintf("vertex[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
			}
			if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
				changes = append(changes, fmt.Sprintf("vertex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
			}
			if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
				changes = append(changes, fmt.Sprintf("vertex[%d].api-key: updated", i))
			}
			oldModels := SummarizeVertexModels(o.Models)
			newModels := SummarizeVertexModels(n.Models)
			if oldModels.hash != newModels.hash {
				changes = append(changes, fmt.Sprintf("vertex[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
			}
			oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
			newExcluded := SummarizeExcludedModels(n.ExcludedModels)
			if oldExcluded.hash != newExcluded.hash {
				changes = append(changes, fmt.Sprintf("vertex[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
			}
			if !equalStringMap(o.Headers, n.Headers) {
				changes = append(changes, fmt.Sprintf("vertex[%d].headers: updated", i))
			}
		}
	}
	return changes
}
// trimStrings returns a new slice with every element of in stripped of
// leading and trailing whitespace. The result has len(in) elements and is
// never nil, even for a nil input, so two results can be compared with
// reflect.DeepEqual without nil-versus-empty mismatches.
func trimStrings(in []string) []string {
	trimmed := make([]string, 0, len(in))
	for _, s := range in {
		trimmed = append(trimmed, strings.TrimSpace(s))
	}
	return trimmed
}
// equalStringMap reports whether a and b hold exactly the same keys mapped to
// the same values. A nil map and an empty map compare equal.
func equalStringMap(a, b map[string]string) bool {
	if len(a) != len(b) {
		return false
	}
	for k, av := range a {
		// The comma-ok form is required: a plain b[k] yields "" for a missing
		// key, which would wrongly match an entry in a whose value is "".
		if bv, ok := b[k]; !ok || bv != av {
			return false
		}
	}
	return true
}
// formatProxyURL renders a proxy URL for logs with credentials, path, query
// and fragment stripped. It returns "scheme://host[:port]" when a scheme is
// present, "host[:port]" for scheme-less host:port input, and "" for empty,
// unparsable or host-less values.
func formatProxyURL(raw string) string {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return ""
	}
	var host, scheme string
	parsed, err := url.Parse(trimmed)
	if err == nil {
		host = strings.TrimSpace(parsed.Host)
		scheme = strings.TrimSpace(parsed.Scheme)
	}
	// Allow host:port style without scheme. url.Parse rejects IP-literal forms
	// such as "127.0.0.1:8080" or "[::1]:8080" ("first path segment in URL
	// cannot contain colon"), so retry with a synthetic scheme on such errors
	// too. Inputs that already contain "://" and failed to parse are
	// genuinely malformed and are not retried.
	if host == "" && (err == nil || !strings.Contains(trimmed, "://")) {
		scheme = ""
		if reparsed, err2 := url.Parse("http://" + trimmed); err2 == nil {
			host = strings.TrimSpace(reparsed.Host)
		}
	}
	if host == "" {
		return ""
	}
	if scheme == "" {
		return host
	}
	return scheme + "://" + host
}
// equalStringSet reports whether a and b contain the same distinct values
// once each element is whitespace-trimmed. Order and duplicates are ignored,
// and two empty (or nil) inputs are equal.
func equalStringSet(a, b []string) bool {
	if len(a) == 0 && len(b) == 0 {
		return true
	}
	toSet := func(items []string) map[string]struct{} {
		set := make(map[string]struct{}, len(items))
		for _, item := range items {
			set[strings.TrimSpace(item)] = struct{}{}
		}
		return set
	}
	left, right := toSet(a), toSet(b)
	if len(left) != len(right) {
		return false
	}
	for key := range left {
		if _, found := right[key]; !found {
			return false
		}
	}
	return true
}
// equalUpstreamAPIKeys reports whether two lists of AmpUpstreamAPIKeyEntry are
// equivalent. Entries are matched positionally: each pair must share the same
// trimmed upstream key and the same set of client API keys (order-insensitive,
// per equalStringSet). Lists of different length are never equal.
func equalUpstreamAPIKeys(a, b []config.AmpUpstreamAPIKeyEntry) bool {
	if len(a) != len(b) {
		return false
	}
	for i, left := range a {
		right := b[i]
		sameUpstream := strings.TrimSpace(left.UpstreamAPIKey) == strings.TrimSpace(right.UpstreamAPIKey)
		if !sameUpstream || !equalStringSet(left.APIKeys, right.APIKeys) {
			return false
		}
	}
	return true
}
================================================
FILE: internal/watcher/diff/config_diff_test.go
================================================
package diff
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// TestBuildConfigChangeDetails exercises a mixed set of scalar, gemini,
// ampcode, remote-management, oauth-excluded-models and openai-compatibility
// changes in a single diff.
func TestBuildConfigChangeDetails(t *testing.T) {
	before := &config.Config{
		Port:    8080,
		AuthDir: "/tmp/auth-old",
		GeminiKey: []config.GeminiKey{
			{APIKey: "old", BaseURL: "http://old", ExcludedModels: []string{"old-model"}},
		},
		AmpCode: config.AmpCode{
			UpstreamURL:                   "http://old-upstream",
			ModelMappings:                 []config.AmpModelMapping{{From: "from-old", To: "to-old"}},
			RestrictManagementToLocalhost: false,
		},
		RemoteManagement: config.RemoteManagement{
			AllowRemote:           false,
			SecretKey:             "old",
			DisableControlPanel:   false,
			PanelGitHubRepository: "repo-old",
		},
		OAuthExcludedModels: map[string][]string{
			"providerA": {"m1"},
		},
		OpenAICompatibility: []config.OpenAICompatibility{
			{
				Name:          "compat-a",
				APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "k1"}},
				Models:        []config.OpenAICompatibilityModel{{Name: "m1"}},
			},
		},
	}
	after := &config.Config{
		Port:    9090,
		AuthDir: "/tmp/auth-new",
		GeminiKey: []config.GeminiKey{
			{APIKey: "old", BaseURL: "http://old", ExcludedModels: []string{"old-model", "extra"}},
		},
		AmpCode: config.AmpCode{
			UpstreamURL:                   "http://new-upstream",
			RestrictManagementToLocalhost: true,
			ModelMappings: []config.AmpModelMapping{
				{From: "from-old", To: "to-old"},
				{From: "from-new", To: "to-new"},
			},
		},
		RemoteManagement: config.RemoteManagement{
			AllowRemote:           true,
			SecretKey:             "new",
			DisableControlPanel:   true,
			PanelGitHubRepository: "repo-new",
		},
		OAuthExcludedModels: map[string][]string{
			"providerA": {"m1", "m2"},
			"providerB": {"x"},
		},
		OpenAICompatibility: []config.OpenAICompatibility{
			{
				Name:          "compat-a",
				APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "k1"}},
				Models:        []config.OpenAICompatibilityModel{{Name: "m1"}, {Name: "m2"}},
			},
			{
				Name:          "compat-b",
				APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "k2"}},
			},
		},
	}

	got := BuildConfigChangeDetails(before, after)
	for _, want := range []string{
		"port: 8080 -> 9090",
		"auth-dir: /tmp/auth-old -> /tmp/auth-new",
		"gemini[0].excluded-models: updated (1 -> 2 entries)",
		"ampcode.upstream-url: http://old-upstream -> http://new-upstream",
		"ampcode.model-mappings: updated (1 -> 2 entries)",
		"remote-management.allow-remote: false -> true",
		"remote-management.secret-key: updated",
		"oauth-excluded-models[providera]: updated (1 -> 2 entries)",
		"oauth-excluded-models[providerb]: added (1 entries)",
		"openai-compatibility:",
		" provider added: compat-b (api-keys=1, models=0)",
		" provider updated: compat-a (models 1 -> 2)",
	} {
		expectContains(t, got, want)
	}
}
// TestBuildConfigChangeDetails_NoChanges checks that diffing a config against
// itself yields no entries.
func TestBuildConfigChangeDetails_NoChanges(t *testing.T) {
	same := &config.Config{Port: 8080}
	details := BuildConfigChangeDetails(same, same)
	if len(details) != 0 {
		t.Fatalf("expected no change entries, got %v", details)
	}
}
// TestBuildConfigChangeDetails_GeminiVertexHeadersAndForceMappings covers
// gemini header/excluded-model changes, vertex base-url/model changes and the
// ampcode model-mapping and force-model-mappings flags.
func TestBuildConfigChangeDetails_GeminiVertexHeadersAndForceMappings(t *testing.T) {
	oldCfg := &config.Config{
		GeminiKey: []config.GeminiKey{
			{APIKey: "g1", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"a"}},
		},
		VertexCompatAPIKey: []config.VertexCompatKey{
			{APIKey: "v1", BaseURL: "http://v-old", Models: []config.VertexCompatModel{{Name: "m1"}}},
		},
		AmpCode: config.AmpCode{
			ModelMappings:      []config.AmpModelMapping{{From: "a", To: "b"}},
			ForceModelMappings: false,
		},
	}
	newCfg := &config.Config{
		GeminiKey: []config.GeminiKey{
			{APIKey: "g1", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"a", "b"}},
		},
		VertexCompatAPIKey: []config.VertexCompatKey{
			{APIKey: "v1", BaseURL: "http://v-new", Models: []config.VertexCompatModel{{Name: "m1"}, {Name: "m2"}}},
		},
		AmpCode: config.AmpCode{
			ModelMappings:      []config.AmpModelMapping{{From: "a", To: "c"}},
			ForceModelMappings: true,
		},
	}
	details := BuildConfigChangeDetails(oldCfg, newCfg)
	expectContains(t, details, "gemini[0].headers: updated")
	expectContains(t, details, "gemini[0].excluded-models: updated (1 -> 2 entries)")
	// The vertex entry differs in base-url and model list; assert both so the
	// "Vertex" half of this test is actually exercised.
	expectContains(t, details, "vertex[0].base-url: http://v-old -> http://v-new")
	expectContains(t, details, "vertex[0].models: updated (1 -> 2 entries)")
	expectContains(t, details, "ampcode.model-mappings: updated (1 -> 1 entries)")
	expectContains(t, details, "ampcode.force-model-mappings: false -> true")
}
// TestBuildConfigChangeDetails_ModelPrefixes checks that a prefix change on
// each provider key type (gemini, claude, codex, vertex) is reported.
func TestBuildConfigChangeDetails_ModelPrefixes(t *testing.T) {
	before := &config.Config{
		GeminiKey:          []config.GeminiKey{{APIKey: "g1", Prefix: "old-g", BaseURL: "http://g", ProxyURL: "http://gp"}},
		ClaudeKey:          []config.ClaudeKey{{APIKey: "c1", Prefix: "old-c", BaseURL: "http://c", ProxyURL: "http://cp"}},
		CodexKey:           []config.CodexKey{{APIKey: "x1", Prefix: "old-x", BaseURL: "http://x", ProxyURL: "http://xp"}},
		VertexCompatAPIKey: []config.VertexCompatKey{{APIKey: "v1", Prefix: "old-v", BaseURL: "http://v", ProxyURL: "http://vp"}},
	}
	after := &config.Config{
		GeminiKey:          []config.GeminiKey{{APIKey: "g1", Prefix: "new-g", BaseURL: "http://g", ProxyURL: "http://gp"}},
		ClaudeKey:          []config.ClaudeKey{{APIKey: "c1", Prefix: "new-c", BaseURL: "http://c", ProxyURL: "http://cp"}},
		CodexKey:           []config.CodexKey{{APIKey: "x1", Prefix: "new-x", BaseURL: "http://x", ProxyURL: "http://xp"}},
		VertexCompatAPIKey: []config.VertexCompatKey{{APIKey: "v1", Prefix: "new-v", BaseURL: "http://v", ProxyURL: "http://vp"}},
	}

	got := BuildConfigChangeDetails(before, after)
	for _, want := range []string{
		"gemini[0].prefix: old-g -> new-g",
		"claude[0].prefix: old-c -> new-c",
		"codex[0].prefix: old-x -> new-x",
		"vertex[0].prefix: old-v -> new-v",
	} {
		expectContains(t, got, want)
	}
}
func TestBuildConfigChangeDetails_NilSafe(t *testing.T) {
if details := BuildConfigChangeDetails(nil, &config.Config{}); len(details) != 0 {
t.Fatalf("expected empty change list when old nil, got %v", details)
}
if details := BuildConfigChangeDetails(&config.Config{}, nil); len(details) != 0 {
t.Fatalf("expected empty change list when new nil, got %v", details)
}
}
// TestBuildConfigChangeDetails_SecretsAndCounts checks that api-key count
// changes and newly set secrets are reported without printing their values.
func TestBuildConfigChangeDetails_SecretsAndCounts(t *testing.T) {
	before := &config.Config{
		SDKConfig:        sdkconfig.SDKConfig{APIKeys: []string{"a"}},
		AmpCode:          config.AmpCode{UpstreamAPIKey: ""},
		RemoteManagement: config.RemoteManagement{SecretKey: ""},
	}
	after := &config.Config{
		SDKConfig:        sdkconfig.SDKConfig{APIKeys: []string{"a", "b", "c"}},
		AmpCode:          config.AmpCode{UpstreamAPIKey: "new-key"},
		RemoteManagement: config.RemoteManagement{SecretKey: "new-secret"},
	}

	got := BuildConfigChangeDetails(before, after)
	for _, want := range []string{
		"api-keys count: 1 -> 3",
		"ampcode.upstream-api-key: added",
		"remote-management.secret-key: created",
	} {
		expectContains(t, got, want)
	}
}
// TestBuildConfigChangeDetails_FlagsAndKeys flips every boolean/integer flag,
// grows the claude/codex key lists and removes secrets, then checks each
// corresponding change line.
func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
	before := &config.Config{
		Port:                   1000,
		AuthDir:                "/old",
		Debug:                  false,
		LoggingToFile:          false,
		UsageStatisticsEnabled: false,
		DisableCooling:         false,
		RequestRetry:           1,
		MaxRetryCredentials:    1,
		MaxRetryInterval:       1,
		WebsocketAuth:          false,
		QuotaExceeded:          config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
		ClaudeKey:              []config.ClaudeKey{{APIKey: "c1"}},
		CodexKey:               []config.CodexKey{{APIKey: "x1"}},
		AmpCode:                config.AmpCode{UpstreamAPIKey: "keep", RestrictManagementToLocalhost: false},
		RemoteManagement:       config.RemoteManagement{DisableControlPanel: false, PanelGitHubRepository: "old/repo", SecretKey: "keep"},
		SDKConfig: sdkconfig.SDKConfig{
			RequestLog:                 false,
			ProxyURL:                   "http://old-proxy",
			APIKeys:                    []string{"key-1"},
			ForceModelPrefix:           false,
			NonStreamKeepAliveInterval: 0,
		},
	}
	after := &config.Config{
		Port:                   2000,
		AuthDir:                "/new",
		Debug:                  true,
		LoggingToFile:          true,
		UsageStatisticsEnabled: true,
		DisableCooling:         true,
		RequestRetry:           2,
		MaxRetryCredentials:    3,
		MaxRetryInterval:       3,
		WebsocketAuth:          true,
		QuotaExceeded:          config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
		ClaudeKey: []config.ClaudeKey{
			{APIKey: "c1", BaseURL: "http://new", ProxyURL: "http://p", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"a"}},
			{APIKey: "c2"},
		},
		CodexKey: []config.CodexKey{
			{APIKey: "x1", BaseURL: "http://x", ProxyURL: "http://px", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"b"}},
			{APIKey: "x2"},
		},
		AmpCode: config.AmpCode{
			UpstreamAPIKey:                "",
			RestrictManagementToLocalhost: true,
			ModelMappings:                 []config.AmpModelMapping{{From: "a", To: "b"}},
		},
		RemoteManagement: config.RemoteManagement{
			DisableControlPanel:   true,
			PanelGitHubRepository: "new/repo",
			SecretKey:             "",
		},
		SDKConfig: sdkconfig.SDKConfig{
			RequestLog:                 true,
			ProxyURL:                   "http://new-proxy",
			APIKeys:                    []string{" key-1 ", "key-2"},
			ForceModelPrefix:           true,
			NonStreamKeepAliveInterval: 5,
		},
	}

	got := BuildConfigChangeDetails(before, after)
	for _, want := range []string{
		"debug: false -> true",
		"logging-to-file: false -> true",
		"usage-statistics-enabled: false -> true",
		"disable-cooling: false -> true",
		"request-log: false -> true",
		"request-retry: 1 -> 2",
		"max-retry-credentials: 1 -> 3",
		"max-retry-interval: 1 -> 3",
		"proxy-url: http://old-proxy -> http://new-proxy",
		"ws-auth: false -> true",
		"force-model-prefix: false -> true",
		"nonstream-keepalive-interval: 0 -> 5",
		"quota-exceeded.switch-project: false -> true",
		"quota-exceeded.switch-preview-model: false -> true",
		"api-keys count: 1 -> 2",
		"claude-api-key count: 1 -> 2",
		"codex-api-key count: 1 -> 2",
		"ampcode.restrict-management-to-localhost: false -> true",
		"ampcode.upstream-api-key: removed",
		"remote-management.disable-control-panel: false -> true",
		"remote-management.panel-github-repository: old/repo -> new/repo",
		"remote-management.secret-key: deleted",
	} {
		expectContains(t, got, want)
	}
}
// TestBuildConfigChangeDetails_AllBranches changes nearly every tracked field
// at once, including same-length provider key lists so that per-index
// base-url/proxy-url/api-key/headers/models diffs are produced.
func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
	before := &config.Config{
		Port:                   1,
		AuthDir:                "/a",
		Debug:                  false,
		LoggingToFile:          false,
		UsageStatisticsEnabled: false,
		DisableCooling:         false,
		RequestRetry:           1,
		MaxRetryCredentials:    1,
		MaxRetryInterval:       1,
		WebsocketAuth:          false,
		QuotaExceeded:          config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
		GeminiKey: []config.GeminiKey{
			{APIKey: "g-old", BaseURL: "http://g-old", ProxyURL: "http://gp-old", Headers: map[string]string{"A": "1"}},
		},
		ClaudeKey: []config.ClaudeKey{
			{APIKey: "c-old", BaseURL: "http://c-old", ProxyURL: "http://cp-old", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"x"}},
		},
		CodexKey: []config.CodexKey{
			{APIKey: "x-old", BaseURL: "http://x-old", ProxyURL: "http://xp-old", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"x"}},
		},
		VertexCompatAPIKey: []config.VertexCompatKey{
			{APIKey: "v-old", BaseURL: "http://v-old", ProxyURL: "http://vp-old", Headers: map[string]string{"H": "1"}, Models: []config.VertexCompatModel{{Name: "m1"}}},
		},
		AmpCode: config.AmpCode{
			UpstreamURL:                   "http://amp-old",
			UpstreamAPIKey:                "old-key",
			RestrictManagementToLocalhost: false,
			ModelMappings:                 []config.AmpModelMapping{{From: "a", To: "b"}},
			ForceModelMappings:            false,
		},
		RemoteManagement: config.RemoteManagement{
			AllowRemote:           false,
			DisableControlPanel:   false,
			PanelGitHubRepository: "old/repo",
			SecretKey:             "old",
		},
		SDKConfig: sdkconfig.SDKConfig{
			RequestLog: false,
			ProxyURL:   "http://old-proxy",
			APIKeys:    []string{" keyA "},
		},
		OAuthExcludedModels: map[string][]string{"p1": {"a"}},
		OpenAICompatibility: []config.OpenAICompatibility{
			{
				Name:          "prov-old",
				APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "k1"}},
				Models:        []config.OpenAICompatibilityModel{{Name: "m1"}},
			},
		},
	}
	after := &config.Config{
		Port:                   2,
		AuthDir:                "/b",
		Debug:                  true,
		LoggingToFile:          true,
		UsageStatisticsEnabled: true,
		DisableCooling:         true,
		RequestRetry:           2,
		MaxRetryCredentials:    3,
		MaxRetryInterval:       3,
		WebsocketAuth:          true,
		QuotaExceeded:          config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
		GeminiKey: []config.GeminiKey{
			{APIKey: "g-new", BaseURL: "http://g-new", ProxyURL: "http://gp-new", Headers: map[string]string{"A": "2"}, ExcludedModels: []string{"x", "y"}},
		},
		ClaudeKey: []config.ClaudeKey{
			{APIKey: "c-new", BaseURL: "http://c-new", ProxyURL: "http://cp-new", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"x", "y"}},
		},
		CodexKey: []config.CodexKey{
			{APIKey: "x-new", BaseURL: "http://x-new", ProxyURL: "http://xp-new", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"x", "y"}},
		},
		VertexCompatAPIKey: []config.VertexCompatKey{
			{APIKey: "v-new", BaseURL: "http://v-new", ProxyURL: "http://vp-new", Headers: map[string]string{"H": "2"}, Models: []config.VertexCompatModel{{Name: "m1"}, {Name: "m2"}}},
		},
		AmpCode: config.AmpCode{
			UpstreamURL:                   "http://amp-new",
			UpstreamAPIKey:                "",
			RestrictManagementToLocalhost: true,
			ModelMappings:                 []config.AmpModelMapping{{From: "a", To: "c"}},
			ForceModelMappings:            true,
		},
		RemoteManagement: config.RemoteManagement{
			AllowRemote:           true,
			DisableControlPanel:   true,
			PanelGitHubRepository: "new/repo",
			SecretKey:             "",
		},
		SDKConfig: sdkconfig.SDKConfig{
			RequestLog: true,
			ProxyURL:   "http://new-proxy",
			APIKeys:    []string{"keyB"},
		},
		OAuthExcludedModels: map[string][]string{"p1": {"b", "c"}, "p2": {"d"}},
		OpenAICompatibility: []config.OpenAICompatibility{
			{
				Name: "prov-old",
				APIKeyEntries: []config.OpenAICompatibilityAPIKey{
					{APIKey: "k1"},
					{APIKey: "k2"},
				},
				Models: []config.OpenAICompatibilityModel{{Name: "m1"}, {Name: "m2"}},
			},
			{
				Name:          "prov-new",
				APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "k3"}},
			},
		},
	}

	got := BuildConfigChangeDetails(before, after)
	for _, want := range []string{
		"port: 1 -> 2",
		"auth-dir: /a -> /b",
		"debug: false -> true",
		"logging-to-file: false -> true",
		"usage-statistics-enabled: false -> true",
		"disable-cooling: false -> true",
		"request-retry: 1 -> 2",
		"max-retry-credentials: 1 -> 3",
		"max-retry-interval: 1 -> 3",
		"proxy-url: http://old-proxy -> http://new-proxy",
		"ws-auth: false -> true",
		"quota-exceeded.switch-project: false -> true",
		"quota-exceeded.switch-preview-model: false -> true",
		"api-keys: values updated (count unchanged, redacted)",
		"gemini[0].base-url: http://g-old -> http://g-new",
		"gemini[0].proxy-url: http://gp-old -> http://gp-new",
		"gemini[0].api-key: updated",
		"gemini[0].headers: updated",
		"gemini[0].excluded-models: updated (0 -> 2 entries)",
		"claude[0].base-url: http://c-old -> http://c-new",
		"claude[0].proxy-url: http://cp-old -> http://cp-new",
		"claude[0].api-key: updated",
		"claude[0].headers: updated",
		"claude[0].excluded-models: updated (1 -> 2 entries)",
		"codex[0].base-url: http://x-old -> http://x-new",
		"codex[0].proxy-url: http://xp-old -> http://xp-new",
		"codex[0].api-key: updated",
		"codex[0].headers: updated",
		"codex[0].excluded-models: updated (1 -> 2 entries)",
		"vertex[0].base-url: http://v-old -> http://v-new",
		"vertex[0].proxy-url: http://vp-old -> http://vp-new",
		"vertex[0].api-key: updated",
		"vertex[0].models: updated (1 -> 2 entries)",
		"vertex[0].headers: updated",
		"ampcode.upstream-url: http://amp-old -> http://amp-new",
		"ampcode.upstream-api-key: removed",
		"ampcode.restrict-management-to-localhost: false -> true",
		"ampcode.model-mappings: updated (1 -> 1 entries)",
		"ampcode.force-model-mappings: false -> true",
		"oauth-excluded-models[p1]: updated (1 -> 2 entries)",
		"oauth-excluded-models[p2]: added (1 entries)",
		"remote-management.allow-remote: false -> true",
		"remote-management.disable-control-panel: false -> true",
		"remote-management.panel-github-repository: old/repo -> new/repo",
		"remote-management.secret-key: deleted",
		"openai-compatibility:",
	} {
		expectContains(t, got, want)
	}
}
// TestFormatProxyURL checks that formatProxyURL keeps only scheme and host,
// strips credentials/path/query/fragment, and yields "" for unusable input.
func TestFormatProxyURL(t *testing.T) {
	cases := []struct {
		label    string
		input    string
		expected string
	}{
		{label: "empty", input: "", expected: ""},
		{label: "invalid", input: "http://[::1", expected: ""},
		{label: "fullURLRedactsUserinfoAndPath", input: "http://user:pass@example.com:8080/path?x=1#frag", expected: "http://example.com:8080"},
		{label: "socks5RedactsUserinfoAndPath", input: "socks5://user:pass@192.168.1.1:1080/path?x=1", expected: "socks5://192.168.1.1:1080"},
		{label: "socks5HostPort", input: "socks5://proxy.example.com:1080/", expected: "socks5://proxy.example.com:1080"},
		{label: "hostPortNoScheme", input: "example.com:1234/path?x=1", expected: "example.com:1234"},
		{label: "relativePathRedacted", input: "/just/path", expected: ""},
		{label: "schemeAndHost", input: "https://example.com", expected: "https://example.com"},
	}
	for _, tc := range cases {
		tc := tc
		t.Run(tc.label, func(t *testing.T) {
			got := formatProxyURL(tc.input)
			if got != tc.expected {
				t.Fatalf("expected %q, got %q", tc.expected, got)
			}
		})
	}
}
// TestBuildConfigChangeDetails_SecretAndUpstreamUpdates verifies that rotating
// secret values is reported as "updated" without echoing the secrets.
func TestBuildConfigChangeDetails_SecretAndUpstreamUpdates(t *testing.T) {
	// withSecret builds a config whose only populated fields are the secrets under test.
	withSecret := func(secret string) *config.Config {
		return &config.Config{
			AmpCode:          config.AmpCode{UpstreamAPIKey: secret},
			RemoteManagement: config.RemoteManagement{SecretKey: secret},
		}
	}
	details := BuildConfigChangeDetails(withSecret("old"), withSecret("new"))
	for _, want := range []string{
		"ampcode.upstream-api-key: updated",
		"remote-management.secret-key: updated",
	} {
		expectContains(t, details, want)
	}
}
// TestBuildConfigChangeDetails_CountBranches verifies that going from zero to
// one key for each provider family is reported as a count change.
func TestBuildConfigChangeDetails_CountBranches(t *testing.T) {
	before := &config.Config{}
	after := &config.Config{
		GeminiKey: []config.GeminiKey{{APIKey: "g"}},
		ClaudeKey: []config.ClaudeKey{{APIKey: "c"}},
		CodexKey:  []config.CodexKey{{APIKey: "x"}},
		VertexCompatAPIKey: []config.VertexCompatKey{
			{APIKey: "v", BaseURL: "http://v"},
		},
	}
	details := BuildConfigChangeDetails(before, after)
	for _, family := range []string{"gemini", "claude", "codex", "vertex"} {
		expectContains(t, details, family+"-api-key count: 0 -> 1")
	}
}
// TestTrimStrings verifies that trimStrings strips surrounding whitespace from every element.
func TestTrimStrings(t *testing.T) {
	want := []string{"a", "b", "c"}
	got := trimStrings([]string{" a ", "b", " c"})
	if len(got) != len(want) {
		t.Fatalf("unexpected trimmed strings: %v", got)
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("unexpected trimmed strings: %v", got)
		}
	}
}
================================================
FILE: internal/watcher/diff/model_hash.go
================================================
package diff
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"sort"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// ComputeOpenAICompatModelsHash returns a stable digest of an OpenAI-compatible
// model list, used to detect model list changes during hot reload.
// Each entry contributes a lowercased "name|alias" key after trimming. Entries
// whose name and alias are both blank are skipped. Duplicate keys collapse and
// ordering is irrelevant. The result is "" when no entry contributes a key.
func ComputeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) string {
	collect := func(emit func(key string)) {
		for i := range models {
			name, alias := strings.TrimSpace(models[i].Name), strings.TrimSpace(models[i].Alias)
			if name != "" || alias != "" {
				emit(strings.ToLower(name) + "|" + strings.ToLower(alias))
			}
		}
	}
	return hashJoined(normalizeModelPairs(collect))
}
// ComputeVertexCompatModelsHash returns a stable digest of a Vertex-compatible
// model list. It uses the same normalization as the other Compute*ModelsHash
// helpers: trimmed, lowercased "name|alias" keys, with fully blank entries
// ignored and duplicates collapsed. It returns "" for an effectively empty list.
func ComputeVertexCompatModelsHash(models []config.VertexCompatModel) string {
	pairs := normalizeModelPairs(func(emit func(key string)) {
		for _, m := range models {
			name := strings.TrimSpace(m.Name)
			alias := strings.TrimSpace(m.Alias)
			if name == "" && alias == "" {
				continue
			}
			lowerName, lowerAlias := strings.ToLower(name), strings.ToLower(alias)
			emit(lowerName + "|" + lowerAlias)
		}
	})
	return hashJoined(pairs)
}
// ComputeClaudeModelsHash returns a stable digest of Claude model aliases.
// Keys are trimmed, lowercased "name|alias" pairs. Blank entries are skipped,
// duplicates collapse, and an effectively empty list yields "".
func ComputeClaudeModelsHash(models []config.ClaudeModel) string {
	var pairs []string
	pairs = normalizeModelPairs(func(emit func(key string)) {
		for i := 0; i < len(models); i++ {
			name := strings.TrimSpace(models[i].Name)
			alias := strings.TrimSpace(models[i].Alias)
			if name == "" && alias == "" {
				continue
			}
			emit(strings.ToLower(name) + "|" + strings.ToLower(alias))
		}
	})
	return hashJoined(pairs)
}
// ComputeCodexModelsHash returns a stable digest of Codex model aliases.
// Keys are trimmed, lowercased "name|alias" pairs. Entries blank on both sides
// are ignored, duplicates collapse, and an effectively empty list yields "".
func ComputeCodexModelsHash(models []config.CodexModel) string {
	emitAll := func(emit func(key string)) {
		for _, model := range models {
			trimmedName := strings.TrimSpace(model.Name)
			trimmedAlias := strings.TrimSpace(model.Alias)
			if trimmedName == "" && trimmedAlias == "" {
				continue
			}
			key := strings.ToLower(trimmedName) + "|" + strings.ToLower(trimmedAlias)
			emit(key)
		}
	}
	keys := normalizeModelPairs(emitAll)
	return hashJoined(keys)
}
// ComputeGeminiModelsHash returns a stable digest of Gemini model aliases,
// built from trimmed, lowercased "name|alias" keys. Fully blank entries are
// ignored and duplicates collapse. It returns "" when nothing remains.
func ComputeGeminiModelsHash(models []config.GeminiModel) string {
	return hashJoined(normalizeModelPairs(func(emit func(key string)) {
		for idx := range models {
			n := strings.TrimSpace(models[idx].Name)
			a := strings.TrimSpace(models[idx].Alias)
			if n == "" && a == "" {
				continue
			}
			emit(strings.ToLower(n) + "|" + strings.ToLower(a))
		}
	}))
}
// ComputeExcludedModelsHash returns a hex SHA-256 digest of an excluded-model
// list. Entries are trimmed and lowercased, blanks are dropped, and the rest is
// sorted. Duplicates are kept, so the digest reflects the multiset of entries.
// It returns "" when no entry survives normalization.
func ComputeExcludedModelsHash(excluded []string) string {
	var cleaned []string
	for _, raw := range excluded {
		entry := strings.TrimSpace(raw)
		if entry == "" {
			continue
		}
		cleaned = append(cleaned, strings.ToLower(entry))
	}
	if len(cleaned) == 0 {
		return ""
	}
	sort.Strings(cleaned)
	// Marshalling a []string cannot fail, so the error is safe to ignore.
	payload, _ := json.Marshal(cleaned)
	digest := sha256.Sum256(payload)
	return hex.EncodeToString(digest[:])
}
// normalizeModelPairs runs collect, gathers every distinct key it emits, and
// returns those keys sorted ascending. It returns nil when collect emits nothing.
func normalizeModelPairs(collect func(out func(key string))) []string {
	seenKeys := make(map[string]struct{})
	var keys []string
	collect(func(key string) {
		if _, dup := seenKeys[key]; !dup {
			seenKeys[key] = struct{}{}
			keys = append(keys, key)
		}
	})
	if len(keys) == 0 {
		return nil
	}
	sort.Strings(keys)
	return keys
}
// hashJoined joins keys with newlines and returns the hex-encoded SHA-256 of
// the result, or "" when keys is empty.
func hashJoined(keys []string) string {
	if len(keys) > 0 {
		digest := sha256.Sum256([]byte(strings.Join(keys, "\n")))
		return hex.EncodeToString(digest[:])
	}
	return ""
}
================================================
FILE: internal/watcher/diff/model_hash_test.go
================================================
package diff
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// TestComputeOpenAICompatModelsHash_Deterministic checks repeatability and
// sensitivity to model list changes.
func TestComputeOpenAICompatModelsHash_Deterministic(t *testing.T) {
	input := []config.OpenAICompatibilityModel{
		{Name: "gpt-4", Alias: "gpt4"},
		{Name: "gpt-3.5-turbo"},
	}
	first, second := ComputeOpenAICompatModelsHash(input), ComputeOpenAICompatModelsHash(input)
	switch {
	case first == "":
		t.Fatal("hash should not be empty")
	case first != second:
		t.Fatalf("hash should be deterministic, got %s vs %s", first, second)
	}
	other := []config.OpenAICompatibilityModel{{Name: "gpt-4"}, {Name: "gpt-4.1"}}
	if ComputeOpenAICompatModelsHash(other) == first {
		t.Fatal("hash should change when model list changes")
	}
}
// TestComputeOpenAICompatModelsHash_NormalizesAndDedups checks that case,
// blanks, duplicates and order do not affect the hash.
func TestComputeOpenAICompatModelsHash_NormalizesAndDedups(t *testing.T) {
	noisy := []config.OpenAICompatibilityModel{
		{Name: "gpt-4", Alias: "gpt4"},
		{Name: " "},
		{Name: "GPT-4", Alias: "GPT4"},
		{Alias: "a1"},
	}
	canonical := []config.OpenAICompatibilityModel{
		{Alias: "A1"},
		{Name: "gpt-4", Alias: "gpt4"},
	}
	hNoisy := ComputeOpenAICompatModelsHash(noisy)
	hCanon := ComputeOpenAICompatModelsHash(canonical)
	if hNoisy == "" || hCanon == "" {
		t.Fatal("expected non-empty hashes for non-empty model sets")
	}
	if hNoisy != hCanon {
		t.Fatalf("expected normalized hashes to match, got %s / %s", hNoisy, hCanon)
	}
}
func TestComputeVertexCompatModelsHash_DifferentInputs(t *testing.T) {
models := []config.VertexCompatModel{{Name: "gemini-pro", Alias: "pro"}}
hash1 := ComputeVertexCompatModelsHash(models)
hash2 := ComputeVertexCompatModelsHash([]config.VertexCompatModel{{Name: "gemini-1.5-pro", Alias: "pro"}})
if hash1 == "" || hash2 == "" {
t.Fatal("hashes should not be empty for non-empty models")
}
if hash1 == hash2 {
t.Fatal("hash should differ when model content differs")
}
}
// TestComputeVertexCompatModelsHash_IgnoresBlankAndOrder checks that blank
// entries, case-only duplicates and entry order do not affect the hash.
// The inputs deliberately list the distinct models in different orders, so
// order-insensitivity is actually exercised and not just implied by the name.
func TestComputeVertexCompatModelsHash_IgnoresBlankAndOrder(t *testing.T) {
	a := []config.VertexCompatModel{
		{Name: "m2", Alias: "a2"},
		{Name: "m1", Alias: "a1"},
		{Name: " "},
		{Name: "M1", Alias: "A1"},
	}
	b := []config.VertexCompatModel{
		{Name: "m1", Alias: "a1"},
		{Name: "m2", Alias: "a2"},
	}
	if h1, h2 := ComputeVertexCompatModelsHash(a), ComputeVertexCompatModelsHash(b); h1 == "" || h1 != h2 {
		t.Fatalf("expected same hash ignoring blanks/dupes/order, got %q / %q", h1, h2)
	}
}
func TestComputeClaudeModelsHash_Empty(t *testing.T) {
if got := ComputeClaudeModelsHash(nil); got != "" {
t.Fatalf("expected empty hash for nil models, got %q", got)
}
if got := ComputeClaudeModelsHash([]config.ClaudeModel{}); got != "" {
t.Fatalf("expected empty hash for empty slice, got %q", got)
}
}
func TestComputeCodexModelsHash_Empty(t *testing.T) {
if got := ComputeCodexModelsHash(nil); got != "" {
t.Fatalf("expected empty hash for nil models, got %q", got)
}
if got := ComputeCodexModelsHash([]config.CodexModel{}); got != "" {
t.Fatalf("expected empty hash for empty slice, got %q", got)
}
}
// TestComputeClaudeModelsHash_IgnoresBlankAndDedup checks that blanks and
// case-only duplicates do not change the hash.
func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) {
	noisy := []config.ClaudeModel{
		{Name: "m1", Alias: "a1"},
		{Name: " "},
		{Name: "M1", Alias: "A1"},
	}
	clean := []config.ClaudeModel{{Name: "m1", Alias: "a1"}}
	h1, h2 := ComputeClaudeModelsHash(noisy), ComputeClaudeModelsHash(clean)
	if h1 == "" || h1 != h2 {
		t.Fatalf("expected same hash ignoring blanks/dupes, got %q / %q", h1, h2)
	}
}
// TestComputeCodexModelsHash_IgnoresBlankAndDedup checks that blanks and
// case-only duplicates do not change the hash.
func TestComputeCodexModelsHash_IgnoresBlankAndDedup(t *testing.T) {
	noisy := []config.CodexModel{
		{Name: "m1", Alias: "a1"},
		{Name: " "},
		{Name: "M1", Alias: "A1"},
	}
	clean := []config.CodexModel{{Name: "m1", Alias: "a1"}}
	h1, h2 := ComputeCodexModelsHash(noisy), ComputeCodexModelsHash(clean)
	if h1 == "" || h1 != h2 {
		t.Fatalf("expected same hash ignoring blanks/dupes, got %q / %q", h1, h2)
	}
}
// TestComputeExcludedModelsHash_Normalizes checks that case, whitespace and
// order are ignored, and that different contents hash differently.
func TestComputeExcludedModelsHash_Normalizes(t *testing.T) {
	left := ComputeExcludedModelsHash([]string{" A ", "b", "a"})
	right := ComputeExcludedModelsHash([]string{"a", " b", "A"})
	switch {
	case left == "" || right == "":
		t.Fatal("hash should not be empty for non-empty input")
	case left != right:
		t.Fatalf("hash should be order/space insensitive for same multiset, got %s vs %s", left, right)
	}
	if ComputeExcludedModelsHash([]string{"c"}) == left {
		t.Fatal("hash should differ for different normalized sets")
	}
}
func TestComputeOpenAICompatModelsHash_Empty(t *testing.T) {
if got := ComputeOpenAICompatModelsHash(nil); got != "" {
t.Fatalf("expected empty hash for nil input, got %q", got)
}
if got := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{}); got != "" {
t.Fatalf("expected empty hash for empty slice, got %q", got)
}
if got := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{{Name: " "}, {Alias: ""}}); got != "" {
t.Fatalf("expected empty hash for blank models, got %q", got)
}
}
func TestComputeVertexCompatModelsHash_Empty(t *testing.T) {
if got := ComputeVertexCompatModelsHash(nil); got != "" {
t.Fatalf("expected empty hash for nil input, got %q", got)
}
if got := ComputeVertexCompatModelsHash([]config.VertexCompatModel{}); got != "" {
t.Fatalf("expected empty hash for empty slice, got %q", got)
}
if got := ComputeVertexCompatModelsHash([]config.VertexCompatModel{{Name: " "}}); got != "" {
t.Fatalf("expected empty hash for blank models, got %q", got)
}
}
// TestComputeExcludedModelsHash_Empty checks that nil, empty and
// whitespace-only inputs hash to "".
func TestComputeExcludedModelsHash_Empty(t *testing.T) {
	cases := []struct {
		desc  string
		input []string
	}{
		{"nil input", nil},
		{"empty slice", []string{}},
		{"whitespace-only entries", []string{" ", ""}},
	}
	for _, tc := range cases {
		if got := ComputeExcludedModelsHash(tc.input); got != "" {
			t.Fatalf("expected empty hash for %s, got %q", tc.desc, got)
		}
	}
}
// TestComputeClaudeModelsHash_Deterministic checks repeatability and
// sensitivity to model list changes.
func TestComputeClaudeModelsHash_Deterministic(t *testing.T) {
	models := []config.ClaudeModel{{Name: "a", Alias: "A"}, {Name: "b"}}
	first := ComputeClaudeModelsHash(models)
	if second := ComputeClaudeModelsHash(models); first == "" || first != second {
		t.Fatalf("expected deterministic hash, got %s / %s", first, second)
	}
	changed := ComputeClaudeModelsHash([]config.ClaudeModel{{Name: "a"}})
	if changed == first {
		t.Fatalf("expected different hash when models change, got %s", changed)
	}
}
// TestComputeCodexModelsHash_Deterministic checks repeatability and
// sensitivity to model list changes.
func TestComputeCodexModelsHash_Deterministic(t *testing.T) {
	models := []config.CodexModel{{Name: "a", Alias: "A"}, {Name: "b"}}
	first := ComputeCodexModelsHash(models)
	if second := ComputeCodexModelsHash(models); first == "" || first != second {
		t.Fatalf("expected deterministic hash, got %s / %s", first, second)
	}
	changed := ComputeCodexModelsHash([]config.CodexModel{{Name: "a"}})
	if changed == first {
		t.Fatalf("expected different hash when models change, got %s", changed)
	}
}
================================================
FILE: internal/watcher/diff/models_summary.go
================================================
package diff
import (
"crypto/sha256"
"encoding/hex"
"sort"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// Per-provider model summaries used for change detection. Each pairs a digest
// of the normalized model list with the number of entries it covers. The zero
// value stands for "no models".
type (
	// GeminiModelsSummary summarizes a Gemini key's model aliases.
	GeminiModelsSummary struct {
		hash  string
		count int
	}
	// ClaudeModelsSummary summarizes a Claude key's model aliases.
	ClaudeModelsSummary struct {
		hash  string
		count int
	}
	// CodexModelsSummary summarizes a Codex key's model aliases.
	CodexModelsSummary struct {
		hash  string
		count int
	}
	// VertexModelsSummary summarizes a Vertex-compatible key's models.
	VertexModelsSummary struct {
		hash  string
		count int
	}
)
// SummarizeGeminiModels reduces Gemini model aliases to a digest plus the count
// of distinct normalized "name|alias" keys, for change detection. It returns
// the zero summary for nil/empty input.
func SummarizeGeminiModels(models []config.GeminiModel) GeminiModelsSummary {
	var summary GeminiModelsSummary
	if len(models) == 0 {
		return summary
	}
	pairs := normalizeModelPairs(func(emit func(key string)) {
		for _, m := range models {
			name, alias := strings.TrimSpace(m.Name), strings.TrimSpace(m.Alias)
			if name == "" && alias == "" {
				continue
			}
			emit(strings.ToLower(name) + "|" + strings.ToLower(alias))
		}
	})
	summary.hash = hashJoined(pairs)
	summary.count = len(pairs)
	return summary
}
// SummarizeClaudeModels reduces Claude model aliases to a digest and the number
// of distinct normalized "name|alias" keys, for change detection. It returns
// the zero summary for nil/empty input.
func SummarizeClaudeModels(models []config.ClaudeModel) ClaudeModelsSummary {
	if len(models) == 0 {
		return ClaudeModelsSummary{}
	}
	collect := func(emit func(key string)) {
		for i := range models {
			name := strings.TrimSpace(models[i].Name)
			alias := strings.TrimSpace(models[i].Alias)
			if name != "" || alias != "" {
				emit(strings.ToLower(name) + "|" + strings.ToLower(alias))
			}
		}
	}
	keys := normalizeModelPairs(collect)
	return ClaudeModelsSummary{count: len(keys), hash: hashJoined(keys)}
}
// SummarizeCodexModels reduces Codex model aliases to a digest and the number
// of distinct normalized "name|alias" keys, for change detection. It returns
// the zero summary for nil/empty input.
func SummarizeCodexModels(models []config.CodexModel) CodexModelsSummary {
	if len(models) == 0 {
		return CodexModelsSummary{}
	}
	var keys []string
	keys = normalizeModelPairs(func(emit func(key string)) {
		for idx := 0; idx < len(models); idx++ {
			trimmedName := strings.TrimSpace(models[idx].Name)
			trimmedAlias := strings.TrimSpace(models[idx].Alias)
			if trimmedName == "" && trimmedAlias == "" {
				continue
			}
			emit(strings.ToLower(trimmedName) + "|" + strings.ToLower(trimmedAlias))
		}
	})
	result := CodexModelsSummary{}
	result.hash, result.count = hashJoined(keys), len(keys)
	return result
}
// SummarizeVertexModels hashes Vertex-compatible model entries for change
// detection. It normalizes the same way as the Gemini/Claude/Codex summaries:
// each entry contributes a trimmed, lowercased "name|alias" key, entries blank
// on both sides are skipped, duplicates collapse, and order does not matter.
// count is the number of distinct keys. Nil/empty or all-blank input yields
// the zero summary.
func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummary {
	if len(models) == 0 {
		return VertexModelsSummary{}
	}
	seen := make(map[string]struct{}, len(models))
	keys := make([]string, 0, len(models))
	for _, model := range models {
		name := strings.TrimSpace(model.Name)
		alias := strings.TrimSpace(model.Alias)
		if name == "" && alias == "" {
			continue
		}
		// Keep both halves of the pair: hashing only the alias would hide a
		// change of the upstream model name behind an unchanged alias.
		key := strings.ToLower(name) + "|" + strings.ToLower(alias)
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		keys = append(keys, key)
	}
	if len(keys) == 0 {
		return VertexModelsSummary{}
	}
	sort.Strings(keys)
	// Same encoding as hashJoined (newline-joined SHA-256, hex encoded).
	sum := sha256.Sum256([]byte(strings.Join(keys, "\n")))
	return VertexModelsSummary{
		hash:  hex.EncodeToString(sum[:]),
		count: len(keys),
	}
}
================================================
FILE: internal/watcher/diff/oauth_excluded.go
================================================
package diff
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"sort"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// ExcludedModelsSummary is a normalized digest of one excluded-model list, as
// produced by SummarizeExcludedModels. The zero value represents an empty list.
type ExcludedModelsSummary struct {
	// hash is ComputeExcludedModelsHash over the sorted, distinct, lowercased entries.
	hash string
	// count is the number of distinct non-blank entries.
	count int
}
// SummarizeExcludedModels trims, lowercases and de-duplicates an excluded-model
// list. It records the digest of the distinct entries (via
// ComputeExcludedModelsHash) together with how many there are. Nil or empty
// input yields the zero summary.
func SummarizeExcludedModels(list []string) ExcludedModelsSummary {
	if len(list) == 0 {
		return ExcludedModelsSummary{}
	}
	distinct := make([]string, 0, len(list))
	seen := make(map[string]struct{}, len(list))
	for _, raw := range list {
		entry := strings.ToLower(strings.TrimSpace(raw))
		if entry == "" {
			continue
		}
		if _, dup := seen[entry]; dup {
			continue
		}
		seen[entry] = struct{}{}
		distinct = append(distinct, entry)
	}
	sort.Strings(distinct)
	return ExcludedModelsSummary{
		hash:  ComputeExcludedModelsHash(distinct),
		count: len(distinct),
	}
}
// SummarizeOAuthExcludedModels summarizes OAuth excluded models per provider.
// Provider keys are trimmed and lowercased, and blank keys are dropped. Keys
// that collide after normalization (e.g. "Codex" and "codex") have their lists
// merged before summarizing, so the result does not depend on map iteration
// order. Nil or empty input returns nil.
func SummarizeOAuthExcludedModels(entries map[string][]string) map[string]ExcludedModelsSummary {
	if len(entries) == 0 {
		return nil
	}
	// Merge before summarizing. Assigning per raw key would let whichever
	// colliding key Go's randomized map iteration visits last win, so two
	// diffs of the same config could disagree.
	merged := make(map[string][]string, len(entries))
	for k, v := range entries {
		key := strings.ToLower(strings.TrimSpace(k))
		if key == "" {
			continue
		}
		merged[key] = append(merged[key], v...)
	}
	out := make(map[string]ExcludedModelsSummary, len(merged))
	for key, list := range merged {
		// SummarizeExcludedModels sorts and de-duplicates, so merge order is irrelevant.
		out[key] = SummarizeExcludedModels(list)
	}
	return out
}
// DiffOAuthExcludedModelChanges compares two provider -> excluded-models maps.
// It returns sorted human-readable change lines and the sorted normalized
// provider keys that were added, removed or updated.
func DiffOAuthExcludedModelChanges(oldMap, newMap map[string][]string) ([]string, []string) {
	before := SummarizeOAuthExcludedModels(oldMap)
	after := SummarizeOAuthExcludedModels(newMap)
	union := make(map[string]struct{}, len(before)+len(after))
	for _, summaries := range []map[string]ExcludedModelsSummary{before, after} {
		for key := range summaries {
			union[key] = struct{}{}
		}
	}
	changes := make([]string, 0, len(union))
	affected := make([]string, 0, len(union))
	for key := range union {
		prev, hadPrev := before[key]
		next, hasNext := after[key]
		var line string
		switch {
		case hadPrev && !hasNext:
			line = fmt.Sprintf("oauth-excluded-models[%s]: removed", key)
		case !hadPrev && hasNext:
			line = fmt.Sprintf("oauth-excluded-models[%s]: added (%d entries)", key, next.count)
		case hadPrev && hasNext && prev.hash != next.hash:
			line = fmt.Sprintf("oauth-excluded-models[%s]: updated (%d -> %d entries)", key, prev.count, next.count)
		default:
			continue
		}
		changes = append(changes, line)
		affected = append(affected, key)
	}
	sort.Strings(changes)
	sort.Strings(affected)
	return changes, affected
}
// AmpModelMappingsSummary is a digest of Amp model mappings, as produced by
// SummarizeAmpModelMappings. The zero value represents "no mappings".
type AmpModelMappingsSummary struct {
	// hash is the hex SHA-256 of the sorted, "|"-joined "from->to" entries.
	hash string
	// count is the number of non-blank mappings (duplicates included).
	count int
}
// SummarizeAmpModelMappings digests Amp model mappings for change detection.
// Each mapping becomes a trimmed "from->to" entry. Case is preserved and
// duplicates are kept. Mappings blank on both sides are ignored. Entries are
// sorted before hashing, so order does not matter.
func SummarizeAmpModelMappings(mappings []config.AmpModelMapping) AmpModelMappingsSummary {
	var entries []string
	for _, m := range mappings {
		from, to := strings.TrimSpace(m.From), strings.TrimSpace(m.To)
		if from != "" || to != "" {
			entries = append(entries, from+"->"+to)
		}
	}
	if len(entries) == 0 {
		return AmpModelMappingsSummary{}
	}
	sort.Strings(entries)
	digest := sha256.Sum256([]byte(strings.Join(entries, "|")))
	return AmpModelMappingsSummary{
		hash:  hex.EncodeToString(digest[:]),
		count: len(entries),
	}
}
================================================
FILE: internal/watcher/diff/oauth_excluded_test.go
================================================
package diff
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// TestSummarizeExcludedModels_NormalizesAndDedupes checks case/whitespace
// folding, de-duplication and the nil-input zero summary.
func TestSummarizeExcludedModels_NormalizesAndDedupes(t *testing.T) {
	got := SummarizeExcludedModels([]string{"A", " a ", "B", "b"})
	switch {
	case got.count != 2:
		t.Fatalf("expected 2 unique entries, got %d", got.count)
	case got.hash == "":
		t.Fatal("expected non-empty hash")
	}
	empty := SummarizeExcludedModels(nil)
	if empty.count != 0 || empty.hash != "" {
		t.Fatalf("expected empty summary for nil input, got %+v", empty)
	}
}
// TestDiffOAuthExcludedModelChanges covers the updated, removed and added
// branches, including provider-key case normalization.
func TestDiffOAuthExcludedModelChanges(t *testing.T) {
	changes, affected := DiffOAuthExcludedModelChanges(
		map[string][]string{
			"ProviderA": {"model-1", "model-2"},
			"providerB": {"x"},
		},
		map[string][]string{
			"providerA": {"model-1", "model-3"},
			"providerC": {"y"},
		},
	)
	for _, want := range []string{
		"oauth-excluded-models[providera]: updated (2 -> 2 entries)",
		"oauth-excluded-models[providerb]: removed",
		"oauth-excluded-models[providerc]: added (1 entries)",
	} {
		expectContains(t, changes, want)
	}
	if n := len(affected); n != 3 {
		t.Fatalf("expected 3 affected providers, got %d", n)
	}
}
// TestSummarizeAmpModelMappings checks counting, blank-mapping filtering and
// the empty-input zero summary.
func TestSummarizeAmpModelMappings(t *testing.T) {
	got := SummarizeAmpModelMappings([]config.AmpModelMapping{
		{From: "a", To: "A"},
		{From: "b", To: "B"},
		{From: " ", To: " "}, // ignored
	})
	if got.count != 2 {
		t.Fatalf("expected 2 entries, got %d", got.count)
	}
	if got.hash == "" {
		t.Fatal("expected non-empty hash")
	}
	nilSummary := SummarizeAmpModelMappings(nil)
	if nilSummary.count != 0 || nilSummary.hash != "" {
		t.Fatalf("expected empty summary for nil input, got %+v", nilSummary)
	}
	blankSummary := SummarizeAmpModelMappings([]config.AmpModelMapping{{From: " ", To: " "}})
	if blankSummary.count != 0 || blankSummary.hash != "" {
		t.Fatalf("expected blank mappings ignored, got %+v", blankSummary)
	}
}
// TestSummarizeOAuthExcludedModels_NormalizesKeys checks key lowercasing,
// blank-key filtering, and nil output for nil input.
func TestSummarizeOAuthExcludedModels_NormalizesKeys(t *testing.T) {
	summaries := SummarizeOAuthExcludedModels(map[string][]string{
		"ProvA": {"X"},
		"":      {"ignored"},
	})
	if len(summaries) != 1 {
		t.Fatalf("expected only non-empty key summary, got %d", len(summaries))
	}
	prova, ok := summaries["prova"]
	if !ok {
		t.Fatalf("expected normalized key 'prova', got keys %v", summaries)
	}
	if prova.count != 1 || prova.hash == "" {
		t.Fatalf("unexpected summary %+v", prova)
	}
	if none := SummarizeOAuthExcludedModels(nil); none != nil {
		t.Fatalf("expected nil map for nil input, got %v", none)
	}
}
// TestSummarizeVertexModels checks name-only and alias-only entries, blank
// filtering, and the empty-input zero summary.
func TestSummarizeVertexModels(t *testing.T) {
	got := SummarizeVertexModels([]config.VertexCompatModel{
		{Name: "m1"},
		{Name: " ", Alias: "alias"},
		{}, // ignored
	})
	if got.count != 2 {
		t.Fatalf("expected 2 vertex models, got %d", got.count)
	}
	if got.hash == "" {
		t.Fatal("expected non-empty hash")
	}
	nilSummary := SummarizeVertexModels(nil)
	if nilSummary.count != 0 || nilSummary.hash != "" {
		t.Fatalf("expected empty summary for nil input, got %+v", nilSummary)
	}
	blankSummary := SummarizeVertexModels([]config.VertexCompatModel{{Name: " "}})
	if blankSummary.count != 0 || blankSummary.hash != "" {
		t.Fatalf("expected blank model ignored, got %+v", blankSummary)
	}
}
// expectContains fails the test immediately unless target is an exact element of list.
func expectContains(t *testing.T, list []string, target string) {
	t.Helper()
	found := false
	for i := range list {
		if list[i] == target {
			found = true
			break
		}
	}
	if !found {
		t.Fatalf("expected list to contain %q, got %#v", target, list)
	}
}
================================================
FILE: internal/watcher/diff/oauth_model_alias.go
================================================
package diff
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"sort"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// OAuthModelAliasSummary is a digest of one channel's OAuth model aliases, as
// produced by summarizeOAuthModelAliasList. The zero value represents "no aliases".
type OAuthModelAliasSummary struct {
	// hash is the hex SHA-256 of the sorted, distinct "name->alias[|fork]" keys.
	hash string
	// count is the number of distinct keys.
	count int
}
// SummarizeOAuthModelAlias summarizes OAuth model aliases per channel.
// Channel keys are trimmed and lowercased, and blank keys are dropped. Keys
// that collide after normalization have their alias lists merged before
// summarizing, so the result does not depend on map iteration order. It
// returns nil when the input is empty or every key is blank.
func SummarizeOAuthModelAlias(entries map[string][]config.OAuthModelAlias) map[string]OAuthModelAliasSummary {
	if len(entries) == 0 {
		return nil
	}
	// Merge colliding keys first. Assigning per raw key would let the last
	// key visited by Go's randomized map iteration win.
	merged := make(map[string][]config.OAuthModelAlias, len(entries))
	for k, v := range entries {
		key := strings.ToLower(strings.TrimSpace(k))
		if key == "" {
			continue
		}
		merged[key] = append(merged[key], v...)
	}
	if len(merged) == 0 {
		return nil
	}
	out := make(map[string]OAuthModelAliasSummary, len(merged))
	for key, list := range merged {
		// The list summary sorts and de-duplicates, so merge order is irrelevant.
		out[key] = summarizeOAuthModelAliasList(list)
	}
	return out
}
// DiffOAuthModelAliasChanges compares two channel -> OAuth model alias maps.
// It returns sorted human-readable change lines and the sorted normalized
// channel keys that were added, removed or updated.
func DiffOAuthModelAliasChanges(oldMap, newMap map[string][]config.OAuthModelAlias) ([]string, []string) {
	prevSummary := SummarizeOAuthModelAlias(oldMap)
	nextSummary := SummarizeOAuthModelAlias(newMap)

	channels := make([]string, 0, len(prevSummary)+len(nextSummary))
	for channel := range prevSummary {
		channels = append(channels, channel)
	}
	for channel := range nextSummary {
		if _, known := prevSummary[channel]; !known {
			channels = append(channels, channel)
		}
	}

	changes := make([]string, 0, len(channels))
	affected := make([]string, 0, len(channels))
	for _, channel := range channels {
		prev, inPrev := prevSummary[channel]
		next, inNext := nextSummary[channel]
		var line string
		if inPrev && !inNext {
			line = fmt.Sprintf("oauth-model-alias[%s]: removed", channel)
		} else if !inPrev && inNext {
			line = fmt.Sprintf("oauth-model-alias[%s]: added (%d entries)", channel, next.count)
		} else if prev.hash != next.hash {
			line = fmt.Sprintf("oauth-model-alias[%s]: updated (%d -> %d entries)", channel, prev.count, next.count)
		} else {
			continue
		}
		changes = append(changes, line)
		affected = append(affected, channel)
	}
	sort.Strings(changes)
	sort.Strings(affected)
	return changes, affected
}
// summarizeOAuthModelAliasList digests one channel's aliases. Every entry
// with both a non-blank name and a non-blank alias becomes a lowercased
// "name->alias" key, suffixed "|fork" when Fork is set. Keys are
// de-duplicated and sorted, then hashed. Nothing usable yields the zero
// summary.
func summarizeOAuthModelAliasList(list []config.OAuthModelAlias) OAuthModelAliasSummary {
	seen := make(map[string]struct{}, len(list))
	var keys []string
	for _, entry := range list {
		name := strings.ToLower(strings.TrimSpace(entry.Name))
		target := strings.ToLower(strings.TrimSpace(entry.Alias))
		if name == "" || target == "" {
			continue
		}
		key := name + "->" + target
		if entry.Fork {
			key += "|fork"
		}
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		keys = append(keys, key)
	}
	if len(keys) == 0 {
		return OAuthModelAliasSummary{}
	}
	sort.Strings(keys)
	digest := sha256.Sum256([]byte(strings.Join(keys, "|")))
	return OAuthModelAliasSummary{hash: hex.EncodeToString(digest[:]), count: len(keys)}
}
================================================
FILE: internal/watcher/diff/openai_compat.go
================================================
package diff
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"sort"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// DiffOpenAICompatibility compares two openai-compatibility provider lists and
// returns human-readable lines for providers that were added, removed or
// updated. Providers are matched via openAICompatKey and reported in sorted
// key order. A matched provider with no tracked differences produces no line.
func DiffOpenAICompatibility(oldList, newList []config.OpenAICompatibility) []string {
	type indexed struct {
		entries map[string]config.OpenAICompatibility
		labels  map[string]string
	}
	index := func(list []config.OpenAICompatibility) indexed {
		ix := indexed{
			entries: make(map[string]config.OpenAICompatibility, len(list)),
			labels:  make(map[string]string, len(list)),
		}
		for i, entry := range list {
			key, label := openAICompatKey(entry, i)
			ix.entries[key] = entry
			ix.labels[key] = label
		}
		return ix
	}
	before, after := index(oldList), index(newList)

	keys := make([]string, 0, len(before.entries)+len(after.entries))
	for key := range before.entries {
		keys = append(keys, key)
	}
	for key := range after.entries {
		if _, shared := before.entries[key]; !shared {
			keys = append(keys, key)
		}
	}
	sort.Strings(keys)

	changes := make([]string, 0)
	for _, key := range keys {
		prev, inOld := before.entries[key]
		next, inNew := after.entries[key]
		label := before.labels[key]
		if label == "" {
			label = after.labels[key]
		}
		if !inOld {
			changes = append(changes, fmt.Sprintf("provider added: %s (api-keys=%d, models=%d)", label, countAPIKeys(next), countOpenAIModels(next.Models)))
			continue
		}
		if !inNew {
			changes = append(changes, fmt.Sprintf("provider removed: %s (api-keys=%d, models=%d)", label, countAPIKeys(prev), countOpenAIModels(prev.Models)))
			continue
		}
		if detail := describeOpenAICompatibilityUpdate(prev, next); detail != "" {
			changes = append(changes, fmt.Sprintf("provider updated: %s %s", label, detail))
		}
	}
	return changes
}
// describeOpenAICompatibilityUpdate summarizes what changed between two
// versions of the same openai-compatibility provider (as matched by
// openAICompatKey). It returns a parenthesized, comma-separated detail list
// such as "(api-keys 1 -> 2, headers updated)", or "" when none of the tracked
// attributes differ. Base URLs, API keys and header values are never echoed,
// because they may carry credentials.
func describeOpenAICompatibilityUpdate(oldEntry, newEntry config.OpenAICompatibility) string {
	details := make([]string, 0, 4)

	// A provider matched by name can still be repointed to another endpoint.
	if strings.TrimSpace(oldEntry.BaseURL) != strings.TrimSpace(newEntry.BaseURL) {
		details = append(details, "base-url updated")
	}

	oldKeyCount := countAPIKeys(oldEntry)
	newKeyCount := countAPIKeys(newEntry)
	if oldKeyCount != newKeyCount {
		details = append(details, fmt.Sprintf("api-keys %d -> %d", oldKeyCount, newKeyCount))
	}

	oldModelCount := countOpenAIModels(oldEntry.Models)
	newModelCount := countOpenAIModels(newEntry.Models)
	switch {
	case oldModelCount != newModelCount:
		details = append(details, fmt.Sprintf("models %d -> %d", oldModelCount, newModelCount))
	case ComputeOpenAICompatModelsHash(oldEntry.Models) != ComputeOpenAICompatModelsHash(newEntry.Models):
		// Same number of models, but names/aliases differ (e.g. a rename).
		details = append(details, "models updated")
	}

	if !equalStringMap(oldEntry.Headers, newEntry.Headers) {
		details = append(details, "headers updated")
	}
	if len(details) == 0 {
		return ""
	}
	return "(" + strings.Join(details, ", ") + ")"
}
// countAPIKeys reports how many of the provider's API key entries carry a
// non-blank key.
func countAPIKeys(entry config.OpenAICompatibility) int {
	n := 0
	for i := range entry.APIKeyEntries {
		if strings.TrimSpace(entry.APIKeyEntries[i].APIKey) == "" {
			continue
		}
		n++
	}
	return n
}
// countOpenAIModels counts models with a non-blank name or alias. Duplicates
// are counted individually.
func countOpenAIModels(models []config.OpenAICompatibilityModel) int {
	total := 0
	for _, m := range models {
		if strings.TrimSpace(m.Name) != "" || strings.TrimSpace(m.Alias) != "" {
			total++
		}
	}
	return total
}
// openAICompatKey returns the identity key used to match a provider across
// config versions, plus a display label. It takes the first of these that is
// available:
//  1. the trimmed name
//  2. the trimmed base URL
//  3. the first model's alias (or its name when the alias is blank)
//  4. a content signature, labelled with its first 8 characters
//  5. the list index, labelled 1-based
func openAICompatKey(entry config.OpenAICompatibility, index int) (string, string) {
	if name := strings.TrimSpace(entry.Name); name != "" {
		return "name:" + name, name
	}
	if base := strings.TrimSpace(entry.BaseURL); base != "" {
		return "base:" + base, base
	}
	for _, model := range entry.Models {
		label := strings.TrimSpace(model.Alias)
		if label == "" {
			label = strings.TrimSpace(model.Name)
		}
		if label != "" {
			return "alias:" + label, label
		}
	}
	sig := openAICompatSignature(entry)
	if sig == "" {
		return fmt.Sprintf("index:%d", index), fmt.Sprintf("entry-%d", index+1)
	}
	const labelLen = 8
	prefix := sig
	if len(prefix) > labelLen {
		prefix = prefix[:labelLen]
	}
	return "sig:" + sig, "compat-" + prefix
}
// openAICompatSignature builds a hex SHA-256 fingerprint of a provider entry.
// It combines the lowercased name, the base URL, the sorted normalized model
// pairs, the sorted lowercased header names, and the count of non-blank API
// keys. Header values and key material are deliberately excluded. It returns
// "" when none of these contribute anything.
func openAICompatSignature(entry config.OpenAICompatibility) string {
	parts := make([]string, 0, 5)
	if name := strings.TrimSpace(entry.Name); name != "" {
		parts = append(parts, "name="+strings.ToLower(name))
	}
	if base := strings.TrimSpace(entry.BaseURL); base != "" {
		parts = append(parts, "base="+base)
	}

	var modelKeys []string
	for _, m := range entry.Models {
		n, a := strings.TrimSpace(m.Name), strings.TrimSpace(m.Alias)
		if n == "" && a == "" {
			continue
		}
		modelKeys = append(modelKeys, strings.ToLower(n)+"|"+strings.ToLower(a))
	}
	if len(modelKeys) > 0 {
		sort.Strings(modelKeys)
		parts = append(parts, "models="+strings.Join(modelKeys, ","))
	}

	var headerNames []string
	for header := range entry.Headers {
		if trimmed := strings.TrimSpace(header); trimmed != "" {
			headerNames = append(headerNames, strings.ToLower(trimmed))
		}
	}
	if len(headerNames) > 0 {
		sort.Strings(headerNames)
		parts = append(parts, "headers="+strings.Join(headerNames, ","))
	}

	// Only the number of usable keys is recorded, never the keys themselves.
	if n := countAPIKeys(entry); n > 0 {
		parts = append(parts, fmt.Sprintf("api_keys=%d", n))
	}
	if len(parts) == 0 {
		return ""
	}
	digest := sha256.Sum256([]byte(strings.Join(parts, "|")))
	return hex.EncodeToString(digest[:])
}
================================================
FILE: internal/watcher/diff/openai_compat_test.go
================================================
package diff
import (
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// TestDiffOpenAICompatibility covers an added provider and an updated provider
// whose key count, model count and headers all change.
func TestDiffOpenAICompatibility(t *testing.T) {
	before := []config.OpenAICompatibility{{
		Name:          "provider-a",
		APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-a"}},
		Models:        []config.OpenAICompatibilityModel{{Name: "m1"}},
	}}
	after := []config.OpenAICompatibility{
		{
			Name: "provider-a",
			APIKeyEntries: []config.OpenAICompatibilityAPIKey{
				{APIKey: "key-a"},
				{APIKey: "key-b"},
			},
			Models: []config.OpenAICompatibilityModel{
				{Name: "m1"},
				{Name: "m2"},
			},
			Headers: map[string]string{"X-Test": "1"},
		},
		{
			Name:          "provider-b",
			APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-b"}},
		},
	}
	changes := DiffOpenAICompatibility(before, after)
	for _, want := range []string{
		"provider added: provider-b (api-keys=1, models=0)",
		"provider updated: provider-a (api-keys 1 -> 2, models 1 -> 2, headers updated)",
	} {
		expectContains(t, changes, want)
	}
}
// TestDiffOpenAICompatibility_RemovedAndUnchanged checks that an identical
// provider yields no changes and that dropping it is reported as removed.
func TestDiffOpenAICompatibility_RemovedAndUnchanged(t *testing.T) {
	makeProvider := func() config.OpenAICompatibility {
		return config.OpenAICompatibility{
			Name:          "provider-a",
			APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-a"}},
			Models:        []config.OpenAICompatibilityModel{{Name: "m1"}},
		}
	}
	before := []config.OpenAICompatibility{makeProvider()}
	if changes := DiffOpenAICompatibility(before, []config.OpenAICompatibility{makeProvider()}); len(changes) != 0 {
		t.Fatalf("expected no changes, got %v", changes)
	}
	changes := DiffOpenAICompatibility(before, nil)
	expectContains(t, changes, "provider removed: provider-a (api-keys=1, models=1)")
}
// TestOpenAICompatKeyFallbacks walks the key-derivation fallbacks in order:
// base URL first, then the first model alias, then the positional index.
func TestOpenAICompatKeyFallbacks(t *testing.T) {
	aliasOnly := []config.OpenAICompatibilityModel{{Alias: "alias-only"}}
	cases := []struct {
		entry     config.OpenAICompatibility
		index     int
		wantKey   string
		wantLabel string
		failMsg   string
	}{
		{
			entry:     config.OpenAICompatibility{BaseURL: "http://base", Models: aliasOnly},
			index:     0,
			wantKey:   "base:http://base",
			wantLabel: "http://base",
			failMsg:   "expected base key, got %s/%s",
		},
		{
			entry:     config.OpenAICompatibility{Models: aliasOnly},
			index:     1,
			wantKey:   "alias:alias-only",
			wantLabel: "alias-only",
			failMsg:   "expected alias fallback, got %s/%s",
		},
		{
			entry:     config.OpenAICompatibility{},
			index:     2,
			wantKey:   "index:2",
			wantLabel: "entry-3",
			failMsg:   "expected index fallback, got %s/%s",
		},
	}
	for _, tc := range cases {
		key, label := openAICompatKey(tc.entry, tc.index)
		if key != tc.wantKey || label != tc.wantLabel {
			t.Fatalf(tc.failMsg, key, label)
		}
	}
}
func TestOpenAICompatKey_UsesName(t *testing.T) {
entry := config.OpenAICompatibility{Name: "My-Provider"}
key, label := openAICompatKey(entry, 0)
if key != "name:My-Provider" || label != "My-Provider" {
t.Fatalf("expected name key, got %s/%s", key, label)
}
}
// TestOpenAICompatKey_SignatureFallbackWhenOnlyAPIKeys checks that an entry
// with nothing but API keys is keyed by a content signature ("sig:" key,
// "compat-" label).
func TestOpenAICompatKey_SignatureFallbackWhenOnlyAPIKeys(t *testing.T) {
	entry := config.OpenAICompatibility{
		APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "k1"}, {APIKey: "k2"}},
	}
	key, label := openAICompatKey(entry, 0)
	hasSigKey := strings.HasPrefix(key, "sig:")
	hasCompatLabel := strings.HasPrefix(label, "compat-")
	if !(hasSigKey && hasCompatLabel) {
		t.Fatalf("expected signature key, got %s/%s", key, label)
	}
}
func TestOpenAICompatSignature_EmptyReturnsEmpty(t *testing.T) {
if got := openAICompatSignature(config.OpenAICompatibility{}); got != "" {
t.Fatalf("expected empty signature, got %q", got)
}
}
// TestOpenAICompatSignature_StableAndNormalized asserts two things. First,
// entries that differ only in name case/whitespace, model order and case,
// blank models/headers/keys, header values and API key contents share one
// signature. Second, adding a model changes the signature.
func TestOpenAICompatSignature_StableAndNormalized(t *testing.T) {
	messy := config.OpenAICompatibility{
		Name:    " Provider ",
		BaseURL: "http://base",
		Models: []config.OpenAICompatibilityModel{
			{Name: "m1"},
			{Name: " "},
			{Alias: "A1"},
		},
		Headers: map[string]string{
			"X-Test": "1",
			" ":      "ignored",
		},
		APIKeyEntries: []config.OpenAICompatibilityAPIKey{
			{APIKey: "k1"},
			{APIKey: " "},
		},
	}
	clean := config.OpenAICompatibility{
		Name:          "provider",
		BaseURL:       "http://base",
		Models:        []config.OpenAICompatibilityModel{{Alias: "a1"}, {Name: "m1"}},
		Headers:       map[string]string{"x-test": "2"},
		APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "k2"}},
	}

	messySig, cleanSig := openAICompatSignature(messy), openAICompatSignature(clean)
	switch {
	case messySig == "" || cleanSig == "":
		t.Fatalf("expected non-empty signatures, got %q / %q", messySig, cleanSig)
	case messySig != cleanSig:
		t.Fatalf("expected normalized signatures to match, got %s / %s", messySig, cleanSig)
	}

	// Same as clean, plus one extra model.
	extended := clean
	extended.Models = []config.OpenAICompatibilityModel{{Alias: "a1"}, {Name: "m1"}, {Name: "m2"}}
	if extSig := openAICompatSignature(extended); extSig == cleanSig {
		t.Fatalf("expected signature to change when models change, got %s", extSig)
	}
}
// TestCountOpenAIModelsSkipsBlanks verifies that models whose name and alias
// are both empty or whitespace-only are not counted.
func TestCountOpenAIModelsSkipsBlanks(t *testing.T) {
	const want = 2
	models := []config.OpenAICompatibilityModel{
		{Name: "m1"}, {Name: ""}, {Alias: ""}, {Name: " "}, {Alias: "a1"},
	}
	if got := countOpenAIModels(models); got != want {
		t.Fatalf("expected 2 counted models, got %d", got)
	}
}
// TestOpenAICompatKeyUsesModelNameWhenAliasEmpty checks that the model-based
// key ("alias:" prefix) falls back to the model Name when Alias is empty.
func TestOpenAICompatKeyUsesModelNameWhenAliasEmpty(t *testing.T) {
	models := []config.OpenAICompatibilityModel{{Name: "model-name"}}
	key, label := openAICompatKey(config.OpenAICompatibility{Models: models}, 5)
	if wantKey, wantLabel := "alias:model-name", "model-name"; key != wantKey || label != wantLabel {
		t.Fatalf("expected model-name fallback, got %s/%s", key, label)
	}
}
================================================
FILE: internal/watcher/dispatcher.go
================================================
// dispatcher.go implements auth update dispatching and queue management.
// It batches, deduplicates, and delivers auth updates to registered consumers.
package watcher
import (
"context"
"fmt"
"reflect"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
var (
	// snapshotCoreAuthsFunc is what refreshAuthState calls to build the
	// config/file auth snapshot. It is a variable rather than a direct call,
	// presumably so tests can substitute a fake (TODO confirm).
	snapshotCoreAuthsFunc = snapshotCoreAuths
)
// setAuthUpdateQueue installs queue as the destination for auth updates.
// A dispatch loop started by an earlier call is cancelled and woken through
// dispatchCond so it can observe the cancellation. When queue is non-nil, a
// fresh dispatchLoop goroutine with its own cancellable context is started.
// Passing nil leaves dispatching stopped. Pending updates are not cleared,
// so a newly started loop picks them up.
//
// Lock order here is clientsMutex -> dispatchMu.
func (w *Watcher) setAuthUpdateQueue(queue chan<- AuthUpdate) {
	w.clientsMutex.Lock()
	defer w.clientsMutex.Unlock()
	w.authQueue = queue
	// Lazily create the condition variable dispatchLoop waits on.
	if w.dispatchCond == nil {
		w.dispatchCond = sync.NewCond(&w.dispatchMu)
	}
	if w.dispatchCancel != nil {
		w.dispatchCancel()
		// Broadcast under dispatchMu so a loop parked in nextPendingBatch wakes
		// up and sees that its context was cancelled.
		if w.dispatchCond != nil {
			w.dispatchMu.Lock()
			w.dispatchCond.Broadcast()
			w.dispatchMu.Unlock()
		}
		w.dispatchCancel = nil
	}
	if queue != nil {
		ctx, cancel := context.WithCancel(context.Background())
		w.dispatchCancel = cancel
		go w.dispatchLoop(ctx)
	}
}
// dispatchRuntimeAuthUpdate records a runtime-originated auth change and, if
// an auth queue is installed, enqueues it for dispatch.
//
// Add/Modify with a non-nil, non-empty-ID auth stores clones in both
// runtimeAuths and currentAuths. Delete removes the auth from both maps,
// using update.ID or else update.Auth.ID. State is recorded even when no queue
// is installed. It returns true only if the update was handed to
// dispatchAuthUpdates. A nil receiver returns false.
func (w *Watcher) dispatchRuntimeAuthUpdate(update AuthUpdate) bool {
	if w == nil {
		return false
	}

	w.clientsMutex.Lock()
	if w.runtimeAuths == nil {
		w.runtimeAuths = make(map[string]*coreauth.Auth)
	}
	if update.Action == AuthUpdateActionAdd || update.Action == AuthUpdateActionModify {
		if incoming := update.Auth; incoming != nil && incoming.ID != "" {
			snapshot := incoming.Clone()
			w.runtimeAuths[snapshot.ID] = snapshot
			if w.currentAuths == nil {
				w.currentAuths = make(map[string]*coreauth.Auth)
			}
			w.currentAuths[snapshot.ID] = snapshot.Clone()
		}
	} else if update.Action == AuthUpdateActionDelete {
		target := update.ID
		if target == "" && update.Auth != nil {
			target = update.Auth.ID
		}
		if target != "" {
			delete(w.runtimeAuths, target)
			// delete on a nil map is a no-op, so no nil check is needed.
			delete(w.currentAuths, target)
		}
	}
	// Release before getAuthQueue, which takes the read lock itself.
	w.clientsMutex.Unlock()

	queued := w.getAuthQueue() != nil
	if queued {
		w.dispatchAuthUpdates([]AuthUpdate{update})
	}
	return queued
}
// refreshAuthState rebuilds the full auth set and dispatches the resulting
// updates. The set is the config/file snapshot plus clones of every non-nil
// runtime-registered auth. It is diffed against the last known state. When
// force is true, every auth present in both old and new state is emitted as
// a Modify, even if unchanged.
func (w *Watcher) refreshAuthState(force bool) {
	w.clientsMutex.RLock()
	cfg, dir := w.config, w.authDir
	w.clientsMutex.RUnlock()

	// Built outside the lock; the synthesizers presumably read files under dir.
	snapshot := snapshotCoreAuthsFunc(cfg, dir)

	w.clientsMutex.Lock()
	for _, runtimeAuth := range w.runtimeAuths {
		if runtimeAuth == nil {
			continue
		}
		snapshot = append(snapshot, runtimeAuth.Clone())
	}
	pending := w.prepareAuthUpdatesLocked(snapshot, force)
	w.clientsMutex.Unlock()

	w.dispatchAuthUpdates(pending)
}
// prepareAuthUpdatesLocked replaces w.currentAuths with clones of auths,
// keyed by ID (nil entries and empty IDs are skipped). It returns the updates
// that move consumers from the previous state to the new one:
//   - Add for new IDs;
//   - Modify for IDs whose auth differs per authEqual, or for every surviving
//     ID when force is true;
//   - Delete for IDs that disappeared.
//
// The state is always replaced, but nil is returned when no auth queue is
// installed. Callers hold w.clientsMutex for writing, as refreshAuthState does.
func (w *Watcher) prepareAuthUpdatesLocked(auths []*coreauth.Auth, force bool) []AuthUpdate {
	next := make(map[string]*coreauth.Auth, len(auths))
	for _, candidate := range auths {
		if candidate != nil && candidate.ID != "" {
			next[candidate.ID] = candidate.Clone()
		}
	}

	previous := w.currentAuths
	w.currentAuths = next
	if w.authQueue == nil {
		return nil
	}

	// A nil previous map reads as empty. So on the first refresh every auth
	// becomes an Add and there are no deletions.
	updates := make([]AuthUpdate, 0, len(next)+len(previous))
	for id, current := range next {
		old, known := previous[id]
		switch {
		case !known:
			updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: current.Clone()})
		case force || !authEqual(old, current):
			updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: current.Clone()})
		}
	}
	for id := range previous {
		if _, kept := next[id]; !kept {
			updates = append(updates, AuthUpdate{Action: AuthUpdateActionDelete, ID: id})
		}
	}
	return updates
}
// dispatchAuthUpdates adds updates to the pending set and signals the
// dispatch loop. Updates are coalesced by authUpdateKey. A newer update for
// an already-pending key replaces the old one but keeps its original queue
// position. It is a no-op for an empty slice or when no queue is installed.
func (w *Watcher) dispatchAuthUpdates(updates []AuthUpdate) {
	if len(updates) == 0 || w.getAuthQueue() == nil {
		return
	}
	// Offsetting by the slice index keeps ID-less keys distinct within a call.
	stamp := time.Now().UnixNano()

	w.dispatchMu.Lock()
	defer w.dispatchMu.Unlock()
	if w.pendingUpdates == nil {
		w.pendingUpdates = make(map[string]AuthUpdate)
	}
	for i := range updates {
		key := w.authUpdateKey(updates[i], stamp+int64(i))
		if _, queued := w.pendingUpdates[key]; !queued {
			w.pendingOrder = append(w.pendingOrder, key)
		}
		w.pendingUpdates[key] = updates[i]
	}
	if w.dispatchCond != nil {
		w.dispatchCond.Signal()
	}
}
// authUpdateKey returns the coalescing key for update: its ID when set,
// otherwise "<action>:<ts>". ID-less updates are therefore only merged if they
// share both action and ts.
func (w *Watcher) authUpdateKey(update AuthUpdate, ts int64) string {
	if id := update.ID; id == "" {
		return fmt.Sprintf("%s:%d", update.Action, ts)
	}
	return update.ID
}
// dispatchLoop repeatedly takes the pending batch and sends each update, in
// order, to the currently installed auth queue. It returns when ctx is
// cancelled, either while waiting for work or while blocked on a send. If no
// queue is installed when a batch is taken, that batch is discarded and the
// loop pauses briefly before trying again.
func (w *Watcher) dispatchLoop(ctx context.Context) {
	for {
		batch, ok := w.nextPendingBatch(ctx)
		if !ok {
			return
		}
		target := w.getAuthQueue()
		if target != nil {
			for i := range batch {
				select {
				case target <- batch[i]:
				case <-ctx.Done():
					return
				}
			}
			continue
		}
		if ctx.Err() != nil {
			return
		}
		time.Sleep(10 * time.Millisecond)
	}
}
// nextPendingBatch blocks on dispatchCond until at least one update is pending.
// It then removes every pending update and returns them in the order their
// keys were first enqueued, with true. It returns false if ctx is cancelled
// while nothing is pending. Cancellers must Broadcast dispatchCond, as
// setAuthUpdateQueue and stopDispatch do, to end the wait.
func (w *Watcher) nextPendingBatch(ctx context.Context) ([]AuthUpdate, bool) {
	w.dispatchMu.Lock()
	defer w.dispatchMu.Unlock()
	for len(w.pendingOrder) == 0 {
		if ctx.Err() != nil {
			return nil, false
		}
		// Wait releases dispatchMu while parked. Re-check ctx on wake-up,
		// because a cancellation Broadcast also ends the wait.
		w.dispatchCond.Wait()
		if ctx.Err() != nil {
			return nil, false
		}
	}
	batch := make([]AuthUpdate, 0, len(w.pendingOrder))
	for _, key := range w.pendingOrder {
		batch = append(batch, w.pendingUpdates[key])
		delete(w.pendingUpdates, key)
	}
	// Truncate but keep the backing array for later enqueues.
	w.pendingOrder = w.pendingOrder[:0]
	return batch, true
}
// getAuthQueue returns the installed auth update queue (nil if none). It reads
// the queue under the clients read lock.
func (w *Watcher) getAuthQueue() chan<- AuthUpdate {
	w.clientsMutex.RLock()
	queue := w.authQueue
	w.clientsMutex.RUnlock()
	return queue
}
// stopDispatch cancels the running dispatch loop (if any), uninstalls the auth
// queue and discards all pending updates.
//
// dispatchCancel, authQueue and dispatchCond are accessed under clientsMutex,
// the lock setAuthUpdateQueue uses for them, so shutdown cannot race with a
// concurrent setAuthUpdateQueue. The queue is uninstalled before the pending
// set is cleared. After that point dispatchAuthUpdates becomes a no-op instead
// of leaving stale entries that a later setAuthUpdateQueue would deliver.
// The two locks are taken one after another, never nested, so the
// clientsMutex -> dispatchMu ordering used elsewhere is respected. Callers
// must not hold clientsMutex.
func (w *Watcher) stopDispatch() {
	w.clientsMutex.Lock()
	cancel := w.dispatchCancel
	w.dispatchCancel = nil
	w.authQueue = nil
	cond := w.dispatchCond
	w.clientsMutex.Unlock()

	if cancel != nil {
		cancel()
	}

	w.dispatchMu.Lock()
	w.pendingOrder = nil
	w.pendingUpdates = nil
	// Wake a loop parked in nextPendingBatch so it observes the cancellation.
	if cond != nil {
		cond.Broadcast()
	}
	w.dispatchMu.Unlock()
}
// authEqual reports whether a and b are deeply equal after normalizeAuth has
// cleared their timestamp, runtime and quota-recovery fields.
func authEqual(a, b *coreauth.Auth) bool {
	left, right := normalizeAuth(a), normalizeAuth(b)
	return reflect.DeepEqual(left, right)
}
// normalizeAuth returns a clone of a with these fields zeroed: CreatedAt,
// UpdatedAt, LastRefreshedAt, NextRefreshAfter, Quota.NextRecoverAt and
// Runtime. Comparisons then ignore these bookkeeping values. A nil input
// yields nil.
func normalizeAuth(a *coreauth.Auth) *coreauth.Auth {
	if a == nil {
		return nil
	}
	n := a.Clone()
	var zero time.Time
	n.CreatedAt, n.UpdatedAt = zero, zero
	n.LastRefreshedAt, n.NextRefreshAfter = zero, zero
	n.Quota.NextRecoverAt = zero
	n.Runtime = nil
	return n
}
// snapshotCoreAuths synthesizes auths from the config (API keys) and then from
// authDir (files). Both synthesizers share one context with a fresh stable ID
// generator. This is best-effort: a synthesizer that returns an error
// contributes nothing, while the other's results are still returned.
func snapshotCoreAuths(cfg *config.Config, authDir string) []*coreauth.Auth {
	sctx := &synthesizer.SynthesisContext{
		Config:      cfg,
		AuthDir:     authDir,
		Now:         time.Now(),
		IDGenerator: synthesizer.NewStableIDGenerator(),
	}
	var result []*coreauth.Auth

	fromConfig, errConfig := synthesizer.NewConfigSynthesizer().Synthesize(sctx)
	if errConfig == nil {
		result = append(result, fromConfig...)
	}

	fromFiles, errFiles := synthesizer.NewFileSynthesizer().Synthesize(sctx)
	if errFiles == nil {
		result = append(result, fromFiles...)
	}
	return result
}
================================================
FILE: internal/watcher/events.go
================================================
// events.go implements fsnotify event handling for config and auth file changes.
// It normalizes paths, debounces noisy events, and triggers reload/update logic.
package watcher
import (
"context"
"crypto/sha256"
"encoding/hex"
"os"
"path/filepath"
"runtime"
"strings"
"time"
"github.com/fsnotify/fsnotify"
log "github.com/sirupsen/logrus"
)
// matchProvider normalizes provider (trimmed, lower-cased) and reports whether
// it case-insensitively equals any whitespace-trimmed entry in targets. The
// normalized provider is returned either way.
func matchProvider(provider string, targets []string) (string, bool) {
	normalized := strings.ToLower(strings.TrimSpace(provider))
	found := false
	for i := 0; i < len(targets) && !found; i++ {
		found = strings.EqualFold(normalized, strings.TrimSpace(targets[i]))
	}
	return normalized, found
}
// start registers the config file and then the auth directory with the
// fsnotify watcher. It returns (after logging) the first registration error.
// On success it launches processEvents in a goroutine and calls
// reloadClients(true, nil, false).
func (w *Watcher) start(ctx context.Context) error {
	cfgPath, authDir := w.configPath, w.authDir

	if err := w.watcher.Add(cfgPath); err != nil {
		log.Errorf("failed to watch config file %s: %v", cfgPath, err)
		return err
	}
	log.Debugf("watching config file: %s", cfgPath)

	if err := w.watcher.Add(authDir); err != nil {
		log.Errorf("failed to watch auth directory %s: %v", authDir, err)
		return err
	}
	log.Debugf("watching auth directory: %s", authDir)

	go w.processEvents(ctx)
	w.reloadClients(true, nil, false)
	return nil
}
// processEvents forwards fsnotify events to handleEvent and logs watcher
// errors. It returns when ctx is cancelled or when either fsnotify channel is
// closed.
func (w *Watcher) processEvents(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			return
		case ev, open := <-w.watcher.Events:
			if !open {
				return
			}
			w.handleEvent(ev)
		case watchErr, open := <-w.watcher.Errors:
			if !open {
				return
			}
			log.Errorf("file watcher error: %v", watchErr)
		}
	}
}
// handleEvent reacts to a single fsnotify event. Only two kinds are acted on:
//   - Write/Create/Rename on the config file, which schedules a config reload;
//   - Create/Write/Remove/Rename on a *.json path inside the auth directory,
//     which adds, updates or removes the corresponding client.
//
// Everything else is ignored. For Remove/Rename, the handler sleeps for
// replaceCheckDelay and re-checks the path. If the file has reappeared
// (atomic replace), the event is treated as an update.
func (w *Watcher) handleEvent(event fsnotify.Event) {
	// Filter only relevant events: config file or auth-dir JSON files.
	configOps := fsnotify.Write | fsnotify.Create | fsnotify.Rename
	normalizedName := w.normalizeAuthPath(event.Name)
	normalizedConfigPath := w.normalizeAuthPath(w.configPath)
	normalizedAuthDir := w.normalizeAuthPath(w.authDir)
	isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0
	authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
	// Match on a directory boundary rather than a raw string prefix:
	// "<authDir>-backup/x.json" shares the prefix but is not inside authDir.
	isAuthJSON := pathWithinDir(normalizedName, normalizedAuthDir) &&
		strings.HasSuffix(normalizedName, ".json") &&
		event.Op&authOps != 0
	if !isConfigEvent && !isAuthJSON {
		// Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise.
		return
	}
	now := time.Now()
	log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name)
	// Handle config file changes
	if isConfigEvent {
		log.Debugf("config file change details - operation: %s, timestamp: %s", event.Op.String(), now.Format("2006-01-02 15:04:05.000"))
		w.scheduleConfigReload()
		return
	}
	// Handle auth directory changes incrementally (.json only)
	if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
		if w.shouldDebounceRemove(normalizedName, now) {
			log.Debugf("debouncing remove event for %s", filepath.Base(event.Name))
			return
		}
		// Atomic replace on some platforms may surface as Rename (or Remove) before the new file is ready.
		// Wait briefly; if the path exists again, treat as an update instead of removal.
		time.Sleep(replaceCheckDelay)
		if _, statErr := os.Stat(event.Name); statErr == nil {
			if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
				log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
				return
			}
			log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
			w.addOrUpdateClient(event.Name)
			return
		}
		if !w.isKnownAuthFile(event.Name) {
			log.Debugf("ignoring remove for unknown auth file: %s", filepath.Base(event.Name))
			return
		}
		log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
		w.removeClient(event.Name)
		return
	}
	if event.Op&(fsnotify.Create|fsnotify.Write) != 0 {
		if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
			log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
			return
		}
		log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
		w.addOrUpdateClient(event.Name)
	}
}

// pathWithinDir reports whether path lies strictly below dir. Both arguments
// are expected to be normalized already (see normalizeAuthPath). An empty dir
// or path never matches.
func pathWithinDir(path, dir string) bool {
	if dir == "" || path == "" {
		return false
	}
	sep := string(filepath.Separator)
	prefix := dir
	// A root dir such as "/" already ends with the separator.
	if !strings.HasSuffix(prefix, sep) {
		prefix += sep
	}
	return strings.HasPrefix(path, prefix)
}
// authFileUnchanged reports whether the file at path has the same SHA-256
// content hash as the one recorded for it in w.lastAuthHashes. An empty file
// is never considered unchanged. A read error is returned unchanged, with false.
func (w *Watcher) authFileUnchanged(path string) (bool, error) {
	content, err := os.ReadFile(path)
	switch {
	case err != nil:
		return false, err
	case len(content) == 0:
		return false, nil
	}
	digest := sha256.Sum256(content)
	current := hex.EncodeToString(digest[:])

	key := w.normalizeAuthPath(path)
	w.clientsMutex.RLock()
	previous, seen := w.lastAuthHashes[key]
	w.clientsMutex.RUnlock()

	return seen && previous == current, nil
}
// isKnownAuthFile reports whether w.lastAuthHashes has an entry for the
// normalized form of path.
func (w *Watcher) isKnownAuthFile(path string) bool {
	key := w.normalizeAuthPath(path)
	w.clientsMutex.RLock()
	_, tracked := w.lastAuthHashes[key]
	w.clientsMutex.RUnlock()
	return tracked
}
// normalizeAuthPath canonicalizes path for use as a map key and for
// comparisons. It trims surrounding whitespace and applies filepath.Clean. On
// Windows it also strips the `\\?\` prefix and lower-cases the result. Blank
// input yields "".
func (w *Watcher) normalizeAuthPath(path string) string {
	p := strings.TrimSpace(path)
	if p == "" {
		return ""
	}
	p = filepath.Clean(p)
	if runtime.GOOS != "windows" {
		return p
	}
	return strings.ToLower(strings.TrimPrefix(p, `\\?\`))
}
// shouldDebounceRemove reports whether a remove/rename event for
// normalizedPath arrived within authRemoveDebounceWindow of the last recorded
// one and should be ignored. When not debounced, now is recorded as the
// latest remove time. Once more than 128 paths are tracked, entries older than
// two debounce windows are pruned. An empty path is never debounced.
func (w *Watcher) shouldDebounceRemove(normalizedPath string, now time.Time) bool {
	if normalizedPath == "" {
		return false
	}
	w.clientsMutex.Lock()
	defer w.clientsMutex.Unlock()

	if w.lastRemoveTimes == nil {
		w.lastRemoveTimes = make(map[string]time.Time)
	}
	if prev, seen := w.lastRemoveTimes[normalizedPath]; seen && now.Sub(prev) < authRemoveDebounceWindow {
		return true
	}
	w.lastRemoveTimes[normalizedPath] = now

	const pruneThreshold = 128
	if len(w.lastRemoveTimes) <= pruneThreshold {
		return false
	}
	staleBefore := now.Add(-2 * authRemoveDebounceWindow)
	for p, ts := range w.lastRemoveTimes {
		if ts.Before(staleBefore) {
			delete(w.lastRemoveTimes, p)
		}
	}
	return false
}
================================================
FILE: internal/watcher/synthesizer/config.go
================================================
package synthesizer
import (
"fmt"
"strconv"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// ConfigSynthesizer turns API keys declared in the configuration into
// coreauth.Auth entries for the Gemini, Claude, Codex, OpenAI-compatible and
// Vertex-compatible providers. It carries no state.
type ConfigSynthesizer struct{}

// NewConfigSynthesizer returns a ready-to-use ConfigSynthesizer.
func NewConfigSynthesizer() *ConfigSynthesizer {
	return new(ConfigSynthesizer)
}
// Synthesize builds Auth entries from every API-key section of ctx.Config.
// Providers are processed in a fixed order: Gemini, Claude, Codex,
// OpenAI-compat, Vertex-compat. A nil ctx or nil config yields an empty,
// non-nil slice. The returned error is currently always nil.
func (s *ConfigSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, error) {
	auths := make([]*coreauth.Auth, 0, 32)
	if ctx == nil || ctx.Config == nil {
		return auths, nil
	}
	providers := [...]func(*SynthesisContext) []*coreauth.Auth{
		s.synthesizeGeminiKeys,
		s.synthesizeClaudeKeys,
		s.synthesizeCodexKeys,
		s.synthesizeOpenAICompat,
		s.synthesizeVertexCompat,
	}
	for _, synth := range providers {
		auths = append(auths, synth(ctx)...)
	}
	return auths, nil
}
// synthesizeGeminiKeys emits one "gemini" Auth per configured Gemini key whose
// trimmed API key is non-empty. IDs derive from the key and the trimmed base
// URL. Priority, base URL, models hash and custom headers are stored as
// attributes when present, and excluded-model metadata is applied.
func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*coreauth.Auth {
	cfg := ctx.Config
	idGen := ctx.IDGenerator
	result := make([]*coreauth.Auth, 0, len(cfg.GeminiKey))
	for idx := range cfg.GeminiKey {
		gk := &cfg.GeminiKey[idx]
		apiKey := strings.TrimSpace(gk.APIKey)
		if apiKey == "" {
			continue
		}
		baseURL := strings.TrimSpace(gk.BaseURL)
		id, token := idGen.Next("gemini:apikey", apiKey, baseURL)

		attrs := map[string]string{
			"source":  fmt.Sprintf("config:gemini[%s]", token),
			"api_key": apiKey,
		}
		if gk.Priority != 0 {
			attrs["priority"] = strconv.Itoa(gk.Priority)
		}
		if baseURL != "" {
			attrs["base_url"] = baseURL
		}
		if hash := diff.ComputeGeminiModelsHash(gk.Models); hash != "" {
			attrs["models_hash"] = hash
		}
		addConfigHeadersToAttrs(gk.Headers, attrs)

		auth := &coreauth.Auth{
			ID:         id,
			Provider:   "gemini",
			Label:      "gemini-apikey",
			Prefix:     strings.TrimSpace(gk.Prefix),
			Status:     coreauth.StatusActive,
			ProxyURL:   strings.TrimSpace(gk.ProxyURL),
			Attributes: attrs,
			CreatedAt:  ctx.Now,
			UpdatedAt:  ctx.Now,
		}
		ApplyAuthExcludedModelsMeta(auth, cfg, gk.ExcludedModels, "apikey")
		result = append(result, auth)
	}
	return result
}
// synthesizeClaudeKeys emits one "claude" Auth per configured Claude key whose
// trimmed API key is non-empty. IDs derive from the key and the trimmed base
// URL. Optional priority, base URL, models hash and headers become attributes,
// and excluded-model metadata is applied.
func (s *ConfigSynthesizer) synthesizeClaudeKeys(ctx *SynthesisContext) []*coreauth.Auth {
	cfg, now, idGen := ctx.Config, ctx.Now, ctx.IDGenerator
	auths := make([]*coreauth.Auth, 0, len(cfg.ClaudeKey))
	for i := range cfg.ClaudeKey {
		entry := &cfg.ClaudeKey[i]
		apiKey := strings.TrimSpace(entry.APIKey)
		if apiKey == "" {
			continue
		}
		baseURL := strings.TrimSpace(entry.BaseURL)
		id, token := idGen.Next("claude:apikey", apiKey, baseURL)

		attrs := make(map[string]string)
		attrs["source"] = fmt.Sprintf("config:claude[%s]", token)
		attrs["api_key"] = apiKey
		if entry.Priority != 0 {
			attrs["priority"] = strconv.Itoa(entry.Priority)
		}
		if baseURL != "" {
			attrs["base_url"] = baseURL
		}
		if hash := diff.ComputeClaudeModelsHash(entry.Models); hash != "" {
			attrs["models_hash"] = hash
		}
		addConfigHeadersToAttrs(entry.Headers, attrs)

		auth := &coreauth.Auth{
			ID:         id,
			Provider:   "claude",
			Label:      "claude-apikey",
			Prefix:     strings.TrimSpace(entry.Prefix),
			Status:     coreauth.StatusActive,
			ProxyURL:   strings.TrimSpace(entry.ProxyURL),
			Attributes: attrs,
			CreatedAt:  now,
			UpdatedAt:  now,
		}
		ApplyAuthExcludedModelsMeta(auth, cfg, entry.ExcludedModels, "apikey")
		auths = append(auths, auth)
	}
	return auths
}
// synthesizeCodexKeys emits one "codex" Auth per configured Codex key whose
// trimmed API key is non-empty.
//
// The base URL is trimmed before it feeds the ID and the base_url attribute,
// as in the Gemini and Claude paths. Surrounding whitespace in the config
// therefore neither produces a distinct ID nor ends up in base_url, and a
// whitespace-only base URL is treated as absent. websockets="true" is
// recorded when enabled. Excluded-model metadata is applied.
func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreauth.Auth {
	cfg := ctx.Config
	now := ctx.Now
	idGen := ctx.IDGenerator
	out := make([]*coreauth.Auth, 0, len(cfg.CodexKey))
	for i := range cfg.CodexKey {
		ck := cfg.CodexKey[i]
		key := strings.TrimSpace(ck.APIKey)
		if key == "" {
			continue
		}
		prefix := strings.TrimSpace(ck.Prefix)
		base := strings.TrimSpace(ck.BaseURL)
		id, token := idGen.Next("codex:apikey", key, base)
		attrs := map[string]string{
			"source":  fmt.Sprintf("config:codex[%s]", token),
			"api_key": key,
		}
		if ck.Priority != 0 {
			attrs["priority"] = strconv.Itoa(ck.Priority)
		}
		if base != "" {
			attrs["base_url"] = base
		}
		if ck.Websockets {
			attrs["websockets"] = "true"
		}
		if hash := diff.ComputeCodexModelsHash(ck.Models); hash != "" {
			attrs["models_hash"] = hash
		}
		addConfigHeadersToAttrs(ck.Headers, attrs)
		proxyURL := strings.TrimSpace(ck.ProxyURL)
		a := &coreauth.Auth{
			ID:         id,
			Provider:   "codex",
			Label:      "codex-apikey",
			Prefix:     prefix,
			Status:     coreauth.StatusActive,
			ProxyURL:   proxyURL,
			Attributes: attrs,
			CreatedAt:  now,
			UpdatedAt:  now,
		}
		ApplyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey")
		out = append(out, a)
	}
	return out
}
// synthesizeOpenAICompat emits Auth entries for each OpenAI-compatible
// provider. The provider name is the trimmed, lower-cased compat Name, or
// "openai-compatibility" if that is empty.
//
// One Auth is produced per APIKeyEntries element. This includes elements with
// a blank key, which then carry no api_key attribute. A provider with no
// entries at all gets a single key-less Auth whose ID derives from the base
// URL alone.
func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*coreauth.Auth {
	cfg := ctx.Config
	now := ctx.Now
	idGen := ctx.IDGenerator
	result := make([]*coreauth.Auth, 0)
	for i := range cfg.OpenAICompatibility {
		compat := &cfg.OpenAICompatibility[i]
		providerName := strings.ToLower(strings.TrimSpace(compat.Name))
		if providerName == "" {
			providerName = "openai-compatibility"
		}
		prefix := strings.TrimSpace(compat.Prefix)
		base := strings.TrimSpace(compat.BaseURL)
		idKind := fmt.Sprintf("openai-compatibility:%s", providerName)

		// newAuth assembles the Auth shared by keyed and key-less entries;
		// apiKey and proxyURL may be empty.
		newAuth := func(id, token, apiKey, proxyURL string) *coreauth.Auth {
			attrs := map[string]string{
				"source":       fmt.Sprintf("config:%s[%s]", providerName, token),
				"base_url":     base,
				"compat_name":  compat.Name,
				"provider_key": providerName,
			}
			if compat.Priority != 0 {
				attrs["priority"] = strconv.Itoa(compat.Priority)
			}
			if apiKey != "" {
				attrs["api_key"] = apiKey
			}
			if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" {
				attrs["models_hash"] = hash
			}
			addConfigHeadersToAttrs(compat.Headers, attrs)
			return &coreauth.Auth{
				ID:         id,
				Provider:   providerName,
				Label:      compat.Name,
				Prefix:     prefix,
				Status:     coreauth.StatusActive,
				ProxyURL:   proxyURL,
				Attributes: attrs,
				CreatedAt:  now,
				UpdatedAt:  now,
			}
		}

		if len(compat.APIKeyEntries) == 0 {
			id, token := idGen.Next(idKind, base)
			result = append(result, newAuth(id, token, "", ""))
			continue
		}
		for j := range compat.APIKeyEntries {
			entry := &compat.APIKeyEntries[j]
			key := strings.TrimSpace(entry.APIKey)
			proxyURL := strings.TrimSpace(entry.ProxyURL)
			id, token := idGen.Next(idKind, key, base, proxyURL)
			result = append(result, newAuth(id, token, key, proxyURL))
		}
	}
	return result
}
// synthesizeVertexCompat emits one "vertex" Auth per Vertex-compatible entry.
// Unlike the Gemini/Claude/Codex paths, entries are kept even when the API key
// is blank; api_key is then simply omitted from the attributes. IDs derive
// from the trimmed key, base URL and proxy URL.
func (s *ConfigSynthesizer) synthesizeVertexCompat(ctx *SynthesisContext) []*coreauth.Auth {
	const providerName = "vertex"
	cfg := ctx.Config
	idGen := ctx.IDGenerator
	result := make([]*coreauth.Auth, 0, len(cfg.VertexCompatAPIKey))
	for i := range cfg.VertexCompatAPIKey {
		vc := &cfg.VertexCompatAPIKey[i]
		apiKey := strings.TrimSpace(vc.APIKey)
		baseURL := strings.TrimSpace(vc.BaseURL)
		proxyURL := strings.TrimSpace(vc.ProxyURL)
		id, token := idGen.Next("vertex:apikey", apiKey, baseURL, proxyURL)

		attrs := map[string]string{
			"source":       fmt.Sprintf("config:vertex-apikey[%s]", token),
			"base_url":     baseURL,
			"provider_key": providerName,
		}
		if vc.Priority != 0 {
			attrs["priority"] = strconv.Itoa(vc.Priority)
		}
		if apiKey != "" {
			attrs["api_key"] = apiKey
		}
		if hash := diff.ComputeVertexCompatModelsHash(vc.Models); hash != "" {
			attrs["models_hash"] = hash
		}
		addConfigHeadersToAttrs(vc.Headers, attrs)

		auth := &coreauth.Auth{
			ID:         id,
			Provider:   providerName,
			Label:      "vertex-apikey",
			Prefix:     strings.TrimSpace(vc.Prefix),
			Status:     coreauth.StatusActive,
			ProxyURL:   proxyURL,
			Attributes: attrs,
			CreatedAt:  ctx.Now,
			UpdatedAt:  ctx.Now,
		}
		ApplyAuthExcludedModelsMeta(auth, cfg, vc.ExcludedModels, "apikey")
		result = append(result, auth)
	}
	return result
}
================================================
FILE: internal/watcher/synthesizer/config_test.go
================================================
package synthesizer
import (
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// TestNewConfigSynthesizer checks that the constructor returns a non-nil value.
func TestNewConfigSynthesizer(t *testing.T) {
	if NewConfigSynthesizer() == nil {
		t.Fatal("expected non-nil synthesizer")
	}
}
// TestConfigSynthesizer_Synthesize_NilContext checks that a nil context yields
// no auths and no error.
func TestConfigSynthesizer_Synthesize_NilContext(t *testing.T) {
	auths, err := NewConfigSynthesizer().Synthesize(nil)
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case len(auths) != 0:
		t.Fatalf("expected empty auths, got %d", len(auths))
	}
}
func TestConfigSynthesizer_Synthesize_NilConfig(t *testing.T) {
synth := NewConfigSynthesizer()
ctx := &SynthesisContext{
Config: nil,
Now: time.Now(),
IDGenerator: NewStableIDGenerator(),
}
auths, err := synth.Synthesize(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(auths) != 0 {
t.Fatalf("expected empty auths, got %d", len(auths))
}
}
// TestConfigSynthesizer_GeminiKeys covers Gemini key synthesis: attribute
// mapping, base/proxy URLs, custom headers, skipping of blank keys (checking
// which key survives), and that multiple keys are emitted in config order.
func TestConfigSynthesizer_GeminiKeys(t *testing.T) {
	tests := []struct {
		name       string
		geminiKeys []config.GeminiKey
		wantLen    int
		validate   func(*testing.T, []*coreauth.Auth)
	}{
		{
			name: "single gemini key",
			geminiKeys: []config.GeminiKey{
				{APIKey: "test-key-123", Prefix: "team-a"},
			},
			wantLen: 1,
			validate: func(t *testing.T, auths []*coreauth.Auth) {
				if auths[0].Provider != "gemini" {
					t.Errorf("expected provider gemini, got %s", auths[0].Provider)
				}
				if auths[0].Prefix != "team-a" {
					t.Errorf("expected prefix team-a, got %s", auths[0].Prefix)
				}
				if auths[0].Label != "gemini-apikey" {
					t.Errorf("expected label gemini-apikey, got %s", auths[0].Label)
				}
				if auths[0].Attributes["api_key"] != "test-key-123" {
					t.Errorf("expected api_key test-key-123, got %s", auths[0].Attributes["api_key"])
				}
				if auths[0].Status != coreauth.StatusActive {
					t.Errorf("expected status active, got %s", auths[0].Status)
				}
			},
		},
		{
			name: "gemini key with base url and proxy",
			geminiKeys: []config.GeminiKey{
				{
					APIKey:   "api-key",
					BaseURL:  "https://custom.api.com",
					ProxyURL: "http://proxy.local:8080",
					Prefix:   "custom",
				},
			},
			wantLen: 1,
			validate: func(t *testing.T, auths []*coreauth.Auth) {
				if auths[0].Attributes["base_url"] != "https://custom.api.com" {
					t.Errorf("expected base_url https://custom.api.com, got %s", auths[0].Attributes["base_url"])
				}
				if auths[0].ProxyURL != "http://proxy.local:8080" {
					t.Errorf("expected proxy_url http://proxy.local:8080, got %s", auths[0].ProxyURL)
				}
			},
		},
		{
			name: "gemini key with headers",
			geminiKeys: []config.GeminiKey{
				{
					APIKey:  "api-key",
					Headers: map[string]string{"X-Custom": "value"},
				},
			},
			wantLen: 1,
			validate: func(t *testing.T, auths []*coreauth.Auth) {
				if auths[0].Attributes["header:X-Custom"] != "value" {
					t.Errorf("expected header:X-Custom=value, got %s", auths[0].Attributes["header:X-Custom"])
				}
			},
		},
		{
			name: "empty api key skipped",
			geminiKeys: []config.GeminiKey{
				{APIKey: ""},
				{APIKey: "   "},
				{APIKey: "valid-key"},
			},
			wantLen: 1,
			validate: func(t *testing.T, auths []*coreauth.Auth) {
				// Only the non-blank key may survive.
				if auths[0].Attributes["api_key"] != "valid-key" {
					t.Errorf("expected api_key valid-key, got %s", auths[0].Attributes["api_key"])
				}
			},
		},
		{
			name: "multiple gemini keys",
			geminiKeys: []config.GeminiKey{
				{APIKey: "key-1", Prefix: "a"},
				{APIKey: "key-2", Prefix: "b"},
				{APIKey: "key-3", Prefix: "c"},
			},
			wantLen: 3,
			validate: func(t *testing.T, auths []*coreauth.Auth) {
				// Output must follow config order.
				for i, want := range []string{"a", "b", "c"} {
					if auths[i].Prefix != want {
						t.Errorf("auth %d: expected prefix %s, got %s", i, want, auths[i].Prefix)
					}
				}
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			synth := NewConfigSynthesizer()
			ctx := &SynthesisContext{
				Config: &config.Config{
					GeminiKey: tt.geminiKeys,
				},
				Now:         time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
				IDGenerator: NewStableIDGenerator(),
			}

			auths, err := synth.Synthesize(ctx)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if len(auths) != tt.wantLen {
				t.Fatalf("expected %d auths, got %d", tt.wantLen, len(auths))
			}
			if tt.validate != nil && len(auths) > 0 {
				tt.validate(t, auths)
			}
		})
	}
}
func TestConfigSynthesizer_ClaudeKeys(t *testing.T) {
synth := NewConfigSynthesizer()
ctx := &SynthesisContext{
Config: &config.Config{
ClaudeKey: []config.ClaudeKey{
{
APIKey: "sk-ant-api-xxx",
Prefix: "main",
BaseURL: "https://api.anthropic.com",
Models: []config.ClaudeModel{
{Name: "claude-3-opus"},
{Name: "claude-3-sonnet"},
},
},
},
},
Now: time.Now(),
IDGenerator: NewStableIDGenerator(),
}
auths, err := synth.Synthesize(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(auths) != 1 {
t.Fatalf("expected 1 auth, got %d", len(auths))
}
if auths[0].Provider != "claude" {
t.Errorf("expected provider claude, got %s", auths[0].Provider)
}
if auths[0].Label != "claude-apikey" {
t.Errorf("expected label claude-apikey, got %s", auths[0].Label)
}
if auths[0].Prefix != "main" {
t.Errorf("expected prefix main, got %s", auths[0].Prefix)
}
if auths[0].Attributes["api_key"] != "sk-ant-api-xxx" {
t.Errorf("expected api_key sk-ant-api-xxx, got %s", auths[0].Attributes["api_key"])
}
if _, ok := auths[0].Attributes["models_hash"]; !ok {
t.Error("expected models_hash in attributes")
}
}
func TestConfigSynthesizer_ClaudeKeys_SkipsEmptyAndHeaders(t *testing.T) {
synth := NewConfigSynthesizer()
ctx := &SynthesisContext{
Config: &config.Config{
ClaudeKey: []config.ClaudeKey{
{APIKey: ""}, // empty, should be skipped
{APIKey: " "}, // whitespace, should be skipped
{APIKey: "valid-key", Headers: map[string]string{"X-Custom": "value"}},
},
},
Now: time.Now(),
IDGenerator: NewStableIDGenerator(),
}
auths, err := synth.Synthesize(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(auths) != 1 {
t.Fatalf("expected 1 auth (empty keys skipped), got %d", len(auths))
}
if auths[0].Attributes["header:X-Custom"] != "value" {
t.Errorf("expected header:X-Custom=value, got %s", auths[0].Attributes["header:X-Custom"])
}
}
func TestConfigSynthesizer_CodexKeys(t *testing.T) {
synth := NewConfigSynthesizer()
ctx := &SynthesisContext{
Config: &config.Config{
CodexKey: []config.CodexKey{
{
APIKey: "codex-key-123",
Prefix: "dev",
BaseURL: "https://api.openai.com",
ProxyURL: "http://proxy.local",
Websockets: true,
},
},
},
Now: time.Now(),
IDGenerator: NewStableIDGenerator(),
}
auths, err := synth.Synthesize(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(auths) != 1 {
t.Fatalf("expected 1 auth, got %d", len(auths))
}
if auths[0].Provider != "codex" {
t.Errorf("expected provider codex, got %s", auths[0].Provider)
}
if auths[0].Label != "codex-apikey" {
t.Errorf("expected label codex-apikey, got %s", auths[0].Label)
}
if auths[0].ProxyURL != "http://proxy.local" {
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
}
if auths[0].Attributes["websockets"] != "true" {
t.Errorf("expected websockets=true, got %s", auths[0].Attributes["websockets"])
}
}
func TestConfigSynthesizer_CodexKeys_SkipsEmptyAndHeaders(t *testing.T) {
synth := NewConfigSynthesizer()
ctx := &SynthesisContext{
Config: &config.Config{
CodexKey: []config.CodexKey{
{APIKey: ""}, // empty, should be skipped
{APIKey: " "}, // whitespace, should be skipped
{APIKey: "valid-key", Headers: map[string]string{"Authorization": "Bearer xyz"}},
},
},
Now: time.Now(),
IDGenerator: NewStableIDGenerator(),
}
auths, err := synth.Synthesize(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(auths) != 1 {
t.Fatalf("expected 1 auth (empty keys skipped), got %d", len(auths))
}
if auths[0].Attributes["header:Authorization"] != "Bearer xyz" {
t.Errorf("expected header:Authorization=Bearer xyz, got %s", auths[0].Attributes["header:Authorization"])
}
}
func TestConfigSynthesizer_OpenAICompat(t *testing.T) {
	// Each case lists the OpenAI-compatible providers to synthesize and how
	// many auths are expected back.
	cases := []struct {
		name   string
		compat []config.OpenAICompatibility
		want   int
	}{
		{
			name: "with APIKeyEntries",
			compat: []config.OpenAICompatibility{{
				Name:    "CustomProvider",
				BaseURL: "https://custom.api.com",
				APIKeyEntries: []config.OpenAICompatibilityAPIKey{
					{APIKey: "key-1"},
					{APIKey: "key-2"},
				},
			}},
			want: 2,
		},
		{
			name: "empty APIKeyEntries included (legacy)",
			compat: []config.OpenAICompatibility{{
				Name:    "EmptyKeys",
				BaseURL: "https://empty.api.com",
				APIKeyEntries: []config.OpenAICompatibilityAPIKey{
					{APIKey: ""},
					{APIKey: " "},
				},
			}},
			want: 2,
		},
		{
			name: "without APIKeyEntries (fallback)",
			compat: []config.OpenAICompatibility{{
				Name:    "NoKeyProvider",
				BaseURL: "https://no-key.api.com",
			}},
			want: 1,
		},
		{
			name: "empty name defaults",
			compat: []config.OpenAICompatibility{{
				Name:    "",
				BaseURL: "https://default.api.com",
			}},
			want: 1,
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got, err := NewConfigSynthesizer().Synthesize(&SynthesisContext{
				Config:      &config.Config{OpenAICompatibility: tc.compat},
				Now:         time.Now(),
				IDGenerator: NewStableIDGenerator(),
			})
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if len(got) != tc.want {
				t.Fatalf("expected %d auths, got %d", tc.want, len(got))
			}
		})
	}
}
func TestConfigSynthesizer_VertexCompat(t *testing.T) {
synth := NewConfigSynthesizer()
ctx := &SynthesisContext{
Config: &config.Config{
VertexCompatAPIKey: []config.VertexCompatKey{
{
APIKey: "vertex-key-123",
BaseURL: "https://vertex.googleapis.com",
Prefix: "vertex-prod",
},
},
},
Now: time.Now(),
IDGenerator: NewStableIDGenerator(),
}
auths, err := synth.Synthesize(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(auths) != 1 {
t.Fatalf("expected 1 auth, got %d", len(auths))
}
if auths[0].Provider != "vertex" {
t.Errorf("expected provider vertex, got %s", auths[0].Provider)
}
if auths[0].Label != "vertex-apikey" {
t.Errorf("expected label vertex-apikey, got %s", auths[0].Label)
}
if auths[0].Prefix != "vertex-prod" {
t.Errorf("expected prefix vertex-prod, got %s", auths[0].Prefix)
}
}
func TestConfigSynthesizer_VertexCompat_SkipsEmptyAndHeaders(t *testing.T) {
synth := NewConfigSynthesizer()
ctx := &SynthesisContext{
Config: &config.Config{
VertexCompatAPIKey: []config.VertexCompatKey{
{APIKey: "", BaseURL: "https://vertex.api"}, // empty key creates auth without api_key attr
{APIKey: " ", BaseURL: "https://vertex.api"}, // whitespace key creates auth without api_key attr
{APIKey: "valid-key", BaseURL: "https://vertex.api", Headers: map[string]string{"X-Vertex": "test"}},
},
},
Now: time.Now(),
IDGenerator: NewStableIDGenerator(),
}
auths, err := synth.Synthesize(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
// Vertex compat doesn't skip empty keys - it creates auths without api_key attribute
if len(auths) != 3 {
t.Fatalf("expected 3 auths, got %d", len(auths))
}
// First two should not have api_key attribute
if _, ok := auths[0].Attributes["api_key"]; ok {
t.Error("expected first auth to not have api_key attribute")
}
if _, ok := auths[1].Attributes["api_key"]; ok {
t.Error("expected second auth to not have api_key attribute")
}
// Third should have headers
if auths[2].Attributes["header:X-Vertex"] != "test" {
t.Errorf("expected header:X-Vertex=test, got %s", auths[2].Attributes["header:X-Vertex"])
}
}
func TestConfigSynthesizer_OpenAICompat_WithModelsHash(t *testing.T) {
synth := NewConfigSynthesizer()
ctx := &SynthesisContext{
Config: &config.Config{
OpenAICompatibility: []config.OpenAICompatibility{
{
Name: "TestProvider",
BaseURL: "https://test.api.com",
Models: []config.OpenAICompatibilityModel{
{Name: "model-a"},
{Name: "model-b"},
},
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
{APIKey: "key-with-models"},
},
},
},
},
Now: time.Now(),
IDGenerator: NewStableIDGenerator(),
}
auths, err := synth.Synthesize(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(auths) != 1 {
t.Fatalf("expected 1 auth, got %d", len(auths))
}
if _, ok := auths[0].Attributes["models_hash"]; !ok {
t.Error("expected models_hash in attributes")
}
if auths[0].Attributes["api_key"] != "key-with-models" {
t.Errorf("expected api_key key-with-models, got %s", auths[0].Attributes["api_key"])
}
}
func TestConfigSynthesizer_OpenAICompat_FallbackWithModels(t *testing.T) {
synth := NewConfigSynthesizer()
ctx := &SynthesisContext{
Config: &config.Config{
OpenAICompatibility: []config.OpenAICompatibility{
{
Name: "NoKeyWithModels",
BaseURL: "https://nokey.api.com",
Models: []config.OpenAICompatibilityModel{
{Name: "model-x"},
},
Headers: map[string]string{"X-API": "header-value"},
// No APIKeyEntries - should use fallback path
},
},
},
Now: time.Now(),
IDGenerator: NewStableIDGenerator(),
}
auths, err := synth.Synthesize(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(auths) != 1 {
t.Fatalf("expected 1 auth, got %d", len(auths))
}
if _, ok := auths[0].Attributes["models_hash"]; !ok {
t.Error("expected models_hash in fallback path")
}
if auths[0].Attributes["header:X-API"] != "header-value" {
t.Errorf("expected header:X-API=header-value, got %s", auths[0].Attributes["header:X-API"])
}
}
func TestConfigSynthesizer_VertexCompat_WithModels(t *testing.T) {
synth := NewConfigSynthesizer()
ctx := &SynthesisContext{
Config: &config.Config{
VertexCompatAPIKey: []config.VertexCompatKey{
{
APIKey: "vertex-key",
BaseURL: "https://vertex.api",
Models: []config.VertexCompatModel{
{Name: "gemini-pro", Alias: "pro"},
{Name: "gemini-ultra", Alias: "ultra"},
},
},
},
},
Now: time.Now(),
IDGenerator: NewStableIDGenerator(),
}
auths, err := synth.Synthesize(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(auths) != 1 {
t.Fatalf("expected 1 auth, got %d", len(auths))
}
if _, ok := auths[0].Attributes["models_hash"]; !ok {
t.Error("expected models_hash in vertex auth with models")
}
}
func TestConfigSynthesizer_IDStability(t *testing.T) {
	cfg := &config.Config{
		GeminiKey: []config.GeminiKey{
			{APIKey: "stable-key", Prefix: "test"},
		},
	}
	fixed := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC)
	// firstID runs a fresh synthesizer and ID generator over cfg, so the two
	// runs share no state, and returns the first auth's ID. Errors and empty
	// results are fatal: previously they were discarded and the [0] index
	// below would panic instead of reporting a clear failure.
	firstID := func(run int) string {
		t.Helper()
		auths, err := NewConfigSynthesizer().Synthesize(&SynthesisContext{
			Config:      cfg,
			Now:         fixed,
			IDGenerator: NewStableIDGenerator(),
		})
		if err != nil {
			t.Fatalf("run %d: unexpected error: %v", run, err)
		}
		if len(auths) == 0 {
			t.Fatalf("run %d: expected at least 1 auth, got 0", run)
		}
		return auths[0].ID
	}
	id1, id2 := firstID(1), firstID(2)
	if id1 != id2 {
		t.Errorf("same config should produce same ID: got %q and %q", id1, id2)
	}
}
func TestConfigSynthesizer_AllProviders(t *testing.T) {
synth := NewConfigSynthesizer()
ctx := &SynthesisContext{
Config: &config.Config{
GeminiKey: []config.GeminiKey{
{APIKey: "gemini-key"},
},
ClaudeKey: []config.ClaudeKey{
{APIKey: "claude-key"},
},
CodexKey: []config.CodexKey{
{APIKey: "codex-key"},
},
OpenAICompatibility: []config.OpenAICompatibility{
{Name: "compat", BaseURL: "https://compat.api"},
},
VertexCompatAPIKey: []config.VertexCompatKey{
{APIKey: "vertex-key", BaseURL: "https://vertex.api"},
},
},
Now: time.Now(),
IDGenerator: NewStableIDGenerator(),
}
auths, err := synth.Synthesize(ctx)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(auths) != 5 {
t.Fatalf("expected 5 auths, got %d", len(auths))
}
providers := make(map[string]bool)
for _, a := range auths {
providers[a.Provider] = true
}
expected := []string{"gemini", "claude", "codex", "compat", "vertex"}
for _, p := range expected {
if !providers[p] {
t.Errorf("expected provider %s not found", p)
}
}
}
================================================
FILE: internal/watcher/synthesizer/context.go
================================================
package synthesizer
import (
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// SynthesisContext bundles the inputs shared by the auth synthesizers
// (for example FileSynthesizer). Synthesizers receive it by pointer; a nil
// context makes the file synthesizer return no auths.
type SynthesisContext struct {
// Config is the current configuration. The file synthesizer passes it to
// ApplyAuthExcludedModelsMeta when resolving excluded models.
Config *config.Config
// AuthDir is the directory scanned for *.json auth files. When empty, the
// file synthesizer produces nothing; file-based auth IDs are the file path
// relative to this directory.
AuthDir string
// Now is stamped into CreatedAt/UpdatedAt of file-synthesized auths.
Now time.Time
// IDGenerator generates stable IDs for auth entries. The file synthesizer
// does not consult it (it derives IDs from file paths); presumably it is
// used by the config-based synthesizer — confirm against its implementation.
IDGenerator *StableIDGenerator
}
================================================
FILE: internal/watcher/synthesizer/file.go
================================================
package synthesizer
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// FileSynthesizer turns the OAuth credential JSON files found in an auth
// directory into Auth entries, expanding multi-project Gemini CLI credentials
// into per-project virtual auths. It holds no state.
type FileSynthesizer struct{}

// NewFileSynthesizer returns a ready-to-use FileSynthesizer.
func NewFileSynthesizer() *FileSynthesizer {
	return new(FileSynthesizer)
}
// Synthesize builds Auth entries from every regular *.json file (extension
// matched case-insensitively) directly inside ctx.AuthDir, in directory
// order. It never returns an error: a nil context, an empty AuthDir, or an
// unreadable/missing directory yields an empty slice, and files that cannot
// be read, are empty, or do not describe an auth are skipped.
func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, error) {
	collected := make([]*coreauth.Auth, 0, 16)
	if ctx == nil || ctx.AuthDir == "" {
		return collected, nil
	}
	dirEntries, errReadDir := os.ReadDir(ctx.AuthDir)
	if errReadDir != nil {
		// Any ReadDir failure (not only a missing directory) is treated as
		// "no file-backed auths" rather than an error.
		return collected, nil
	}
	for _, entry := range dirEntries {
		fileName := entry.Name()
		if entry.IsDir() || !strings.HasSuffix(strings.ToLower(fileName), ".json") {
			continue
		}
		path := filepath.Join(ctx.AuthDir, fileName)
		payload, errReadFile := os.ReadFile(path)
		if errReadFile != nil || len(payload) == 0 {
			continue
		}
		// Appending an empty result is a no-op, so unparseable files vanish here.
		collected = append(collected, synthesizeFileAuths(ctx, path, payload)...)
	}
	return collected, nil
}
// SynthesizeAuthFile maps a single auth JSON payload (read from fullPath) to
// its Auth entries. It is the per-file core of FileSynthesizer.Synthesize, so
// both produce identical results for the same file.
func SynthesizeAuthFile(ctx *SynthesisContext, fullPath string, data []byte) []*coreauth.Auth {
	auths := synthesizeFileAuths(ctx, fullPath, data)
	return auths
}
// synthesizeFileAuths parses one auth JSON payload read from fullPath and
// builds the Auth entries it describes.
//
// It returns nil when ctx is nil, data is empty, data does not decode as a
// JSON object, or the object has no non-empty string "type". Otherwise the
// first element is the primary Auth for the file. For "gemini-cli"
// credentials listing several projects, it is followed by one virtual Auth
// per project, and SynthesizeGeminiVirtualAuths disables the primary.
//
// Metadata keys read here: "type" (provider; "gemini" maps to "gemini-cli"),
// "email" (label), "proxy_url", "prefix" (kept only if it is a single path
// segment), "disabled", "priority" (whole number, as JSON number or string),
// "note", the excluded-models keys, and for codex "id_token" (plan type).
func synthesizeFileAuths(ctx *SynthesisContext, fullPath string, data []byte) []*coreauth.Auth {
	if ctx == nil || len(data) == 0 {
		return nil
	}
	now := ctx.Now
	cfg := ctx.Config
	var metadata map[string]any
	if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil {
		return nil
	}
	t, _ := metadata["type"].(string)
	if t == "" {
		return nil
	}
	provider := strings.ToLower(t)
	if provider == "gemini" {
		provider = "gemini-cli"
	}
	label := provider
	if email, _ := metadata["email"].(string); email != "" {
		label = email
	}
	// Use relative path under authDir as ID to stay consistent with the file-based token store.
	id := fullPath
	if strings.TrimSpace(ctx.AuthDir) != "" {
		if rel, errRel := filepath.Rel(ctx.AuthDir, fullPath); errRel == nil && rel != "" {
			id = rel
		}
	}
	// IDs are lower-cased on Windows, presumably because its file names are
	// case-insensitive and the same file must always map to the same ID.
	if runtime.GOOS == "windows" {
		id = strings.ToLower(id)
	}
	proxyURL := ""
	if p, ok := metadata["proxy_url"].(string); ok {
		proxyURL = p
	}
	// A prefix is accepted only as a single segment once surrounding spaces
	// and slashes are trimmed; anything containing an inner "/" is dropped.
	prefix := ""
	if rawPrefix, ok := metadata["prefix"].(string); ok {
		trimmed := strings.TrimSpace(rawPrefix)
		trimmed = strings.Trim(trimmed, "/")
		if trimmed != "" && !strings.Contains(trimmed, "/") {
			prefix = trimmed
		}
	}
	disabled, _ := metadata["disabled"].(bool)
	status := coreauth.StatusActive
	if disabled {
		status = coreauth.StatusDisabled
	}
	// Read per-account excluded models from the OAuth JSON file.
	perAccountExcluded := extractExcludedModelsFromMetadata(metadata)
	a := &coreauth.Auth{
		ID:       id,
		Provider: provider,
		Label:    label,
		Prefix:   prefix,
		Status:   status,
		Disabled: disabled,
		Attributes: map[string]string{
			"source": fullPath,
			"path":   fullPath,
		},
		ProxyURL:  proxyURL,
		Metadata:  metadata,
		CreatedAt: now,
		UpdatedAt: now,
	}
	// Read priority from auth file. Only whole numbers in int range are
	// accepted, whether written as a JSON number or as a string.
	if rawPriority, ok := metadata["priority"]; ok {
		switch v := rawPriority.(type) {
		case float64:
			// JSON numbers decode as float64. Run them through the same Atoi
			// validation as the string form. Otherwise 1.5 would be truncated
			// to "1", while the string "1.5" is rejected. Out-of-range values
			// such as 1e30 would hit Go's implementation-defined float->int
			// conversion.
			if n, errAtoi := strconv.Atoi(strconv.FormatFloat(v, 'f', -1, 64)); errAtoi == nil {
				a.Attributes["priority"] = strconv.Itoa(n)
			}
		case string:
			priority := strings.TrimSpace(v)
			if _, errAtoi := strconv.Atoi(priority); errAtoi == nil {
				a.Attributes["priority"] = priority
			}
		}
	}
	// Read note from auth file.
	if rawNote, ok := metadata["note"]; ok {
		if note, isStr := rawNote.(string); isStr {
			if trimmed := strings.TrimSpace(note); trimmed != "" {
				a.Attributes["note"] = trimmed
			}
		}
	}
	ApplyAuthExcludedModelsMeta(a, cfg, perAccountExcluded, "oauth")
	// For codex auth files, extract plan_type from the JWT id_token.
	if provider == "codex" {
		if idTokenRaw, ok := metadata["id_token"].(string); ok && strings.TrimSpace(idTokenRaw) != "" {
			if claims, errParse := codex.ParseJWTToken(idTokenRaw); errParse == nil && claims != nil {
				if pt := strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType); pt != "" {
					a.Attributes["plan_type"] = pt
				}
			}
		}
	}
	if provider == "gemini-cli" {
		if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
			for _, v := range virtuals {
				ApplyAuthExcludedModelsMeta(v, cfg, perAccountExcluded, "oauth")
			}
			out := make([]*coreauth.Auth, 0, 1+len(virtuals))
			out = append(out, a)
			out = append(out, virtuals...)
			return out
		}
	}
	return []*coreauth.Auth{a}
}
// SynthesizeGeminiVirtualAuths expands a Gemini CLI credential whose metadata
// "project_id" lists several comma-separated projects into one virtual Auth
// per distinct project.
//
// When primary or metadata is nil, or fewer than two distinct projects are
// listed, it returns nil and leaves primary untouched. Otherwise it mutates
// primary in place:
//   - primary is disabled;
//   - its Runtime is replaced by a shared credential;
//   - it is tagged with "gemini_virtual_primary" and "virtual_children".
//
// Each returned virtual inherits the primary's source/path/priority/note
// attributes, proxy, prefix and timestamps, plus a small metadata subset:
// email, project, type, and the cooling and retry settings.
//
// Virtual auths honour the credential file's own "disabled" flag. If the
// metadata marks the credential disabled, every virtual is created disabled
// too, so disabling the file cannot be bypassed through its projects.
//
// now is currently unused; virtuals copy CreatedAt/UpdatedAt from primary.
func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]any, now time.Time) []*coreauth.Auth {
	if primary == nil || metadata == nil {
		return nil
	}
	projects := splitGeminiProjectIDs(metadata)
	if len(projects) <= 1 {
		return nil
	}
	// primary.Disabled is force-set to true below, and may already have been
	// forced by an earlier expansion. The user's intent is therefore read
	// from the file metadata, the same source synthesizeFileAuths uses.
	fileDisabled, _ := metadata["disabled"].(bool)
	childStatus := coreauth.StatusActive
	if fileDisabled {
		childStatus = coreauth.StatusDisabled
	}
	email, _ := metadata["email"].(string)
	shared := geminicli.NewSharedCredential(primary.ID, email, metadata, projects)
	primary.Disabled = true
	primary.Status = coreauth.StatusDisabled
	primary.Runtime = shared
	if primary.Attributes == nil {
		primary.Attributes = make(map[string]string)
	}
	primary.Attributes["gemini_virtual_primary"] = "true"
	primary.Attributes["virtual_children"] = strings.Join(projects, ",")
	source := primary.Attributes["source"]
	authPath := primary.Attributes["path"]
	originalProvider := primary.Provider
	if originalProvider == "" {
		originalProvider = "gemini-cli"
	}
	label := primary.Label
	if label == "" {
		label = originalProvider
	}
	virtuals := make([]*coreauth.Auth, 0, len(projects))
	for _, projectID := range projects {
		attrs := map[string]string{
			"runtime_only":           "true",
			"gemini_virtual_parent":  primary.ID,
			"gemini_virtual_project": projectID,
		}
		if source != "" {
			attrs["source"] = source
		}
		if authPath != "" {
			attrs["path"] = authPath
		}
		// Propagate priority from primary auth to virtual auths
		if priorityVal, hasPriority := primary.Attributes["priority"]; hasPriority && priorityVal != "" {
			attrs["priority"] = priorityVal
		}
		// Propagate note from primary auth to virtual auths
		if noteVal, hasNote := primary.Attributes["note"]; hasNote && noteVal != "" {
			attrs["note"] = noteVal
		}
		metadataCopy := map[string]any{
			"email":             email,
			"project_id":        projectID,
			"virtual":           true,
			"virtual_parent_id": primary.ID,
			"type":              metadata["type"],
		}
		// Both snake_case and kebab-case spellings are accepted; the copy is
		// normalised to snake_case.
		if v, ok := metadata["disable_cooling"]; ok {
			metadataCopy["disable_cooling"] = v
		} else if v, ok := metadata["disable-cooling"]; ok {
			metadataCopy["disable_cooling"] = v
		}
		if v, ok := metadata["request_retry"]; ok {
			metadataCopy["request_retry"] = v
		} else if v, ok := metadata["request-retry"]; ok {
			metadataCopy["request_retry"] = v
		}
		proxy := strings.TrimSpace(primary.ProxyURL)
		if proxy != "" {
			metadataCopy["proxy_url"] = proxy
		}
		virtual := &coreauth.Auth{
			ID:        buildGeminiVirtualID(primary.ID, projectID),
			Provider:  originalProvider,
			Label:     fmt.Sprintf("%s [%s]", label, projectID),
			Status:    childStatus,
			Disabled:  fileDisabled,
			Attributes: attrs,
			Metadata:  metadataCopy,
			ProxyURL:  primary.ProxyURL,
			Prefix:    primary.Prefix,
			CreatedAt: primary.CreatedAt,
			UpdatedAt: primary.UpdatedAt,
			Runtime:   geminicli.NewVirtualCredential(projectID, shared),
		}
		virtuals = append(virtuals, virtual)
	}
	return virtuals
}
// splitGeminiProjectIDs reads the comma-separated "project_id" string from
// metadata. It returns the non-blank entries, trimmed, in first-seen order
// with duplicates removed. It returns nil when the value is missing, not a
// string, or blank.
func splitGeminiProjectIDs(metadata map[string]any) []string {
	raw, _ := metadata["project_id"].(string)
	list := strings.TrimSpace(raw)
	if list == "" {
		return nil
	}
	fields := strings.Split(list, ",")
	ids := make([]string, 0, len(fields))
	seen := make(map[string]bool, len(fields))
	for _, field := range fields {
		candidate := strings.TrimSpace(field)
		if candidate != "" && !seen[candidate] {
			seen[candidate] = true
			ids = append(ids, candidate)
		}
	}
	return ids
}
// buildGeminiVirtualID returns "<baseID>::<project>". project is projectID
// with surrounding whitespace trimmed, or "project" when blank, and with
// every '/', '\\' and ' ' replaced by '_'.
func buildGeminiVirtualID(baseID, projectID string) string {
	project := strings.TrimSpace(projectID)
	if project == "" {
		project = "project"
	}
	for _, sep := range []string{"/", "\\", " "} {
		project = strings.ReplaceAll(project, sep, "_")
	}
	return baseID + "::" + project
}
// extractExcludedModelsFromMetadata returns the per-account excluded models
// listed in OAuth JSON metadata.
//
// "excluded_models" takes precedence over "excluded-models". If the winning
// key holds nil or an unsupported type, the result is nil. Both []string and
// []any values are accepted; non-string elements are ignored. Entries are
// trimmed and blanks dropped.
func extractExcludedModelsFromMetadata(metadata map[string]any) []string {
	if metadata == nil {
		return nil
	}
	raw, found := metadata["excluded_models"]
	if !found {
		raw, found = metadata["excluded-models"]
	}
	if !found || raw == nil {
		return nil
	}
	var candidates []string
	switch list := raw.(type) {
	case []string:
		candidates = list
	case []any:
		candidates = make([]string, 0, len(list))
		for _, item := range list {
			if s, isString := item.(string); isString {
				candidates = append(candidates, s)
			}
		}
	default:
		return nil
	}
	models := make([]string, 0, len(candidates))
	for _, c := range candidates {
		if c = strings.TrimSpace(c); c != "" {
			models = append(models, c)
		}
	}
	return models
}
================================================
FILE: internal/watcher/synthesizer/file_test.go
================================================
package synthesizer
import (
"encoding/json"
"os"
"path/filepath"
"strings"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func TestNewFileSynthesizer(t *testing.T) {
	// The constructor must always hand back a usable instance.
	if s := NewFileSynthesizer(); s == nil {
		t.Fatal("expected non-nil synthesizer")
	}
}
func TestFileSynthesizer_Synthesize_NilContext(t *testing.T) {
	// A nil context is tolerated and simply yields nothing.
	got, err := NewFileSynthesizer().Synthesize(nil)
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case len(got) != 0:
		t.Fatalf("expected empty auths, got %d", len(got))
	}
}
func TestFileSynthesizer_Synthesize_EmptyAuthDir(t *testing.T) {
	// With no auth directory configured there is nothing to scan.
	got, err := NewFileSynthesizer().Synthesize(&SynthesisContext{
		Config:      &config.Config{},
		AuthDir:     "",
		Now:         time.Now(),
		IDGenerator: NewStableIDGenerator(),
	})
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case len(got) != 0:
		t.Fatalf("expected empty auths, got %d", len(got))
	}
}
func TestFileSynthesizer_Synthesize_NonExistentDir(t *testing.T) {
	// A missing directory is not an error; it just produces no auths.
	got, err := NewFileSynthesizer().Synthesize(&SynthesisContext{
		Config:      &config.Config{},
		AuthDir:     "/non/existent/path",
		Now:         time.Now(),
		IDGenerator: NewStableIDGenerator(),
	})
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case len(got) != 0:
		t.Fatalf("expected empty auths, got %d", len(got))
	}
}
func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
	dir := t.TempDir()
	// A complete claude credential: each field should surface on the auth,
	// with Metadata keeping the raw JSON-decoded values (numbers as float64).
	// Marshalling a literal map of basic values cannot fail, hence the _.
	payload, _ := json.Marshal(map[string]any{
		"type":            "claude",
		"email":           "test@example.com",
		"proxy_url":       "http://proxy.local",
		"prefix":          "test-prefix",
		"disable_cooling": true,
		"request_retry":   2,
	})
	if err := os.WriteFile(filepath.Join(dir, "claude-auth.json"), payload, 0644); err != nil {
		t.Fatalf("failed to write auth file: %v", err)
	}
	result, err := NewFileSynthesizer().Synthesize(&SynthesisContext{
		Config:      &config.Config{},
		AuthDir:     dir,
		Now:         time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
		IDGenerator: NewStableIDGenerator(),
	})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(result) != 1 {
		t.Fatalf("expected 1 auth, got %d", len(result))
	}
	got := result[0]
	if got.Provider != "claude" {
		t.Errorf("expected provider claude, got %s", got.Provider)
	}
	if got.Label != "test@example.com" {
		t.Errorf("expected label test@example.com, got %s", got.Label)
	}
	if got.Prefix != "test-prefix" {
		t.Errorf("expected prefix test-prefix, got %s", got.Prefix)
	}
	if got.ProxyURL != "http://proxy.local" {
		t.Errorf("expected proxy_url http://proxy.local, got %s", got.ProxyURL)
	}
	if cooling, ok := got.Metadata["disable_cooling"].(bool); !ok || !cooling {
		t.Errorf("expected disable_cooling true, got %v", got.Metadata["disable_cooling"])
	}
	if retry, ok := got.Metadata["request_retry"].(float64); !ok || int(retry) != 2 {
		t.Errorf("expected request_retry 2, got %v", got.Metadata["request_retry"])
	}
	if got.Status != coreauth.StatusActive {
		t.Errorf("expected status active, got %s", got.Status)
	}
}
func TestFileSynthesizer_Synthesize_GeminiProviderMapping(t *testing.T) {
	dir := t.TempDir()
	// Files declaring type "gemini" are surfaced under the gemini-cli provider.
	payload, _ := json.Marshal(map[string]any{
		"type":  "gemini",
		"email": "gemini@example.com",
	})
	if err := os.WriteFile(filepath.Join(dir, "gemini-auth.json"), payload, 0644); err != nil {
		t.Fatalf("failed to write auth file: %v", err)
	}
	result, err := NewFileSynthesizer().Synthesize(&SynthesisContext{
		Config:      &config.Config{},
		AuthDir:     dir,
		Now:         time.Now(),
		IDGenerator: NewStableIDGenerator(),
	})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(result) != 1 {
		t.Fatalf("expected 1 auth, got %d", len(result))
	}
	if p := result[0].Provider; p != "gemini-cli" {
		t.Errorf("gemini should be mapped to gemini-cli, got %s", p)
	}
}
func TestFileSynthesizer_Synthesize_SkipsInvalidFiles(t *testing.T) {
	tempDir := t.TempDir()
	validData, errMarshal := json.Marshal(map[string]any{"type": "claude", "email": "valid@example.com"})
	if errMarshal != nil {
		t.Fatalf("failed to marshal auth data: %v", errMarshal)
	}
	// Only valid.json may produce an auth. The others cover, in order: a
	// wrong extension, malformed JSON, an empty payload, and a JSON object
	// without "type". Write failures are fatal; previously they were ignored,
	// which could let the test pass without exercising those cases.
	fixtures := []struct {
		name    string
		content []byte
	}{
		{"not-json.txt", []byte("text content")},
		{"invalid.json", []byte("not valid json")},
		{"empty.json", []byte("")},
		{"no-type.json", []byte(`{"email": "test@example.com"}`)},
		{"valid.json", validData},
	}
	for _, f := range fixtures {
		if errWrite := os.WriteFile(filepath.Join(tempDir, f.name), f.content, 0644); errWrite != nil {
			t.Fatalf("failed to write %s: %v", f.name, errWrite)
		}
	}
	synth := NewFileSynthesizer()
	ctx := &SynthesisContext{
		Config:      &config.Config{},
		AuthDir:     tempDir,
		Now:         time.Now(),
		IDGenerator: NewStableIDGenerator(),
	}
	auths, err := synth.Synthesize(ctx)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(auths) != 1 {
		t.Fatalf("only valid auth file should be processed, got %d", len(auths))
	}
	if auths[0].Label != "valid@example.com" {
		t.Errorf("expected label valid@example.com, got %s", auths[0].Label)
	}
}
func TestFileSynthesizer_Synthesize_SkipsDirectories(t *testing.T) {
	tempDir := t.TempDir()
	// A directory whose name ends in .json must be ignored, not read as a file.
	subDir := filepath.Join(tempDir, "subdir.json")
	if errMkdir := os.Mkdir(subDir, 0755); errMkdir != nil {
		t.Fatalf("failed to create subdir: %v", errMkdir)
	}
	// Create a valid file in root. The write is checked so that a failure
	// cannot masquerade as "directories skipped".
	validData, errMarshal := json.Marshal(map[string]any{"type": "claude"})
	if errMarshal != nil {
		t.Fatalf("failed to marshal auth data: %v", errMarshal)
	}
	if errWrite := os.WriteFile(filepath.Join(tempDir, "valid.json"), validData, 0644); errWrite != nil {
		t.Fatalf("failed to write auth file: %v", errWrite)
	}
	synth := NewFileSynthesizer()
	ctx := &SynthesisContext{
		Config:      &config.Config{},
		AuthDir:     tempDir,
		Now:         time.Now(),
		IDGenerator: NewStableIDGenerator(),
	}
	auths, err := synth.Synthesize(ctx)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(auths) != 1 {
		t.Fatalf("expected 1 auth, got %d", len(auths))
	}
}
func TestFileSynthesizer_Synthesize_RelativeID(t *testing.T) {
	dir := t.TempDir()
	payload, _ := json.Marshal(map[string]any{"type": "claude"})
	if err := os.WriteFile(filepath.Join(dir, "my-auth.json"), payload, 0644); err != nil {
		t.Fatalf("failed to write auth file: %v", err)
	}
	result, err := NewFileSynthesizer().Synthesize(&SynthesisContext{
		Config:      &config.Config{},
		AuthDir:     dir,
		Now:         time.Now(),
		IDGenerator: NewStableIDGenerator(),
	})
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	if len(result) != 1 {
		t.Fatalf("expected 1 auth, got %d", len(result))
	}
	// The ID is the path relative to AuthDir, not the absolute path.
	if id := result[0].ID; id != "my-auth.json" {
		t.Errorf("expected ID my-auth.json, got %s", id)
	}
}
func TestFileSynthesizer_Synthesize_PrefixValidation(t *testing.T) {
	// Prefixes are trimmed of spaces and outer slashes and must end up as a
	// single path segment; otherwise they are dropped.
	tests := []struct {
		name       string
		prefix     string
		wantPrefix string
	}{
		{"valid prefix", "myprefix", "myprefix"},
		{"prefix with slashes trimmed", "/myprefix/", "myprefix"},
		{"prefix with spaces trimmed", " myprefix ", "myprefix"},
		{"prefix with internal slash rejected", "my/prefix", ""},
		{"empty prefix", "", ""},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tempDir := t.TempDir()
			authData := map[string]any{
				"type":   "claude",
				"prefix": tt.prefix,
			}
			data, errMarshal := json.Marshal(authData)
			if errMarshal != nil {
				t.Fatalf("failed to marshal auth data: %v", errMarshal)
			}
			// Previously ignored: a failed write would surface as a confusing
			// "expected 1 auth, got 0" instead of the real cause.
			if errWrite := os.WriteFile(filepath.Join(tempDir, "auth.json"), data, 0644); errWrite != nil {
				t.Fatalf("failed to write auth file: %v", errWrite)
			}
			synth := NewFileSynthesizer()
			ctx := &SynthesisContext{
				Config:      &config.Config{},
				AuthDir:     tempDir,
				Now:         time.Now(),
				IDGenerator: NewStableIDGenerator(),
			}
			auths, err := synth.Synthesize(ctx)
			if err != nil {
				t.Fatalf("unexpected error: %v", err)
			}
			if len(auths) != 1 {
				t.Fatalf("expected 1 auth, got %d", len(auths))
			}
			if auths[0].Prefix != tt.wantPrefix {
				t.Errorf("expected prefix %q, got %q", tt.wantPrefix, auths[0].Prefix)
			}
		})
	}
}
func TestFileSynthesizer_Synthesize_PriorityParsing(t *testing.T) {
	// priority may be a JSON number or an integer string (whitespace allowed);
	// anything that does not parse as an integer leaves the attribute unset.
	cases := []struct {
		name    string
		input   any
		want    string
		wantSet bool
	}{
		{name: "string with spaces", input: " 10 ", want: "10", wantSet: true},
		{name: "number", input: 8, want: "8", wantSet: true},
		{name: "invalid string", input: "1x", wantSet: false},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			dir := t.TempDir()
			payload, _ := json.Marshal(map[string]any{
				"type":     "claude",
				"priority": tc.input,
			})
			if errWriteFile := os.WriteFile(filepath.Join(dir, "auth.json"), payload, 0644); errWriteFile != nil {
				t.Fatalf("failed to write auth file: %v", errWriteFile)
			}
			result, errSynthesize := NewFileSynthesizer().Synthesize(&SynthesisContext{
				Config:      &config.Config{},
				AuthDir:     dir,
				Now:         time.Now(),
				IDGenerator: NewStableIDGenerator(),
			})
			if errSynthesize != nil {
				t.Fatalf("unexpected error: %v", errSynthesize)
			}
			if len(result) != 1 {
				t.Fatalf("expected 1 auth, got %d", len(result))
			}
			value, present := result[0].Attributes["priority"]
			switch {
			case tc.wantSet && !present:
				t.Fatal("expected priority attribute to be set")
			case tc.wantSet && value != tc.want:
				t.Fatalf("expected priority %q, got %q", tc.want, value)
			case !tc.wantSet && present:
				t.Fatalf("expected priority attribute to be absent, got %q", value)
			}
		})
	}
}
func TestFileSynthesizer_Synthesize_OAuthExcludedModelsMerged(t *testing.T) {
tempDir := t.TempDir()
authData := map[string]any{
"type": "claude",
"excluded_models": []string{"custom-model", "MODEL-B"},
}
data, _ := json.Marshal(authData)
errWriteFile := os.WriteFile(filepath.Join(tempDir, "auth.json"), data, 0644)
if errWriteFile != nil {
t.Fatalf("failed to write auth file: %v", errWriteFile)
}
synth := NewFileSynthesizer()
ctx := &SynthesisContext{
Config: &config.Config{
OAuthExcludedModels: map[string][]string{
"claude": {"shared", "model-b"},
},
},
AuthDir: tempDir,
Now: time.Now(),
IDGenerator: NewStableIDGenerator(),
}
auths, errSynthesize := synth.Synthesize(ctx)
if errSynthesize != nil {
t.Fatalf("unexpected error: %v", errSynthesize)
}
if len(auths) != 1 {
t.Fatalf("expected 1 auth, got %d", len(auths))
}
got := auths[0].Attributes["excluded_models"]
want := "custom-model,model-b,shared"
if got != want {
t.Fatalf("expected excluded_models %q, got %q", want, got)
}
}
func TestSynthesizeGeminiVirtualAuths_NilInputs(t *testing.T) {
now := time.Now()
if SynthesizeGeminiVirtualAuths(nil, nil, now) != nil {
t.Error("expected nil for nil primary")
}
if SynthesizeGeminiVirtualAuths(&coreauth.Auth{}, nil, now) != nil {
t.Error("expected nil for nil metadata")
}
if SynthesizeGeminiVirtualAuths(nil, map[string]any{}, now) != nil {
t.Error("expected nil for nil primary with metadata")
}
}
func TestSynthesizeGeminiVirtualAuths_SingleProject(t *testing.T) {
	primary := &coreauth.Auth{ID: "test-id", Provider: "gemini-cli", Label: "test@example.com"}
	metadata := map[string]any{
		"project_id": "single-project",
		"email":      "test@example.com",
		"type":       "gemini",
	}
	// A lone project needs no fan-out, so no virtual auths are produced.
	if SynthesizeGeminiVirtualAuths(primary, metadata, time.Now()) != nil {
		t.Error("single project should not create virtuals")
	}
}
func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
	primary := &coreauth.Auth{
		ID:       "primary-id",
		Provider: "gemini-cli",
		Label:    "test@example.com",
		Prefix:   "test-prefix",
		ProxyURL: "http://proxy.local",
		Attributes: map[string]string{
			"source": "test-source",
			"path":   "/path/to/auth",
		},
	}
	metadata := map[string]any{
		"project_id":      "project-a, project-b, project-c",
		"email":           "test@example.com",
		"type":            "gemini",
		"request_retry":   2,
		"disable_cooling": true,
	}
	virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, time.Now())
	if len(virtuals) != 3 {
		t.Fatalf("expected 3 virtuals, got %d", len(virtuals))
	}

	// The primary is parked (disabled) and tagged with its children.
	if !primary.Disabled {
		t.Error("expected primary to be disabled")
	}
	if primary.Status != coreauth.StatusDisabled {
		t.Errorf("expected primary status disabled, got %s", primary.Status)
	}
	if primary.Attributes["gemini_virtual_primary"] != "true" {
		t.Error("expected gemini_virtual_primary=true")
	}
	if !strings.Contains(primary.Attributes["virtual_children"], "project-a") {
		t.Error("expected virtual_children to contain project-a")
	}

	// Each virtual inherits the primary's settings and targets one project,
	// in the order the projects were listed.
	for i, project := range []string{"project-a", "project-b", "project-c"} {
		v := virtuals[i]
		if v.Provider != "gemini-cli" {
			t.Errorf("expected provider gemini-cli, got %s", v.Provider)
		}
		if v.Status != coreauth.StatusActive {
			t.Errorf("expected status active, got %s", v.Status)
		}
		if v.Prefix != "test-prefix" {
			t.Errorf("expected prefix test-prefix, got %s", v.Prefix)
		}
		if v.ProxyURL != "http://proxy.local" {
			t.Errorf("expected proxy_url http://proxy.local, got %s", v.ProxyURL)
		}
		if cooling, ok := v.Metadata["disable_cooling"].(bool); !ok || !cooling {
			t.Errorf("expected disable_cooling true, got %v", v.Metadata["disable_cooling"])
		}
		if retry, ok := v.Metadata["request_retry"].(int); !ok || retry != 2 {
			t.Errorf("expected request_retry 2, got %v", v.Metadata["request_retry"])
		}
		if v.Attributes["runtime_only"] != "true" {
			t.Error("expected runtime_only=true")
		}
		if parent := v.Attributes["gemini_virtual_parent"]; parent != "primary-id" {
			t.Errorf("expected gemini_virtual_parent=primary-id, got %s", parent)
		}
		if got := v.Attributes["gemini_virtual_project"]; got != project {
			t.Errorf("expected gemini_virtual_project=%s, got %s", project, got)
		}
		if !strings.Contains(v.Label, "["+project+"]") {
			t.Errorf("expected label to contain [%s], got %s", project, v.Label)
		}
	}
}
func TestSynthesizeGeminiVirtualAuths_EmptyProviderAndLabel(t *testing.T) {
	// An empty Provider falls back to "gemini-cli"; an empty Label falls back
	// to that provider name in the virtual labels.
	primary := &coreauth.Auth{
		ID:         "primary-id",
		Provider:   "",
		Label:      "",
		Attributes: map[string]string{},
	}
	metadata := map[string]any{
		"project_id": "proj-a, proj-b",
		"email":      "user@example.com",
		"type":       "gemini",
	}
	virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, time.Now())
	if len(virtuals) != 2 {
		t.Fatalf("expected 2 virtuals, got %d", len(virtuals))
	}
	first := virtuals[0]
	if first.Provider != "gemini-cli" {
		t.Errorf("expected provider gemini-cli (default), got %s", first.Provider)
	}
	if !strings.Contains(first.Label, "gemini-cli") {
		t.Errorf("expected label to contain gemini-cli, got %s", first.Label)
	}
}
// TestSynthesizeGeminiVirtualAuths_NilPrimaryAttributes ensures that a primary
// auth whose Attributes map is nil gets a map allocated and is flagged as the
// virtual primary once virtual auths are synthesized from it.
func TestSynthesizeGeminiVirtualAuths_NilPrimaryAttributes(t *testing.T) {
	primary := &coreauth.Auth{
		ID:         "primary-id",
		Provider:   "gemini-cli",
		Label:      "test@example.com",
		Attributes: nil, // deliberately nil
	}
	meta := map[string]any{
		"project_id": "proj-a, proj-b",
		"email":      "test@example.com",
		"type":       "gemini",
	}

	if got := SynthesizeGeminiVirtualAuths(primary, meta, time.Now()); len(got) != 2 {
		t.Fatalf("expected 2 virtuals, got %d", len(got))
	}

	// The synthesizer must allocate the map on the primary and mark it.
	if primary.Attributes == nil {
		t.Error("expected primary.Attributes to be initialized")
	}
	if primary.Attributes["gemini_virtual_primary"] != "true" {
		t.Error("expected gemini_virtual_primary=true")
	}
}
// TestSplitGeminiProjectIDs covers parsing of the comma-separated project_id
// metadata value: trimming, dropping empty segments, removing duplicates while
// keeping first-seen order, and yielding nothing for missing or blank input.
func TestSplitGeminiProjectIDs(t *testing.T) {
	cases := []struct {
		name     string
		metadata map[string]any
		want     []string
	}{
		{"single project", map[string]any{"project_id": "proj-a"}, []string{"proj-a"}},
		{"multiple projects", map[string]any{"project_id": "proj-a, proj-b, proj-c"}, []string{"proj-a", "proj-b", "proj-c"}},
		{"with duplicates", map[string]any{"project_id": "proj-a, proj-b, proj-a"}, []string{"proj-a", "proj-b"}},
		{"with empty parts", map[string]any{"project_id": "proj-a, , proj-b, "}, []string{"proj-a", "proj-b"}},
		{"empty project_id", map[string]any{"project_id": ""}, nil},
		{"no project_id", map[string]any{}, nil},
		{"whitespace only", map[string]any{"project_id": " "}, nil},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			got := splitGeminiProjectIDs(tc.metadata)
			if len(got) != len(tc.want) {
				t.Fatalf("expected %v, got %v", tc.want, got)
			}
			for i, id := range got {
				if id == tc.want[i] {
					continue
				}
				t.Errorf("expected %v, got %v", tc.want, got)
				break
			}
		})
	}
}
// TestFileSynthesizer_Synthesize_MultiProjectGemini verifies that a Gemini auth
// file listing several comma-separated project IDs is expanded into one
// disabled primary auth followed by one active virtual auth per project. It
// also checks that the whitespace-padded "priority" value is trimmed and
// propagated to the primary and to every virtual.
func TestFileSynthesizer_Synthesize_MultiProjectGemini(t *testing.T) {
	tempDir := t.TempDir()
	// Create a gemini auth file with multiple projects.
	authData := map[string]any{
		"type":       "gemini",
		"email":      "multi@example.com",
		"project_id": "project-a, project-b, project-c",
		"priority":   " 10 ",
	}
	data, err := json.Marshal(authData)
	if err != nil {
		t.Fatalf("failed to marshal auth data: %v", err)
	}
	if err = os.WriteFile(filepath.Join(tempDir, "gemini-multi.json"), data, 0o644); err != nil {
		t.Fatalf("failed to write auth file: %v", err)
	}
	synth := NewFileSynthesizer()
	ctx := &SynthesisContext{
		Config:      &config.Config{},
		AuthDir:     tempDir,
		Now:         time.Now(),
		IDGenerator: NewStableIDGenerator(),
	}
	auths, err := synth.Synthesize(ctx)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// Should have 4 auths: 1 primary (disabled) + 3 virtuals.
	if len(auths) != 4 {
		t.Fatalf("expected 4 auths (1 primary + 3 virtuals), got %d", len(auths))
	}
	// The first auth is expected to be the (disabled) primary.
	primary := auths[0]
	if !primary.Disabled {
		t.Error("expected primary to be disabled")
	}
	if primary.Status != coreauth.StatusDisabled {
		t.Errorf("expected primary status disabled, got %s", primary.Status)
	}
	if gotPriority := primary.Attributes["priority"]; gotPriority != "10" {
		t.Errorf("expected primary priority 10, got %q", gotPriority)
	}
	// Every remaining auth must be an active virtual linked to the primary.
	for i := 1; i < len(auths); i++ {
		v := auths[i]
		if v.Status != coreauth.StatusActive {
			t.Errorf("expected virtual %d to be active, got %s", i, v.Status)
		}
		if v.Attributes["gemini_virtual_parent"] != primary.ID {
			t.Errorf("expected virtual %d parent to be %s, got %s", i, primary.ID, v.Attributes["gemini_virtual_parent"])
		}
		if gotPriority := v.Attributes["priority"]; gotPriority != "10" {
			t.Errorf("expected virtual %d priority 10, got %q", i, gotPriority)
		}
	}
}
// TestBuildGeminiVirtualID checks how a virtual auth ID is composed as
// "<baseID>::<sanitized project>": slashes and spaces in the project are
// replaced with underscores, and an empty or blank project becomes "project".
func TestBuildGeminiVirtualID(t *testing.T) {
	cases := []struct {
		name      string
		baseID    string
		projectID string
		want      string
	}{
		{"basic", "auth.json", "my-project", "auth.json::my-project"},
		{"with slashes", "path/to/auth.json", "project/with/slashes", "path/to/auth.json::project_with_slashes"},
		{"with spaces", "auth.json", "my project", "auth.json::my_project"},
		{"empty project", "auth.json", "", "auth.json::project"},
		{"whitespace project", "auth.json", " ", "auth.json::project"},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			if got := buildGeminiVirtualID(tc.baseID, tc.projectID); got != tc.want {
				t.Errorf("expected %q, got %q", tc.want, got)
			}
		})
	}
}
// TestSynthesizeGeminiVirtualAuths_NotePropagated ensures the primary's "note"
// and "priority" attributes are copied onto every synthesized virtual auth.
func TestSynthesizeGeminiVirtualAuths_NotePropagated(t *testing.T) {
	const wantNote, wantPriority = "my test note", "5"
	primary := &coreauth.Auth{
		ID:       "primary-id",
		Provider: "gemini-cli",
		Label:    "test@example.com",
		Attributes: map[string]string{
			"source":   "test-source",
			"path":     "/path/to/auth",
			"priority": wantPriority,
			"note":     wantNote,
		},
	}
	meta := map[string]any{
		"project_id": "proj-a, proj-b",
		"email":      "test@example.com",
		"type":       "gemini",
	}

	virtuals := SynthesizeGeminiVirtualAuths(primary, meta, time.Now())
	if len(virtuals) != 2 {
		t.Fatalf("expected 2 virtuals, got %d", len(virtuals))
	}
	for idx, virtual := range virtuals {
		if got := virtual.Attributes["note"]; got != wantNote {
			t.Errorf("virtual %d: expected note %q, got %q", idx, wantNote, got)
		}
		if got := virtual.Attributes["priority"]; got != wantPriority {
			t.Errorf("virtual %d: expected priority %q, got %q", idx, wantPriority, got)
		}
	}
}
// TestSynthesizeGeminiVirtualAuths_NoteAbsentWhenEmpty ensures virtual auths do
// not gain a "note" attribute when the primary has none.
func TestSynthesizeGeminiVirtualAuths_NoteAbsentWhenEmpty(t *testing.T) {
	primary := &coreauth.Auth{
		ID:       "primary-id",
		Provider: "gemini-cli",
		Label:    "test@example.com",
		Attributes: map[string]string{
			"source": "test-source",
			"path":   "/path/to/auth",
		},
	}
	meta := map[string]any{
		"project_id": "proj-a, proj-b",
		"email":      "test@example.com",
		"type":       "gemini",
	}

	virtuals := SynthesizeGeminiVirtualAuths(primary, meta, time.Now())
	if len(virtuals) != 2 {
		t.Fatalf("expected 2 virtuals, got %d", len(virtuals))
	}
	for idx := range virtuals {
		if _, present := virtuals[idx].Attributes["note"]; !present {
			continue
		}
		t.Errorf("virtual %d: expected no note attribute when primary has no note", idx)
	}
}
// TestFileSynthesizer_Synthesize_NoteParsing checks how the "note" field of an
// auth file becomes the "note" attribute:
//   - string notes are trimmed and stored;
//   - empty or whitespace-only strings are dropped;
//   - non-string notes are ignored.
func TestFileSynthesizer_Synthesize_NoteParsing(t *testing.T) {
	tests := []struct {
		name     string
		note     any
		want     string
		hasValue bool
	}{
		{
			name:     "valid string note",
			note:     "hello world",
			want:     "hello world",
			hasValue: true,
		},
		{
			name:     "string note with whitespace",
			note:     " trimmed note ",
			want:     "trimmed note",
			hasValue: true,
		},
		{
			name:     "empty string note",
			note:     "",
			hasValue: false,
		},
		{
			name:     "whitespace only note",
			note:     " ",
			hasValue: false,
		},
		{
			name:     "non-string note ignored",
			note:     12345,
			hasValue: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			tempDir := t.TempDir()
			authData := map[string]any{
				"type": "claude",
				"note": tt.note,
			}
			data, errMarshal := json.Marshal(authData)
			if errMarshal != nil {
				t.Fatalf("failed to marshal auth data: %v", errMarshal)
			}
			errWriteFile := os.WriteFile(filepath.Join(tempDir, "auth.json"), data, 0o644)
			if errWriteFile != nil {
				t.Fatalf("failed to write auth file: %v", errWriteFile)
			}
			synth := NewFileSynthesizer()
			ctx := &SynthesisContext{
				Config:      &config.Config{},
				AuthDir:     tempDir,
				Now:         time.Now(),
				IDGenerator: NewStableIDGenerator(),
			}
			auths, errSynthesize := synth.Synthesize(ctx)
			if errSynthesize != nil {
				t.Fatalf("unexpected error: %v", errSynthesize)
			}
			if len(auths) != 1 {
				t.Fatalf("expected 1 auth, got %d", len(auths))
			}
			value, ok := auths[0].Attributes["note"]
			if tt.hasValue {
				if !ok {
					t.Fatal("expected note attribute to be set")
				}
				if value != tt.want {
					t.Fatalf("expected note %q, got %q", tt.want, value)
				}
				return
			}
			if ok {
				t.Fatalf("expected note attribute to be absent, got %q", value)
			}
		})
	}
}
// TestFileSynthesizer_Synthesize_MultiProjectGeminiWithNote verifies that a
// multi-project Gemini auth file keeps its note on the primary and that every
// virtual auth inherits both the note and the (numeric) priority, rendered as a
// string attribute.
func TestFileSynthesizer_Synthesize_MultiProjectGeminiWithNote(t *testing.T) {
	tempDir := t.TempDir()
	authData := map[string]any{
		"type":       "gemini",
		"email":      "multi@example.com",
		"project_id": "project-a, project-b",
		"priority":   5,
		"note":       "production keys",
	}
	data, err := json.Marshal(authData)
	if err != nil {
		t.Fatalf("failed to marshal auth data: %v", err)
	}
	if err = os.WriteFile(filepath.Join(tempDir, "gemini-multi.json"), data, 0o644); err != nil {
		t.Fatalf("failed to write auth file: %v", err)
	}
	synth := NewFileSynthesizer()
	ctx := &SynthesisContext{
		Config:      &config.Config{},
		AuthDir:     tempDir,
		Now:         time.Now(),
		IDGenerator: NewStableIDGenerator(),
	}
	auths, err := synth.Synthesize(ctx)
	if err != nil {
		t.Fatalf("unexpected error: %v", err)
	}
	// Should have 3 auths: 1 primary (disabled) + 2 virtuals.
	if len(auths) != 3 {
		t.Fatalf("expected 3 auths (1 primary + 2 virtuals), got %d", len(auths))
	}
	primary := auths[0]
	if gotNote := primary.Attributes["note"]; gotNote != "production keys" {
		t.Errorf("expected primary note %q, got %q", "production keys", gotNote)
	}
	// Verify virtuals inherit note and priority.
	for i := 1; i < len(auths); i++ {
		v := auths[i]
		if gotNote := v.Attributes["note"]; gotNote != "production keys" {
			t.Errorf("expected virtual %d note %q, got %q", i, "production keys", gotNote)
		}
		if gotPriority := v.Attributes["priority"]; gotPriority != "5" {
			t.Errorf("expected virtual %d priority %q, got %q", i, "5", gotPriority)
		}
	}
}
================================================
FILE: internal/watcher/synthesizer/helpers.go
================================================
package synthesizer
import (
"crypto/sha256"
"encoding/hex"
"fmt"
"sort"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// StableIDGenerator generates stable, deterministic IDs for auth entries.
// IDs are derived from a SHA-256 digest of the kind and parts. Repeated inputs
// seen by the same generator are disambiguated with a "-N" counter suffix.
// The zero value is ready to use. It is not safe for concurrent use.
type StableIDGenerator struct {
	counters map[string]int
}

// NewStableIDGenerator creates a new StableIDGenerator instance.
func NewStableIDGenerator() *StableIDGenerator {
	return &StableIDGenerator{counters: make(map[string]int)}
}

// Next generates a stable ID based on the kind and parts.
//
// Each part is whitespace-trimmed. The kind and the parts are hashed with
// SHA-256, with a NUL byte written before every part so that ("ab", "c") and
// ("a", "bc") hash differently. The short hash is the first 12 hex characters
// of the digest. If this generator has already issued the same kind/short
// pair, a "-N" suffix (N = 1, 2, ...) is appended to keep IDs unique.
//
// It returns the full ID ("kind:short") and the short portion. A nil receiver
// returns a fixed all-zero placeholder and does not track collisions.
func (g *StableIDGenerator) Next(kind string, parts ...string) (string, string) {
	const shortLen = 12
	if g == nil {
		return kind + ":000000000000", "000000000000"
	}
	hasher := sha256.New()
	// hash.Hash.Write never returns an error, so results are not checked.
	hasher.Write([]byte(kind))
	for _, part := range parts {
		hasher.Write([]byte{0})
		hasher.Write([]byte(strings.TrimSpace(part)))
	}
	// A SHA-256 hex digest is always 64 characters, so this slice is in range.
	short := hex.EncodeToString(hasher.Sum(nil))[:shortLen]
	if g.counters == nil {
		// Lazily allocate so a zero-value generator does not panic on the
		// map write below.
		g.counters = make(map[string]int)
	}
	key := kind + ":" + short
	index := g.counters[key]
	g.counters[key] = index + 1
	if index > 0 {
		short = fmt.Sprintf("%s-%d", short, index)
	}
	return kind + ":" + short, short
}
// ApplyAuthExcludedModelsMeta records excluded-model metadata on an auth entry.
//
// The exclusion list always includes perKey (for OAuth entries this comes from
// the JSON file's excluded-models field). For any authKind other than "apikey"
// (compared case-insensitively after trimming), the global
// cfg.OAuthExcludedModels list for the auth's lower-cased provider is merged
// in. Entries are trimmed, lower-cased, de-duplicated and sorted.
//
// The following attributes are written:
//   - "excluded_models_hash": the hash of the list, when non-empty;
//   - "excluded_models": the comma-joined list, when non-empty;
//   - "auth_kind": authKind verbatim, when non-empty.
//
// Nothing happens when auth or cfg is nil.
func ApplyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) {
	if auth == nil || cfg == nil {
		return
	}

	// Collect every source list first; per-key exclusions apply to all kinds.
	sources := [][]string{perKey}
	if kind := strings.ToLower(strings.TrimSpace(authKind)); kind != "apikey" && cfg.OAuthExcludedModels != nil {
		provider := strings.ToLower(strings.TrimSpace(auth.Provider))
		sources = append(sources, cfg.OAuthExcludedModels[provider])
	}

	unique := make(map[string]struct{})
	for _, list := range sources {
		for _, raw := range list {
			model := strings.TrimSpace(raw)
			if model == "" {
				continue
			}
			unique[strings.ToLower(model)] = struct{}{}
		}
	}

	models := make([]string, 0, len(unique))
	for model := range unique {
		models = append(models, model)
	}
	sort.Strings(models)

	digest := diff.ComputeExcludedModelsHash(models)
	if auth.Attributes == nil {
		auth.Attributes = make(map[string]string)
	}
	if digest != "" {
		auth.Attributes["excluded_models_hash"] = digest
	}
	// Routing reads the combined list from this attribute at runtime.
	if len(models) != 0 {
		auth.Attributes["excluded_models"] = strings.Join(models, ",")
	}
	if authKind != "" {
		auth.Attributes["auth_kind"] = authKind
	}
}
// addConfigHeadersToAttrs copies configured headers into attrs under keys of
// the form "header:<name>". Names and values are whitespace-trimmed, and pairs
// where either side ends up empty are skipped. It is a no-op when attrs is nil
// or there are no headers.
func addConfigHeadersToAttrs(headers map[string]string, attrs map[string]string) {
	if attrs == nil || len(headers) == 0 {
		return
	}
	for rawName, rawValue := range headers {
		name, value := strings.TrimSpace(rawName), strings.TrimSpace(rawValue)
		if name != "" && value != "" {
			attrs["header:"+name] = value
		}
	}
}
================================================
FILE: internal/watcher/synthesizer/helpers_test.go
================================================
package synthesizer
import (
"reflect"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// TestNewStableIDGenerator checks that the constructor returns a generator
// whose collision-counter map is already allocated.
func TestNewStableIDGenerator(t *testing.T) {
	switch gen := NewStableIDGenerator(); {
	case gen == nil:
		t.Fatal("expected non-nil generator")
	case gen.counters == nil:
		t.Fatal("expected non-nil counters map")
	}
}
// TestStableIDGenerator_Next checks that Next returns an ID of exactly the
// form "<kind>:<short>", where short is a 12-character hash, for a variety of
// kinds and parts (including none).
func TestStableIDGenerator_Next(t *testing.T) {
	tests := []struct {
		name       string
		kind       string
		parts      []string
		wantPrefix string
	}{
		{
			name:       "basic gemini apikey",
			kind:       "gemini:apikey",
			parts:      []string{"test-key", ""},
			wantPrefix: "gemini:apikey:",
		},
		{
			name:       "claude with base url",
			kind:       "claude:apikey",
			parts:      []string{"sk-ant-xxx", "https://api.anthropic.com"},
			wantPrefix: "claude:apikey:",
		},
		{
			name:       "empty parts",
			kind:       "codex:apikey",
			parts:      []string{},
			wantPrefix: "codex:apikey:",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gen := NewStableIDGenerator()
			id, short := gen.Next(tt.kind, tt.parts...)
			// The kind must be a real prefix, not merely a substring.
			if !strings.HasPrefix(id, tt.wantPrefix) {
				t.Errorf("expected id to start with %q, got %q", tt.wantPrefix, id)
			}
			if len(short) != 12 {
				t.Errorf("expected short id length 12, got %d (%q)", len(short), short)
			}
			// The full ID is exactly the kind prefix followed by the short hash.
			if want := tt.wantPrefix + short; id != want {
				t.Errorf("expected id %q, got %q", want, id)
			}
		})
	}
}
// TestStableIDGenerator_Stability checks that two independent generators
// produce the same ID for identical inputs.
func TestStableIDGenerator_Stability(t *testing.T) {
	const kind, key, baseURL = "gemini:apikey", "test-key", "https://api.example.com"
	first, _ := NewStableIDGenerator().Next(kind, key, baseURL)
	second, _ := NewStableIDGenerator().Next(kind, key, baseURL)
	if first != second {
		t.Errorf("same inputs should produce same ID: got %q and %q", first, second)
	}
}
// TestStableIDGenerator_CollisionHandling checks that repeating the same
// inputs on one generator yields the original short hash with an exact "-1"
// suffix, and a full ID built from that suffixed short hash.
func TestStableIDGenerator_CollisionHandling(t *testing.T) {
	gen := NewStableIDGenerator()
	id1, short1 := gen.Next("gemini:apikey", "same-key")
	id2, short2 := gen.Next("gemini:apikey", "same-key")
	if id1 == id2 {
		t.Error("collision should be handled with suffix")
	}
	// The suffix must be appended to the original hash, not just appear anywhere.
	if want := short1 + "-1"; short2 != want {
		t.Errorf("expected second short id %q, got %q", want, short2)
	}
	if want := "gemini:apikey:" + short2; id2 != want {
		t.Errorf("expected second id %q, got %q", want, id2)
	}
}
// TestStableIDGenerator_NilReceiver checks that calling Next on a nil
// generator returns the all-zero placeholder instead of panicking.
func TestStableIDGenerator_NilReceiver(t *testing.T) {
	var gen *StableIDGenerator
	const placeholder = "000000000000"
	id, short := gen.Next("test:kind", "part")
	if want := "test:kind:" + placeholder; id != want {
		t.Errorf("expected test:kind:000000000000, got %q", id)
	}
	if short != placeholder {
		t.Errorf("expected 000000000000, got %q", short)
	}
}
// TestApplyAuthExcludedModelsMeta table-tests ApplyAuthExcludedModelsMeta. It
// covers the hash, auth_kind and the combined (lower-cased, de-duplicated,
// sorted) excluded_models attribute, plus the nil-auth and nil-config guards.
func TestApplyAuthExcludedModelsMeta(t *testing.T) {
	tests := []struct {
		name       string
		auth       *coreauth.Auth
		cfg        *config.Config
		perKey     []string
		authKind   string
		wantHash   bool
		wantKind   string
		wantModels string
	}{
		{
			name: "apikey with excluded models",
			auth: &coreauth.Auth{
				Provider:   "gemini",
				Attributes: make(map[string]string),
			},
			cfg:        &config.Config{},
			perKey:     []string{"model-a", "model-b"},
			authKind:   "apikey",
			wantHash:   true,
			wantKind:   "apikey",
			wantModels: "model-a,model-b",
		},
		{
			name: "oauth with provider excluded models",
			auth: &coreauth.Auth{
				Provider:   "claude",
				Attributes: make(map[string]string),
			},
			cfg: &config.Config{
				OAuthExcludedModels: map[string][]string{
					"claude": {"claude-2.0"},
				},
			},
			perKey:     nil,
			authKind:   "oauth",
			wantHash:   true,
			wantKind:   "oauth",
			wantModels: "claude-2.0",
		},
		{
			name: "nil auth",
			auth: nil,
			cfg:  &config.Config{},
		},
		{
			name:     "nil config",
			auth:     &coreauth.Auth{Provider: "test"},
			cfg:      nil,
			authKind: "apikey",
		},
		{
			name: "nil attributes initialized",
			auth: &coreauth.Auth{
				Provider:   "gemini",
				Attributes: nil,
			},
			cfg:        &config.Config{},
			perKey:     []string{"model-x"},
			authKind:   "apikey",
			wantHash:   true,
			wantKind:   "apikey",
			wantModels: "model-x",
		},
		{
			name: "apikey with duplicate excluded models",
			auth: &coreauth.Auth{
				Provider:   "gemini",
				Attributes: make(map[string]string),
			},
			cfg:        &config.Config{},
			perKey:     []string{"model-a", "MODEL-A", "model-b", "model-a"},
			authKind:   "apikey",
			wantHash:   true,
			wantKind:   "apikey",
			wantModels: "model-a,model-b",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			ApplyAuthExcludedModelsMeta(tt.auth, tt.cfg, tt.perKey, tt.authKind)
			switch {
			case tt.auth == nil:
				// Only verifying that a nil auth does not panic.
				return
			case tt.cfg == nil:
				// With no config the auth must be left untouched.
				if len(tt.auth.Attributes) != 0 {
					t.Errorf("expected attributes untouched with nil config, got %v", tt.auth.Attributes)
				}
				return
			}
			if tt.wantHash {
				if _, ok := tt.auth.Attributes["excluded_models_hash"]; !ok {
					t.Error("expected excluded_models_hash in attributes")
				}
			}
			if tt.wantKind != "" {
				if got := tt.auth.Attributes["auth_kind"]; got != tt.wantKind {
					t.Errorf("expected auth_kind=%s, got %s", tt.wantKind, got)
				}
			}
			if got := tt.auth.Attributes["excluded_models"]; got != tt.wantModels {
				t.Errorf("expected excluded_models=%q, got %q", tt.wantModels, got)
			}
		})
	}
}
// TestApplyAuthExcludedModelsMeta_OAuthMergeWritesCombinedModels checks that,
// for OAuth auths, per-key and provider-wide exclusions are merged
// case-insensitively into a sorted excluded_models list and a matching hash.
func TestApplyAuthExcludedModelsMeta_OAuthMergeWritesCombinedModels(t *testing.T) {
	target := &coreauth.Auth{Provider: "claude", Attributes: make(map[string]string)}
	cfg := &config.Config{
		OAuthExcludedModels: map[string][]string{"claude": {"global-a", "shared"}},
	}

	ApplyAuthExcludedModelsMeta(target, cfg, []string{"per", "SHARED"}, "oauth")

	const wantCombined = "global-a,per,shared"
	if got := target.Attributes["excluded_models"]; got != wantCombined {
		t.Fatalf("expected excluded_models=%q, got %q", wantCombined, got)
	}
	wantHash := diff.ComputeExcludedModelsHash([]string{"global-a", "per", "shared"})
	if got := target.Attributes["excluded_models_hash"]; got != wantHash {
		t.Fatalf("expected excluded_models_hash=%q, got %q", wantHash, got)
	}
}
// TestAddConfigHeadersToAttrs verifies header copying into attribute maps. It
// checks prefixing with "header:", preservation of existing keys, skipping of
// blank keys and values, and no-op behaviour for empty, nil or missing maps.
func TestAddConfigHeadersToAttrs(t *testing.T) {
	type testCase struct {
		name    string
		headers map[string]string
		attrs   map[string]string
		want    map[string]string
	}
	cases := []testCase{
		{
			name:    "basic headers",
			headers: map[string]string{"Authorization": "Bearer token", "X-Custom": "value"},
			attrs:   map[string]string{"existing": "key"},
			want: map[string]string{
				"existing":             "key",
				"header:Authorization": "Bearer token",
				"header:X-Custom":      "value",
			},
		},
		{
			name:    "empty headers",
			headers: map[string]string{},
			attrs:   map[string]string{"existing": "key"},
			want:    map[string]string{"existing": "key"},
		},
		{
			name:    "nil headers",
			headers: nil,
			attrs:   map[string]string{"existing": "key"},
			want:    map[string]string{"existing": "key"},
		},
		{
			name:    "nil attrs",
			headers: map[string]string{"key": "value"},
			attrs:   nil,
			want:    nil,
		},
		{
			name:    "skip empty keys and values",
			headers: map[string]string{"": "value", "key": "", " ": "value", "valid": "valid-value"},
			attrs:   make(map[string]string),
			want:    map[string]string{"header:valid": "valid-value"},
		},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			addConfigHeadersToAttrs(tc.headers, tc.attrs)
			if reflect.DeepEqual(tc.attrs, tc.want) {
				return
			}
			t.Errorf("expected %v, got %v", tc.want, tc.attrs)
		})
	}
}
================================================
FILE: internal/watcher/synthesizer/interface.go
================================================
// Package synthesizer provides auth synthesis strategies for the watcher package.
// It implements the Strategy pattern to support multiple auth sources:
// - ConfigSynthesizer: generates Auth entries from config API keys
// - FileSynthesizer: generates Auth entries from OAuth JSON files
package synthesizer
import (
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// AuthSynthesizer defines the interface for generating Auth entries from various sources.
// Each auth source (e.g. config-defined API keys or OAuth JSON files) is
// implemented as a separate strategy behind this interface.
type AuthSynthesizer interface {
	// Synthesize generates Auth entries from the given context.
	// Returns a slice of Auth pointers and any error encountered.
	Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, error)
}
================================================
FILE: internal/watcher/watcher.go
================================================
// Package watcher watches config/auth files and triggers hot reloads.
// It supports cross-platform fsnotify event handling.
package watcher
import (
"context"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/fsnotify/fsnotify"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"gopkg.in/yaml.v3"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// storePersister captures persistence-capable token store methods used by the watcher.
// NewWatcher keeps a reference when the global token store implements it.
type storePersister interface {
	// PersistConfig persists the configuration through the token store.
	PersistConfig(ctx context.Context) error
	// PersistAuthFiles persists the given auth file paths; message is passed
	// through to the store (presumably a commit/log message — confirm with the
	// implementing store).
	PersistAuthFiles(ctx context.Context, message string, paths ...string) error
}

// authDirProvider is implemented by token stores that expose a fixed auth
// directory. NewWatcher records a non-blank AuthDir() result as the watcher's
// mirrored auth directory.
type authDirProvider interface {
	AuthDir() string
}
// Watcher manages file watching for configuration and authentication files
type Watcher struct {
	// configPath is the configuration file passed to NewWatcher.
	configPath string
	// authDir is the authentication directory passed to NewWatcher.
	authDir string
	// config is the current configuration; written by SetConfig under
	// clientsMutex and read by SnapshotCoreAuths under its read lock.
	config *config.Config
	// clientsMutex guards config and oldConfigYaml (see SetConfig).
	clientsMutex sync.RWMutex
	// configReloadMu/configReloadTimer presumably implement the
	// configReloadDebounce delay — confirm against the reload code.
	configReloadMu    sync.Mutex
	configReloadTimer *time.Timer
	// serverUpdate* fields presumably implement serverUpdateDebounce
	// coalescing of server updates — confirm against the update code.
	serverUpdateMu    sync.Mutex
	serverUpdateTimer *time.Timer
	serverUpdateLast  time.Time
	serverUpdatePend  bool
	// stopped is set to true by Stop.
	stopped atomic.Bool
	// reloadCallback is the callback supplied to NewWatcher.
	reloadCallback func(*config.Config)
	// watcher is the underlying fsnotify watcher; closed by Stop.
	watcher *fsnotify.Watcher
	// lastAuthHashes is allocated by NewWatcher; presumably keyed by auth file
	// path — verify.
	lastAuthHashes   map[string]string
	lastAuthContents map[string]*coreauth.Auth
	// fileAuthsByPath is allocated by NewWatcher; presumably maps an auth file
	// path to the auths synthesized from it, keyed by auth ID — verify.
	fileAuthsByPath map[string]map[string]*coreauth.Auth
	lastRemoveTimes map[string]time.Time
	lastConfigHash  string
	// authQueue receives AuthUpdate events (set via SetAuthUpdateQueue).
	authQueue    chan<- AuthUpdate
	currentAuths map[string]*coreauth.Auth
	runtimeAuths map[string]*coreauth.Auth
	// dispatchCond is bound to dispatchMu by NewWatcher.
	dispatchMu     sync.Mutex
	dispatchCond   *sync.Cond
	pendingUpdates map[string]AuthUpdate
	pendingOrder   []string
	dispatchCancel context.CancelFunc
	// storePersister is set when the token store supports persistence.
	storePersister storePersister
	// mirroredAuthDir is the token store's fixed auth directory, if any.
	mirroredAuthDir string
	// oldConfigYaml is the YAML serialization of the config last passed to
	// SetConfig.
	oldConfigYaml []byte
}
// AuthUpdateAction represents the type of change detected in auth sources.
type AuthUpdateAction string

const (
	// AuthUpdateActionAdd marks a newly discovered auth entry.
	AuthUpdateActionAdd AuthUpdateAction = "add"
	// AuthUpdateActionModify marks a change to an existing auth entry.
	AuthUpdateActionModify AuthUpdateAction = "modify"
	// AuthUpdateActionDelete marks removal of an auth entry.
	AuthUpdateActionDelete AuthUpdateAction = "delete"
)

// AuthUpdate describes an incremental change to auth configuration.
type AuthUpdate struct {
	// Action is the kind of change.
	Action AuthUpdateAction
	// ID identifies the affected auth entry.
	ID string
	// Auth is the affected entry (presumably nil for deletes — confirm with
	// the producers of delete updates).
	Auth *coreauth.Auth
}
const (
	// replaceCheckDelay is a short delay to allow atomic replace (rename) to settle
	// before deciding whether a Remove event indicates a real deletion.
	replaceCheckDelay = 50 * time.Millisecond
	// configReloadDebounce is presumably the quiet period applied before a
	// config reload is performed (see configReloadTimer) — confirm.
	configReloadDebounce = 150 * time.Millisecond
	// authRemoveDebounceWindow is presumably the window used with
	// lastRemoveTimes to collapse repeated Remove events — confirm.
	authRemoveDebounceWindow = 1 * time.Second
	// serverUpdateDebounce is presumably the minimum spacing between server
	// updates (see serverUpdateTimer/serverUpdateLast) — confirm.
	serverUpdateDebounce = 1 * time.Second
)
// NewWatcher creates a Watcher for the given configuration file and auth
// directory, backed by a fresh fsnotify watcher. reloadCallback is stored on
// the Watcher for configuration reloads.
//
// If the globally registered token store supports persistence
// (storePersister) or exposes a fixed, non-blank auth directory
// (authDirProvider), those capabilities are captured here. It returns an error
// only when the fsnotify watcher cannot be created.
func NewWatcher(configPath, authDir string, reloadCallback func(*config.Config)) (*Watcher, error) {
	fsWatcher, err := fsnotify.NewWatcher()
	if err != nil {
		return nil, err
	}

	w := &Watcher{
		configPath:      configPath,
		authDir:         authDir,
		reloadCallback:  reloadCallback,
		watcher:         fsWatcher,
		lastAuthHashes:  make(map[string]string),
		fileAuthsByPath: make(map[string]map[string]*coreauth.Auth),
	}
	w.dispatchCond = sync.NewCond(&w.dispatchMu)

	store := sdkAuth.GetTokenStore()
	if store == nil {
		return w, nil
	}
	if persister, ok := store.(storePersister); ok {
		w.storePersister = persister
		log.Debug("persistence-capable token store detected; watcher will propagate persisted changes")
	}
	if provider, ok := store.(authDirProvider); ok {
		if fixed := strings.TrimSpace(provider.AuthDir()); fixed != "" {
			w.mirroredAuthDir = fixed
			log.Debugf("mirrored auth directory locked to %s", fixed)
		}
	}
	return w, nil
}
// Start begins watching the configuration file and the authentication
// directory. It delegates to the internal start routine and returns its error.
func (w *Watcher) Start(ctx context.Context) error {
	err := w.start(ctx)
	return err
}
// Stop stops the file watcher. It does the following, in order:
//   - marks the watcher as stopped;
//   - invokes the dispatch and debounce-timer stop helpers;
//   - closes the underlying fsnotify watcher and returns Close's error.
//
// A Watcher that has no fsnotify watcher (for example, one built as a struct
// literal in tests) returns nil instead of panicking on a nil Close.
func (w *Watcher) Stop() error {
	w.stopped.Store(true)
	w.stopDispatch()
	w.stopConfigReloadTimer()
	w.stopServerUpdateTimer()
	if w.watcher == nil {
		return nil
	}
	return w.watcher.Close()
}
// SetConfig updates the current configuration under clientsMutex. It also
// stores the YAML serialization of cfg in oldConfigYaml (presumably the
// baseline for later change comparison — confirm against the reload path).
// Serialization failure is logged rather than silently ignored; the config is
// still applied and the snapshot is left nil, as before.
func (w *Watcher) SetConfig(cfg *config.Config) {
	w.clientsMutex.Lock()
	defer w.clientsMutex.Unlock()
	w.config = cfg
	snapshot, errMarshal := yaml.Marshal(cfg)
	if errMarshal != nil {
		log.Warnf("failed to marshal config snapshot: %v", errMarshal)
		snapshot = nil
	}
	w.oldConfigYaml = snapshot
}
// SetAuthUpdateQueue sets the queue used to emit auth updates.
// It is a thin exported wrapper around setAuthUpdateQueue.
func (w *Watcher) SetAuthUpdateQueue(queue chan<- AuthUpdate) {
	w.setAuthUpdateQueue(queue)
}
// DispatchRuntimeAuthUpdate lets external runtime providers (for example,
// websocket-driven auths) push an auth update through the same queue used by
// the file and config watchers. It reports whether the update was enqueued;
// false means no queue is configured.
func (w *Watcher) DispatchRuntimeAuthUpdate(update AuthUpdate) bool {
	enqueued := w.dispatchRuntimeAuthUpdate(update)
	return enqueued
}
// SnapshotCoreAuths builds core auth entries from the current configuration
// and the watcher's auth directory. The configuration pointer is read under
// the read lock, and the lock is released before the (potentially slower)
// snapshot work runs.
func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
	current := func() *config.Config {
		w.clientsMutex.RLock()
		defer w.clientsMutex.RUnlock()
		return w.config
	}()
	return snapshotCoreAuths(current, w.authDir)
}
================================================
FILE: internal/watcher/watcher_test.go
================================================
package watcher
import (
"context"
"crypto/sha256"
"encoding/json"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/fsnotify/fsnotify"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"gopkg.in/yaml.v3"
)
// TestApplyAuthExcludedModelsMeta_APIKey checks that API-key auths hash their
// trimmed, lower-cased per-key exclusions and record auth_kind=apikey.
func TestApplyAuthExcludedModelsMeta_APIKey(t *testing.T) {
	target := &coreauth.Auth{Attributes: map[string]string{}}
	synthesizer.ApplyAuthExcludedModelsMeta(target, &config.Config{}, []string{" Model-1 ", "model-2"}, "apikey")

	wantHash := diff.ComputeExcludedModelsHash([]string{"model-1", "model-2"})
	if got := target.Attributes["excluded_models_hash"]; got != wantHash {
		t.Fatalf("expected hash %s, got %s", wantHash, got)
	}
	if got := target.Attributes["auth_kind"]; got != "apikey" {
		t.Fatalf("expected auth_kind=apikey, got %s", got)
	}
}
// TestApplyAuthExcludedModelsMeta_OAuthProvider checks that OAuth auths pick
// up provider-wide exclusions, with the provider name matched case-insensitively.
func TestApplyAuthExcludedModelsMeta_OAuthProvider(t *testing.T) {
	target := &coreauth.Auth{Provider: "TestProv", Attributes: map[string]string{}}
	cfg := &config.Config{
		OAuthExcludedModels: map[string][]string{"testprov": {"A", "b"}},
	}
	synthesizer.ApplyAuthExcludedModelsMeta(target, cfg, nil, "oauth")

	wantHash := diff.ComputeExcludedModelsHash([]string{"a", "b"})
	if got := target.Attributes["excluded_models_hash"]; got != wantHash {
		t.Fatalf("expected hash %s, got %s", wantHash, got)
	}
	if got := target.Attributes["auth_kind"]; got != "oauth" {
		t.Fatalf("expected auth_kind=oauth, got %s", got)
	}
}
func TestBuildAPIKeyClientsCounts(t *testing.T) {
cfg := &config.Config{
GeminiKey: []config.GeminiKey{{APIKey: "g1"}, {APIKey: "g2"}},
VertexCompatAPIKey: []config.VertexCompatKey{
{APIKey: "v1"},
},
ClaudeKey: []config.ClaudeKey{{APIKey: "c1"}},
CodexKey: []config.CodexKey{{APIKey: "x1"}, {APIKey: "x2"}},
OpenAICompatibility: []config.OpenAICompatibility{
{APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "o1"}, {APIKey: "o2"}}},
},
}
gemini, vertex, claude, codex, compat := BuildAPIKeyClients(cfg)
if gemini != 2 || vertex != 1 || claude != 1 || codex != 2 || compat != 2 {
t.Fatalf("unexpected counts: %d %d %d %d %d", gemini, vertex, claude, codex, compat)
}
}
func TestNormalizeAuthStripsTemporalFields(t *testing.T) {
now := time.Now()
auth := &coreauth.Auth{
CreatedAt: now,
UpdatedAt: now,
LastRefreshedAt: now,
NextRefreshAfter: now,
Quota: coreauth.QuotaState{
NextRecoverAt: now,
},
Runtime: map[string]any{"k": "v"},
}
normalized := normalizeAuth(auth)
if !normalized.CreatedAt.IsZero() || !normalized.UpdatedAt.IsZero() || !normalized.LastRefreshedAt.IsZero() || !normalized.NextRefreshAfter.IsZero() {
t.Fatal("expected time fields to be zeroed")
}
if normalized.Runtime != nil {
t.Fatal("expected runtime to be nil")
}
if !normalized.Quota.NextRecoverAt.IsZero() {
t.Fatal("expected quota.NextRecoverAt to be zeroed")
}
}
// TestMatchProvider checks case-insensitive matching against a candidate list
// and rejection of providers that are not listed.
func TestMatchProvider(t *testing.T) {
	checks := []struct {
		provider   string
		candidates []string
		wantOK     bool
		failure    string
	}{
		{"OpenAI", []string{"openai", "claude"}, true, "expected match to succeed ignoring case"},
		{"missing", []string{"openai"}, false, "expected match to fail for unknown provider"},
	}
	for _, c := range checks {
		if _, ok := matchProvider(c.provider, c.candidates); ok != c.wantOK {
			t.Fatal(c.failure)
		}
	}
}
func TestSnapshotCoreAuths_ConfigAndAuthFiles(t *testing.T) {
authDir := t.TempDir()
metadata := map[string]any{
"type": "gemini",
"email": "user@example.com",
"project_id": "proj-a, proj-b",
"proxy_url": "https://proxy",
}
authFile := filepath.Join(authDir, "gemini.json")
data, err := json.Marshal(metadata)
if err != nil {
t.Fatalf("failed to marshal metadata: %v", err)
}
if err = os.WriteFile(authFile, data, 0o644); err != nil {
t.Fatalf("failed to write auth file: %v", err)
}
cfg := &config.Config{
AuthDir: authDir,
GeminiKey: []config.GeminiKey{
{
APIKey: "g-key",
BaseURL: "https://gemini",
ExcludedModels: []string{"Model-A", "model-b"},
Headers: map[string]string{"X-Req": "1"},
},
},
OAuthExcludedModels: map[string][]string{
"gemini-cli": {"Foo", "bar"},
},
}
w := &Watcher{authDir: authDir}
w.SetConfig(cfg)
auths := w.SnapshotCoreAuths()
if len(auths) != 4 {
t.Fatalf("expected 4 auth entries (1 config + 1 primary + 2 virtual), got %d", len(auths))
}
var geminiAPIKeyAuth *coreauth.Auth
var geminiPrimary *coreauth.Auth
virtuals := make([]*coreauth.Auth, 0)
for _, a := range auths {
switch {
case a.Provider == "gemini" && a.Attributes["api_key"] == "g-key":
geminiAPIKeyAuth = a
case a.Attributes["gemini_virtual_primary"] == "true":
geminiPrimary = a
case strings.TrimSpace(a.Attributes["gemini_virtual_parent"]) != "":
virtuals = append(virtuals, a)
}
}
if geminiAPIKeyAuth == nil {
t.Fatal("expected synthesized Gemini API key auth")
}
expectedAPIKeyHash := diff.ComputeExcludedModelsHash([]string{"Model-A", "model-b"})
if geminiAPIKeyAuth.Attributes["excluded_models_hash"] != expectedAPIKeyHash {
t.Fatalf("expected API key excluded hash %s, got %s", expectedAPIKeyHash, geminiAPIKeyAuth.Attributes["excluded_models_hash"])
}
if geminiAPIKeyAuth.Attributes["auth_kind"] != "apikey" {
t.Fatalf("expected auth_kind=apikey, got %s", geminiAPIKeyAuth.Attributes["auth_kind"])
}
if geminiPrimary == nil {
t.Fatal("expected primary gemini-cli auth from file")
}
if !geminiPrimary.Disabled || geminiPrimary.Status != coreauth.StatusDisabled {
t.Fatal("expected primary gemini-cli auth to be disabled when virtual auths are synthesized")
}
expectedOAuthHash := diff.ComputeExcludedModelsHash([]string{"Foo", "bar"})
if geminiPrimary.Attributes["excluded_models_hash"] != expectedOAuthHash {
t.Fatalf("expected OAuth excluded hash %s, got %s", expectedOAuthHash, geminiPrimary.Attributes["excluded_models_hash"])
}
if geminiPrimary.Attributes["auth_kind"] != "oauth" {
t.Fatalf("expected auth_kind=oauth, got %s", geminiPrimary.Attributes["auth_kind"])
}
if len(virtuals) != 2 {
t.Fatalf("expected 2 virtual auths, got %d", len(virtuals))
}
for _, v := range virtuals {
if v.Attributes["gemini_virtual_parent"] != geminiPrimary.ID {
t.Fatalf("virtual auth missing parent link to %s", geminiPrimary.ID)
}
if v.Attributes["excluded_models_hash"] != expectedOAuthHash {
t.Fatalf("expected virtual excluded hash %s, got %s", expectedOAuthHash, v.Attributes["excluded_models_hash"])
}
if v.Status != coreauth.StatusActive {
t.Fatalf("expected virtual auth to be active, got %s", v.Status)
}
}
}
func TestReloadConfigIfChanged_TriggersOnChangeAndSkipsUnchanged(t *testing.T) {
tmpDir := t.TempDir()
authDir := filepath.Join(tmpDir, "auth")
if err := os.MkdirAll(authDir, 0o755); err != nil {
t.Fatalf("failed to create auth dir: %v", err)
}
configPath := filepath.Join(tmpDir, "config.yaml")
writeConfig := func(port int, allowRemote bool) {
cfg := &config.Config{
Port: port,
AuthDir: authDir,
RemoteManagement: config.RemoteManagement{
AllowRemote: allowRemote,
},
}
data, err := yaml.Marshal(cfg)
if err != nil {
t.Fatalf("failed to marshal config: %v", err)
}
if err = os.WriteFile(configPath, data, 0o644); err != nil {
t.Fatalf("failed to write config: %v", err)
}
}
writeConfig(8080, false)
reloads := 0
w := &Watcher{
configPath: configPath,
authDir: authDir,
reloadCallback: func(*config.Config) { reloads++ },
}
w.reloadConfigIfChanged()
if reloads != 1 {
t.Fatalf("expected first reload to trigger callback once, got %d", reloads)
}
// Same content should be skipped by hash check.
w.reloadConfigIfChanged()
if reloads != 1 {
t.Fatalf("expected unchanged config to be skipped, callback count %d", reloads)
}
writeConfig(9090, true)
w.reloadConfigIfChanged()
if reloads != 2 {
t.Fatalf("expected changed config to trigger reload, callback count %d", reloads)
}
w.clientsMutex.RLock()
defer w.clientsMutex.RUnlock()
if w.config == nil || w.config.Port != 9090 || !w.config.RemoteManagement.AllowRemote {
t.Fatalf("expected config to be updated after reload, got %+v", w.config)
}
}
func TestStartAndStopSuccess(t *testing.T) {
tmpDir := t.TempDir()
authDir := filepath.Join(tmpDir, "auth")
if err := os.MkdirAll(authDir, 0o755); err != nil {
t.Fatalf("failed to create auth dir: %v", err)
}
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir), 0o644); err != nil {
t.Fatalf("failed to create config file: %v", err)
}
var reloads int32
w, err := NewWatcher(configPath, authDir, func(*config.Config) {
atomic.AddInt32(&reloads, 1)
})
if err != nil {
t.Fatalf("failed to create watcher: %v", err)
}
w.SetConfig(&config.Config{AuthDir: authDir})
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if err := w.Start(ctx); err != nil {
t.Fatalf("expected Start to succeed: %v", err)
}
cancel()
if err := w.Stop(); err != nil {
t.Fatalf("expected Stop to succeed: %v", err)
}
if got := atomic.LoadInt32(&reloads); got != 1 {
t.Fatalf("expected one reload callback, got %d", got)
}
}
// TestStartFailsWhenConfigMissing expects Start to return an error when the
// watcher's config path points at a file that does not exist.
func TestStartFailsWhenConfigMissing(t *testing.T) {
	root := t.TempDir()
	authDir := filepath.Join(root, "auth")
	if errMkdir := os.MkdirAll(authDir, 0o755); errMkdir != nil {
		t.Fatalf("failed to create auth dir: %v", errMkdir)
	}

	w, errNew := NewWatcher(filepath.Join(root, "missing-config.yaml"), authDir, nil)
	if errNew != nil {
		t.Fatalf("failed to create watcher: %v", errNew)
	}
	// Stop is best-effort cleanup here; its error is irrelevant to the assertion.
	defer func() { _ = w.Stop() }()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	if errStart := w.Start(ctx); errStart == nil {
		t.Fatal("expected Start to fail for missing config file")
	}
}
// TestDispatchRuntimeAuthUpdateEnqueuesAndUpdatesState dispatches an add and
// then a delete for the same runtime auth. It checks that both reach the queue
// and that the auth is absent from w.runtimeAuths afterwards.
func TestDispatchRuntimeAuthUpdateEnqueuesAndUpdatesState(t *testing.T) {
	queue := make(chan AuthUpdate, 4)
	w := &Watcher{}
	w.SetAuthUpdateQueue(queue)
	defer w.stopDispatch()

	// awaitUpdate returns the next queued update, or fails with timeoutMsg after 2s.
	awaitUpdate := func(timeoutMsg string) AuthUpdate {
		t.Helper()
		select {
		case u := <-queue:
			return u
		case <-time.After(2 * time.Second):
			t.Fatal(timeoutMsg)
			return AuthUpdate{}
		}
	}

	added := &coreauth.Auth{ID: "auth-1", Provider: "test"}
	if !w.DispatchRuntimeAuthUpdate(AuthUpdate{Action: AuthUpdateActionAdd, Auth: added}) {
		t.Fatal("expected DispatchRuntimeAuthUpdate to enqueue")
	}
	if u := awaitUpdate("timed out waiting for auth update"); u.Action != AuthUpdateActionAdd || u.Auth.ID != "auth-1" {
		t.Fatalf("unexpected update: %+v", u)
	}

	if !w.DispatchRuntimeAuthUpdate(AuthUpdate{Action: AuthUpdateActionDelete, ID: "auth-1"}) {
		t.Fatal("expected delete update to enqueue")
	}
	if u := awaitUpdate("timed out waiting for delete update"); u.Action != AuthUpdateActionDelete || u.ID != "auth-1" {
		t.Fatalf("unexpected delete update: %+v", u)
	}

	w.clientsMutex.RLock()
	_, stillPresent := w.runtimeAuths["auth-1"]
	w.clientsMutex.RUnlock()
	if stillPresent {
		t.Fatal("expected runtime auth to be cleared after delete")
	}
}
// TestAddOrUpdateClientSkipsUnchanged seeds the hash cache with the current
// file digest. It expects addOrUpdateClient to treat the file as unchanged and
// not invoke the reload callback.
func TestAddOrUpdateClientSkipsUnchanged(t *testing.T) {
	dir := t.TempDir()
	authFile := filepath.Join(dir, "sample.json")
	payload := []byte(`{"type":"demo"}`)
	if errWrite := os.WriteFile(authFile, payload, 0o644); errWrite != nil {
		t.Fatalf("failed to create auth file: %v", errWrite)
	}
	// The file holds exactly payload, so hash the bytes we wrote.
	digest := sha256.Sum256(payload)

	var callbacks int32
	w := &Watcher{
		authDir:        dir,
		lastAuthHashes: map[string]string{},
		reloadCallback: func(*config.Config) { atomic.AddInt32(&callbacks, 1) },
	}
	w.SetConfig(&config.Config{AuthDir: dir})
	// Key the cache via normalizeAuthPath, the same way addOrUpdateClient stores it.
	w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(digest[:])

	w.addOrUpdateClient(authFile)

	if n := atomic.LoadInt32(&callbacks); n != 0 {
		t.Fatalf("expected no reload for unchanged file, got %d", n)
	}
}
// TestAddOrUpdateClientTriggersReloadAndHash adds a new auth file. It checks
// that the file's hash is cached under the normalized path and that the global
// reload callback is not invoked.
func TestAddOrUpdateClientTriggersReloadAndHash(t *testing.T) {
	dir := t.TempDir()
	authFile := filepath.Join(dir, "sample.json")
	if errWrite := os.WriteFile(authFile, []byte(`{"type":"demo","api_key":"k"}`), 0o644); errWrite != nil {
		t.Fatalf("failed to create auth file: %v", errWrite)
	}

	var callbacks int32
	w := &Watcher{
		authDir:        dir,
		lastAuthHashes: map[string]string{},
		reloadCallback: func(*config.Config) { atomic.AddInt32(&callbacks, 1) },
	}
	w.SetConfig(&config.Config{AuthDir: dir})

	w.addOrUpdateClient(authFile)

	if n := atomic.LoadInt32(&callbacks); n != 0 {
		t.Fatalf("expected no reload callback for auth update, got %d", n)
	}
	// addOrUpdateClient keys the cache by normalizeAuthPath, so look it up the same way.
	key := w.normalizeAuthPath(authFile)
	if _, cached := w.lastAuthHashes[key]; !cached {
		t.Fatalf("expected hash to be stored for %s", key)
	}
}
// TestRemoveClientRemovesHash checks that removeClient deletes the cached hash
// for the removed path without invoking the global reload callback.
func TestRemoveClientRemovesHash(t *testing.T) {
	dir := t.TempDir()
	authFile := filepath.Join(dir, "sample.json")

	var callbacks int32
	w := &Watcher{
		authDir:        dir,
		lastAuthHashes: map[string]string{},
		reloadCallback: func(*config.Config) { atomic.AddInt32(&callbacks, 1) },
	}
	w.SetConfig(&config.Config{AuthDir: dir})

	// Seed the cache using the same normalized key format the watcher uses.
	key := w.normalizeAuthPath(authFile)
	w.lastAuthHashes[key] = "hash"

	w.removeClient(authFile)

	if _, still := w.lastAuthHashes[key]; still {
		t.Fatal("expected hash to be removed after deletion")
	}
	if n := atomic.LoadInt32(&callbacks); n != 0 {
		t.Fatalf("expected no reload callback for auth removal, got %d", n)
	}
}
// TestAuthFileEventsDoNotInvokeSnapshotCoreAuths wraps the package-level
// snapshotCoreAuthsFunc hook with a call counter. It then adds and removes one
// auth file and expects the hook never to have been called.
func TestAuthFileEventsDoNotInvokeSnapshotCoreAuths(t *testing.T) {
	dir := t.TempDir()
	authFile := filepath.Join(dir, "sample.json")
	if errWrite := os.WriteFile(authFile, []byte(`{"type":"codex","email":"u@example.com"}`), 0o644); errWrite != nil {
		t.Fatalf("failed to create auth file: %v", errWrite)
	}

	// Install a counting wrapper around the real hook and restore it on exit.
	realSnapshot := snapshotCoreAuthsFunc
	defer func() { snapshotCoreAuthsFunc = realSnapshot }()
	var snapshots int32
	snapshotCoreAuthsFunc = func(cfg *config.Config, authDir string) []*coreauth.Auth {
		atomic.AddInt32(&snapshots, 1)
		return realSnapshot(cfg, authDir)
	}

	w := &Watcher{
		authDir:          dir,
		lastAuthHashes:   map[string]string{},
		lastAuthContents: map[string]*coreauth.Auth{},
		fileAuthsByPath:  map[string]map[string]*coreauth.Auth{},
	}
	w.SetConfig(&config.Config{AuthDir: dir})

	w.addOrUpdateClient(authFile)
	w.removeClient(authFile)

	if n := atomic.LoadInt32(&snapshots); n != 0 {
		t.Fatalf("expected auth file events to avoid full snapshot, got %d calls", n)
	}
}
// TestAuthSliceToMap is table-driven and covers authSliceToMap:
//   - nil and empty input yield a non-nil empty map;
//   - nil entries and a blank (" ") ID are dropped;
//   - valid auths are keyed by ID;
//   - on duplicate IDs the later entry wins.
func TestAuthSliceToMap(t *testing.T) {
	t.Parallel()

	first := &coreauth.Auth{ID: "a"}
	second := &coreauth.Auth{ID: "b"}
	olderDup := &coreauth.Auth{ID: "dup", Label: "old"}
	newerDup := &coreauth.Auth{ID: "dup", Label: "new"}
	blankID := &coreauth.Auth{ID: " "}

	type testCase struct {
		name string
		in   []*coreauth.Auth
		want map[string]*coreauth.Auth
	}
	cases := []testCase{
		{name: "nil input", in: nil, want: map[string]*coreauth.Auth{}},
		{name: "empty input", in: []*coreauth.Auth{}, want: map[string]*coreauth.Auth{}},
		{name: "filters invalid auths", in: []*coreauth.Auth{nil, blankID}, want: map[string]*coreauth.Auth{}},
		{name: "keeps valid auths", in: []*coreauth.Auth{first, nil, second}, want: map[string]*coreauth.Auth{"a": first, "b": second}},
		{name: "last duplicate wins", in: []*coreauth.Auth{olderDup, newerDup}, want: map[string]*coreauth.Auth{"dup": newerDup}},
	}

	for _, tc := range cases {
		tc := tc // per-iteration copy for the parallel subtest closure
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			got := authSliceToMap(tc.in)

			if len(tc.want) == 0 {
				// Empty expectation: result must be an allocated, empty map.
				switch {
				case got == nil:
					t.Fatal("expected empty map, got nil")
				case len(got) != 0:
					t.Fatalf("expected empty map, got %#v", got)
				}
				return
			}

			if len(got) != len(tc.want) {
				t.Fatalf("unexpected map length: got %d, want %d", len(got), len(tc.want))
			}
			for id, wantAuth := range tc.want {
				gotAuth, ok := got[id]
				switch {
				case !ok:
					t.Fatalf("missing id %q in result map", id)
				case !authEqual(gotAuth, wantAuth):
					t.Fatalf("unexpected auth for id %q: got %#v, want %#v", id, gotAuth, wantAuth)
				}
			}
		})
	}
}
// TestTriggerServerUpdateCancelsPendingTimerOnImmediate exercises the debounce
// logic in triggerServerUpdate:
//   - a trigger inside the debounce window only schedules a pending timer;
//   - a later trigger outside the window runs the reload immediately;
//   - that immediate reload cancels the pending timer, so the callback fires
//     exactly once.
func TestTriggerServerUpdateCancelsPendingTimerOnImmediate(t *testing.T) {
	tmpDir := t.TempDir()
	cfg := &config.Config{AuthDir: tmpDir}
	var reloads int32
	w := &Watcher{
		reloadCallback: func(*config.Config) {
			atomic.AddInt32(&reloads, 1)
		},
	}
	w.SetConfig(cfg)
	// Backdate the last update to 100ms short of the debounce interval so the
	// next trigger falls inside the window and is deferred.
	w.serverUpdateMu.Lock()
	w.serverUpdateLast = time.Now().Add(-(serverUpdateDebounce - 100*time.Millisecond))
	w.serverUpdateMu.Unlock()
	w.triggerServerUpdate(cfg)
	if got := atomic.LoadInt32(&reloads); got != 0 {
		t.Fatalf("expected no immediate reload, got %d", got)
	}
	// The deferred trigger must leave the pending flag set and a timer armed.
	w.serverUpdateMu.Lock()
	if !w.serverUpdatePend || w.serverUpdateTimer == nil {
		w.serverUpdateMu.Unlock()
		t.Fatal("expected a pending server update timer")
	}
	// Move the last update just outside the debounce window, still under the
	// same lock, so the next trigger is expected to run immediately.
	w.serverUpdateLast = time.Now().Add(-(serverUpdateDebounce + 10*time.Millisecond))
	w.serverUpdateMu.Unlock()
	w.triggerServerUpdate(cfg)
	if got := atomic.LoadInt32(&reloads); got != 1 {
		t.Fatalf("expected immediate reload once, got %d", got)
	}
	// Wait long enough for the earlier timer to have fired had it not been
	// cancelled. NOTE(review): assumes its delay was roughly the ~100ms left in
	// the debounce window; a second reload here means the timer was not cancelled.
	time.Sleep(250 * time.Millisecond)
	if got := atomic.LoadInt32(&reloads); got != 1 {
		t.Fatalf("expected pending timer to be cancelled, got %d reloads", got)
	}
}
// TestShouldDebounceRemove checks shouldDebounceRemove across three calls for
// the same path:
//   - the first call is not debounced;
//   - an immediate repeat is debounced;
//   - once the recorded removal is older than authRemoveDebounceWindow, the
//     call is not debounced.
func TestShouldDebounceRemove(t *testing.T) {
	w := new(Watcher)
	p := filepath.Clean("test.json")

	if w.shouldDebounceRemove(p, time.Now()) {
		t.Fatal("first call should not debounce")
	}
	if !w.shouldDebounceRemove(p, time.Now()) {
		t.Fatal("second call within window should debounce")
	}

	// Backdate the last removal to twice the debounce window.
	w.clientsMutex.Lock()
	w.lastRemoveTimes = map[string]time.Time{
		p: time.Now().Add(-2 * authRemoveDebounceWindow),
	}
	w.clientsMutex.Unlock()

	if w.shouldDebounceRemove(p, time.Now()) {
		t.Fatal("call after window should not debounce")
	}
}
// TestAuthFileUnchangedUsesHash checks authFileUnchanged against the hash
// cache. With no cached entry it reports "changed"; once the file's hex-encoded
// SHA-256 is cached under the normalized path, it reports "unchanged".
func TestAuthFileUnchangedUsesHash(t *testing.T) {
	dir := t.TempDir()
	authFile := filepath.Join(dir, "sample.json")
	payload := []byte(`{"type":"demo"}`)
	if errWrite := os.WriteFile(authFile, payload, 0o644); errWrite != nil {
		t.Fatalf("failed to write auth file: %v", errWrite)
	}
	w := &Watcher{lastAuthHashes: map[string]string{}}

	same, err := w.authFileUnchanged(authFile)
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case same:
		t.Fatal("expected first check to report changed")
	}

	// Seed the cache under the same normalized key authFileUnchanged looks up.
	digest := sha256.Sum256(payload)
	w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(digest[:])

	same, err = w.authFileUnchanged(authFile)
	switch {
	case err != nil:
		t.Fatalf("unexpected error: %v", err)
	case !same:
		t.Fatal("expected hash match to report unchanged")
	}
}
// TestAuthFileUnchangedEmptyAndMissing expects authFileUnchanged to report an
// empty file as changed (with no error) and to return an error for a path that
// does not exist.
func TestAuthFileUnchangedEmptyAndMissing(t *testing.T) {
	dir := t.TempDir()
	emptyFile := filepath.Join(dir, "empty.json")
	if errWrite := os.WriteFile(emptyFile, []byte(""), 0o644); errWrite != nil {
		t.Fatalf("failed to write empty auth file: %v", errWrite)
	}
	w := &Watcher{lastAuthHashes: map[string]string{}}

	same, err := w.authFileUnchanged(emptyFile)
	if err != nil {
		t.Fatalf("unexpected error for empty file: %v", err)
	}
	if same {
		t.Fatal("expected empty file to be treated as changed")
	}

	if _, errMissing := w.authFileUnchanged(filepath.Join(dir, "missing.json")); errMissing == nil {
		t.Fatal("expected error for missing auth file")
	}
}
// TestReloadClientsCachesAuthHashes runs reloadClients over an auth directory
// holding a single file and expects exactly one entry in lastAuthHashes.
func TestReloadClientsCachesAuthHashes(t *testing.T) {
	dir := t.TempDir()
	if errWrite := os.WriteFile(filepath.Join(dir, "one.json"), []byte(`{"type":"demo"}`), 0o644); errWrite != nil {
		t.Fatalf("failed to write auth file: %v", errWrite)
	}
	w := &Watcher{authDir: dir, config: &config.Config{AuthDir: dir}}

	w.reloadClients(true, nil, false)

	w.clientsMutex.RLock()
	cached := len(w.lastAuthHashes)
	w.clientsMutex.RUnlock()
	if cached != 1 {
		t.Fatalf("expected hash cache for one auth file, got %d", cached)
	}
}
// TestReloadClientsLogsConfigDiffs exercises reloadClients when w.oldConfigYaml
// (the YAML of an earlier config) differs from the live w.config (a different
// Port and Debug). There are no assertions: the test only requires that the
// call completes without panicking.
func TestReloadClientsLogsConfigDiffs(t *testing.T) {
	dir := t.TempDir()
	before := &config.Config{AuthDir: dir, Port: 1, Debug: false}
	after := &config.Config{AuthDir: dir, Port: 2, Debug: true}

	w := &Watcher{authDir: dir, config: before}
	w.SetConfig(before)
	// Marshalling a plain config struct; the error is intentionally ignored.
	w.oldConfigYaml, _ = yaml.Marshal(before)

	w.clientsMutex.Lock()
	w.config = after
	w.clientsMutex.Unlock()

	w.reloadClients(false, nil, false)
}
// TestReloadClientsHandlesNilConfig calls reloadClients on a zero-value
// watcher, which has no config, and requires only that it does not panic.
func TestReloadClientsHandlesNilConfig(t *testing.T) {
	new(Watcher).reloadClients(true, nil, false)
}
// TestReloadClientsFiltersProvidersWithNilCurrentAuths runs a provider-filtered
// reload on a watcher whose currentAuths was never populated. It expects
// currentAuths to be nil or empty afterwards.
func TestReloadClientsFiltersProvidersWithNilCurrentAuths(t *testing.T) {
	tmp := t.TempDir()
	w := &Watcher{
		authDir: tmp,
		config:  &config.Config{AuthDir: tmp},
	}
	w.reloadClients(false, []string{"match"}, false)

	// Read watcher state under clientsMutex, like the other watcher-state
	// assertions in this file. len of a nil map is 0, so a single length check
	// covers both "nil" and "empty".
	w.clientsMutex.RLock()
	n := len(w.currentAuths)
	w.clientsMutex.RUnlock()
	if n != 0 {
		t.Fatalf("expected currentAuths to be nil or empty, got %d", n)
	}
}
// TestSetAuthUpdateQueueNilResetsDispatch checks two things:
//   - installing a queue initializes dispatchCond and dispatchCancel;
//   - passing a nil queue afterwards clears dispatchCancel.
func TestSetAuthUpdateQueueNilResetsDispatch(t *testing.T) {
	w := new(Watcher)
	w.SetAuthUpdateQueue(make(chan AuthUpdate, 1))
	if initialized := w.dispatchCond != nil && w.dispatchCancel != nil; !initialized {
		t.Fatal("expected dispatch to be initialized")
	}

	w.SetAuthUpdateQueue(nil)
	if w.dispatchCancel != nil {
		t.Fatal("expected dispatch cancel to be cleared when queue nil")
	}
}
// TestPersistAsyncEarlyReturns calls the async persist helpers in two settings,
// neither of which may panic:
//   - on a nil *Watcher;
//   - on a zero-value watcher with no storePersister, passing blank auth paths.
func TestPersistAsyncEarlyReturns(t *testing.T) {
	var nilW *Watcher
	nilW.persistConfigAsync()
	nilW.persistAuthAsync("msg", "a")

	bare := &Watcher{}
	bare.persistConfigAsync()
	bare.persistAuthAsync("msg", " ", "")
}
// errorPersister is a test double whose persistence methods always fail. It
// counts how many times each method was invoked. The counters are incremented
// with atomic.AddInt32, so tests should read them with atomic.LoadInt32.
type errorPersister struct {
	configCalls, authCalls int32
}

// PersistConfig records the call and always returns an error.
func (p *errorPersister) PersistConfig(_ context.Context) error {
	atomic.AddInt32(&p.configCalls, 1)
	return fmt.Errorf("persist config error")
}

// PersistAuthFiles records the call and always returns an error; the message
// and path arguments are ignored.
func (p *errorPersister) PersistAuthFiles(_ context.Context, _ string, _ ...string) error {
	atomic.AddInt32(&p.authCalls, 1)
	return fmt.Errorf("persist auth error")
}
// TestPersistAsyncErrorPaths drives persistConfigAsync and persistAuthAsync
// with a persister whose methods always fail. It expects each persister method
// to be called exactly once even though every call returns an error.
func TestPersistAsyncErrorPaths(t *testing.T) {
	p := &errorPersister{}
	w := &Watcher{storePersister: p}
	w.persistConfigAsync()
	w.persistAuthAsync("msg", "a")

	// The *Async helpers do not block until the persister has run, so poll with
	// a generous deadline instead of one short fixed sleep. Every read of the
	// counters, including those used in failure messages, is atomic: the
	// persister increments them from another goroutine, and a plain read would
	// be a data race.
	deadline := time.Now().Add(2 * time.Second)
	for time.Now().Before(deadline) {
		if atomic.LoadInt32(&p.configCalls) >= 1 && atomic.LoadInt32(&p.authCalls) >= 1 {
			break
		}
		time.Sleep(5 * time.Millisecond)
	}
	if got := atomic.LoadInt32(&p.configCalls); got != 1 {
		t.Fatalf("expected PersistConfig to be called once, got %d", got)
	}
	if got := atomic.LoadInt32(&p.authCalls); got != 1 {
		t.Fatalf("expected PersistAuthFiles to be called once, got %d", got)
	}
}
// TestStopConfigReloadTimerSafeWhenNil calls stopConfigReloadTimer twice:
// first with no timer installed, then with a short no-op timer installed.
// Neither call may panic.
func TestStopConfigReloadTimerSafeWhenNil(t *testing.T) {
	w := new(Watcher)
	w.stopConfigReloadTimer() // no timer yet

	noop := func() {}
	w.configReloadMu.Lock()
	w.configReloadTimer = time.AfterFunc(10*time.Millisecond, noop)
	w.configReloadMu.Unlock()

	time.Sleep(time.Millisecond)
	w.stopConfigReloadTimer()
}
// TestHandleEventRemovesAuthFile delivers a Remove event for an auth file that
// no longer exists on disk but still has a cached hash. It expects the cache
// entry to be dropped and the global reload callback not to be invoked.
func TestHandleEventRemovesAuthFile(t *testing.T) {
	tmpDir := t.TempDir()
	authFile := filepath.Join(tmpDir, "remove.json")
	if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil {
		t.Fatalf("failed to write auth file: %v", err)
	}
	if err := os.Remove(authFile); err != nil {
		t.Fatalf("failed to remove auth file pre-check: %v", err)
	}
	var reloads int32
	w := &Watcher{
		authDir:        tmpDir,
		config:         &config.Config{AuthDir: tmpDir},
		lastAuthHashes: make(map[string]string),
		reloadCallback: func(*config.Config) {
			atomic.AddInt32(&reloads, 1)
		},
	}
	// Use normalizeAuthPath to set up the hash with the correct key format
	w.lastAuthHashes[w.normalizeAuthPath(authFile)] = "hash"
	w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
	// reloads is written with atomic.AddInt32; load it once atomically and
	// reuse that value in the failure message rather than reading it plainly.
	if got := atomic.LoadInt32(&reloads); got != 0 {
		t.Fatalf("expected no reload callback for auth removal, got %d", got)
	}
	if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
		t.Fatal("expected hash entry to be removed")
	}
}
// TestDispatchAuthUpdatesFlushesQueue dispatches two updates in one batch and
// expects both to arrive on the queue in the order they were submitted.
func TestDispatchAuthUpdatesFlushesQueue(t *testing.T) {
	queue := make(chan AuthUpdate, 4)
	w := &Watcher{}
	w.SetAuthUpdateQueue(queue)
	defer w.stopDispatch()

	sent := []AuthUpdate{
		{Action: AuthUpdateActionAdd, ID: "a"},
		{Action: AuthUpdateActionModify, ID: "b"},
	}
	w.dispatchAuthUpdates(sent)

	received := make([]AuthUpdate, 0, len(sent))
	for i := range sent {
		select {
		case u := <-queue:
			received = append(received, u)
		case <-time.After(2 * time.Second):
			t.Fatalf("timed out waiting for update %d", i)
		}
	}

	inOrder := len(received) == 2 && received[0].ID == "a" && received[1].ID == "b"
	if !inOrder {
		t.Fatalf("unexpected updates order/content: %+v", received)
	}
}
// TestDispatchLoopExitsOnContextDoneWhileSending leaves one update pending on a
// queue that nobody reads. It then cancels the context and expects dispatchLoop
// to return within 2s, even though its send can never complete.
func TestDispatchLoopExitsOnContextDoneWhileSending(t *testing.T) {
	// Unbuffered with no receiver: any send to it blocks indefinitely.
	blocked := make(chan AuthUpdate)
	w := &Watcher{
		authQueue:      blocked,
		pendingUpdates: map[string]AuthUpdate{"k": {Action: AuthUpdateActionAdd, ID: "k"}},
		pendingOrder:   []string{"k"},
	}

	ctx, cancel := context.WithCancel(context.Background())
	exited := make(chan struct{})
	go func() {
		defer close(exited)
		w.dispatchLoop(ctx)
	}()

	// Give the loop a moment to reach the blocking send before cancelling.
	time.Sleep(30 * time.Millisecond)
	cancel()

	select {
	case <-exited:
	case <-time.After(2 * time.Second):
		t.Fatal("expected dispatchLoop to exit after ctx canceled while blocked on send")
	}
}
// TestProcessEventsHandlesEventErrorAndChannelClose drives processEvents
// through a hand-built fsnotify.Watcher:
//   - it feeds one event for an unrelated file and one watcher error;
//   - it then closes both channels;
//   - processEvents must return within 500ms.
func TestProcessEventsHandlesEventErrorAndChannelClose(t *testing.T) {
	// Buffered channels let the test enqueue the event and error up front.
	w := &Watcher{
		watcher: &fsnotify.Watcher{
			Events: make(chan fsnotify.Event, 2),
			Errors: make(chan error, 2),
		},
		configPath: "config.yaml",
		authDir:    "auth",
	}
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	done := make(chan struct{})
	go func() {
		w.processEvents(ctx)
		close(done)
	}()
	w.watcher.Events <- fsnotify.Event{Name: "unrelated.txt", Op: fsnotify.Write}
	w.watcher.Errors <- fmt.Errorf("watcher error")
	// Give processEvents a moment to consume the queued event and error
	// before the channels are closed.
	time.Sleep(20 * time.Millisecond)
	close(w.watcher.Events)
	close(w.watcher.Errors)
	select {
	case <-done:
	case <-time.After(500 * time.Millisecond):
		t.Fatal("processEvents did not exit after channels closed")
	}
}
// TestProcessEventsReturnsWhenErrorsChannelClosed gives processEvents a nil
// Events channel and an already-closed Errors channel. It expects the loop to
// return within 500ms. A receive from a nil channel never proceeds, so the
// closed Errors channel is the only way out of the loop.
func TestProcessEventsReturnsWhenErrorsChannelClosed(t *testing.T) {
	errs := make(chan error)
	close(errs)
	w := &Watcher{
		watcher: &fsnotify.Watcher{Events: nil, Errors: errs},
	}

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	finished := make(chan struct{})
	go func() {
		w.processEvents(ctx)
		close(finished)
	}()

	select {
	case <-finished:
	case <-time.After(500 * time.Millisecond):
		t.Fatal("processEvents did not exit after errors channel closed")
	}
}
// TestHandleEventIgnoresUnrelatedFiles sends a Write event for a file that is
// neither the config file nor inside the auth directory. It expects no reload
// callback.
func TestHandleEventIgnoresUnrelatedFiles(t *testing.T) {
	tmpDir := t.TempDir()
	authDir := filepath.Join(tmpDir, "auth")
	if err := os.MkdirAll(authDir, 0o755); err != nil {
		t.Fatalf("failed to create auth dir: %v", err)
	}
	configPath := filepath.Join(tmpDir, "config.yaml")
	if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
		t.Fatalf("failed to write config file: %v", err)
	}
	var reloads int32
	w := &Watcher{
		authDir:        authDir,
		configPath:     configPath,
		lastAuthHashes: make(map[string]string),
		reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
	}
	w.SetConfig(&config.Config{AuthDir: authDir})
	w.handleEvent(fsnotify.Event{Name: filepath.Join(tmpDir, "note.txt"), Op: fsnotify.Write})
	// reloads is written with atomic.AddInt32; read it atomically once and
	// report that same value in the failure message.
	if got := atomic.LoadInt32(&reloads); got != 0 {
		t.Fatalf("expected no reloads for unrelated file, got %d", got)
	}
}
// TestHandleEventConfigChangeSchedulesReload sends a Write event for the config
// file. It expects the scheduled (not inline) reload to invoke the reload
// callback exactly once.
func TestHandleEventConfigChangeSchedulesReload(t *testing.T) {
	tmpDir := t.TempDir()
	authDir := filepath.Join(tmpDir, "auth")
	if err := os.MkdirAll(authDir, 0o755); err != nil {
		t.Fatalf("failed to create auth dir: %v", err)
	}
	configPath := filepath.Join(tmpDir, "config.yaml")
	if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
		t.Fatalf("failed to write config file: %v", err)
	}
	var reloads int32
	w := &Watcher{
		authDir:        authDir,
		configPath:     configPath,
		lastAuthHashes: make(map[string]string),
		reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
	}
	w.SetConfig(&config.Config{AuthDir: authDir})
	w.handleEvent(fsnotify.Event{Name: configPath, Op: fsnotify.Write})

	// The reload runs asynchronously, so the callback may increment reloads from
	// another goroutine; every read below is atomic. Wait at least 400ms so a
	// duplicate reload would still be caught. Keep polling up to a longer
	// deadline so a slow machine does not produce a false failure.
	start := time.Now()
	settle := start.Add(400 * time.Millisecond)
	deadline := start.Add(3 * time.Second)
	for {
		now := time.Now()
		if now.After(deadline) || (atomic.LoadInt32(&reloads) > 0 && !now.Before(settle)) {
			break
		}
		time.Sleep(10 * time.Millisecond)
	}
	if got := atomic.LoadInt32(&reloads); got != 1 {
		t.Fatalf("expected config change to trigger reload once, got %d", got)
	}
}
// TestHandleEventAuthWriteTriggersUpdate sends a Write event for a file inside
// the auth directory. It expects the event to be handled without invoking the
// global reload callback.
func TestHandleEventAuthWriteTriggersUpdate(t *testing.T) {
	tmpDir := t.TempDir()
	authDir := filepath.Join(tmpDir, "auth")
	if err := os.MkdirAll(authDir, 0o755); err != nil {
		t.Fatalf("failed to create auth dir: %v", err)
	}
	configPath := filepath.Join(tmpDir, "config.yaml")
	if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
		t.Fatalf("failed to write config file: %v", err)
	}
	authFile := filepath.Join(authDir, "a.json")
	if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil {
		t.Fatalf("failed to write auth file: %v", err)
	}
	var reloads int32
	w := &Watcher{
		authDir:        authDir,
		configPath:     configPath,
		lastAuthHashes: make(map[string]string),
		reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
	}
	w.SetConfig(&config.Config{AuthDir: authDir})
	w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Write})
	// reloads is written with atomic.AddInt32; read it atomically once and
	// report that same value in the failure message.
	if got := atomic.LoadInt32(&reloads); got != 0 {
		t.Fatalf("expected auth write to avoid global reload callback, got %d", got)
	}
}
// TestHandleEventRemoveDebounceSkips pre-records a removal of the same path
// "just now" in lastRemoveTimes. It expects the new Remove event to be debounced
// with no reload callback.
func TestHandleEventRemoveDebounceSkips(t *testing.T) {
	tmpDir := t.TempDir()
	authDir := filepath.Join(tmpDir, "auth")
	if err := os.MkdirAll(authDir, 0o755); err != nil {
		t.Fatalf("failed to create auth dir: %v", err)
	}
	configPath := filepath.Join(tmpDir, "config.yaml")
	if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
		t.Fatalf("failed to write config file: %v", err)
	}
	authFile := filepath.Join(authDir, "remove.json")
	var reloads int32
	w := &Watcher{
		authDir:        authDir,
		configPath:     configPath,
		lastAuthHashes: make(map[string]string),
		lastRemoveTimes: map[string]time.Time{
			filepath.Clean(authFile): time.Now(),
		},
		reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
	}
	w.SetConfig(&config.Config{AuthDir: authDir})
	w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
	// reloads is written with atomic.AddInt32; read it atomically once and
	// report that same value in the failure message.
	if got := atomic.LoadInt32(&reloads); got != 0 {
		t.Fatalf("expected remove to be debounced, got %d", got)
	}
}
// TestHandleEventAtomicReplaceUnchangedSkips simulates an atomic replace (a
// Rename event) whose resulting content hash matches the cached one. It expects
// no reload callback.
func TestHandleEventAtomicReplaceUnchangedSkips(t *testing.T) {
	tmpDir := t.TempDir()
	authDir := filepath.Join(tmpDir, "auth")
	if err := os.MkdirAll(authDir, 0o755); err != nil {
		t.Fatalf("failed to create auth dir: %v", err)
	}
	configPath := filepath.Join(tmpDir, "config.yaml")
	if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
		t.Fatalf("failed to write config file: %v", err)
	}
	authFile := filepath.Join(authDir, "same.json")
	content := []byte(`{"type":"demo"}`)
	if err := os.WriteFile(authFile, content, 0o644); err != nil {
		t.Fatalf("failed to write auth file: %v", err)
	}
	sum := sha256.Sum256(content)
	var reloads int32
	w := &Watcher{
		authDir:        authDir,
		configPath:     configPath,
		lastAuthHashes: make(map[string]string),
		reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
	}
	w.SetConfig(&config.Config{AuthDir: authDir})
	w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(sum[:])
	w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Rename})
	// reloads is written with atomic.AddInt32; read it atomically once and
	// report that same value in the failure message.
	if got := atomic.LoadInt32(&reloads); got != 0 {
		t.Fatalf("expected unchanged atomic replace to be skipped, got %d", got)
	}
}
// TestHandleEventAtomicReplaceChangedTriggersUpdate simulates an atomic replace
// (a Rename event) whose new content differs from the cached hash. It expects
// the change to be handled without invoking the global reload callback.
func TestHandleEventAtomicReplaceChangedTriggersUpdate(t *testing.T) {
	tmpDir := t.TempDir()
	authDir := filepath.Join(tmpDir, "auth")
	if err := os.MkdirAll(authDir, 0o755); err != nil {
		t.Fatalf("failed to create auth dir: %v", err)
	}
	configPath := filepath.Join(tmpDir, "config.yaml")
	if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
		t.Fatalf("failed to write config file: %v", err)
	}
	authFile := filepath.Join(authDir, "change.json")
	oldContent := []byte(`{"type":"demo","v":1}`)
	newContent := []byte(`{"type":"demo","v":2}`)
	if err := os.WriteFile(authFile, newContent, 0o644); err != nil {
		t.Fatalf("failed to write auth file: %v", err)
	}
	oldSum := sha256.Sum256(oldContent)
	var reloads int32
	w := &Watcher{
		authDir:        authDir,
		configPath:     configPath,
		lastAuthHashes: make(map[string]string),
		reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
	}
	w.SetConfig(&config.Config{AuthDir: authDir})
	// Cache the hash of the *old* content so the on-disk file looks changed.
	w.lastAuthHashes[w.normalizeAuthPath(authFile)] = hexString(oldSum[:])
	w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Rename})
	// reloads is written with atomic.AddInt32; read it atomically once and
	// report that same value in the failure message.
	if got := atomic.LoadInt32(&reloads); got != 0 {
		t.Fatalf("expected changed atomic replace to avoid global reload, got %d", got)
	}
}
// TestHandleEventRemoveUnknownFileIgnored sends a Remove event for an auth path
// with no cached hash. It expects the event to be ignored (no reload callback).
func TestHandleEventRemoveUnknownFileIgnored(t *testing.T) {
	tmpDir := t.TempDir()
	authDir := filepath.Join(tmpDir, "auth")
	if err := os.MkdirAll(authDir, 0o755); err != nil {
		t.Fatalf("failed to create auth dir: %v", err)
	}
	configPath := filepath.Join(tmpDir, "config.yaml")
	if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
		t.Fatalf("failed to write config file: %v", err)
	}
	authFile := filepath.Join(authDir, "unknown.json")
	var reloads int32
	w := &Watcher{
		authDir:        authDir,
		configPath:     configPath,
		lastAuthHashes: make(map[string]string),
		reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
	}
	w.SetConfig(&config.Config{AuthDir: authDir})
	w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
	// reloads is written with atomic.AddInt32; read it atomically once and
	// report that same value in the failure message.
	if got := atomic.LoadInt32(&reloads); got != 0 {
		t.Fatalf("expected unknown remove to be ignored, got %d", got)
	}
}
// TestHandleEventRemoveKnownFileDeletes sends a Remove event for an auth path
// that has a cached hash. It expects the cache entry to be deleted and the
// global reload callback not to be invoked.
func TestHandleEventRemoveKnownFileDeletes(t *testing.T) {
	tmpDir := t.TempDir()
	authDir := filepath.Join(tmpDir, "auth")
	if err := os.MkdirAll(authDir, 0o755); err != nil {
		t.Fatalf("failed to create auth dir: %v", err)
	}
	configPath := filepath.Join(tmpDir, "config.yaml")
	if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
		t.Fatalf("failed to write config file: %v", err)
	}
	authFile := filepath.Join(authDir, "known.json")
	var reloads int32
	w := &Watcher{
		authDir:        authDir,
		configPath:     configPath,
		lastAuthHashes: make(map[string]string),
		reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
	}
	w.SetConfig(&config.Config{AuthDir: authDir})
	w.lastAuthHashes[w.normalizeAuthPath(authFile)] = "hash"
	w.handleEvent(fsnotify.Event{Name: authFile, Op: fsnotify.Remove})
	// reloads is written with atomic.AddInt32; read it atomically once and
	// report that same value in the failure message.
	if got := atomic.LoadInt32(&reloads); got != 0 {
		t.Fatalf("expected known remove to avoid global reload, got %d", got)
	}
	if _, ok := w.lastAuthHashes[w.normalizeAuthPath(authFile)]; ok {
		t.Fatal("expected known auth hash to be deleted")
	}
}
// TestNormalizeAuthPathAndDebounceCleanup covers two behaviours:
//   - normalizeAuthPath returns "" for whitespace-only input and otherwise
//     trims and cleans the path;
//   - shouldDebounceRemove prunes stale entries once lastRemoveTimes holds 129
//     entries older than the debounce window.
func TestNormalizeAuthPathAndDebounceCleanup(t *testing.T) {
	w := new(Watcher)

	if got := w.normalizeAuthPath(" "); got != "" {
		t.Fatalf("expected empty normalize result, got %q", got)
	}
	if got, want := w.normalizeAuthPath(" a/../b "), filepath.Clean("a/../b"); got != want {
		t.Fatalf("unexpected normalize result: %q", got)
	}

	const staleEntries = 129
	w.clientsMutex.Lock()
	w.lastRemoveTimes = make(map[string]time.Time, 140)
	stale := time.Now().Add(-3 * authRemoveDebounceWindow)
	for i := 0; i < staleEntries; i++ {
		w.lastRemoveTimes[fmt.Sprintf("old-%d", i)] = stale
	}
	w.clientsMutex.Unlock()

	w.shouldDebounceRemove("new-path", time.Now())

	w.clientsMutex.Lock()
	remaining := len(w.lastRemoveTimes)
	w.clientsMutex.Unlock()
	if remaining >= staleEntries {
		t.Fatalf("expected debounce cleanup to shrink map, got %d", remaining)
	}
}
// TestRefreshAuthStateDispatchesRuntimeAuths seeds runtimeAuths with a nil entry
// and one real auth ("r1"), then calls refreshAuthState. It expects the first
// queued update to be an Add for "r1".
func TestRefreshAuthStateDispatchesRuntimeAuths(t *testing.T) {
	queue := make(chan AuthUpdate, 8)
	dir := t.TempDir()
	w := &Watcher{authDir: dir, lastAuthHashes: map[string]string{}}
	w.SetConfig(&config.Config{AuthDir: dir})
	w.SetAuthUpdateQueue(queue)
	defer w.stopDispatch()

	seeded := map[string]*coreauth.Auth{
		"nil": nil,
		"r1":  {ID: "r1", Provider: "runtime"},
	}
	w.clientsMutex.Lock()
	w.runtimeAuths = seeded
	w.clientsMutex.Unlock()

	w.refreshAuthState(false)

	select {
	case u := <-queue:
		if isAddR1 := u.Action == AuthUpdateActionAdd && u.ID == "r1"; !isAddR1 {
			t.Fatalf("unexpected auth update: %+v", u)
		}
	case <-time.After(2 * time.Second):
		t.Fatal("timed out waiting for runtime auth update")
	}
}
// TestAddOrUpdateClientEdgeCases covers addOrUpdateClient edge cases:
//   - a missing file and an empty file must not trigger the reload callback;
//   - with no config set, a valid file must not produce a cached hash (and must
//     not panic).
func TestAddOrUpdateClientEdgeCases(t *testing.T) {
	tmpDir := t.TempDir()
	authDir := tmpDir
	authFile := filepath.Join(tmpDir, "edge.json")
	if err := os.WriteFile(authFile, []byte(`{"type":"demo"}`), 0o644); err != nil {
		t.Fatalf("failed to write auth file: %v", err)
	}
	emptyFile := filepath.Join(tmpDir, "empty.json")
	if err := os.WriteFile(emptyFile, []byte(""), 0o644); err != nil {
		t.Fatalf("failed to write empty auth file: %v", err)
	}
	var reloads int32
	w := &Watcher{
		authDir:        authDir,
		lastAuthHashes: make(map[string]string),
		reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
	}
	w.addOrUpdateClient(filepath.Join(tmpDir, "missing.json"))
	w.addOrUpdateClient(emptyFile)
	// reloads is written with atomic.AddInt32; read it atomically once and
	// report that same value in the failure message.
	if got := atomic.LoadInt32(&reloads); got != 0 {
		t.Fatalf("expected no reloads for missing/empty file, got %d", got)
	}
	w.addOrUpdateClient(authFile) // config nil -> should not panic or update
	if len(w.lastAuthHashes) != 0 {
		t.Fatalf("expected no hash entries without config, got %d", len(w.lastAuthHashes))
	}
}
// TestLoadFileClientsWalkError makes a subdirectory of the auth dir unreadable
// (mode 0) and expects loadFileClients to report zero clients. The test is
// skipped where chmod fails, and the directory's permissions are restored on
// exit.
func TestLoadFileClientsWalkError(t *testing.T) {
	dir := t.TempDir()
	locked := filepath.Join(dir, "0noaccess")
	if errMkdir := os.MkdirAll(locked, 0o755); errMkdir != nil {
		t.Fatalf("failed to create noaccess dir: %v", errMkdir)
	}
	if errChmod := os.Chmod(locked, 0); errChmod != nil {
		t.Skipf("chmod not supported: %v", errChmod)
	}
	defer func() { _ = os.Chmod(locked, 0o755) }()

	cfg := &config.Config{AuthDir: dir}
	w := new(Watcher)
	w.SetConfig(cfg)

	if n := w.loadFileClients(cfg); n != 0 {
		t.Fatalf("expected count 0 due to walk error, got %d", n)
	}
}
// TestReloadConfigIfChangedHandlesMissingAndEmpty calls reloadConfigIfChanged
// twice, first with a config path that does not exist and then with an empty
// config file. Both calls must return without panicking.
func TestReloadConfigIfChangedHandlesMissingAndEmpty(t *testing.T) {
	dir := t.TempDir()
	authDir := filepath.Join(dir, "auth")
	if errMkdir := os.MkdirAll(authDir, 0o755); errMkdir != nil {
		t.Fatalf("failed to create auth dir: %v", errMkdir)
	}

	w := &Watcher{configPath: filepath.Join(dir, "missing.yaml"), authDir: authDir}
	w.reloadConfigIfChanged() // missing file

	emptyPath := filepath.Join(dir, "empty.yaml")
	if errWrite := os.WriteFile(emptyPath, []byte(""), 0o644); errWrite != nil {
		t.Fatalf("failed to write empty config: %v", errWrite)
	}
	w.configPath = emptyPath
	w.reloadConfigIfChanged() // empty file
}
// TestReloadConfigUsesMirroredAuthDir writes a config whose auth_dir points at
// a different directory. It expects reloadConfig to replace that value with
// w.mirroredAuthDir.
func TestReloadConfigUsesMirroredAuthDir(t *testing.T) {
	dir := t.TempDir()
	authDir := filepath.Join(dir, "auth")
	if errMkdir := os.MkdirAll(authDir, 0o755); errMkdir != nil {
		t.Fatalf("failed to create auth dir: %v", errMkdir)
	}
	configPath := filepath.Join(dir, "config.yaml")
	yamlBody := "auth_dir: " + filepath.Join(dir, "other") + "\n"
	if errWrite := os.WriteFile(configPath, []byte(yamlBody), 0o644); errWrite != nil {
		t.Fatalf("failed to write config: %v", errWrite)
	}

	w := &Watcher{
		configPath:      configPath,
		authDir:         authDir,
		mirroredAuthDir: authDir,
		lastAuthHashes:  map[string]string{},
	}
	w.SetConfig(&config.Config{AuthDir: authDir})
	if !w.reloadConfig() {
		t.Fatal("expected reloadConfig to succeed")
	}

	w.clientsMutex.RLock()
	defer w.clientsMutex.RUnlock()
	if cfg := w.config; cfg == nil || cfg.AuthDir != authDir {
		t.Fatalf("expected AuthDir to be overridden by mirroredAuthDir %s, got %+v", authDir, cfg)
	}
}
// TestReloadConfigFiltersAffectedOAuthProviders changes the OAuth excluded
// models for provider-a only. It expects reloadConfig to drop provider-a auths
// from currentAuths and to keep the provider-b auth loaded from the auth
// directory.
func TestReloadConfigFiltersAffectedOAuthProviders(t *testing.T) {
	dir := t.TempDir()
	authDir := filepath.Join(dir, "auth")
	if errMkdir := os.MkdirAll(authDir, 0o755); errMkdir != nil {
		t.Fatalf("failed to create auth dir: %v", errMkdir)
	}
	configPath := filepath.Join(dir, "config.yaml")

	// An auth file for a provider the config change does not touch, so we can
	// assert it survives the reload.
	unaffected := []byte(`{"type":"provider-b","email":"b@example.com"}`)
	if errWrite := os.WriteFile(filepath.Join(authDir, "provider-b.json"), unaffected, 0o644); errWrite != nil {
		t.Fatalf("failed to write auth file: %v", errWrite)
	}

	// withExclusions builds a config excluding the given models for provider-a.
	withExclusions := func(models ...string) *config.Config {
		return &config.Config{
			AuthDir:             authDir,
			OAuthExcludedModels: map[string][]string{"provider-a": models},
		}
	}
	oldCfg := withExclusions("m1")
	newCfg := withExclusions("m2")

	data, errMarshal := yaml.Marshal(newCfg)
	if errMarshal != nil {
		t.Fatalf("failed to marshal config: %v", errMarshal)
	}
	if errWrite := os.WriteFile(configPath, data, 0o644); errWrite != nil {
		t.Fatalf("failed to write config: %v", errWrite)
	}

	w := &Watcher{
		configPath:     configPath,
		authDir:        authDir,
		lastAuthHashes: map[string]string{},
		currentAuths:   map[string]*coreauth.Auth{"a": {ID: "a", Provider: "provider-a"}},
	}
	w.SetConfig(oldCfg)
	if !w.reloadConfig() {
		t.Fatal("expected reloadConfig to succeed")
	}

	w.clientsMutex.RLock()
	defer w.clientsMutex.RUnlock()
	sawB := false
	for _, a := range w.currentAuths {
		if a == nil {
			continue
		}
		switch a.Provider {
		case "provider-a":
			t.Fatal("expected affected provider auth to be filtered")
		case "provider-b":
			sawB = true
		}
	}
	if !sawB {
		t.Fatal("expected unaffected provider auth to remain")
	}
}
// TestReloadConfigTriggersCallbackForMaxRetryCredentialsChange verifies that a
// change to MaxRetryCredentials alone is treated as a config change: the reload
// callback fires exactly once with the new value and the watcher stores it.
func TestReloadConfigTriggersCallbackForMaxRetryCredentialsChange(t *testing.T) {
	root := t.TempDir()
	authDir := filepath.Join(root, "auth")
	if err := os.MkdirAll(authDir, 0o755); err != nil {
		t.Fatalf("failed to create auth dir: %v", err)
	}
	configPath := filepath.Join(root, "config.yaml")

	cfgWith := func(maxRetryCredentials int) *config.Config {
		return &config.Config{
			AuthDir:             authDir,
			MaxRetryCredentials: maxRetryCredentials,
			RequestRetry:        1,
			MaxRetryInterval:    5,
		}
	}
	oldCfg, newCfg := cfgWith(0), cfgWith(2)

	data, errMarshal := yaml.Marshal(newCfg)
	if errMarshal != nil {
		t.Fatalf("failed to marshal config: %v", errMarshal)
	}
	if errWrite := os.WriteFile(configPath, data, 0o644); errWrite != nil {
		t.Fatalf("failed to write config: %v", errWrite)
	}

	var (
		calls  int
		gotMax = -1
	)
	w := &Watcher{
		configPath:     configPath,
		authDir:        authDir,
		lastAuthHashes: make(map[string]string),
		reloadCallback: func(cfg *config.Config) {
			calls++
			if cfg != nil {
				gotMax = cfg.MaxRetryCredentials
			}
		},
	}
	w.SetConfig(oldCfg)
	if !w.reloadConfig() {
		t.Fatal("expected reloadConfig to succeed")
	}
	if calls != 1 {
		t.Fatalf("expected reload callback to be called once, got %d", calls)
	}
	if gotMax != 2 {
		t.Fatalf("expected callback MaxRetryCredentials=2, got %d", gotMax)
	}

	w.clientsMutex.RLock()
	defer w.clientsMutex.RUnlock()
	if w.config == nil || w.config.MaxRetryCredentials != 2 {
		t.Fatalf("expected watcher config MaxRetryCredentials=2, got %+v", w.config)
	}
}
// TestStartFailsWhenAuthDirMissing verifies that Start refuses to run when the
// configured auth directory does not exist.
func TestStartFailsWhenAuthDirMissing(t *testing.T) {
	root := t.TempDir()
	authDir := filepath.Join(root, "missing-auth")
	configPath := filepath.Join(root, "config.yaml")
	if err := os.WriteFile(configPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
		t.Fatalf("failed to write config file: %v", err)
	}

	w, err := NewWatcher(configPath, authDir, nil)
	if err != nil {
		t.Fatalf("failed to create watcher: %v", err)
	}
	defer w.Stop()
	w.SetConfig(&config.Config{AuthDir: authDir})

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	if startErr := w.Start(ctx); startErr == nil {
		t.Fatal("expected Start to fail for missing auth dir")
	}
}
// TestDispatchRuntimeAuthUpdateReturnsFalseWithoutQueue verifies that runtime
// auth updates are rejected (return false) when the watcher has no auth queue.
func TestDispatchRuntimeAuthUpdateReturnsFalseWithoutQueue(t *testing.T) {
	w := &Watcher{}
	cases := []struct {
		update  AuthUpdate
		failMsg string
	}{
		{
			update:  AuthUpdate{Action: AuthUpdateActionAdd, Auth: &coreauth.Auth{ID: "a"}},
			failMsg: "expected DispatchRuntimeAuthUpdate to return false when no queue configured",
		},
		{
			update:  AuthUpdate{Action: AuthUpdateActionDelete, Auth: &coreauth.Auth{ID: "a"}},
			failMsg: "expected DispatchRuntimeAuthUpdate delete to return false when no queue configured",
		},
	}
	for _, tc := range cases {
		if w.DispatchRuntimeAuthUpdate(tc.update) {
			t.Fatal(tc.failMsg)
		}
	}
}
// TestNormalizeAuthNil verifies that normalizeAuth passes a nil auth through.
func TestNormalizeAuthNil(t *testing.T) {
	if got := normalizeAuth(nil); got != nil {
		t.Fatal("expected normalizeAuth(nil) to return nil")
	}
}
// stubStore implements coreauth.Store plus watcher-specific persistence helpers.
//
// The persist counters are updated with sync/atomic because the watcher's
// *Async helpers may invoke the persister from background goroutines.
// PersistAuthFiles records its arguments BEFORE incrementing authPersisted, so
// a test that observes the counter via atomic.LoadInt32 is guaranteed to also
// observe lastAuthMessage and lastAuthPaths (atomic operations establish the
// happens-before edge; writing the fields after the increment would race).
type stubStore struct {
	authDir         string
	cfgPersisted    int32
	authPersisted   int32
	lastAuthMessage string
	lastAuthPaths   []string
}

// List reports no stored auths.
func (s *stubStore) List(context.Context) ([]*coreauth.Auth, error) { return nil, nil }

// Save accepts the auth without persisting anything.
func (s *stubStore) Save(context.Context, *coreauth.Auth) (string, error) {
	return "", nil
}

// Delete is a no-op.
func (s *stubStore) Delete(context.Context, string) error { return nil }

// PersistConfig counts invocations.
func (s *stubStore) PersistConfig(context.Context) error {
	atomic.AddInt32(&s.cfgPersisted, 1)
	return nil
}

// PersistAuthFiles records the message and a private copy of the paths, then
// publishes the call by incrementing authPersisted.
func (s *stubStore) PersistAuthFiles(_ context.Context, message string, paths ...string) error {
	s.lastAuthMessage = message
	// Copy so later mutation of the caller's variadic slice cannot alter what
	// the test inspects.
	s.lastAuthPaths = append([]string(nil), paths...)
	atomic.AddInt32(&s.authPersisted, 1)
	return nil
}

// AuthDir returns the directory the store claims to mirror.
func (s *stubStore) AuthDir() string { return s.authDir }
// TestNewWatcherDetectsPersisterAndAuthDir verifies that NewWatcher picks up a
// registered token store that also implements persistence and mirrors its auth
// directory.
func TestNewWatcherDetectsPersisterAndAuthDir(t *testing.T) {
	tmp := t.TempDir()
	store := &stubStore{authDir: tmp}
	orig := sdkAuth.GetTokenStore()
	sdkAuth.RegisterTokenStore(store)
	defer sdkAuth.RegisterTokenStore(orig)

	w, err := NewWatcher("config.yaml", "auth", nil)
	if err != nil {
		t.Fatalf("NewWatcher failed: %v", err)
	}
	// Release the watcher's resources (presumably the underlying fsnotify
	// watcher) just as TestStartFailsWhenAuthDirMissing does; without this the
	// instance leaks for the rest of the package's test run.
	defer w.Stop()

	if w.storePersister == nil {
		t.Fatal("expected storePersister to be set from token store")
	}
	if w.mirroredAuthDir != tmp {
		t.Fatalf("expected mirroredAuthDir %s, got %s", tmp, w.mirroredAuthDir)
	}
}
func TestPersistConfigAndAuthAsyncInvokePersister(t *testing.T) {
w := &Watcher{
storePersister: &stubStore{},
}
w.persistConfigAsync()
w.persistAuthAsync("msg", " a ", "", "b ")
time.Sleep(30 * time.Millisecond)
store := w.storePersister.(*stubStore)
if atomic.LoadInt32(&store.cfgPersisted) != 1 {
t.Fatalf("expected PersistConfig to be called once, got %d", store.cfgPersisted)
}
if atomic.LoadInt32(&store.authPersisted) != 1 {
t.Fatalf("expected PersistAuthFiles to be called once, got %d", store.authPersisted)
}
if store.lastAuthMessage != "msg" {
t.Fatalf("unexpected auth message: %s", store.lastAuthMessage)
}
if len(store.lastAuthPaths) != 2 || store.lastAuthPaths[0] != "a" || store.lastAuthPaths[1] != "b" {
t.Fatalf("unexpected filtered paths: %#v", store.lastAuthPaths)
}
}
// TestScheduleConfigReloadDebounces verifies that back-to-back reload requests
// are coalesced into a single reload and that the reload records the config
// hash.
func TestScheduleConfigReloadDebounces(t *testing.T) {
	tmp := t.TempDir()
	authDir := tmp
	// filepath.Join keeps the path valid on every OS (the previous
	// string concatenation hard-coded a forward slash).
	cfgPath := filepath.Join(tmp, "config.yaml")
	if err := os.WriteFile(cfgPath, []byte("auth_dir: "+authDir+"\n"), 0o644); err != nil {
		t.Fatalf("failed to write config: %v", err)
	}

	var reloads int32
	w := &Watcher{
		configPath:     cfgPath,
		authDir:        authDir,
		reloadCallback: func(*config.Config) { atomic.AddInt32(&reloads, 1) },
	}
	w.SetConfig(&config.Config{AuthDir: authDir})

	w.scheduleConfigReload()
	w.scheduleConfigReload()
	// Wait past the debounce window.
	time.Sleep(400 * time.Millisecond)

	// Read the counter atomically; it is written from the reload goroutine.
	if got := atomic.LoadInt32(&reloads); got != 1 {
		t.Fatalf("expected single debounced reload, got %d", got)
	}
	if w.lastConfigHash == "" {
		t.Fatal("expected lastConfigHash to be set after reload")
	}
}
// TestPrepareAuthUpdatesLockedForceAndDelete checks three consecutive diff
// computations against the same watcher state:
//   - a changed auth yields a modify update,
//   - a forced pass yields a modify update,
//   - an empty snapshot yields a delete for the previously known auth.
func TestPrepareAuthUpdatesLockedForceAndDelete(t *testing.T) {
	w := &Watcher{
		currentAuths: map[string]*coreauth.Auth{"a": {ID: "a", Provider: "p1"}},
		authQueue:    make(chan AuthUpdate, 4),
	}

	got := w.prepareAuthUpdatesLocked([]*coreauth.Auth{{ID: "a", Provider: "p2"}}, false)
	if len(got) != 1 || got[0].Action != AuthUpdateActionModify || got[0].ID != "a" {
		t.Fatalf("unexpected modify updates: %+v", got)
	}

	got = w.prepareAuthUpdatesLocked([]*coreauth.Auth{{ID: "a", Provider: "p2"}}, true)
	if len(got) != 1 || got[0].Action != AuthUpdateActionModify {
		t.Fatalf("expected force modify, got %+v", got)
	}

	got = w.prepareAuthUpdatesLocked([]*coreauth.Auth{}, false)
	if len(got) != 1 || got[0].Action != AuthUpdateActionDelete || got[0].ID != "a" {
		t.Fatalf("expected delete for missing auth, got %+v", got)
	}
}
// TestAuthEqualIgnoresTemporalFields verifies that two auths differing only in
// CreatedAt compare equal.
func TestAuthEqualIgnoresTemporalFields(t *testing.T) {
	base := time.Now()
	older := &coreauth.Auth{ID: "x", CreatedAt: base}
	newer := &coreauth.Auth{ID: "x", CreatedAt: base.Add(5 * time.Second)}
	if equal := authEqual(older, newer); !equal {
		t.Fatal("expected authEqual to ignore temporal differences")
	}
}
// TestDispatchLoopExitsWhenQueueNilAndContextCanceled verifies that the
// dispatch loop, holding a pending update but no queue to deliver it to,
// returns once its context is canceled and the condition variable is signaled.
func TestDispatchLoopExitsWhenQueueNilAndContextCanceled(t *testing.T) {
	w := &Watcher{
		pendingUpdates: map[string]AuthUpdate{"k": {ID: "k"}},
		pendingOrder:   []string{"k"},
	}
	w.dispatchCond = sync.NewCond(&w.dispatchMu)

	ctx, cancel := context.WithCancel(context.Background())
	exited := make(chan struct{})
	go func() {
		defer close(exited)
		w.dispatchLoop(ctx)
	}()

	// Let the loop reach its wait before canceling.
	time.Sleep(20 * time.Millisecond)
	cancel()
	// Wake the loop so it re-checks the context.
	w.dispatchMu.Lock()
	w.dispatchCond.Broadcast()
	w.dispatchMu.Unlock()

	select {
	case <-exited:
	case <-time.After(500 * time.Millisecond):
		t.Fatal("dispatchLoop did not exit after context cancel")
	}
}
// TestReloadClientsFiltersOAuthProvidersWithoutRescan verifies that filtering
// by provider (matched case-insensitively: "Match" vs "match") removes the
// affected auths while a non-rescanning reload keeps the auth hash cache.
func TestReloadClientsFiltersOAuthProvidersWithoutRescan(t *testing.T) {
	dir := t.TempDir()
	w := &Watcher{
		authDir: dir,
		config:  &config.Config{AuthDir: dir},
		currentAuths: map[string]*coreauth.Auth{
			"a": {ID: "a", Provider: "Match"},
			"b": {ID: "b", Provider: "other"},
		},
		lastAuthHashes: map[string]string{"cached": "hash"},
	}

	w.reloadClients(false, []string{"match"}, false)

	w.clientsMutex.RLock()
	defer w.clientsMutex.RUnlock()
	if _, stillPresent := w.currentAuths["a"]; stillPresent {
		t.Fatal("expected filtered provider to be removed")
	}
	if n := len(w.lastAuthHashes); n != 1 {
		t.Fatalf("expected existing hash cache to be retained, got %d", n)
	}
}
func TestScheduleProcessEventsStopsOnContextDone(t *testing.T) {
w := &Watcher{
watcher: &fsnotify.Watcher{
Events: make(chan fsnotify.Event, 1),
Errors: make(chan error, 1),
},
configPath: "config.yaml",
authDir: "auth",
}
ctx, cancel := context.WithCancel(context.Background())
done := make(chan struct{})
go func() {
w.processEvents(ctx)
close(done)
}()
cancel()
select {
case <-done:
case <-time.After(500 * time.Millisecond):
t.Fatal("processEvents did not exit on context cancel")
}
}
// hexString renders data as lowercase hexadecimal, two digits per byte.
func hexString(data []byte) string {
	var sb strings.Builder
	fmt.Fprintf(&sb, "%x", data)
	return sb.String()
}
================================================
FILE: internal/wsrelay/http.go
================================================
package wsrelay
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"time"
"github.com/google/uuid"
)
// HTTPRequest represents a proxied HTTP request delivered to websocket clients.
// It is serialized by encodeRequest into the payload of an http_request message.
type HTTPRequest struct {
	// Method is the HTTP method, sent as "method".
	Method string
	// URL is the target URL, sent as "url".
	URL string
	// Headers are copied per key and sent as "headers".
	Headers http.Header
	// Body is sent as a string in "body" (converted with string(Body), not
	// base64-encoded).
	Body []byte
}

// HTTPResponse captures the response relayed back from websocket clients.
// decodeResponse builds it from a message payload: Status defaults to 200
// (502 when the payload is nil) and Headers is always non-nil.
type HTTPResponse struct {
	Status  int
	Headers http.Header
	Body    []byte
}

// StreamEvent represents a streaming response event from clients.
type StreamEvent struct {
	// Type is one of the MessageType* constants. It is empty for the synthetic
	// event Stream emits when the response channel closes unexpectedly.
	Type string
	// Payload carries chunk data (stream_chunk) or the full body (http_response).
	Payload []byte
	// Status and Headers are populated for stream_start and http_response events.
	Status  int
	Headers http.Header
	// Err is set for error events and for the unexpected-close event.
	Err error
}
// NonStream executes a non-streaming HTTP request using the websocket provider.
//
// The client may answer with a single http_response envelope or with a stream
// (stream_start, stream_chunk..., stream_end). Streamed chunks are buffered and
// returned as the body of one HTTPResponse. An error envelope is converted with
// decodeError; context cancellation returns ctx.Err().
func (m *Manager) NonStream(ctx context.Context, provider string, req *HTTPRequest) (*HTTPResponse, error) {
	if req == nil {
		return nil, fmt.Errorf("wsrelay: request is nil")
	}
	envelope := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)}
	replies, err := m.Send(ctx, provider, envelope)
	if err != nil {
		return nil, err
	}

	var (
		streaming bool
		head      *HTTPResponse
		collected bytes.Buffer
	)
	for {
		var (
			reply Message
			open  bool
		)
		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case reply, open = <-replies:
		}

		if !open {
			if !streaming {
				return nil, errors.New("wsrelay: connection closed during response")
			}
			// The channel closed after a stream began; hand back what arrived.
			return finishStreamedResponse(head, &collected), nil
		}

		switch reply.Type {
		case MessageTypeHTTPResp:
			resp := decodeResponse(reply.Payload)
			if streaming && collected.Len() > 0 && len(resp.Body) == 0 {
				resp.Body = append(resp.Body[:0], collected.Bytes()...)
			}
			return resp, nil
		case MessageTypeError:
			return nil, decodeError(reply.Payload)
		case MessageTypeStreamStart:
			streaming = true
			head = decodeResponse(reply.Payload)
			if head.Headers == nil {
				head.Headers = make(http.Header)
			}
			collected.Reset()
		case MessageTypeStreamChunk:
			// A chunk without a preceding stream_start implies a 200 response.
			if !streaming {
				streaming = true
				head = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
			}
			if chunk := decodeChunk(reply.Payload); len(chunk) > 0 {
				collected.Write(chunk)
			}
		case MessageTypeStreamEnd:
			if !streaming {
				return &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}, nil
			}
			return finishStreamedResponse(head, &collected), nil
		}
	}
}

// finishStreamedResponse attaches the buffered stream body to head, creating a
// default 200 response when no stream_start was seen and ensuring Headers is
// non-nil.
func finishStreamedResponse(head *HTTPResponse, body *bytes.Buffer) *HTTPResponse {
	if head == nil {
		head = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
	} else if head.Headers == nil {
		head.Headers = make(http.Header)
	}
	head.Body = append(head.Body[:0], body.Bytes()...)
	return head
}
// Stream executes a streaming HTTP request and returns a channel of stream
// events.
//
// The returned channel is closed after a terminal event (stream_end, error or
// http_response), after the response channel closes (reported as an event with
// only Err set), or once ctx is done.
func (m *Manager) Stream(ctx context.Context, provider string, req *HTTPRequest) (<-chan StreamEvent, error) {
	if req == nil {
		return nil, fmt.Errorf("wsrelay: request is nil")
	}
	envelope := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)}
	replies, err := m.Send(ctx, provider, envelope)
	if err != nil {
		return nil, err
	}
	events := make(chan StreamEvent)
	go pumpStreamEvents(ctx, replies, events)
	return events, nil
}

// pumpStreamEvents translates relay messages into StreamEvents until a terminal
// message arrives, replies closes, or ctx is done; it always closes events.
func pumpStreamEvents(ctx context.Context, replies <-chan Message, events chan<- StreamEvent) {
	defer close(events)

	// emit delivers ev unless ctx is done first; false means stop pumping.
	emit := func(ev StreamEvent) bool {
		if ctx == nil {
			events <- ev
			return true
		}
		select {
		case <-ctx.Done():
			return false
		case events <- ev:
			return true
		}
	}

	for {
		var (
			reply Message
			open  bool
		)
		select {
		case <-ctx.Done():
			return
		case reply, open = <-replies:
		}
		if !open {
			_ = emit(StreamEvent{Err: errors.New("wsrelay: stream closed")})
			return
		}

		switch reply.Type {
		case MessageTypeStreamStart:
			head := decodeResponse(reply.Payload)
			if !emit(StreamEvent{Type: MessageTypeStreamStart, Status: head.Status, Headers: head.Headers}) {
				return
			}
		case MessageTypeStreamChunk:
			if !emit(StreamEvent{Type: MessageTypeStreamChunk, Payload: decodeChunk(reply.Payload)}) {
				return
			}
		case MessageTypeStreamEnd:
			_ = emit(StreamEvent{Type: MessageTypeStreamEnd})
			return
		case MessageTypeError:
			_ = emit(StreamEvent{Type: MessageTypeError, Err: decodeError(reply.Payload)})
			return
		case MessageTypeHTTPResp:
			full := decodeResponse(reply.Payload)
			_ = emit(StreamEvent{Type: MessageTypeHTTPResp, Status: full.Status, Headers: full.Headers, Payload: full.Body})
			return
		}
	}
}
// encodeRequest converts req into the JSON-friendly payload of an http_request
// message. Header value slices are copied so the payload does not alias req,
// the body is sent as a plain string, and "sent_at" records the current UTC
// time in RFC 3339 format with nanoseconds.
func encodeRequest(req *HTTPRequest) map[string]any {
	headerCopy := make(map[string]any, len(req.Headers))
	for name, vals := range req.Headers {
		dup := make([]string, len(vals))
		copy(dup, vals)
		headerCopy[name] = dup
	}
	payload := map[string]any{
		"sent_at": time.Now().UTC().Format(time.RFC3339Nano),
		"method":  req.Method,
		"url":     req.URL,
		"body":    string(req.Body),
		"headers": headerCopy,
	}
	return payload
}
// decodeResponse builds an HTTPResponse from a relay message payload.
//
// A nil payload yields 502 Bad Gateway. Otherwise the status defaults to 200
// and is taken from a numeric "status" field (JSON numbers decode as float64).
// Header values may be a single string (replaces the key) or a list of strings
// (appended; non-string list items are skipped). "body" is used when it is a
// string. The returned Headers map is never nil.
func decodeResponse(payload map[string]any) *HTTPResponse {
	if payload == nil {
		return &HTTPResponse{Status: http.StatusBadGateway, Headers: make(http.Header)}
	}
	resp := &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
	if code, isNum := payload["status"].(float64); isNum {
		resp.Status = int(code)
	}

	rawHeaders, _ := payload["headers"].(map[string]any)
	for name, raw := range rawHeaders {
		switch vals := raw.(type) {
		case string:
			resp.Headers.Set(name, vals)
		case []string:
			for _, v := range vals {
				resp.Headers.Add(name, v)
			}
		case []any:
			for _, item := range vals {
				if v, isStr := item.(string); isStr {
					resp.Headers.Add(name, v)
				}
			}
		}
	}

	if text, isStr := payload["body"].(string); isStr {
		resp.Body = []byte(text)
	}
	return resp
}
// decodeChunk extracts the "data" string of a stream_chunk payload as bytes.
// It returns nil when the payload is nil or "data" is missing or not a string.
func decodeChunk(payload map[string]any) []byte {
	// Indexing a nil map is safe and yields nil, which fails the assertion.
	text, isString := payload["data"].(string)
	if !isString {
		return nil
	}
	return []byte(text)
}
// decodeError converts an error payload into a Go error of the form
// "<message> (status=<code>)". The message comes from "error" and falls back
// to a generic upstream error; the status comes from a numeric "status" field
// and is 0 when absent. A nil payload yields a fixed "unknown error".
func decodeError(payload map[string]any) error {
	if payload == nil {
		return errors.New("wsrelay: unknown error")
	}
	code := 0
	if v, isNum := payload["status"].(float64); isNum {
		code = int(v)
	}
	text, _ := payload["error"].(string)
	if text == "" {
		text = "wsrelay: upstream error"
	}
	return fmt.Errorf("%s (status=%d)", text, code)
}
================================================
FILE: internal/wsrelay/manager.go
================================================
package wsrelay
import (
"context"
"crypto/rand"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
)
// Manager exposes a websocket endpoint that proxies Gemini requests to
// connected clients.
//
// Each connected client is tracked as a session registered under its
// lower-cased provider name; Send routes a message to the session registered
// for a provider. Construct it with NewManager.
type Manager struct {
	// path is the only URL path accepted for websocket upgrades.
	path string
	upgrader websocket.Upgrader
	// sessions maps provider name to its active session; guarded by sessMutex.
	sessions map[string]*session
	sessMutex sync.RWMutex
	// providerFactory optionally resolves the provider name for a new
	// connection from the upgrade request; an error rejects the connection.
	providerFactory func(*http.Request) (string, error)
	// onConnected / onDisconnected are optional lifecycle callbacks invoked
	// with the provider name (and, on disconnect, the close cause).
	onConnected func(string)
	onDisconnected func(string, error)
	// Logging hooks; NewManager replaces nil hooks with defaults.
	logDebugf func(string, ...any)
	logInfof func(string, ...any)
	logWarnf func(string, ...any)
}
// Options configures a Manager instance. All fields are optional.
type Options struct {
	// Path is the websocket endpoint path. Surrounding whitespace is trimmed,
	// an empty value becomes "/v1/ws", and a leading "/" is added if missing.
	Path string
	// ProviderFactory resolves the provider name for an incoming connection.
	// Returning an error rejects the connection; returning an empty name makes
	// the manager fall back to a random "aistudio-..." identifier.
	ProviderFactory func(*http.Request) (string, error)
	// OnConnected is called with the provider name after the session is registered.
	OnConnected func(string)
	// OnDisconnected is called with the provider name and the cause when a
	// session closes.
	OnDisconnected func(string, error)
	// LogDebugf and LogInfof default to no-ops when nil.
	LogDebugf func(string, ...any)
	LogInfof func(string, ...any)
	// LogWarnf defaults to printing the formatted line to stdout via fmt.Printf.
	LogWarnf func(string, ...any)
}
// NewManager builds a websocket relay manager with the supplied options.
//
// The endpoint path is normalized (trimmed, defaulting to "/v1/ws", with a
// leading "/" enforced). Missing debug/info loggers become no-ops, and a
// missing warn logger prints to stdout.
func NewManager(opts Options) *Manager {
	path := strings.TrimSpace(opts.Path)
	switch {
	case path == "":
		path = "/v1/ws"
	case !strings.HasPrefix(path, "/"):
		path = "/" + path
	}

	noop := func(string, ...any) {}
	debugf, infof, warnf := opts.LogDebugf, opts.LogInfof, opts.LogWarnf
	if debugf == nil {
		debugf = noop
	}
	if infof == nil {
		infof = noop
	}
	if warnf == nil {
		warnf = func(s string, args ...any) { fmt.Printf(s+"\n", args...) }
	}

	return &Manager{
		path:     path,
		sessions: make(map[string]*session),
		upgrader: websocket.Upgrader{
			ReadBufferSize:  1024,
			WriteBufferSize: 1024,
			// NOTE(review): every Origin is accepted, so cross-site pages can
			// open this socket; access control must come from ProviderFactory
			// or upstream middleware — confirm that is always configured.
			CheckOrigin: func(*http.Request) bool { return true },
		},
		providerFactory: opts.ProviderFactory,
		onConnected:     opts.OnConnected,
		onDisconnected:  opts.OnDisconnected,
		logDebugf:       debugf,
		logInfof:        infof,
		logWarnf:        warnf,
	}
}
// Path returns the HTTP path the manager expects for websocket upgrades.
// A nil manager reports the default "/v1/ws".
func (m *Manager) Path() string {
	if m != nil {
		return m.path
	}
	return "/v1/ws"
}
// Handler exposes an http.Handler that upgrades connections to websocket
// sessions.
func (m *Manager) Handler() http.Handler {
	var upgrade http.HandlerFunc = m.handleWebsocket
	return upgrade
}
// Stop gracefully closes all active websocket sessions.
//
// The session table is detached under the lock and replaced with an empty
// one; the detached sessions are then cleaned up outside the lock, each with
// its own "manager stopped" cause. The context is currently unused and the
// returned error is always nil.
func (m *Manager) Stop(_ context.Context) error {
	m.sessMutex.Lock()
	detached := make([]*session, 0, len(m.sessions))
	for _, sess := range m.sessions {
		detached = append(detached, sess)
	}
	m.sessions = make(map[string]*session)
	m.sessMutex.Unlock()

	for _, sess := range detached {
		if sess == nil {
			continue
		}
		sess.cleanup(errors.New("wsrelay: manager stopped"))
	}
	return nil
}
// handleWebsocket upgrades the connection and wires the session into the pool.
//
// Only GET requests on the configured path are accepted. After the upgrade,
// the provider name is resolved via providerFactory (if set). It is normalized
// with TrimSpace+ToLower — the same normalization session() and
// handleSessionClosed apply to their lookup keys — and falls back to a random
// identifier. An existing session for the same provider is replaced and
// cleaned up. The new session's read loop then runs in its own goroutine.
//
// NOTE(review): provider resolution (which may perform authentication) runs
// after the upgrade, so a rejected client sees the socket close rather than an
// HTTP error status.
func (m *Manager) handleWebsocket(w http.ResponseWriter, r *http.Request) {
	expectedPath := m.Path()
	if expectedPath != "" && r.URL != nil && r.URL.Path != expectedPath {
		http.NotFound(w, r)
		return
	}
	if !strings.EqualFold(r.Method, http.MethodGet) {
		w.Header().Set("Allow", http.MethodGet)
		http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
		return
	}
	conn, err := m.upgrader.Upgrade(w, r, nil)
	if err != nil {
		m.logWarnf("wsrelay: upgrade failed: %v", err)
		return
	}

	id := randomProviderName()
	provider := ""
	if m.providerFactory != nil {
		name, errFactory := m.providerFactory(r)
		if errFactory != nil {
			// Reject before creating a session: a session would start a
			// heartbeat goroutine, and its cleanup would report a disconnect
			// (with an empty provider) for a connection that was never
			// announced via onConnected.
			m.logInfof("wsrelay: rejecting connection: %v", errFactory)
			_ = conn.Close()
			return
		}
		// Trim as well as lower-case so the map key matches the keys used by
		// session() and handleSessionClosed; otherwise a padded name would be
		// unreachable and never removed from m.sessions.
		provider = strings.ToLower(strings.TrimSpace(name))
	}
	if provider == "" {
		provider = strings.ToLower(id)
	}

	s := newSession(conn, m, id)
	s.provider = provider

	m.sessMutex.Lock()
	replaced := m.sessions[s.provider]
	m.sessions[s.provider] = s
	m.sessMutex.Unlock()

	if replaced != nil {
		replaced.cleanup(errors.New("replaced by new connection"))
	}
	if m.onConnected != nil {
		m.onConnected(s.provider)
	}
	go s.run(context.Background())
}
// Send forwards the message to the specific provider connection and returns a
// channel yielding response messages. The provider name is matched
// case-insensitively after trimming whitespace.
func (m *Manager) Send(ctx context.Context, provider string, msg Message) (<-chan Message, error) {
	if target := m.session(provider); target != nil {
		return target.request(ctx, msg)
	}
	return nil, fmt.Errorf("wsrelay: provider %s not connected", provider)
}
// session returns the active session for provider (trimmed and lower-cased),
// or nil when none is registered.
func (m *Manager) session(provider string) *session {
	key := strings.ToLower(strings.TrimSpace(provider))
	m.sessMutex.RLock()
	defer m.sessMutex.RUnlock()
	return m.sessions[key]
}
// handleSessionClosed unregisters s if it is still the active session for its
// provider (a replacement session is left untouched), then reports the
// disconnect through onDisconnected.
func (m *Manager) handleSessionClosed(s *session, cause error) {
	if s == nil {
		return
	}
	key := strings.ToLower(strings.TrimSpace(s.provider))

	m.sessMutex.Lock()
	if m.sessions[key] == s {
		delete(m.sessions, key)
	}
	m.sessMutex.Unlock()

	if notify := m.onDisconnected; notify != nil {
		notify(s.provider, cause)
	}
}
// randomProviderName returns "aistudio-" followed by 16 random characters from
// [a-z0-9]. If the system random source fails it falls back to the current
// Unix time in nanoseconds, hex-encoded.
func randomProviderName() string {
	const alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
	raw := make([]byte, 16)
	if _, err := rand.Read(raw); err != nil {
		return fmt.Sprintf("aistudio-%x", time.Now().UnixNano())
	}
	suffix := make([]byte, len(raw))
	for i, b := range raw {
		// Slight modulo bias is acceptable: the name only needs to be unique,
		// not uniformly distributed.
		suffix[i] = alphabet[int(b)%len(alphabet)]
	}
	return "aistudio-" + string(suffix)
}
================================================
FILE: internal/wsrelay/message.go
================================================
package wsrelay
// Message represents the JSON payload exchanged with websocket clients.
//
// ID correlates a request with all of its responses (including stream
// messages); Type is one of the MessageType* constants; Payload carries the
// type-specific fields and is omitted from the JSON when empty.
type Message struct {
	ID string `json:"id"`
	Type string `json:"type"`
	Payload map[string]any `json:"payload,omitempty"`
}
// Message types of the relay protocol. The string values are part of the wire
// format shared with clients.
const (
	// MessageTypeHTTPReq identifies an HTTP-style request envelope
	// (payload: method, url, headers, body, sent_at).
	MessageTypeHTTPReq = "http_request"
	// MessageTypeHTTPResp identifies a non-streaming HTTP response envelope
	// (payload: status, headers, body).
	MessageTypeHTTPResp = "http_response"
	// MessageTypeStreamStart marks the beginning of a streaming response and
	// carries its status and headers.
	MessageTypeStreamStart = "stream_start"
	// MessageTypeStreamChunk carries a streaming response chunk in "data".
	MessageTypeStreamChunk = "stream_chunk"
	// MessageTypeStreamEnd marks the completion of a streaming response.
	MessageTypeStreamEnd = "stream_end"
	// MessageTypeError carries an error response ("error" message and an
	// optional numeric "status").
	MessageTypeError = "error"
	// MessageTypePing represents ping messages from clients; the session
	// answers each with a pong that echoes the ping's ID.
	MessageTypePing = "ping"
	// MessageTypePong represents pong responses back to clients.
	MessageTypePong = "pong"
)
================================================
FILE: internal/wsrelay/session.go
================================================
package wsrelay
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/gorilla/websocket"
)
// Connection tuning for relay sessions.
const (
	// readTimeout is the read deadline armed in newSession. It is pushed
	// forward only by the pong handler, so a client must answer the heartbeat
	// pings; data frames alone do not extend it.
	readTimeout = 60 * time.Second
	// writeTimeout is the deadline for each outbound JSON frame and each
	// heartbeat ping control frame.
	writeTimeout = 10 * time.Second
	// maxInboundMessageLen caps the size of a single message read from a client.
	maxInboundMessageLen = 64 << 20 // 64 MiB
	// heartbeatInterval is the period between websocket ping control frames.
	heartbeatInterval = 30 * time.Second
)

// errClosed is returned by send once the session has shut down.
var errClosed = errors.New("websocket session closed")
// pendingRequestBuffer is the capacity of each request's response channel.
// dispatch never blocks the read loop, so a message arriving while the buffer
// is full is dropped (and logged at debug level).
const pendingRequestBuffer = 8

// pendingRequest tracks the response channel of one in-flight request.
//
// Several goroutines touch the channel:
//   - the read loop delivers messages (dispatch);
//   - the request's context watcher closes it on cancellation;
//   - session cleanup, possibly triggered from the heartbeat goroutine,
//     delivers a final error and closes it.
//
// A send on a closed channel panics even inside a select with a default case,
// so both deliver and close hold mu. After close, deliver becomes a harmless
// no-op instead of a crash.
type pendingRequest struct {
	ch   chan Message
	done chan struct{} // closed together with ch so the context watcher can exit

	mu     sync.Mutex
	closed bool
}

// newPendingRequest allocates a pendingRequest with a buffered response channel.
func newPendingRequest() *pendingRequest {
	return &pendingRequest{
		ch:   make(chan Message, pendingRequestBuffer),
		done: make(chan struct{}),
	}
}

// deliver performs a non-blocking send of msg. It reports false, discarding
// msg, when the request is already closed or its buffer is full.
func (pr *pendingRequest) deliver(msg Message) bool {
	if pr == nil {
		return false
	}
	pr.mu.Lock()
	defer pr.mu.Unlock()
	if pr.closed {
		return false
	}
	select {
	case pr.ch <- msg:
		return true
	default:
		return false
	}
}

// close closes the response channel (and done) exactly once; later calls are
// no-ops.
func (pr *pendingRequest) close() {
	if pr == nil {
		return
	}
	pr.mu.Lock()
	defer pr.mu.Unlock()
	if pr.closed {
		return
	}
	pr.closed = true
	close(pr.ch)
	if pr.done != nil {
		close(pr.done)
	}
}

// isTerminalMessageType reports whether a message of type typ ends a request.
func isTerminalMessageType(typ string) bool {
	return typ == MessageTypeHTTPResp || typ == MessageTypeError || typ == MessageTypeStreamEnd
}

// session is one websocket connection to a relay client.
type session struct {
	conn     *websocket.Conn
	manager  *Manager
	provider string
	id       string
	// closed is closed exactly once (via closeOnce) when the session shuts down.
	closed    chan struct{}
	closeOnce sync.Once
	// writeMutex serializes writes to conn.
	writeMutex sync.Mutex
	pending    sync.Map // map[string]*pendingRequest
}

// newSession wraps conn, applies the read limit and read deadline, installs a
// pong handler that extends the deadline, and starts the heartbeat goroutine.
// The provider is left empty for the caller to fill in.
func newSession(conn *websocket.Conn, mgr *Manager, id string) *session {
	s := &session{
		conn:     conn,
		manager:  mgr,
		provider: "",
		id:       id,
		closed:   make(chan struct{}),
	}
	conn.SetReadLimit(maxInboundMessageLen)
	conn.SetReadDeadline(time.Now().Add(readTimeout))
	conn.SetPongHandler(func(string) error {
		conn.SetReadDeadline(time.Now().Add(readTimeout))
		return nil
	})
	s.startHeartbeat()
	return s
}

// startHeartbeat sends a websocket ping every heartbeatInterval until the
// session closes; a failed ping tears the session down.
func (s *session) startHeartbeat() {
	if s == nil || s.conn == nil {
		return
	}
	ticker := time.NewTicker(heartbeatInterval)
	go func() {
		defer ticker.Stop()
		for {
			select {
			case <-s.closed:
				return
			case <-ticker.C:
				s.writeMutex.Lock()
				err := s.conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(writeTimeout))
				s.writeMutex.Unlock()
				if err != nil {
					s.cleanup(err)
					return
				}
			}
		}
	}()
}

// run is the read loop: it decodes JSON messages and dispatches them until a
// read fails, then cleans up with the read error. ctx is currently unused.
func (s *session) run(ctx context.Context) {
	defer s.cleanup(errClosed)
	for {
		var msg Message
		if err := s.conn.ReadJSON(&msg); err != nil {
			s.cleanup(err)
			return
		}
		s.dispatch(msg)
	}
}

// debugf logs through the manager's debug hook when one is available.
func (s *session) debugf(format string, args ...any) {
	if s.manager != nil && s.manager.logDebugf != nil {
		s.manager.logDebugf(format, args...)
	}
}

// dispatch routes one inbound message. Pings are answered with a pong. Other
// messages go to the pending request with the same ID; terminal messages
// (http_response, error, stream_end) also complete and unregister that request.
func (s *session) dispatch(msg Message) {
	if msg.Type == MessageTypePing {
		_ = s.send(context.Background(), Message{ID: msg.ID, Type: MessageTypePong})
		return
	}
	terminal := isTerminalMessageType(msg.Type)
	value, ok := s.pending.Load(msg.ID)
	if !ok {
		if terminal {
			s.debugf("wsrelay: received terminal message for unknown id %s (provider=%s)", msg.ID, s.provider)
		}
		return
	}
	if !value.(*pendingRequest).deliver(msg) {
		// The read loop must never block, so a lagging or departed receiver
		// loses the message; make that visible instead of silent.
		s.debugf("wsrelay: dropped %s message for id %s (provider=%s): receiver closed or buffer full", msg.Type, msg.ID, s.provider)
	}
	if terminal {
		if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
			actual.(*pendingRequest).close()
		}
	}
}

// send writes msg as JSON under the write lock with a write deadline. It fails
// fast with errClosed once the session has shut down. ctx is currently unused.
func (s *session) send(ctx context.Context, msg Message) error {
	select {
	case <-s.closed:
		return errClosed
	default:
	}
	s.writeMutex.Lock()
	defer s.writeMutex.Unlock()
	if err := s.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
		return fmt.Errorf("set write deadline: %w", err)
	}
	if err := s.conn.WriteJSON(msg); err != nil {
		return fmt.Errorf("write json: %w", err)
	}
	return nil
}

// request registers msg.ID as pending, sends msg, and returns the channel on
// which its responses arrive. The channel is closed after a terminal message,
// when ctx is done, or when the session closes.
func (s *session) request(ctx context.Context, msg Message) (<-chan Message, error) {
	if msg.ID == "" {
		return nil, fmt.Errorf("wsrelay: message id is required")
	}
	req := newPendingRequest()
	if _, loaded := s.pending.LoadOrStore(msg.ID, req); loaded {
		return nil, fmt.Errorf("wsrelay: duplicate message id %s", msg.ID)
	}
	if err := s.send(ctx, msg); err != nil {
		if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
			actual.(*pendingRequest).close()
		}
		return nil, err
	}
	go func() {
		select {
		case <-ctx.Done():
			if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
				actual.(*pendingRequest).close()
			}
		case <-s.closed:
			// cleanup completes every pending request.
		case <-req.done:
			// Completed normally; exit now instead of lingering until ctx or
			// the session ends (which leaked a goroutine per request for
			// long-lived contexts).
		}
	}()
	return req.ch, nil
}

// cleanup shuts the session down exactly once. It closes the closed channel,
// fails every pending request with an error message carrying cause (errClosed
// when cause is nil), closes the connection, and notifies the manager.
func (s *session) cleanup(cause error) {
	s.closeOnce.Do(func() {
		if cause == nil {
			cause = errClosed
		}
		close(s.closed)
		s.pending.Range(func(key, _ any) bool {
			// Claim each entry with LoadAndDelete so only one path closes it.
			// Entries are removed in place; the map is not overwritten, which
			// previously raced with concurrent Load/LoadOrStore callers.
			actual, loaded := s.pending.LoadAndDelete(key)
			if !loaded {
				return true
			}
			req := actual.(*pendingRequest)
			req.deliver(Message{ID: key.(string), Type: MessageTypeError, Payload: map[string]any{"error": cause.Error()}})
			req.close()
			return true
		})
		_ = s.conn.Close()
		if s.manager != nil {
			s.manager.handleSessionClosed(s, cause)
		}
	})
}
================================================
FILE: sdk/access/errors.go
================================================
package access
import (
"fmt"
"net/http"
"strings"
)
// AuthErrorCode classifies authentication failures so callers can branch on
// the kind of failure instead of parsing messages.
type AuthErrorCode string

// Known authentication failure codes.
const (
	AuthErrorCodeNoCredentials      AuthErrorCode = "no_credentials"
	AuthErrorCodeInvalidCredential  AuthErrorCode = "invalid_credential"
	AuthErrorCodeNotHandled         AuthErrorCode = "not_handled"
	AuthErrorCodeInternal           AuthErrorCode = "internal_error"
)

// AuthError carries authentication failure details and the HTTP status to
// answer with. Cause, when set, is exposed through Unwrap.
type AuthError struct {
	Code       AuthErrorCode
	Message    string
	StatusCode int
	Cause      error
}

// Error formats the trimmed message (or "authentication error" when blank),
// followed by ": <cause>" when a cause is present. A nil receiver yields "".
func (e *AuthError) Error() string {
	if e == nil {
		return ""
	}
	msg := strings.TrimSpace(e.Message)
	if msg == "" {
		msg = "authentication error"
	}
	if e.Cause == nil {
		return msg
	}
	return fmt.Sprintf("%s: %v", msg, e.Cause)
}

// Unwrap returns the underlying cause (nil for a nil receiver).
func (e *AuthError) Unwrap() error {
	if e != nil {
		return e.Cause
	}
	return nil
}

// HTTPStatusCode returns StatusCode, falling back to 500 when it is unset,
// non-positive, or the receiver is nil.
func (e *AuthError) HTTPStatusCode() int {
	if e != nil && e.StatusCode > 0 {
		return e.StatusCode
	}
	return http.StatusInternalServerError
}

// newAuthError assembles an AuthError from its parts.
func newAuthError(code AuthErrorCode, message string, statusCode int, cause error) *AuthError {
	return &AuthError{Code: code, Message: message, StatusCode: statusCode, Cause: cause}
}

// NewNoCredentialsError reports a request that carried no credentials (401).
func NewNoCredentialsError() *AuthError {
	return newAuthError(AuthErrorCodeNoCredentials, "Missing API key", http.StatusUnauthorized, nil)
}

// NewInvalidCredentialError reports credentials that were rejected (401).
func NewInvalidCredentialError() *AuthError {
	return newAuthError(AuthErrorCodeInvalidCredential, "Invalid API key", http.StatusUnauthorized, nil)
}

// NewNotHandledError signals that a provider does not apply to the request.
// It carries no status code, so HTTPStatusCode falls back to 500.
func NewNotHandledError() *AuthError {
	return newAuthError(AuthErrorCodeNotHandled, "authentication provider did not handle request", 0, nil)
}

// NewInternalAuthError reports a failure of the authentication machinery
// itself (500). A blank message becomes "Authentication service error".
func NewInternalAuthError(message string, cause error) *AuthError {
	msg := strings.TrimSpace(message)
	if msg == "" {
		msg = "Authentication service error"
	}
	return newAuthError(AuthErrorCodeInternal, msg, http.StatusInternalServerError, cause)
}

// IsAuthErrorCode reports whether authErr is non-nil and carries code.
func IsAuthErrorCode(authErr *AuthError, code AuthErrorCode) bool {
	return authErr != nil && authErr.Code == code
}
================================================
FILE: sdk/access/manager.go
================================================
package access
import (
"context"
"net/http"
"sync"
)
// Manager coordinates authentication providers.
//
// The provider list is replaced wholesale under mu, and readers operate on a
// copied snapshot, so the methods are safe for concurrent use.
type Manager struct {
	mu        sync.RWMutex
	providers []Provider
}

// NewManager constructs an empty manager.
func NewManager() *Manager {
	return &Manager{}
}

// SetProviders replaces the active provider list with a copy of providers.
func (m *Manager) SetProviders(providers []Provider) {
	if m == nil {
		return
	}
	next := make([]Provider, len(providers))
	copy(next, providers)
	m.mu.Lock()
	defer m.mu.Unlock()
	m.providers = next
}

// Providers returns a snapshot of the active providers.
func (m *Manager) Providers() []Provider {
	if m == nil {
		return nil
	}
	m.mu.RLock()
	snapshot := make([]Provider, len(m.providers))
	copy(snapshot, m.providers)
	m.mu.RUnlock()
	return snapshot
}

// Authenticate evaluates providers in order until one succeeds.
//
// With no manager or no providers the request is allowed (nil, nil).
// "Not handled" and "no credentials" answers move on to the next provider.
// An "invalid credential" answer is remembered while later providers are
// still tried. Any other error aborts immediately. When no provider accepts
// the request, the result is an invalid-credential error if one was seen,
// otherwise a no-credentials error.
func (m *Manager) Authenticate(ctx context.Context, r *http.Request) (*Result, *AuthError) {
	if m == nil {
		return nil, nil
	}
	providers := m.Providers()
	if len(providers) == 0 {
		return nil, nil
	}

	sawInvalid := false
	for _, p := range providers {
		if p == nil {
			continue
		}
		result, authErr := p.Authenticate(ctx, r)
		switch {
		case authErr == nil:
			return result, nil
		case IsAuthErrorCode(authErr, AuthErrorCodeInvalidCredential):
			sawInvalid = true
		case IsAuthErrorCode(authErr, AuthErrorCodeNotHandled),
			IsAuthErrorCode(authErr, AuthErrorCodeNoCredentials):
			// Let the next provider try.
		default:
			return nil, authErr
		}
	}

	if sawInvalid {
		return nil, NewInvalidCredentialError()
	}
	return nil, NewNoCredentialsError()
}
================================================
FILE: sdk/access/registry.go
================================================
package access
import (
"context"
"net/http"
"strings"
"sync"
)
// Provider validates credentials for incoming requests.
//
// A nil *AuthError from Authenticate means the request is accepted. On failure
// the error's Code tells Manager.Authenticate how to proceed:
//   - AuthErrorCodeNotHandled or AuthErrorCodeNoCredentials: the next
//     provider is tried;
//   - AuthErrorCodeInvalidCredential: remembered while later providers are
//     tried;
//   - any other code: authentication aborts with that error.
type Provider interface {
	// Identifier names the provider instance.
	Identifier() string
	Authenticate(ctx context.Context, r *http.Request) (*Result, *AuthError)
}

// Result conveys authentication outcome.
//
// NOTE(review): field semantics are set by provider implementations —
// presumably Provider names the accepting provider and Principal identifies
// the authenticated caller; confirm against the implementations.
type Result struct {
	Provider string
	Principal string
	Metadata map[string]string
}
// Global provider registry. registry maps a trimmed type identifier to its
// provider; order records first-registration order. Both are guarded by
// registryMu.
var (
	registryMu sync.RWMutex
	registry   = make(map[string]Provider)
	order      []string
)

// RegisterProvider registers a pre-built provider instance for a given type
// identifier. Blank identifiers and nil providers are ignored. Re-registering
// an identifier replaces its provider but keeps its original position.
func RegisterProvider(typ string, provider Provider) {
	key := strings.TrimSpace(typ)
	if key == "" || provider == nil {
		return
	}
	registryMu.Lock()
	defer registryMu.Unlock()
	if _, known := registry[key]; !known {
		order = append(order, key)
	}
	registry[key] = provider
}

// UnregisterProvider removes a provider by type identifier; unknown or blank
// identifiers are ignored.
func UnregisterProvider(typ string) {
	key := strings.TrimSpace(typ)
	if key == "" {
		return
	}
	registryMu.Lock()
	defer registryMu.Unlock()
	if _, known := registry[key]; !known {
		return
	}
	delete(registry, key)
	for i, name := range order {
		if name == key {
			order = append(order[:i], order[i+1:]...)
			break
		}
	}
}

// RegisteredProviders returns the global provider instances in registration
// order, or nil when nothing has been registered.
func RegisteredProviders() []Provider {
	registryMu.RLock()
	defer registryMu.RUnlock()
	if len(order) == 0 {
		return nil
	}
	out := make([]Provider, 0, len(order))
	for _, name := range order {
		if p, found := registry[name]; found && p != nil {
			out = append(out, p)
		}
	}
	return out
}
================================================
FILE: sdk/access/types.go
================================================
package access
// AccessConfig groups request authentication providers.
type AccessConfig struct {
// Providers lists configured authentication providers.
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
}
// AccessProvider describes a request authentication provider entry.
type AccessProvider struct {
// Name is the instance identifier for the provider.
Name string `yaml:"name" json:"name"`
// Type selects the provider implementation registered via the SDK.
Type string `yaml:"type" json:"type"`
// SDK optionally names a third-party SDK module providing this provider.
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
// APIKeys lists inline keys for providers that require them.
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
// Config passes provider-specific options to the implementation.
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
}
// Identifiers used when building the built-in inline API key provider entry.
const (
	// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys.
	AccessProviderTypeConfigAPIKey = "config-api-key"
	// DefaultAccessProviderName is applied when no provider name is supplied
	// (MakeInlineAPIKeyProvider uses it as the entry name).
	DefaultAccessProviderName = "config-inline"
)
// MakeInlineAPIKeyProvider builds the default inline API key provider entry
// (named DefaultAccessProviderName, typed AccessProviderTypeConfigAPIKey) for
// the given keys. The keys are copied, so later changes to the caller's slice
// do not affect the returned configuration. It returns nil when keys is empty.
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
	if len(keys) > 0 {
		copied := make([]string, len(keys))
		copy(copied, keys)
		return &AccessProvider{
			Name:    DefaultAccessProviderName,
			Type:    AccessProviderTypeConfigAPIKey,
			APIKeys: copied,
		}
	}
	return nil
}
================================================
FILE: sdk/api/handlers/claude/code_handlers.go
================================================
// Package claude provides HTTP handlers for the Claude-compatible API surface.
// It implements the Messages endpoint (streaming and non-streaming), token
// counting, and model listing. Backend selection, including client rotation and
// quota handling, is delegated to the shared BaseAPIHandler auth-manager helpers
// (ExecuteWithAuthManager and related methods); this package handles SSE framing,
// upstream header forwarding, and Claude-formatted error reporting.
package claude
import (
"bytes"
"compress/gzip"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/gin-gonic/gin"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// ClaudeCodeAPIHandler contains the handlers for Claude API endpoints
// (messages, token counting, and model listing).
type ClaudeCodeAPIHandler struct {
	// BaseAPIHandler supplies the shared execution, keep-alive, stream
	// forwarding, and error-writing helpers used by this type's methods.
	*handlers.BaseAPIHandler
}
// NewClaudeCodeAPIHandler wraps a shared BaseAPIHandler in a ClaudeCodeAPIHandler.
//
// Parameters:
//   - apiHandlers: The base API handler instance to embed.
//
// Returns:
//   - *ClaudeCodeAPIHandler: A new handler embedding apiHandlers.
func NewClaudeCodeAPIHandler(apiHandlers *handlers.BaseAPIHandler) *ClaudeCodeAPIHandler {
	handler := new(ClaudeCodeAPIHandler)
	handler.BaseAPIHandler = apiHandlers
	return handler
}
// HandlerType reports this handler's identifier (the Claude constant). The
// methods of this type pass it to the auth-manager execution helpers.
func (h *ClaudeCodeAPIHandler) HandlerType() string {
	return Claude
}
// Models lists the models the global model registry currently reports as
// available for the "claude" handler type.
func (h *ClaudeCodeAPIHandler) Models() []map[string]any {
	return registry.GetGlobalRegistry().GetAvailableModels("claude")
}
// ClaudeMessages handles POST requests for the Claude Messages API.
// It reads the raw request body and dispatches to the streaming (SSE) or
// non-streaming path based on the request's "stream" field.
//
// "stream" is interpreted with gjson's Result.Bool semantics:
//   - true, "true", or a non-zero number select streaming.
//   - false, null, or an absent field select a single JSON response.
//
// Parameters:
//   - c: The Gin context for the request.
func (h *ClaudeCodeAPIHandler) ClaudeMessages(c *gin.Context) {
	// Extract raw JSON data from the incoming request.
	rawJSON, err := c.GetRawData()
	if err != nil {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Invalid request: %v", err),
				Type:    "invalid_request_error",
			},
		})
		return
	}
	// Bool() treats `"stream": null` (and other non-true values) as
	// non-streaming, so such clients get a JSON body rather than an SSE stream.
	if gjson.GetBytes(rawJSON, "stream").Bool() {
		h.handleStreamingResponse(c, rawJSON)
		return
	}
	h.handleNonStreamingResponse(c, rawJSON)
}
// ClaudeCountTokens handles the Claude token-counting endpoint.
// It sends the raw request body to the auth-manager count helper for the model
// named in the body's "model" field. It then writes the upstream response and
// the forwarded upstream headers back to the client. Upstream failures are
// reported through WriteErrorResponse.
//
// Parameters:
//   - c: The Gin context for the request.
func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) {
	// Extract raw JSON data from the incoming request.
	rawJSON, err := c.GetRawData()
	// If data retrieval fails, return a 400 Bad Request error.
	if err != nil {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Invalid request: %v", err),
				Type:    "invalid_request_error",
			},
		})
		return
	}
	c.Header("Content-Type", "application/json")
	alt := h.GetAlt(c)
	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
	modelName := gjson.GetBytes(rawJSON, "model").String()
	resp, upstreamHeaders, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
	if errMsg != nil {
		h.WriteErrorResponse(c, errMsg)
		cliCancel(errMsg.Error)
		return
	}
	handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
	_, _ = c.Writer.Write(resp)
	cliCancel()
}
// ClaudeModels serves the Claude-style model listing.
// The response contains:
//   - "data": the available models.
//   - "has_more": always false.
//   - "first_id"/"last_id": the "id" of the first and last entries. These are
//     empty strings when the list is empty or the id is not a string.
//
// Parameters:
//   - c: The Gin context for the request.
func (h *ClaudeCodeAPIHandler) ClaudeModels(c *gin.Context) {
	models := h.Models()
	idAt := func(i int) string {
		if i < 0 || i >= len(models) {
			return ""
		}
		id, _ := models[i]["id"].(string)
		return id
	}
	c.JSON(http.StatusOK, gin.H{
		"data":     models,
		"has_more": false,
		"first_id": idAt(0),
		"last_id":  idAt(len(models) - 1),
	})
}
// handleNonStreamingResponse executes a non-streaming Claude Messages request and
// writes the complete upstream response as a single JSON body.
//
// StartNonStreamingKeepAlive runs while the upstream call is in flight and is
// stopped before anything else is written. If the upstream payload starts with
// the gzip magic bytes, it is decompressed before being written.
//
// Parameters:
//   - c: The Gin context for the request.
//   - rawJSON: The raw JSON request body; its "model" field selects the model.
func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "application/json")
	alt := h.GetAlt(c)
	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())

	stopKeepAlive := h.StartNonStreamingKeepAlive(c, cliCtx)
	modelName := gjson.GetBytes(rawJSON, "model").String()
	resp, upstreamHeaders, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
	stopKeepAlive()
	if errMsg != nil {
		h.WriteErrorResponse(c, errMsg)
		cliCancel(errMsg.Error)
		return
	}

	// Claude upstreams sometimes return gzip without a Content-Encoding header.
	// Decoding here fixes title generation and other non-streaming responses
	// that would otherwise reach the client compressed.
	resp, decompressed := decompressClaudeGzip(resp)

	handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
	if decompressed {
		// Any forwarded encoding/length headers describe the compressed bytes,
		// not the body written below.
		c.Writer.Header().Del("Content-Encoding")
		c.Writer.Header().Del("Content-Length")
	}
	_, _ = c.Writer.Write(resp)
	cliCancel()
}

// decompressClaudeGzip returns the gunzipped payload and true when body starts
// with the gzip magic bytes (0x1f 0x8b) and decodes successfully. If the magic
// bytes are absent, or decoding fails (the failure is logged), it returns body
// unchanged and false.
func decompressClaudeGzip(body []byte) ([]byte, bool) {
	if len(body) < 2 || body[0] != 0x1f || body[1] != 0x8b {
		return body, false
	}
	gzReader, err := gzip.NewReader(bytes.NewReader(body))
	if err != nil {
		log.Warnf("failed to decompress gzipped Claude response: %v", err)
		return body, false
	}
	defer func() {
		if errClose := gzReader.Close(); errClose != nil {
			log.Warnf("failed to close Claude gzip reader: %v", errClose)
		}
	}()
	decoded, err := io.ReadAll(gzReader)
	if err != nil {
		log.Warnf("failed to read decompressed Claude response: %v", err)
		return body, false
	}
	return decoded, true
}
// handleStreamingResponse streams a Claude Messages response to the client as
// Server-Sent Events.
//
// It starts the upstream stream with ExecuteStreamWithAuthManager, then waits
// for the first event before committing any response headers. Waiting lets an
// immediate upstream failure still be reported with a proper HTTP status and
// error body through WriteErrorResponse. Once the first data chunk arrives:
//   - SSE headers and forwarded upstream headers are set;
//   - the chunk is written verbatim;
//   - the rest of the stream is handed to forwardClaudeStream.
//
// Parameters:
//   - c: The Gin context for the request.
//   - rawJSON: The raw JSON request body; its "model" field selects the model.
func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
	// Get the http.Flusher interface to manually flush the response.
	// Flushing is required so each chunk reaches the client as soon as it is written.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}
	modelName := gjson.GetBytes(rawJSON, "model").String()
	// Create a cancellable context for the backend request. It is cancelled
	// (with a cause) on client disconnect, upstream error, or completion below.
	// Note: alt is passed as "" here, unlike the non-streaming path, which uses GetAlt.
	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
	dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
	setSSEHeaders := func() {
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")
		c.Header("Access-Control-Allow-Origin", "*")
	}
	// Peek at the first event to choose between an error response and a stream
	// before any headers are committed.
	for {
		select {
		case <-c.Request.Context().Done():
			// Client went away before the first event; propagate the cause upstream.
			cliCancel(c.Request.Context().Err())
			return
		case errMsg, ok := <-errChan:
			if !ok {
				// Error channel closed without an error. A nil channel is never
				// selected, so this stops listening on it and waits for data.
				errChan = nil
				continue
			}
			// Upstream failed before any data: return a proper error status and body.
			// NOTE(review): errMsg may be nil if a nil value was sent on the
			// channel; WriteErrorResponse presumably tolerates that — confirm.
			h.WriteErrorResponse(c, errMsg)
			if errMsg != nil {
				cliCancel(errMsg.Error)
			} else {
				cliCancel(nil)
			}
			return
		case chunk, ok := <-dataChan:
			if !ok {
				// Stream closed without any data: send only the SSE and upstream headers.
				setSSEHeaders()
				handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
				flusher.Flush()
				cliCancel(nil)
				return
			}
			// First chunk received: commit the SSE headers now.
			setSSEHeaders()
			handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
			// Write the first chunk verbatim (empty chunks are skipped).
			if len(chunk) > 0 {
				_, _ = c.Writer.Write(chunk)
				flusher.Flush()
			}
			// Relay the remaining chunks and any terminal error.
			h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
			return
		}
	}
}
// forwardClaudeStream relays the remaining stream to the client through the
// shared ForwardStream loop.
//   - Chunks are written verbatim; empty chunks are skipped.
//   - A terminal upstream error sets the response status to the error's
//     StatusCode (500 when unset). This only takes effect if headers were not
//     already written.
//   - The error itself is then written as a Claude-style SSE "event: error" frame.
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
	writeChunk := func(chunk []byte) {
		if len(chunk) > 0 {
			_, _ = c.Writer.Write(chunk)
		}
	}
	writeTerminalError := func(errMsg *interfaces.ErrorMessage) {
		if errMsg == nil {
			return
		}
		status := errMsg.StatusCode
		if status <= 0 {
			status = http.StatusInternalServerError
		}
		c.Status(status)
		payload, _ := json.Marshal(h.toClaudeError(errMsg))
		_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", payload)
	}
	h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
		WriteChunk:         writeChunk,
		WriteTerminalError: writeTerminalError,
	})
}
// claudeErrorDetail is the inner "error" object of a Claude-style error payload.
type claudeErrorDetail struct {
	// Type is the Claude error category string (e.g. "api_error").
	Type string `json:"type"`
	// Message is the human-readable error description.
	Message string `json:"message"`
}
// claudeErrorResponse is the top-level Claude-style error envelope,
// serialized as {"type": ..., "error": {"type": ..., "message": ...}}.
type claudeErrorResponse struct {
	// Type is the envelope type (e.g. "error").
	Type string `json:"type"`
	// Error carries the error category and message.
	Error claudeErrorDetail `json:"error"`
}
// toClaudeError converts an upstream error into the Claude error envelope:
//
//	{"type":"error","error":{"type":"api_error","message":...}}
//
// The message is msg.Error's text. When msg or msg.Error is nil, or the text is
// empty, it falls back to the HTTP status text for msg.StatusCode (500 when
// unset) instead of panicking on a nil error.
func (h *ClaudeCodeAPIHandler) toClaudeError(msg *interfaces.ErrorMessage) claudeErrorResponse {
	status := http.StatusInternalServerError
	message := ""
	if msg != nil {
		if msg.StatusCode > 0 {
			status = msg.StatusCode
		}
		if msg.Error != nil {
			message = msg.Error.Error()
		}
	}
	if message == "" {
		message = http.StatusText(status)
	}
	return claudeErrorResponse{
		Type: "error",
		Error: claudeErrorDetail{
			Type:    "api_error",
			Message: message,
		},
	}
}
================================================
FILE: sdk/api/handlers/gemini/gemini-cli_handlers.go
================================================
// Package gemini provides HTTP handlers for Gemini CLI API functionality.
// This package implements handlers that process CLI-specific requests for Gemini API operations,
// including content generation and streaming content generation endpoints.
// The handlers restrict access to localhost only and manage communication with the backend service.
package gemini
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// GeminiCLIAPIHandler contains the handlers for Gemini CLI "v1internal" API endpoints.
type GeminiCLIAPIHandler struct {
	// BaseAPIHandler supplies the shared execution, configuration, stream
	// forwarding, and error-writing helpers used by this type's methods.
	*handlers.BaseAPIHandler
}
// NewGeminiCLIAPIHandler wraps a shared BaseAPIHandler in a GeminiCLIAPIHandler.
// It takes a BaseAPIHandler instance and returns a handler embedding it.
func NewGeminiCLIAPIHandler(apiHandlers *handlers.BaseAPIHandler) *GeminiCLIAPIHandler {
	handler := new(GeminiCLIAPIHandler)
	handler.BaseAPIHandler = apiHandlers
	return handler
}
// HandlerType reports this handler's identifier (the GeminiCLI constant). The
// methods of this type pass it to the auth-manager execution helpers.
func (h *GeminiCLIAPIHandler) HandlerType() string {
	return GeminiCLI
}
// Models reports no models for the CLI handler: it always returns a non-nil,
// empty list.
func (h *GeminiCLIAPIHandler) Models() []map[string]any {
	return []map[string]any{}
}
// CLIHandler handles Gemini CLI "v1internal" requests.
//
// Only loopback clients are served (see isLoopbackRemoteAddr); other callers
// receive 403. Routing depends on the path:
//   - /v1internal:generateContent and /v1internal:streamGenerateContent are
//     executed through the auth manager.
//   - Every other path is proxied to the Cloud Code endpoint by proxyCLIRequest.
func (h *GeminiCLIAPIHandler) CLIHandler(c *gin.Context) {
	if !isLoopbackRemoteAddr(c.Request.RemoteAddr) {
		c.JSON(http.StatusForbidden, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "CLI reply only allow local access",
				Type:    "forbidden",
			},
		})
		return
	}
	rawJSON, err := c.GetRawData()
	if err != nil {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Invalid request: %v", err),
				Type:    "invalid_request_error",
			},
		})
		return
	}
	switch c.Request.URL.Path {
	case "/v1internal:generateContent":
		h.handleInternalGenerateContent(c, rawJSON)
	case "/v1internal:streamGenerateContent":
		h.handleInternalStreamGenerateContent(c, rawJSON)
	default:
		h.proxyCLIRequest(c, rawJSON)
	}
}

// isLoopbackRemoteAddr reports whether remoteAddr comes from an IPv4 loopback
// address (127.0.0.0/8) or the IPv6 loopback address [::1]. remoteAddr is
// expected in the "host:port" form net/http stores in Request.RemoteAddr.
func isLoopbackRemoteAddr(remoteAddr string) bool {
	return strings.HasPrefix(remoteAddr, "127.") || strings.HasPrefix(remoteAddr, "[::1]:")
}

// proxyCLIRequest forwards the request as a POST to
// https://cloudcode-pa.googleapis.com, keeping the original path, query,
// headers, and body.
//   - A non-2xx upstream status is reported as a 400 error whose message is the
//     upstream body.
//   - On success, the upstream headers and body are relayed, and the body is
//     stored in the gin context under "API_RESPONSE".
func (h *GeminiCLIAPIHandler) proxyCLIRequest(c *gin.Context, rawJSON []byte) {
	upstreamURL := "https://cloudcode-pa.googleapis.com" + c.Request.URL.RequestURI()
	// Bind to the client's context so the upstream call is cancelled when the
	// client disconnects.
	req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodPost, upstreamURL, bytes.NewReader(rawJSON))
	if err != nil {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Invalid request: %v", err),
				Type:    "invalid_request_error",
			},
		})
		return
	}
	for key, values := range c.Request.Header {
		req.Header[key] = values
	}

	httpClient := util.SetProxy(h.Cfg, &http.Client{})
	resp, err := httpClient.Do(req)
	if err != nil {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Invalid request: %v", err),
				Type:    "invalid_request_error",
			},
		})
		return
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Warnf("failed to close response body: %v", errClose)
		}
	}()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		bodyBytes, _ := io.ReadAll(resp.Body)
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: string(bodyBytes),
				Type:    "invalid_request_error",
			},
		})
		return
	}

	// Read the body before copying headers, so a read failure can still be
	// reported instead of returning an empty 200.
	output, err := io.ReadAll(resp.Body)
	if err != nil {
		log.Errorf("Failed to read response body: %v", err)
		c.JSON(http.StatusBadGateway, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Failed to read upstream response: %v", err),
				Type:    "server_error",
			},
		})
		return
	}
	// Copy every value of each header; the previous code kept only the first.
	for key, values := range resp.Header {
		c.Writer.Header()[key] = values
	}
	c.Set("API_RESPONSE_TIMESTAMP", time.Now())
	_, _ = c.Writer.Write(output)
	c.Set("API_RESPONSE", output)
}
// handleInternalStreamGenerateContent handles /v1internal:streamGenerateContent.
// It executes the request through the auth manager and relays the stream to the
// client via forwardCLIStream, which frames chunks as SSE "data:" events.
//
// The Flusher check runs before any response headers are set. gin's JSON
// renderer does not replace an existing Content-Type, so setting SSE headers
// first would send the "Streaming not supported" error as text/event-stream.
func (h *GeminiCLIAPIHandler) handleInternalStreamGenerateContent(c *gin.Context, rawJSON []byte) {
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}
	if h.GetAlt(c) == "" {
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")
		c.Header("Access-Control-Allow-Origin", "*")
	}

	modelName := gjson.GetBytes(rawJSON, "model").String()
	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
	dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
	handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
	h.forwardCLIStream(c, flusher, "", func(err error) { cliCancel(err) }, dataChan, errChan)
}
// handleInternalGenerateContent handles /v1internal:generateContent. It runs the
// request through the auth manager for the body's "model" and writes the full
// upstream response, plus forwarded upstream headers, as one JSON body.
// Failures are reported via WriteErrorResponse.
func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "application/json")
	model := gjson.GetBytes(rawJSON, "model").String()
	ctx, cancel := h.GetContextWithCancel(h, c, context.Background())

	body, upstreamHeaders, failure := h.ExecuteWithAuthManager(ctx, h.HandlerType(), model, rawJSON, "")
	if failure == nil {
		handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
		_, _ = c.Writer.Write(body)
		cancel()
		return
	}
	h.WriteErrorResponse(c, failure)
	cancel(failure.Error)
}
// forwardCLIStream relays the remaining stream to the client through ForwardStream.
//
// With alt == "", chunks are framed as SSE events:
//   - "[DONE]" sentinels (bare or "data: "-prefixed) are dropped;
//   - a "data: " prefix is added when missing;
//   - each event ends with a blank line.
//
// With a non-empty alt, chunks are written verbatim and a zero
// KeepAliveInterval is passed to ForwardStream (presumably disabling
// keep-alives; confirm in ForwardStream).
//
// Terminal errors are rendered with BuildErrorResponseBody, as an SSE
// "event: error" frame or as a raw body depending on alt.
func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
	sse := alt == ""

	var keepAliveInterval *time.Duration
	if !sse {
		disabled := time.Duration(0)
		keepAliveInterval = &disabled
	}

	writeChunk := func(chunk []byte) {
		if !sse {
			_, _ = c.Writer.Write(chunk)
			return
		}
		switch string(chunk) {
		case "data: [DONE]", "[DONE]":
			return
		}
		if !bytes.HasPrefix(chunk, []byte("data:")) {
			_, _ = c.Writer.Write([]byte("data: "))
		}
		_, _ = c.Writer.Write(chunk)
		_, _ = c.Writer.Write([]byte("\n\n"))
	}

	writeTerminalError := func(errMsg *interfaces.ErrorMessage) {
		if errMsg == nil {
			return
		}
		status := errMsg.StatusCode
		if status <= 0 {
			status = http.StatusInternalServerError
		}
		errText := http.StatusText(status)
		if errMsg.Error != nil {
			if text := errMsg.Error.Error(); text != "" {
				errText = text
			}
		}
		body := handlers.BuildErrorResponseBody(status, errText)
		if sse {
			_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
			return
		}
		_, _ = c.Writer.Write(body)
	}

	h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
		KeepAliveInterval:  keepAliveInterval,
		WriteChunk:         writeChunk,
		WriteTerminalError: writeTerminalError,
	})
}
================================================
FILE: sdk/api/handlers/gemini/gemini_handlers.go
================================================
// Package gemini provides HTTP handlers for Gemini API endpoints.
// This package implements handlers for managing Gemini model operations including
// model listing, content generation, streaming content generation, and token counting.
// It serves as a proxy layer between clients and the Gemini backend service,
// handling request translation, client management, and response processing.
package gemini
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
)
// GeminiAPIHandler contains the handlers for Gemini API endpoints
// (model listing/lookup, content generation, streaming, and token counting).
type GeminiAPIHandler struct {
	// BaseAPIHandler supplies the shared execution, keep-alive, stream
	// forwarding, and error-writing helpers used by this type's methods.
	*handlers.BaseAPIHandler
}
// NewGeminiAPIHandler wraps a shared BaseAPIHandler in a GeminiAPIHandler.
// It takes a BaseAPIHandler instance and returns a handler embedding it.
func NewGeminiAPIHandler(apiHandlers *handlers.BaseAPIHandler) *GeminiAPIHandler {
	handler := new(GeminiAPIHandler)
	handler.BaseAPIHandler = apiHandlers
	return handler
}
// HandlerType reports this handler's identifier (the Gemini constant). The
// methods of this type pass it to the auth-manager execution helpers.
func (h *GeminiAPIHandler) HandlerType() string {
	return Gemini
}
// Models lists the models the global model registry currently reports as
// available for the "gemini" handler type.
func (h *GeminiAPIHandler) Models() []map[string]any {
	return registry.GetGlobalRegistry().GetAvailableModels("gemini")
}
// GeminiModels serves the Gemini-style model listing as {"models": [...]}.
// Each registry entry is shallow-copied and normalized:
//   - A non-empty string "name" gets a "models/" prefix if missing.
//   - An empty or non-string "displayName" or "description" defaults to the
//     original name.
//   - A missing "supportedGenerationMethods" defaults to ["generateContent"].
func (h *GeminiAPIHandler) GeminiModels(c *gin.Context) {
	source := h.Models()
	fallbackMethods := []string{"generateContent"}
	models := make([]map[string]any, 0, len(source))

	for _, entry := range source {
		out := make(map[string]any, len(entry))
		for key, value := range entry {
			out[key] = value
		}

		if name, isString := out["name"].(string); isString && name != "" {
			if !strings.HasPrefix(name, "models/") {
				out["name"] = "models/" + name
			}
			if text, _ := out["displayName"].(string); text == "" {
				out["displayName"] = name
			}
			if text, _ := out["description"].(string); text == "" {
				out["description"] = name
			}
		}
		if _, present := out["supportedGenerationMethods"]; !present {
			out["supportedGenerationMethods"] = fallbackMethods
		}
		models = append(models, out)
	}

	c.JSON(http.StatusOK, gin.H{"models": models})
}
// GeminiGetHandler handles GET requests for a single Gemini model.
// The bound "action" URI parameter (leading "/" ignored) is matched against each
// model's "name", with or without a "models/" prefix. The match is returned in
// the same normalized shape as GeminiModels, as a copy so the registry's map is
// never mutated. A missing model yields 404.
func (h *GeminiAPIHandler) GeminiGetHandler(c *gin.Context) {
	var request struct {
		Action string `uri:"action" binding:"required"`
	}
	if err := c.ShouldBindUri(&request); err != nil {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Invalid request: %v", err),
				Type:    "invalid_request_error",
			},
		})
		return
	}
	action := strings.TrimPrefix(request.Action, "/")

	// Find the matching model in the registry's current list.
	var targetModel map[string]any
	for _, model := range h.Models() {
		name, _ := model["name"].(string)
		// Match name with or without the 'models/' prefix.
		if name == action || name == "models/"+action {
			targetModel = model
			break
		}
	}
	if targetModel == nil {
		c.JSON(http.StatusNotFound, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Not Found",
				Type:    "not_found",
			},
		})
		return
	}

	// Shallow-copy before normalizing: the registry may hand out shared maps.
	result := make(map[string]any, len(targetModel))
	for key, value := range targetModel {
		result[key] = value
	}
	if name, ok := result["name"].(string); ok && name != "" {
		if !strings.HasPrefix(name, "models/") {
			result["name"] = "models/" + name
		}
		if displayName, _ := result["displayName"].(string); displayName == "" {
			result["displayName"] = name
		}
		if description, _ := result["description"].(string); description == "" {
			result["description"] = name
		}
	}
	if _, ok := result["supportedGenerationMethods"]; !ok {
		result["supportedGenerationMethods"] = []string{"generateContent"}
	}
	c.JSON(http.StatusOK, result)
}
// GeminiHandler handles POST requests for Gemini API operations.
// The bound "action" URI parameter must have the form "<model>:<method>"
// (a leading "/" is ignored). Supported methods are generateContent,
// streamGenerateContent, and countTokens. Any other action shape or method
// yields a 404 error response instead of an empty 200.
func (h *GeminiAPIHandler) GeminiHandler(c *gin.Context) {
	var request struct {
		Action string `uri:"action" binding:"required"`
	}
	if err := c.ShouldBindUri(&request); err != nil {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Invalid request: %v", err),
				Type:    "invalid_request_error",
			},
		})
		return
	}
	notFound := func() {
		c.JSON(http.StatusNotFound, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("%s not found.", c.Request.URL.Path),
				Type:    "invalid_request_error",
			},
		})
	}

	action := strings.Split(strings.TrimPrefix(request.Action, "/"), ":")
	if len(action) != 2 {
		notFound()
		return
	}
	modelName, method := action[0], action[1]

	rawJSON, err := c.GetRawData()
	if err != nil {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Invalid request: %v", err),
				Type:    "invalid_request_error",
			},
		})
		return
	}

	switch method {
	case "generateContent":
		h.handleGenerateContent(c, modelName, rawJSON)
	case "streamGenerateContent":
		h.handleStreamGenerateContent(c, modelName, rawJSON)
	case "countTokens":
		h.handleCountTokens(c, modelName, rawJSON)
	default:
		// Unknown methods must not fall through to an empty 200 response.
		notFound()
	}
}
// handleStreamGenerateContent handles streaming content generation requests for Gemini models.
//
// It starts the upstream stream and waits for the first event before committing
// any response headers. This lets an immediate upstream failure still be
// reported with a proper status via WriteErrorResponse. With alt == "", the
// response is framed as Server-Sent Events ("data: <chunk>\n\n"). With a
// non-empty alt (from the 'alt' query parameter via GetAlt), chunks are written
// verbatim. After the first chunk, the rest is relayed by forwardGeminiStream.
//
// Parameters:
//   - c: The Gin context for the request
//   - modelName: The name of the Gemini model to use for content generation
//   - rawJSON: The raw JSON request body containing generation parameters
func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName string, rawJSON []byte) {
	alt := h.GetAlt(c)
	// Get the http.Flusher interface to manually flush the response.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}
	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
	dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
	setSSEHeaders := func() {
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")
		c.Header("Access-Control-Allow-Origin", "*")
	}
	// Peek at the first event before committing headers.
	for {
		select {
		case <-c.Request.Context().Done():
			// Client went away before the first event; propagate the cause upstream.
			cliCancel(c.Request.Context().Err())
			return
		case errMsg, ok := <-errChan:
			if !ok {
				// Error channel closed without an error. A nil channel is never
				// selected, so this stops listening on it and waits for data.
				errChan = nil
				continue
			}
			// Upstream failed before any data: return a proper error status and body.
			h.WriteErrorResponse(c, errMsg)
			if errMsg != nil {
				cliCancel(errMsg.Error)
			} else {
				cliCancel(nil)
			}
			return
		case chunk, ok := <-dataChan:
			if !ok {
				// Closed without data: send only headers (SSE headers when alt == "").
				if alt == "" {
					setSSEHeaders()
				}
				handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
				flusher.Flush()
				cliCancel(nil)
				return
			}
			// First chunk received: commit headers now.
			if alt == "" {
				setSSEHeaders()
			}
			handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
			// Write the first chunk, SSE-framed when alt == "".
			if alt == "" {
				_, _ = c.Writer.Write([]byte("data: "))
				_, _ = c.Writer.Write(chunk)
				_, _ = c.Writer.Write([]byte("\n\n"))
			} else {
				_, _ = c.Writer.Write(chunk)
			}
			flusher.Flush()
			// Relay the remaining chunks and any terminal error.
			h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan)
			return
		}
	}
}
// handleCountTokens handles countTokens requests for Gemini models. It sends the
// request through the auth-manager count helper and writes the upstream JSON
// response, plus forwarded upstream headers. Failures are reported via
// WriteErrorResponse.
//
// Parameters:
//   - c: The Gin context for the request
//   - modelName: The name of the Gemini model to use for token counting
//   - rawJSON: The raw JSON request body containing the content to count
func (h *GeminiAPIHandler) handleCountTokens(c *gin.Context, modelName string, rawJSON []byte) {
	c.Header("Content-Type", "application/json")
	ctx, cancel := h.GetContextWithCancel(h, c, context.Background())

	body, upstreamHeaders, failure := h.ExecuteCountWithAuthManager(ctx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
	if failure == nil {
		handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
		_, _ = c.Writer.Write(body)
		cancel()
		return
	}
	h.WriteErrorResponse(c, failure)
	cancel(failure.Error)
}
// handleGenerateContent performs a non-streaming generateContent call and writes
// the complete upstream response as one JSON body. StartNonStreamingKeepAlive
// runs for the duration of the upstream call and is stopped before anything
// else is written.
//
// Parameters:
//   - c: The Gin context for the request
//   - modelName: The name of the Gemini model to use for content generation
//   - rawJSON: The raw JSON request body containing generation parameters and content
func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName string, rawJSON []byte) {
	c.Header("Content-Type", "application/json")
	alt := h.GetAlt(c)
	ctx, cancel := h.GetContextWithCancel(h, c, context.Background())

	stopKeepAlive := h.StartNonStreamingKeepAlive(c, ctx)
	body, upstreamHeaders, failure := h.ExecuteWithAuthManager(ctx, h.HandlerType(), modelName, rawJSON, alt)
	stopKeepAlive()

	if failure == nil {
		handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
		_, _ = c.Writer.Write(body)
		cancel()
		return
	}
	h.WriteErrorResponse(c, failure)
	cancel(failure.Error)
}
// forwardGeminiStream relays the remaining stream to the client through ForwardStream.
//
// With alt == "", every chunk is framed as an SSE event ("data: <chunk>\n\n").
// With a non-empty alt, chunks are written verbatim and a zero
// KeepAliveInterval is passed to ForwardStream (presumably disabling
// keep-alives; confirm in ForwardStream).
//
// Terminal errors are rendered with BuildErrorResponseBody, as an SSE
// "event: error" frame or as a raw body depending on alt.
func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
	sse := alt == ""

	var keepAliveInterval *time.Duration
	if !sse {
		disabled := time.Duration(0)
		keepAliveInterval = &disabled
	}

	writeChunk := func(chunk []byte) {
		if !sse {
			_, _ = c.Writer.Write(chunk)
			return
		}
		_, _ = c.Writer.Write([]byte("data: "))
		_, _ = c.Writer.Write(chunk)
		_, _ = c.Writer.Write([]byte("\n\n"))
	}

	writeTerminalError := func(errMsg *interfaces.ErrorMessage) {
		if errMsg == nil {
			return
		}
		status := errMsg.StatusCode
		if status <= 0 {
			status = http.StatusInternalServerError
		}
		errText := http.StatusText(status)
		if errMsg.Error != nil {
			if text := errMsg.Error.Error(); text != "" {
				errText = text
			}
		}
		body := handlers.BuildErrorResponseBody(status, errText)
		if sse {
			_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
			return
		}
		_, _ = c.Writer.Write(body)
	}

	h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
		KeepAliveInterval:  keepAliveInterval,
		WriteChunk:         writeChunk,
		WriteTerminalError: writeTerminalError,
	})
}
================================================
FILE: sdk/api/handlers/handlers.go
================================================
// Package handlers provides core API handler functionality for the CLI Proxy API server.
// It includes common types, client management, load balancing, and error handling
// shared across all API endpoint handlers (OpenAI, Claude, Gemini).
package handlers
import (
"bytes"
"encoding/json"
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"golang.org/x/net/context"
)
// ErrorResponse represents a standard error response format for the API.
// It serializes as {"error": {...}} and contains a single ErrorDetail field.
type ErrorResponse struct {
	// Error contains detailed information about the error that occurred.
	Error ErrorDetail `json:"error"`
}
// ErrorDetail provides specific information about an error that occurred.
// It includes a human-readable message, an error type, and an optional error code.
type ErrorDetail struct {
	// Message is a human-readable message providing more details about the error.
	Message string `json:"message"`
	// Type is the category of error that occurred (e.g., "invalid_request_error").
	Type string `json:"type"`
	// Code is a short code identifying the error, if applicable; omitted from
	// the JSON output when empty.
	Code string `json:"code,omitempty"`
}
// idempotencyKeyMetadataKey names the metadata entry that carries an
// idempotency key (its readers/writers are outside this view — confirm usage).
const idempotencyKeyMetadataKey = "idempotency_key"
// Streaming defaults applied when no SDK configuration is provided.
const (
	// defaultStreamingKeepAliveSeconds is the SSE keep-alive period used for a
	// nil config; 0 disables keep-alives.
	defaultStreamingKeepAliveSeconds = 0
	// defaultStreamingBootstrapRetries is presumably the retry count used by
	// StreamingBootstrapRetries for a nil config (0 = no retries) — confirm.
	defaultStreamingBootstrapRetries = 0
)
// pinnedAuthContextKey is the unexported context key for the auth ID stored by
// WithPinnedAuthID; a private struct type prevents collisions with other packages.
type pinnedAuthContextKey struct{}

// selectedAuthCallbackContextKey is the unexported context key for the callback
// stored by WithSelectedAuthIDCallback.
type selectedAuthCallbackContextKey struct{}

// executionSessionContextKey is the unexported context key for the session ID
// stored by WithExecutionSessionID.
type executionSessionContextKey struct{}
// WithPinnedAuthID returns a child context that requests execution on a
// specific auth ID. Surrounding whitespace is trimmed from authID.
//   - A blank ID returns ctx unchanged, even when ctx is nil.
//   - Otherwise a nil ctx is replaced with context.Background before the value
//     is attached.
func WithPinnedAuthID(ctx context.Context, authID string) context.Context {
	trimmed := strings.TrimSpace(authID)
	if trimmed == "" {
		return ctx
	}
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	return context.WithValue(parent, pinnedAuthContextKey{}, trimmed)
}
// WithSelectedAuthIDCallback returns a child context carrying callback, which
// is meant to receive the selected auth ID.
//   - A nil callback returns ctx unchanged.
//   - Otherwise a nil ctx is replaced with context.Background before the value
//     is attached.
func WithSelectedAuthIDCallback(ctx context.Context, callback func(string)) context.Context {
	switch {
	case callback == nil:
		return ctx
	case ctx == nil:
		return context.WithValue(context.Background(), selectedAuthCallbackContextKey{}, callback)
	default:
		return context.WithValue(ctx, selectedAuthCallbackContextKey{}, callback)
	}
}
// WithExecutionSessionID returns a child context tagged with a long-lived
// execution session ID. Surrounding whitespace is trimmed from sessionID.
//   - A blank ID returns ctx unchanged, even when ctx is nil.
//   - Otherwise a nil ctx is replaced with context.Background before the value
//     is attached.
func WithExecutionSessionID(ctx context.Context, sessionID string) context.Context {
	id := strings.TrimSpace(sessionID)
	if id == "" {
		return ctx
	}
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	return context.WithValue(parent, executionSessionContextKey{}, id)
}
// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body.
//
// Input handling:
//   - A non-positive status is treated as 500.
//   - An empty or blank errText defaults to the status text.
//   - If the trimmed errText is already a JSON object or array, it is returned
//     as-is to preserve upstream error payloads.
//   - Bare JSON scalars (e.g. "404", "true", "null") are treated as plain text
//     and wrapped, so callers always receive a structured error body.
//
// The error type and code depend on status:
//   - 401: authentication_error / invalid_api_key
//   - 403: permission_error / insufficient_quota
//   - 429: rate_limit_error / rate_limit_exceeded
//   - 404: invalid_request_error / model_not_found
//   - >= 500: server_error / internal_server_error
//   - anything else: invalid_request_error with no code
func BuildErrorResponseBody(status int, errText string) []byte {
	if status <= 0 {
		status = http.StatusInternalServerError
	}
	if strings.TrimSpace(errText) == "" {
		errText = http.StatusText(status)
	}
	trimmed := strings.TrimSpace(errText)
	// Only structured JSON (objects/arrays) is passed through verbatim.
	if trimmed != "" && (trimmed[0] == '{' || trimmed[0] == '[') && json.Valid([]byte(trimmed)) {
		return []byte(trimmed)
	}

	errType := "invalid_request_error"
	var code string
	switch status {
	case http.StatusUnauthorized:
		errType = "authentication_error"
		code = "invalid_api_key"
	case http.StatusForbidden:
		errType = "permission_error"
		code = "insufficient_quota"
	case http.StatusTooManyRequests:
		errType = "rate_limit_error"
		code = "rate_limit_exceeded"
	case http.StatusNotFound:
		errType = "invalid_request_error"
		code = "model_not_found"
	default:
		if status >= http.StatusInternalServerError {
			errType = "server_error"
			code = "internal_server_error"
		}
	}

	payload, err := json.Marshal(ErrorResponse{
		Error: ErrorDetail{
			Message: errText,
			Type:    errType,
			Code:    code,
		},
	})
	if err != nil {
		// Defensive fallback; marshalling this plain struct is not expected to fail.
		return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error","code":"internal_server_error"}}`, errText))
	}
	return payload
}
// StreamingKeepAliveInterval reports the SSE keep-alive period from
// cfg.Streaming.KeepAliveSeconds, falling back to defaultStreamingKeepAliveSeconds when cfg
// is nil. Any non-positive value yields 0, which disables keep-alives.
func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
	keepAlive := defaultStreamingKeepAliveSeconds
	if cfg != nil {
		keepAlive = cfg.Streaming.KeepAliveSeconds
	}
	switch {
	case keepAlive > 0:
		return time.Second * time.Duration(keepAlive)
	default:
		return 0
	}
}

// NonStreamingKeepAliveInterval reports the keep-alive period for non-streaming responses
// from cfg.NonStreamKeepAliveInterval (seconds). A nil cfg or a non-positive value yields 0,
// which disables keep-alives.
func NonStreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
	keepAlive := 0
	if cfg != nil {
		keepAlive = cfg.NonStreamKeepAliveInterval
	}
	switch {
	case keepAlive > 0:
		return time.Second * time.Duration(keepAlive)
	default:
		return 0
	}
}

// StreamingBootstrapRetries reports how many times a streaming request may be re-issued
// before any bytes reach the client. It reads cfg.Streaming.BootstrapRetries, falls back to
// defaultStreamingBootstrapRetries for a nil cfg, and clamps negative values to 0.
func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
	limit := defaultStreamingBootstrapRetries
	if cfg != nil {
		limit = cfg.Streaming.BootstrapRetries
	}
	if limit > 0 {
		return limit
	}
	return 0
}

// PassthroughHeadersEnabled reports whether upstream response headers should be forwarded
// to clients. It is false for a nil cfg.
func PassthroughHeadersEnabled(cfg *config.SDKConfig) bool {
	if cfg == nil {
		return false
	}
	return cfg.PassthroughHeaders
}
// requestExecutionMetadata assembles the per-request metadata map handed to the executor.
//
// The idempotency key comes from the client's optional Idempotency-Key header. The header is
// read via the *gin.Context stored under the "gin" key. When the header is absent or blank,
// a random UUID is generated instead. Pinned auth ID, selected-auth callback and execution
// session ID are added only when present in ctx.
func requestExecutionMetadata(ctx context.Context) map[string]any {
	idempotencyKey := ""
	if ctx != nil {
		ginCtx, _ := ctx.Value("gin").(*gin.Context)
		if ginCtx != nil && ginCtx.Request != nil {
			idempotencyKey = strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key"))
		}
	}
	if idempotencyKey == "" {
		idempotencyKey = uuid.NewString()
	}

	metadata := map[string]any{idempotencyKeyMetadataKey: idempotencyKey}
	if authID := pinnedAuthIDFromContext(ctx); authID != "" {
		metadata[coreexecutor.PinnedAuthMetadataKey] = authID
	}
	if onSelected := selectedAuthIDCallbackFromContext(ctx); onSelected != nil {
		metadata[coreexecutor.SelectedAuthCallbackMetadataKey] = onSelected
	}
	if sessionID := executionSessionIDFromContext(ctx); sessionID != "" {
		metadata[coreexecutor.ExecutionSessionMetadataKey] = sessionID
	}
	return metadata
}
// pinnedAuthIDFromContext returns the trimmed pinned auth ID stored in ctx (as a string or
// []byte), or "" when ctx is nil or holds no such value.
func pinnedAuthIDFromContext(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	value := ctx.Value(pinnedAuthContextKey{})
	if s, isString := value.(string); isString {
		return strings.TrimSpace(s)
	}
	if b, isBytes := value.([]byte); isBytes {
		return strings.TrimSpace(string(b))
	}
	return ""
}

// selectedAuthIDCallbackFromContext returns the selected-auth callback stored in ctx, or nil
// when ctx is nil or holds no func(string) under that key.
func selectedAuthIDCallbackFromContext(ctx context.Context) func(string) {
	if ctx == nil {
		return nil
	}
	onSelected, _ := ctx.Value(selectedAuthCallbackContextKey{}).(func(string))
	return onSelected
}

// executionSessionIDFromContext returns the trimmed execution session ID stored in ctx (as a
// string or []byte), or "" when ctx is nil or holds no such value.
func executionSessionIDFromContext(ctx context.Context) string {
	if ctx == nil {
		return ""
	}
	value := ctx.Value(executionSessionContextKey{})
	if s, isString := value.(string); isString {
		return strings.TrimSpace(s)
	}
	if b, isBytes := value.([]byte); isBytes {
		return strings.TrimSpace(string(b))
	}
	return ""
}
// BaseAPIHandler carries the state shared by the API endpoint handlers: the core auth
// manager used to execute requests and the current SDK configuration.
type BaseAPIHandler struct {
	// AuthManager manages auth lifecycle and executes requests (Execute, ExecuteCount,
	// ExecuteStream) on behalf of the handlers.
	AuthManager *coreauth.Manager
	// Cfg holds the current application configuration.
	Cfg *config.SDKConfig
}

// NewBaseAPIHandlers creates a BaseAPIHandler from a configuration and an auth manager.
//
// Parameters:
//   - cfg: The SDK configuration, stored in Cfg
//   - authManager: The core auth manager, stored in AuthManager and used to execute requests
//
// Returns:
//   - *BaseAPIHandler: A new API handlers instance
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
	return &BaseAPIHandler{
		Cfg:         cfg,
		AuthManager: authManager,
	}
}

// UpdateClients replaces the handler's configuration with cfg. Despite its historical name,
// it takes no client list and leaves AuthManager untouched.
//
// NOTE: the assignment is not synchronized; callers that update the config while requests
// are in flight must provide their own synchronization.
//
// Parameters:
//   - cfg: The new application configuration
func (h *BaseAPIHandler) UpdateClients(cfg *config.SDKConfig) { h.Cfg = cfg }
// GetAlt returns the value of the "alt" query parameter. It falls back to "$alt" only when
// "alt" is absent from the query string altogether.
//
// The value "sse" is reported as the empty string; any other value (including "") is
// returned unchanged.
//
// Parameters:
//   - c: The Gin context containing the HTTP request
func (h *BaseAPIHandler) GetAlt(c *gin.Context) string {
	alt, present := c.GetQuery("alt")
	if !present {
		alt = c.Query("$alt")
	}
	switch alt {
	case "sse":
		return ""
	default:
		return alt
	}
}
// GetContextWithCancel derives a cancellable execution context for a request.
//
// The new context is built on ctx, or on context.Background() when ctx is nil, so caller
// values and deadlines are preserved. If ctx carries no request ID, one is copied from the
// HTTP request context or, failing that, from the Gin context. When the HTTP request
// context differs from the parent, a watcher goroutine cancels the new context as soon as
// the request context is done. The goroutine exits once either context finishes. The Gin
// context and handler are stored under the "gin" and "handler" keys.
//
// The returned APIHandlerCancelFunc always cancels the context. When request logging is
// enabled and exactly one parameter ([]byte, error or string) is passed, that value is
// appended to the "API_RESPONSE" log entry. It is not appended if it is already recorded
// there. It is also not appended when it is an error/string and a non-blank upstream
// response has already been captured.
//
// Parameters:
//   - handler: The API handler associated with the request.
//   - c: The Gin context of the current request (may be nil).
//   - ctx: The parent context.
//
// Returns:
//   - context.Context: The new context with cancellation and embedded values.
//   - APIHandlerCancelFunc: A function to cancel the context and log the response.
func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *gin.Context, ctx context.Context) (context.Context, APIHandlerCancelFunc) {
	parentCtx := ctx
	if parentCtx == nil {
		parentCtx = context.Background()
	}
	var requestCtx context.Context
	if c != nil && c.Request != nil {
		requestCtx = c.Request.Context()
	}
	if requestCtx != nil && logging.GetRequestID(parentCtx) == "" {
		if requestID := logging.GetRequestID(requestCtx); requestID != "" {
			parentCtx = logging.WithRequestID(parentCtx, requestID)
		} else if requestID := logging.GetGinRequestID(c); requestID != "" {
			parentCtx = logging.WithRequestID(parentCtx, requestID)
		}
	}
	newCtx, cancel := context.WithCancel(parentCtx)
	if requestCtx != nil && requestCtx != parentCtx {
		// Capture Done before newCtx is rebound below. Reading the newCtx variable from the
		// goroutine while this function reassigns it would be a data race.
		done := newCtx.Done()
		go func() {
			select {
			case <-requestCtx.Done():
				cancel()
			case <-done:
			}
		}()
	}
	newCtx = context.WithValue(newCtx, "gin", c)
	newCtx = context.WithValue(newCtx, "handler", handler)
	return newCtx, func(params ...interface{}) {
		defer cancel()
		// Handlers may be built without config (NewBaseAPIHandlers(nil, ...)) and contexts
		// without a Gin context; neither may trigger a nil dereference here.
		if h == nil || h.Cfg == nil || !h.Cfg.RequestLog || c == nil || len(params) != 1 {
			return
		}
		var existing []byte
		if raw, ok := c.Get("API_RESPONSE"); ok {
			existing, _ = raw.([]byte)
		}
		switch params[0].(type) {
		case error, string:
			// An upstream response is already recorded; don't add an error summary to it.
			if len(bytes.TrimSpace(existing)) > 0 {
				return
			}
		}
		var payload []byte
		switch data := params[0].(type) {
		case []byte:
			payload = data
		case error:
			if data != nil {
				payload = []byte(data.Error())
			}
		case string:
			payload = []byte(data)
		}
		if len(payload) == 0 {
			return
		}
		if trimmed := bytes.TrimSpace(payload); len(existing) > 0 && len(trimmed) > 0 && bytes.Contains(existing, trimmed) {
			return
		}
		appendAPIResponse(c, payload)
	}
}
// StartNonStreamingKeepAlive periodically writes and flushes a blank line ("\n") to the
// client while a non-streaming response is being prepared.
//
// The period is NonStreamingKeepAliveInterval(h.Cfg). No goroutine is started when that is
// zero, when h or c is nil, or when c.Writer does not implement http.Flusher. In those cases
// a no-op stop function is returned.
//
// The returned stop function is idempotent and blocks until the writer goroutine has exited.
// It MUST be called before the final response is written, because the goroutine writes to
// c.Writer. The goroutine also exits when ctx is done or a keep-alive write fails (e.g. the
// client disconnected). Note that net/http commits the status code on the first Write, so
// once a keep-alive line has been sent the final status can no longer be changed.
func (h *BaseAPIHandler) StartNonStreamingKeepAlive(c *gin.Context, ctx context.Context) func() {
	noop := func() {}
	if h == nil || c == nil {
		return noop
	}
	interval := NonStreamingKeepAliveInterval(h.Cfg)
	if interval <= 0 {
		return noop
	}
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		return noop
	}
	if ctx == nil {
		ctx = context.Background()
	}
	stopChan := make(chan struct{})
	var stopOnce sync.Once
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		keepAlive := []byte("\n")
		for {
			select {
			case <-stopChan:
				return
			case <-ctx.Done():
				return
			case <-ticker.C:
				if _, err := c.Writer.Write(keepAlive); err != nil {
					// Further writes to a failed connection are pointless; stop ticking.
					return
				}
				flusher.Flush()
			}
		}
	}()
	return func() {
		stopOnce.Do(func() {
			close(stopChan)
		})
		wg.Wait()
	}
}
// appendAPIResponse adds data to the "API_RESPONSE" value on the Gin context, keeping any
// bytes already recorded there. A newline is inserted between the old and new content unless
// the old content already ends with one. The first call also stamps
// "API_RESPONSE_TIMESTAMP" with the current time. Nil contexts and empty data are ignored.
func appendAPIResponse(c *gin.Context, data []byte) {
	if c == nil || len(data) == 0 {
		return
	}
	if _, stamped := c.Get("API_RESPONSE_TIMESTAMP"); !stamped {
		c.Set("API_RESPONSE_TIMESTAMP", time.Now())
	}

	var previous []byte
	if raw, found := c.Get("API_RESPONSE"); found {
		previous, _ = raw.([]byte)
	}
	if len(previous) == 0 {
		c.Set("API_RESPONSE", bytes.Clone(data))
		return
	}

	merged := make([]byte, 0, len(previous)+len(data)+1)
	merged = append(merged, previous...)
	if !bytes.HasSuffix(previous, []byte("\n")) {
		merged = append(merged, '\n')
	}
	c.Set("API_RESPONSE", append(merged, data...))
}
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
// This path is the only supported execution route.
//
// The model name is resolved by getRequestDetails, and request metadata (idempotency key,
// pinned auth, selection callback, session ID and the requested model) is attached to the
// executor options. On failure, the ErrorMessage carries the error's StatusCode() when it is
// positive, 500 otherwise, plus a clone of any Headers() the error exposes. Response headers
// are returned, filtered by FilterUpstreamHeaders, only when header passthrough is enabled.
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
	providers, normalizedModel, detailsErr := h.getRequestDetails(modelName)
	if detailsErr != nil {
		return nil, nil, detailsErr
	}

	metadata := requestExecutionMetadata(ctx)
	metadata[coreexecutor.RequestedModelMetadataKey] = normalizedModel

	var body []byte
	if len(rawJSON) > 0 {
		body = rawJSON
	}
	request := coreexecutor.Request{Model: normalizedModel, Payload: body}
	options := coreexecutor.Options{
		Stream:          false,
		Alt:             alt,
		OriginalRequest: rawJSON,
		SourceFormat:    sdktranslator.FromString(handlerType),
		Metadata:        metadata,
	}

	resp, execErr := h.AuthManager.Execute(ctx, providers, request, options)
	if execErr != nil {
		status := statusFromError(execErr)
		if status == 0 {
			status = http.StatusInternalServerError
		}
		var addon http.Header
		if carrier, ok := execErr.(interface{ Headers() http.Header }); ok && carrier != nil {
			if hdr := carrier.Headers(); hdr != nil {
				addon = hdr.Clone()
			}
		}
		return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: execErr, Addon: addon}
	}

	if PassthroughHeadersEnabled(h.Cfg) {
		return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
	}
	return resp.Payload, nil, nil
}
// ExecuteCountWithAuthManager executes a non-streaming token-count request via the core auth
// manager (AuthManager.ExecuteCount). This path is the only supported execution route.
//
// Request preparation, error mapping (StatusCode() when positive, else 500, plus cloned
// Headers()) and header passthrough mirror ExecuteWithAuthManager.
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, http.Header, *interfaces.ErrorMessage) {
	providers, normalizedModel, detailsErr := h.getRequestDetails(modelName)
	if detailsErr != nil {
		return nil, nil, detailsErr
	}

	metadata := requestExecutionMetadata(ctx)
	metadata[coreexecutor.RequestedModelMetadataKey] = normalizedModel

	var body []byte
	if len(rawJSON) > 0 {
		body = rawJSON
	}
	request := coreexecutor.Request{Model: normalizedModel, Payload: body}
	options := coreexecutor.Options{
		Stream:          false,
		Alt:             alt,
		OriginalRequest: rawJSON,
		SourceFormat:    sdktranslator.FromString(handlerType),
		Metadata:        metadata,
	}

	resp, countErr := h.AuthManager.ExecuteCount(ctx, providers, request, options)
	if countErr != nil {
		status := statusFromError(countErr)
		if status == 0 {
			status = http.StatusInternalServerError
		}
		var addon http.Header
		if carrier, ok := countErr.(interface{ Headers() http.Header }); ok && carrier != nil {
			if hdr := carrier.Headers(); hdr != nil {
				addon = hdr.Clone()
			}
		}
		return nil, nil, &interfaces.ErrorMessage{StatusCode: status, Error: countErr, Addon: addon}
	}

	if PassthroughHeadersEnabled(h.Cfg) {
		return resp.Payload, FilterUpstreamHeaders(resp.Headers), nil
	}
	return resp.Payload, nil, nil
}
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
// This path is the only supported execution route.
//
// It returns:
//   - a data channel (unbuffered) that receives a copy of each non-empty upstream chunk
//     payload. It is closed when the upstream chunk channel is closed, after an error is
//     reported, or when ctx is done.
//   - the upstream response headers, filtered through FilterUpstreamHeaders. They are nil
//     unless header passthrough is enabled.
//   - an error channel (capacity 1) that receives at most one ErrorMessage. The worker
//     goroutine closes it together with the data channel.
//
// When request setup fails (unknown model or an ExecuteStream error), the data channel and
// headers are nil and the returned error channel already holds the error and is closed.
//
// Bootstrap retries: if the stream fails before any payload has been forwarded, the request
// is re-issued up to StreamingBootstrapRetries(h.Cfg) times. This applies when the error has
// no status, or a status of 401, 402, 403, 408, 429 or >= 500. Once a payload has been
// forwarded, errors are reported without retrying.
//
// NOTE: on a bootstrap retry the returned header map is rewritten in place by the worker
// goroutine. That only happens before the first payload is sent, so callers should read the
// headers only after receiving the first chunk or the error; reading earlier may race.
//
// For handlerType "openai-response", each chunk's SSE "data:" lines (other than empty ones
// and "[DONE]") must be valid JSON. Otherwise the stream ends with a 502 error.
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, http.Header, <-chan *interfaces.ErrorMessage) {
	providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
	if errMsg != nil {
		errChan := make(chan *interfaces.ErrorMessage, 1)
		errChan <- errMsg
		close(errChan)
		return nil, nil, errChan
	}
	reqMeta := requestExecutionMetadata(ctx)
	reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
	payload := rawJSON
	if len(payload) == 0 {
		payload = nil
	}
	req := coreexecutor.Request{
		Model:   normalizedModel,
		Payload: payload,
	}
	opts := coreexecutor.Options{
		Stream:          true,
		Alt:             alt,
		OriginalRequest: rawJSON,
		SourceFormat:    sdktranslator.FromString(handlerType),
	}
	opts.Metadata = reqMeta
	streamResult, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
	if err != nil {
		// Setup failed: report the error on a pre-filled, closed channel.
		errChan := make(chan *interfaces.ErrorMessage, 1)
		status := http.StatusInternalServerError
		if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
			if code := se.StatusCode(); code > 0 {
				status = code
			}
		}
		var addon http.Header
		if he, ok := err.(interface{ Headers() http.Header }); ok && he != nil {
			if hdr := he.Headers(); hdr != nil {
				addon = hdr.Clone()
			}
		}
		errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
		close(errChan)
		return nil, nil, errChan
	}
	passthroughHeadersEnabled := PassthroughHeadersEnabled(h.Cfg)
	// Capture upstream headers from the initial connection synchronously before the goroutine starts.
	// Keep a mutable map so bootstrap retries can replace it before first payload is sent.
	var upstreamHeaders http.Header
	if passthroughHeadersEnabled {
		upstreamHeaders = cloneHeader(FilterUpstreamHeaders(streamResult.Headers))
		if upstreamHeaders == nil {
			upstreamHeaders = make(http.Header)
		}
	}
	chunks := streamResult.Chunks
	dataChan := make(chan []byte)
	errChan := make(chan *interfaces.ErrorMessage, 1)
	go func() {
		defer close(dataChan)
		defer close(errChan)
		// sentPayload flips to true once any payload is forwarded; after that no retries happen.
		sentPayload := false
		bootstrapRetries := 0
		maxBootstrapRetries := StreamingBootstrapRetries(h.Cfg)
		// sendErr delivers msg unless ctx is cancelled first; it reports whether it was delivered.
		sendErr := func(msg *interfaces.ErrorMessage) bool {
			if ctx == nil {
				errChan <- msg
				return true
			}
			select {
			case <-ctx.Done():
				return false
			case errChan <- msg:
				return true
			}
		}
		// sendData blocks until the consumer takes chunk or ctx is cancelled (with a nil ctx it
		// blocks until the consumer reads).
		sendData := func(chunk []byte) bool {
			if ctx == nil {
				dataChan <- chunk
				return true
			}
			select {
			case <-ctx.Done():
				return false
			case dataChan <- chunk:
				return true
			}
		}
		// bootstrapEligible reports whether a pre-payload failure may be retried: errors without
		// a status, auth/payment/timeout/rate-limit statuses, and any 5xx.
		bootstrapEligible := func(err error) bool {
			status := statusFromError(err)
			if status == 0 {
				return true
			}
			switch status {
			case http.StatusUnauthorized, http.StatusForbidden, http.StatusPaymentRequired,
				http.StatusRequestTimeout, http.StatusTooManyRequests:
				return true
			default:
				return status >= http.StatusInternalServerError
			}
		}
	outer:
		for {
			for {
				var chunk coreexecutor.StreamChunk
				var ok bool
				if ctx != nil {
					select {
					case <-ctx.Done():
						return
					case chunk, ok = <-chunks:
					}
				} else {
					chunk, ok = <-chunks
				}
				if !ok {
					// Upstream closed its channel: normal end of stream.
					return
				}
				if chunk.Err != nil {
					streamErr := chunk.Err
					// Safe bootstrap recovery: if the upstream fails before any payload bytes are sent,
					// re-issue the request a limited number of times (to allow auth rotation / transient
					// recovery). If a retry cannot even be started, its error is reported instead.
					if !sentPayload {
						if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) {
							bootstrapRetries++
							retryResult, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
							if retryErr == nil {
								if passthroughHeadersEnabled {
									replaceHeader(upstreamHeaders, FilterUpstreamHeaders(retryResult.Headers))
								}
								// Restart consumption from the retried stream.
								chunks = retryResult.Chunks
								continue outer
							}
							streamErr = retryErr
						}
					}
					status := http.StatusInternalServerError
					if se, ok := streamErr.(interface{ StatusCode() int }); ok && se != nil {
						if code := se.StatusCode(); code > 0 {
							status = code
						}
					}
					var addon http.Header
					if he, ok := streamErr.(interface{ Headers() http.Header }); ok && he != nil {
						if hdr := he.Headers(); hdr != nil {
							addon = hdr.Clone()
						}
					}
					_ = sendErr(&interfaces.ErrorMessage{StatusCode: status, Error: streamErr, Addon: addon})
					return
				}
				if len(chunk.Payload) > 0 {
					if handlerType == "openai-response" {
						// Reject malformed SSE JSON instead of forwarding it to the client.
						if err := validateSSEDataJSON(chunk.Payload); err != nil {
							_ = sendErr(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err})
							return
						}
					}
					sentPayload = true
					// Copy the payload so the consumer never shares memory with the upstream buffer.
					if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
						return
					}
				}
			}
		}
	}()
	return dataChan, upstreamHeaders, errChan
}
// validateSSEDataJSON checks every SSE "data:" line in chunk and returns an error for the
// first one whose (whitespace-trimmed) value is neither empty, "[DONE]", nor valid JSON.
// Other lines are ignored. The error quotes at most the first 512 bytes of the offending value.
func validateSSEDataJSON(chunk []byte) error {
	const maxPreviewBytes = 512
	for _, rawLine := range bytes.Split(chunk, []byte("\n")) {
		field, isData := bytes.CutPrefix(bytes.TrimSpace(rawLine), []byte("data:"))
		if !isData {
			continue
		}
		value := bytes.TrimSpace(field)
		if len(value) == 0 || bytes.Equal(value, []byte("[DONE]")) || json.Valid(value) {
			continue
		}
		preview := value
		if len(preview) > maxPreviewBytes {
			preview = preview[:maxPreviewBytes]
		}
		return fmt.Errorf("invalid SSE data JSON (len=%d): %q", len(value), preview)
	}
	return nil
}
// statusFromError returns the positive HTTP status exposed by err through a
// StatusCode() int method, or 0 when err is nil, has no such method, or reports a
// non-positive code.
func statusFromError(err error) int {
	coder, hasStatus := err.(interface{ StatusCode() int })
	if !hasStatus || coder == nil {
		return 0
	}
	if code := coder.StatusCode(); code > 0 {
		return code
	}
	return 0
}
// getRequestDetails resolves a client-supplied model name into the providers able to serve
// it and the model name to execute.
//
// An "auto" base model is resolved via util.ResolveAutoModel, and any thinking suffix is
// re-attached as "<resolved>(<suffix>)". Other names are passed through
// util.ResolveAutoModel as a whole. Providers are looked up by the suffix-free base model.
// If that finds nothing and the name carries a suffix, the full suffixed name is tried too.
// This keeps custom models registered under a suffixed name (e.g. "my-model(8192)") working.
// The suffix stays in the returned model name; no separate metadata is used for it.
// An unknown model yields a 502 ErrorMessage.
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, err *interfaces.ErrorMessage) {
	var resolvedModelName string
	if suffix := thinking.ParseSuffix(modelName); suffix.ModelName == "auto" {
		resolvedBase := util.ResolveAutoModel(suffix.ModelName)
		resolvedModelName = resolvedBase
		if suffix.HasSuffix {
			resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, suffix.RawSuffix)
		}
	} else {
		resolvedModelName = util.ResolveAutoModel(modelName)
	}

	baseModel := strings.TrimSpace(thinking.ParseSuffix(resolvedModelName).ModelName)
	providers = util.GetProviderName(baseModel)
	if len(providers) == 0 && baseModel != resolvedModelName {
		providers = util.GetProviderName(resolvedModelName)
	}
	if len(providers) == 0 {
		return nil, "", &interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: fmt.Errorf("unknown provider for model %s", modelName)}
	}
	return providers, resolvedModelName, nil
}
// cloneBytes returns an independent copy of src, or nil when src is nil or empty.
func cloneBytes(src []byte) []byte {
	if len(src) == 0 {
		return nil
	}
	return append(make([]byte, 0, len(src)), src...)
}
// cloneHeader returns a deep copy of src (keys and value slices), or nil for a nil src.
// Keys whose value slice is empty are kept, with a nil value slice.
func cloneHeader(src http.Header) http.Header {
	if src == nil {
		return nil
	}
	out := make(http.Header, len(src))
	for name := range src {
		out[name] = append([]string(nil), src[name]...)
	}
	return out
}
// replaceHeader makes dst an independent deep copy of src, in place.
//
// Keys absent from src are removed from dst and every src key is copied with a fresh value
// slice. Stale keys are removed by checking them against src, not by clearing dst first.
// This makes replaceHeader(h, h) a no-op; clearing first would also empty the shared map
// and lose every value. A nil dst is ignored, since a nil map cannot be written.
func replaceHeader(dst http.Header, src http.Header) {
	if dst == nil {
		return
	}
	for key := range dst {
		if _, keep := src[key]; !keep {
			delete(dst, key)
		}
	}
	for key, values := range src {
		dst[key] = append([]string(nil), values...)
	}
}
// WriteErrorResponse writes msg to the client as a JSON error body (see
// BuildErrorResponseBody). The status embedded in the message is used, or 500 when msg is
// nil or carries no positive status.
//
// When header passthrough is enabled, headers attached to the error (msg.Addon) replace any
// same-named response headers. Message-framing and hop-by-hop headers (Content-Length,
// Content-Encoding, Transfer-Encoding, Connection, ...) are never copied. They describe the
// upstream message, not the locally generated error body. A stale Content-Length or
// Content-Encoding would make the response unreadable.
//
// The body is also appended to the "API_RESPONSE" request-log entry. The previous entry is
// restored if it already contained the error text or the body, to avoid duplicates.
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
	if c == nil {
		return
	}
	status := http.StatusInternalServerError
	if msg != nil && msg.StatusCode > 0 {
		status = msg.StatusCode
	}
	if msg != nil && msg.Addon != nil && h != nil && PassthroughHeadersEnabled(h.Cfg) {
		header := c.Writer.Header()
		for key, values := range msg.Addon {
			if len(values) == 0 {
				continue
			}
			switch http.CanonicalHeaderKey(key) {
			case "Content-Length", "Content-Encoding", "Transfer-Encoding", "Connection",
				"Keep-Alive", "Proxy-Connection", "Te", "Trailer", "Upgrade":
				// Framing/hop-by-hop fields belong to the upstream message, not this body.
				continue
			}
			header.Del(key)
			for _, value := range values {
				header.Add(key, value)
			}
		}
	}
	errText := http.StatusText(status)
	if msg != nil && msg.Error != nil {
		if v := strings.TrimSpace(msg.Error.Error()); v != "" {
			errText = v
		}
	}
	body := BuildErrorResponseBody(status, errText)
	// Append first to preserve upstream response logs, then drop duplicate payloads if already recorded.
	var previous []byte
	if existing, exists := c.Get("API_RESPONSE"); exists {
		if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
			previous = existingBytes
		}
	}
	appendAPIResponse(c, body)
	trimmedErrText := strings.TrimSpace(errText)
	trimmedBody := bytes.TrimSpace(body)
	if len(previous) > 0 {
		if (trimmedErrText != "" && bytes.Contains(previous, []byte(trimmedErrText))) ||
			(len(trimmedBody) > 0 && bytes.Contains(previous, trimmedBody)) {
			c.Set("API_RESPONSE", previous)
		}
	}
	if !c.Writer.Written() {
		c.Writer.Header().Set("Content-Type", "application/json")
	}
	c.Status(status)
	_, _ = c.Writer.Write(body)
}
// LoggingAPIResponseError appends err to the "API_RESPONSE_ERROR" slice kept on the Gin
// context stored under ctx's "gin" key, creating the slice on first use.
//
// It does nothing when request logging is disabled, the handler has no config, ctx is nil,
// or ctx carries no non-nil *gin.Context. An existing "API_RESPONSE_ERROR" value of an
// unexpected type is left untouched.
func (h *BaseAPIHandler) LoggingAPIResponseError(ctx context.Context, err *interfaces.ErrorMessage) {
	if h == nil || h.Cfg == nil || !h.Cfg.RequestLog || ctx == nil {
		return
	}
	ginContext, ok := ctx.Value("gin").(*gin.Context)
	if !ok || ginContext == nil {
		return
	}
	existing, found := ginContext.Get("API_RESPONSE_ERROR")
	if !found {
		// Create new response data entry.
		ginContext.Set("API_RESPONSE_ERROR", []*interfaces.ErrorMessage{err})
		return
	}
	if recorded, isSlice := existing.([]*interfaces.ErrorMessage); isSlice {
		ginContext.Set("API_RESPONSE_ERROR", append(recorded, err))
	}
}
// APIHandlerCancelFunc cancels an API handler's execution context. Its optional params are
// interpreted by the function that created it, e.g. as a response payload to record in the
// request log.
type APIHandlerCancelFunc func(params ...any)
================================================
FILE: sdk/api/handlers/handlers_error_response_test.go
================================================
package handlers
import (
"errors"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// TestWriteErrorResponse_AddonHeadersDisabledByDefault checks that headers attached to an
// ErrorMessage are not sent to the client when no config (and thus no passthrough) is set.
func TestWriteErrorResponse_AddonHeadersDisabledByDefault(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(rec)
	ginCtx.Request = httptest.NewRequest(http.MethodGet, "/", nil)

	NewBaseAPIHandlers(nil, nil).WriteErrorResponse(ginCtx, &interfaces.ErrorMessage{
		StatusCode: http.StatusTooManyRequests,
		Error:      errors.New("rate limit"),
		Addon: http.Header{
			"Retry-After":  {"30"},
			"X-Request-Id": {"req-1"},
		},
	})

	if rec.Code != http.StatusTooManyRequests {
		t.Fatalf("status = %d, want %d", rec.Code, http.StatusTooManyRequests)
	}
	for _, name := range []string{"Retry-After", "X-Request-Id"} {
		if got := rec.Header().Get(name); got != "" {
			t.Fatalf("%s should be empty when passthrough is disabled, got %q", name, got)
		}
	}
}
// TestWriteErrorResponse_AddonHeadersEnabled checks that with passthrough enabled, addon
// headers are sent and replace (rather than extend) an existing header of the same name.
func TestWriteErrorResponse_AddonHeadersEnabled(t *testing.T) {
	gin.SetMode(gin.TestMode)
	rec := httptest.NewRecorder()
	ginCtx, _ := gin.CreateTestContext(rec)
	ginCtx.Request = httptest.NewRequest(http.MethodGet, "/", nil)
	ginCtx.Writer.Header().Set("X-Request-Id", "old-value")

	cfg := &sdkconfig.SDKConfig{PassthroughHeaders: true}
	NewBaseAPIHandlers(cfg, nil).WriteErrorResponse(ginCtx, &interfaces.ErrorMessage{
		StatusCode: http.StatusTooManyRequests,
		Error:      errors.New("rate limit"),
		Addon: http.Header{
			"Retry-After":  {"30"},
			"X-Request-Id": {"new-1", "new-2"},
		},
	})

	if rec.Code != http.StatusTooManyRequests {
		t.Fatalf("status = %d, want %d", rec.Code, http.StatusTooManyRequests)
	}
	if got := rec.Header().Get("Retry-After"); got != "30" {
		t.Fatalf("Retry-After = %q, want %q", got, "30")
	}
	wantIDs := []string{"new-1", "new-2"}
	if got := rec.Header().Values("X-Request-Id"); !reflect.DeepEqual(got, wantIDs) {
		t.Fatalf("X-Request-Id = %#v, want %#v", got, wantIDs)
	}
}
================================================
FILE: sdk/api/handlers/handlers_request_details_test.go
================================================
package handlers
import (
"reflect"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// TestGetRequestDetails_PreservesSuffix checks that getRequestDetails finds providers from
// the base model while returning the model name with its thinking suffix intact. It covers
// "auto" resolution and the special "(none)"/"(auto)" suffixes.
func TestGetRequestDetails_PreservesSuffix(t *testing.T) {
	reg := registry.GetGlobalRegistry()
	now := time.Now().Unix()

	registrations := []struct {
		clientID string
		provider string
		models   []*registry.ModelInfo
	}{
		{"test-request-details-gemini", "gemini", []*registry.ModelInfo{
			{ID: "gemini-2.5-pro", Created: now + 30},
			{ID: "gemini-2.5-flash", Created: now + 25},
		}},
		{"test-request-details-openai", "openai", []*registry.ModelInfo{
			{ID: "gpt-5.2", Created: now + 20},
		}},
		{"test-request-details-claude", "claude", []*registry.ModelInfo{
			{ID: "claude-sonnet-4-5", Created: now + 5},
		}},
	}
	for _, r := range registrations {
		reg.RegisterClient(r.clientID, r.provider, r.models)
	}
	// Unregister every test client when the test ends, newest registration first.
	t.Cleanup(func() {
		for i := len(registrations) - 1; i >= 0; i-- {
			reg.UnregisterClient(registrations[i].clientID)
		}
	})

	handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, coreauth.NewManager(nil, nil, nil))

	cases := []struct {
		name          string
		inputModel    string
		wantProviders []string
		wantModel     string
		wantErr       bool
	}{
		{name: "numeric suffix preserved", inputModel: "gemini-2.5-pro(8192)", wantProviders: []string{"gemini"}, wantModel: "gemini-2.5-pro(8192)"},
		{name: "level suffix preserved", inputModel: "gpt-5.2(high)", wantProviders: []string{"openai"}, wantModel: "gpt-5.2(high)"},
		{name: "no suffix unchanged", inputModel: "claude-sonnet-4-5", wantProviders: []string{"claude"}, wantModel: "claude-sonnet-4-5"},
		{name: "unknown model with suffix", inputModel: "unknown-model(8192)", wantErr: true},
		{name: "auto suffix resolved", inputModel: "auto(high)", wantProviders: []string{"gemini"}, wantModel: "gemini-2.5-pro(high)"},
		{name: "special suffix none preserved", inputModel: "gemini-2.5-flash(none)", wantProviders: []string{"gemini"}, wantModel: "gemini-2.5-flash(none)"},
		{name: "special suffix auto preserved", inputModel: "claude-sonnet-4-5(auto)", wantProviders: []string{"claude"}, wantModel: "claude-sonnet-4-5(auto)"},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			providers, model, errMsg := handler.getRequestDetails(tc.inputModel)
			if gotErr := errMsg != nil; gotErr != tc.wantErr {
				t.Fatalf("getRequestDetails() error = %v, wantErr %v", errMsg, tc.wantErr)
			}
			if errMsg != nil {
				return
			}
			if !reflect.DeepEqual(providers, tc.wantProviders) {
				t.Fatalf("getRequestDetails() providers = %v, want %v", providers, tc.wantProviders)
			}
			if model != tc.wantModel {
				t.Fatalf("getRequestDetails() model = %v, want %v", model, tc.wantModel)
			}
		})
	}
}
================================================
FILE: sdk/api/handlers/handlers_stream_bootstrap_test.go
================================================
package handlers
import (
"context"
"net/http"
"sync"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// failOnceStreamExecutor is a test executor whose first ExecuteStream call yields a stream
// holding a single 401 error chunk. Every later call yields a single "ok" payload. Each
// result carries an X-Upstream-Attempt header ("1" for the first call, "2" afterwards).
type failOnceStreamExecutor struct {
	mu    sync.Mutex
	calls int
}

// Identifier reports the provider this executor poses as.
func (e *failOnceStreamExecutor) Identifier() string { return "codex" }

// Execute is unsupported by this test double.
func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
	return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}

// ExecuteStream fails the first attempt with a 401 chunk and succeeds on later attempts.
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
	e.mu.Lock()
	e.calls++
	attempt := e.calls
	e.mu.Unlock()

	chunk := coreexecutor.StreamChunk{Payload: []byte("ok")}
	attemptHeader := "2"
	if attempt == 1 {
		chunk = coreexecutor.StreamChunk{
			Err: &coreauth.Error{
				Code:       "unauthorized",
				Message:    "unauthorized",
				Retryable:  false,
				HTTPStatus: http.StatusUnauthorized,
			},
		}
		attemptHeader = "1"
	}

	ch := make(chan coreexecutor.StreamChunk, 1)
	ch <- chunk
	close(ch)
	return &coreexecutor.StreamResult{
		Headers: http.Header{"X-Upstream-Attempt": {attemptHeader}},
		Chunks:  ch,
	}, nil
}

// Refresh returns the auth unchanged.
func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
	return auth, nil
}

// CountTokens is unsupported by this test double.
func (e *failOnceStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
	return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
}

// HttpRequest is unsupported by this test double and reports 501.
func (e *failOnceStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
	return nil, &coreauth.Error{
		Code:       "not_implemented",
		Message:    "HttpRequest not implemented",
		HTTPStatus: http.StatusNotImplemented,
	}
}

// Calls reports how many times ExecuteStream has been invoked.
func (e *failOnceStreamExecutor) Calls() int {
	e.mu.Lock()
	n := e.calls
	e.mu.Unlock()
	return n
}
// payloadThenErrorStreamExecutor is a test executor whose streams always emit a "partial"
// payload followed by a 502 "upstream closed" error chunk.
type payloadThenErrorStreamExecutor struct {
	mu    sync.Mutex
	calls int
}

// Identifier reports the provider this executor poses as.
func (e *payloadThenErrorStreamExecutor) Identifier() string { return "codex" }

// Execute is unsupported by this test double.
func (e *payloadThenErrorStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
	return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}

// ExecuteStream returns a stream that sends one payload and then fails.
func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
	e.mu.Lock()
	e.calls++
	e.mu.Unlock()

	upstreamClosed := &coreauth.Error{
		Code:       "upstream_closed",
		Message:    "upstream closed",
		Retryable:  false,
		HTTPStatus: http.StatusBadGateway,
	}
	script := []coreexecutor.StreamChunk{
		{Payload: []byte("partial")},
		{Err: upstreamClosed},
	}
	ch := make(chan coreexecutor.StreamChunk, len(script))
	for _, chunk := range script {
		ch <- chunk
	}
	close(ch)
	return &coreexecutor.StreamResult{Chunks: ch}, nil
}

// Refresh returns the auth unchanged.
func (e *payloadThenErrorStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
	return auth, nil
}

// CountTokens is unsupported by this test double.
func (e *payloadThenErrorStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
	return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
}

// HttpRequest is unsupported by this test double and reports 501.
func (e *payloadThenErrorStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
	return nil, &coreauth.Error{
		Code:       "not_implemented",
		Message:    "HttpRequest not implemented",
		HTTPStatus: http.StatusNotImplemented,
	}
}

// Calls reports how many times ExecuteStream has been invoked.
func (e *payloadThenErrorStreamExecutor) Calls() int {
	e.mu.Lock()
	n := e.calls
	e.mu.Unlock()
	return n
}
// authAwareStreamExecutor is a test executor whose stream outcome depends on
// the auth it is called with: "auth1" gets a non-retryable 401 error chunk,
// any other auth gets a single "ok" payload chunk. Every call's auth ID is
// recorded (an empty string when auth is nil).
type authAwareStreamExecutor struct {
	mu      sync.Mutex // guards calls and authIDs
	calls   int        // number of ExecuteStream invocations
	authIDs []string   // auth ID seen by each invocation, in call order
}

// Identifier returns "codex", the provider used by the auths in these tests.
func (e *authAwareStreamExecutor) Identifier() string {
	return "codex"
}

// Execute is unsupported and always returns a not_implemented error.
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
	notImplemented := &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
	return coreexecutor.Response{}, notImplemented
}

// ExecuteStream records the call and returns a closed one-chunk stream whose
// single chunk is an unauthorized error for "auth1" and an "ok" payload otherwise.
func (e *authAwareStreamExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, _ coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
	var authID string
	if auth != nil {
		authID = auth.ID
	}

	e.mu.Lock()
	e.calls++
	e.authIDs = append(e.authIDs, authID)
	e.mu.Unlock()

	chunk := coreexecutor.StreamChunk{Payload: []byte("ok")}
	if authID == "auth1" {
		chunk = coreexecutor.StreamChunk{
			Err: &coreauth.Error{
				Code:       "unauthorized",
				Message:    "unauthorized",
				Retryable:  false,
				HTTPStatus: http.StatusUnauthorized,
			},
		}
	}

	chunks := make(chan coreexecutor.StreamChunk, 1)
	chunks <- chunk
	close(chunks)
	return &coreexecutor.StreamResult{Chunks: chunks}, nil
}

// Refresh returns the given auth unchanged.
func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
	return auth, nil
}

// CountTokens is unsupported and always returns a not_implemented error.
func (e *authAwareStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
	notImplemented := &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
	return coreexecutor.Response{}, notImplemented
}

// HttpRequest is unsupported and always fails with a not_implemented (501) error.
func (e *authAwareStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
	notImplemented := &coreauth.Error{
		Code:       "not_implemented",
		Message:    "HttpRequest not implemented",
		HTTPStatus: http.StatusNotImplemented,
	}
	return nil, notImplemented
}

// Calls returns how many times ExecuteStream has been invoked.
func (e *authAwareStreamExecutor) Calls() int {
	e.mu.Lock()
	count := e.calls
	e.mu.Unlock()
	return count
}

// AuthIDs returns a snapshot copy (never nil) of the auth IDs seen so far, in
// call order.
func (e *authAwareStreamExecutor) AuthIDs() []string {
	e.mu.Lock()
	defer e.mu.Unlock()
	return append(make([]string, 0, len(e.authIDs)), e.authIDs...)
}

// invalidJSONStreamExecutor is a test executor whose stream emits a single
// SSE-style chunk whose data line is truncated JSON.
type invalidJSONStreamExecutor struct{}

// Identifier returns "codex", the provider used by the auths in these tests.
func (e *invalidJSONStreamExecutor) Identifier() string {
	return "codex"
}

// Execute is unsupported and always returns a not_implemented error.
func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
	notImplemented := &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
	return coreexecutor.Response{}, notImplemented
}

// ExecuteStream returns a closed stream holding one chunk whose "data:" JSON
// object is cut off after the "type" key.
func (e *invalidJSONStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
	truncated := []byte("event: response.completed\ndata: {\"type\"")
	chunks := make(chan coreexecutor.StreamChunk, 1)
	chunks <- coreexecutor.StreamChunk{Payload: truncated}
	close(chunks)
	return &coreexecutor.StreamResult{Chunks: chunks}, nil
}

// Refresh returns the given auth unchanged.
func (e *invalidJSONStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
	return auth, nil
}

// CountTokens is unsupported and always returns a not_implemented error.
func (e *invalidJSONStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
	notImplemented := &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
	return coreexecutor.Response{}, notImplemented
}

// HttpRequest is unsupported and always fails with a not_implemented (501) error.
func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
	notImplemented := &coreauth.Error{
		Code:       "not_implemented",
		Message:    "HttpRequest not implemented",
		HTTPStatus: http.StatusNotImplemented,
	}
	return nil, notImplemented
}
// TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte covers bootstrap
// retries. It runs with BootstrapRetries=1 and PassthroughHeaders enabled, and
// expects the stream to end with:
//   - the payload "ok" and no error message;
//   - exactly two executor calls;
//   - upstream headers taken from the retried attempt (X-Upstream-Attempt "2").
//
// NOTE(review): this presumes failOnceStreamExecutor fails its first call
// before emitting any payload. Its ExecuteStream is defined elsewhere in this
// file; confirm there.
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
	executor := &failOnceStreamExecutor{}
	manager := coreauth.NewManager(nil, nil, nil)
	manager.RegisterExecutor(executor)
	// Register two active codex auths with the manager, and (below) expose both
	// as serving "test-model" in the global model registry.
	auth1 := &coreauth.Auth{
		ID:       "auth1",
		Provider: "codex",
		Status:   coreauth.StatusActive,
		Metadata: map[string]any{"email": "test1@example.com"},
	}
	if _, err := manager.Register(context.Background(), auth1); err != nil {
		t.Fatalf("manager.Register(auth1): %v", err)
	}
	auth2 := &coreauth.Auth{
		ID:       "auth2",
		Provider: "codex",
		Status:   coreauth.StatusActive,
		Metadata: map[string]any{"email": "test2@example.com"},
	}
	if _, err := manager.Register(context.Background(), auth2); err != nil {
		t.Fatalf("manager.Register(auth2): %v", err)
	}
	registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
	registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
	// Undo the global registry registrations when the test finishes.
	t.Cleanup(func() {
		registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
		registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
	})
	// PassthroughHeaders is enabled so upstreamHeaders is populated.
	handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
		PassthroughHeaders: true,
		Streaming: sdkconfig.StreamingConfig{
			BootstrapRetries: 1,
		},
	}, manager)
	dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
	if dataChan == nil || errChan == nil {
		t.Fatalf("expected non-nil channels")
	}
	// Drain the data channel first, then the error channel. Each range loop
	// only terminates once the handler closes that channel.
	var got []byte
	for chunk := range dataChan {
		got = append(got, chunk...)
	}
	for msg := range errChan {
		if msg != nil {
			t.Fatalf("unexpected error: %+v", msg)
		}
	}
	if string(got) != "ok" {
		t.Fatalf("expected payload ok, got %q", string(got))
	}
	if executor.Calls() != 2 {
		t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
	}
	// The headers handed back must be those of the successful (second) attempt.
	upstreamAttemptHeader := upstreamHeaders.Get("X-Upstream-Attempt")
	if upstreamAttemptHeader != "2" {
		t.Fatalf("expected upstream header from retry attempt, got %q", upstreamAttemptHeader)
	}
}
// TestExecuteStreamWithAuthManager_HeaderPassthroughDisabledByDefault uses the
// same retry setup as the RetriesBeforeFirstByte test but leaves
// PassthroughHeaders unset. It expects the payload "ok", no error, and a nil
// upstream header map.
func TestExecuteStreamWithAuthManager_HeaderPassthroughDisabledByDefault(t *testing.T) {
	executor := &failOnceStreamExecutor{}
	manager := coreauth.NewManager(nil, nil, nil)
	manager.RegisterExecutor(executor)
	// Two active codex auths, both exposed as serving "test-model" below.
	auth1 := &coreauth.Auth{
		ID:       "auth1",
		Provider: "codex",
		Status:   coreauth.StatusActive,
		Metadata: map[string]any{"email": "test1@example.com"},
	}
	if _, err := manager.Register(context.Background(), auth1); err != nil {
		t.Fatalf("manager.Register(auth1): %v", err)
	}
	auth2 := &coreauth.Auth{
		ID:       "auth2",
		Provider: "codex",
		Status:   coreauth.StatusActive,
		Metadata: map[string]any{"email": "test2@example.com"},
	}
	if _, err := manager.Register(context.Background(), auth2); err != nil {
		t.Fatalf("manager.Register(auth2): %v", err)
	}
	registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
	registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
	// Undo the global registry registrations when the test finishes.
	t.Cleanup(func() {
		registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
		registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
	})
	// PassthroughHeaders is deliberately left at its zero value.
	handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
		Streaming: sdkconfig.StreamingConfig{
			BootstrapRetries: 1,
		},
	}, manager)
	dataChan, upstreamHeaders, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
	if dataChan == nil || errChan == nil {
		t.Fatalf("expected non-nil channels")
	}
	// Drain data first, then errors; both loops end only when the channels close.
	var got []byte
	for chunk := range dataChan {
		got = append(got, chunk...)
	}
	for msg := range errChan {
		if msg != nil {
			t.Fatalf("unexpected error: %+v", msg)
		}
	}
	if string(got) != "ok" {
		t.Fatalf("expected payload ok, got %q", string(got))
	}
	if upstreamHeaders != nil {
		t.Fatalf("expected nil upstream headers when passthrough is disabled, got %#v", upstreamHeaders)
	}
}
// TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte uses
// payloadThenErrorStreamExecutor, which emits "partial" and then a
// non-retryable 502 error. It checks that once payload bytes have been
// delivered, BootstrapRetries=1 does not trigger another attempt. Expected
// result:
//   - the caller sees the payload "partial";
//   - a terminal error with status 502 follows;
//   - the executor was called exactly once.
func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
	executor := &payloadThenErrorStreamExecutor{}
	manager := coreauth.NewManager(nil, nil, nil)
	manager.RegisterExecutor(executor)
	// A second auth is registered so the test fails if the handler switches
	// auths instead of surfacing the error.
	auth1 := &coreauth.Auth{
		ID:       "auth1",
		Provider: "codex",
		Status:   coreauth.StatusActive,
		Metadata: map[string]any{"email": "test1@example.com"},
	}
	if _, err := manager.Register(context.Background(), auth1); err != nil {
		t.Fatalf("manager.Register(auth1): %v", err)
	}
	auth2 := &coreauth.Auth{
		ID:       "auth2",
		Provider: "codex",
		Status:   coreauth.StatusActive,
		Metadata: map[string]any{"email": "test2@example.com"},
	}
	if _, err := manager.Register(context.Background(), auth2); err != nil {
		t.Fatalf("manager.Register(auth2): %v", err)
	}
	registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
	registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
	// Undo the global registry registrations when the test finishes.
	t.Cleanup(func() {
		registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
		registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
	})
	handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
		Streaming: sdkconfig.StreamingConfig{
			BootstrapRetries: 1,
		},
	}, manager)
	dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
	if dataChan == nil || errChan == nil {
		t.Fatalf("expected non-nil channels")
	}
	// Drain data first, then errors; both loops end only when the channels close.
	var got []byte
	for chunk := range dataChan {
		got = append(got, chunk...)
	}
	// Keep the last non-nil error message and its status code.
	var gotErr error
	var gotStatus int
	for msg := range errChan {
		if msg != nil && msg.Error != nil {
			gotErr = msg.Error
			gotStatus = msg.StatusCode
		}
	}
	if string(got) != "partial" {
		t.Fatalf("expected payload partial, got %q", string(got))
	}
	if gotErr == nil {
		t.Fatalf("expected terminal error, got nil")
	}
	if gotStatus != http.StatusBadGateway {
		t.Fatalf("expected status %d, got %d", http.StatusBadGateway, gotStatus)
	}
	if executor.Calls() != 1 {
		t.Fatalf("expected 1 stream attempt, got %d", executor.Calls())
	}
}
// TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream pins the request
// to "auth1" via WithPinnedAuthID. With authAwareStreamExecutor, that auth
// always yields a non-retryable 401 error chunk. Even though "auth2" is
// registered and would succeed, the test expects:
//   - no payload;
//   - a terminal error;
//   - every executor attempt made with auth1, i.e. no failover to another auth.
func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T) {
	executor := &authAwareStreamExecutor{}
	manager := coreauth.NewManager(nil, nil, nil)
	manager.RegisterExecutor(executor)
	auth1 := &coreauth.Auth{
		ID:       "auth1",
		Provider: "codex",
		Status:   coreauth.StatusActive,
		Metadata: map[string]any{"email": "test1@example.com"},
	}
	if _, err := manager.Register(context.Background(), auth1); err != nil {
		t.Fatalf("manager.Register(auth1): %v", err)
	}
	auth2 := &coreauth.Auth{
		ID:       "auth2",
		Provider: "codex",
		Status:   coreauth.StatusActive,
		Metadata: map[string]any{"email": "test2@example.com"},
	}
	if _, err := manager.Register(context.Background(), auth2); err != nil {
		t.Fatalf("manager.Register(auth2): %v", err)
	}
	registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
	registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
	// Undo the global registry registrations when the test finishes.
	t.Cleanup(func() {
		registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
		registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
	})
	handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
		Streaming: sdkconfig.StreamingConfig{
			BootstrapRetries: 1,
		},
	}, manager)
	// Pin the request to the auth that always fails.
	ctx := WithPinnedAuthID(context.Background(), "auth1")
	dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
	if dataChan == nil || errChan == nil {
		t.Fatalf("expected non-nil channels")
	}
	// Drain data first, then errors; both loops end only when the channels close.
	var got []byte
	for chunk := range dataChan {
		got = append(got, chunk...)
	}
	var gotErr error
	for msg := range errChan {
		if msg != nil && msg.Error != nil {
			gotErr = msg.Error
		}
	}
	if len(got) != 0 {
		t.Fatalf("expected empty payload, got %q", string(got))
	}
	if gotErr == nil {
		t.Fatalf("expected terminal error, got nil")
	}
	// The number of attempts is not asserted, only that none used auth2.
	authIDs := executor.AuthIDs()
	if len(authIDs) == 0 {
		t.Fatalf("expected at least one upstream attempt")
	}
	for _, authID := range authIDs {
		if authID != "auth1" {
			t.Fatalf("expected all attempts on auth1, got sequence %v", authIDs)
		}
	}
}
// TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID registers
// only "auth2" and installs a callback via WithSelectedAuthIDCallback. It
// expects the stream to succeed with payload "ok" and the callback to have
// received "auth2" by the time both channels have been drained.
func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *testing.T) {
	executor := &authAwareStreamExecutor{}
	manager := coreauth.NewManager(nil, nil, nil)
	manager.RegisterExecutor(executor)
	auth2 := &coreauth.Auth{
		ID:       "auth2",
		Provider: "codex",
		Status:   coreauth.StatusActive,
		Metadata: map[string]any{"email": "test2@example.com"},
	}
	if _, err := manager.Register(context.Background(), auth2); err != nil {
		t.Fatalf("manager.Register(auth2): %v", err)
	}
	registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
	// Undo the global registry registration when the test finishes.
	t.Cleanup(func() {
		registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
	})
	handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
		Streaming: sdkconfig.StreamingConfig{
			BootstrapRetries: 0,
		},
	}, manager)
	// The callback stores the auth ID reported by the handler.
	// NOTE(review): selectedAuthID is read only after both channels are
	// drained; this assumes the callback runs before the channels close.
	selectedAuthID := ""
	ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) {
		selectedAuthID = authID
	})
	dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
	if dataChan == nil || errChan == nil {
		t.Fatalf("expected non-nil channels")
	}
	// Drain data first, then errors; both loops end only when the channels close.
	var got []byte
	for chunk := range dataChan {
		got = append(got, chunk...)
	}
	for msg := range errChan {
		if msg != nil {
			t.Fatalf("unexpected error: %+v", msg)
		}
	}
	if string(got) != "ok" {
		t.Fatalf("expected payload ok, got %q", string(got))
	}
	if selectedAuthID != "auth2" {
		t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
	}
}
// TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON runs
// with handler type "openai-response". It uses invalidJSONStreamExecutor, whose
// only chunk carries a truncated JSON object on its "data:" line. The test
// expects:
//   - no payload is forwarded;
//   - at least one error message is sent;
//   - every non-nil error message has status 502 and a non-nil Error.
func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *testing.T) {
	executor := &invalidJSONStreamExecutor{}
	manager := coreauth.NewManager(nil, nil, nil)
	manager.RegisterExecutor(executor)
	auth1 := &coreauth.Auth{
		ID:       "auth1",
		Provider: "codex",
		Status:   coreauth.StatusActive,
		Metadata: map[string]any{"email": "test1@example.com"},
	}
	if _, err := manager.Register(context.Background(), auth1); err != nil {
		t.Fatalf("manager.Register(auth1): %v", err)
	}
	registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
	// Undo the global registry registration when the test finishes.
	t.Cleanup(func() {
		registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
	})
	// Default SDK config: no bootstrap retries or header passthrough configured.
	handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
	dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "")
	if dataChan == nil || errChan == nil {
		t.Fatalf("expected non-nil channels")
	}
	// Drain data first, then errors; both loops end only when the channels close.
	var got []byte
	for chunk := range dataChan {
		got = append(got, chunk...)
	}
	if len(got) != 0 {
		t.Fatalf("expected empty payload, got %q", string(got))
	}
	gotErr := false
	for msg := range errChan {
		if msg == nil {
			continue
		}
		if msg.StatusCode != http.StatusBadGateway {
			t.Fatalf("expected status %d, got %d", http.StatusBadGateway, msg.StatusCode)
		}
		if msg.Error == nil {
			t.Fatalf("expected error")
		}
		gotErr = true
	}
	if !gotErr {
		t.Fatalf("expected terminal error")
	}
}
================================================
FILE: sdk/api/handlers/header_filter.go
================================================
package handlers
import (
"net/http"
"strings"
)
// hopByHopHeaders lists the response headers that must never be relayed from
// an upstream provider to the client. Keys are in canonical MIME form
// (http.CanonicalHeaderKey) because lookups are done with canonicalized names.
//
// The set contains:
//   - the standard hop-by-hop headers (RFC 2616 §13.5.1; Connection semantics
//     per RFC 7230 §6.1);
//   - the legacy Proxy-Connection header, which is non-standard but still sent
//     by some software and also stripped by Go's httputil.ReverseProxy;
//   - security-sensitive headers;
//   - framing headers that CPA itself controls.
//
// Additional per-response hop-by-hop headers named in a Connection header are
// handled separately.
var hopByHopHeaders = map[string]struct{}{
	// Hop-by-hop
	"Connection":          {},
	"Proxy-Connection":    {},
	"Keep-Alive":          {},
	"Proxy-Authenticate":  {},
	"Proxy-Authorization": {},
	"Te":                  {},
	"Trailer":             {},
	"Transfer-Encoding":   {},
	"Upgrade":             {},
	// Security-sensitive
	"Set-Cookie": {},
	// CPA-managed (set by handlers, not upstream)
	"Content-Length":   {},
	"Content-Encoding": {},
}
// FilterUpstreamHeaders returns a copy of src with hop-by-hop and
// security-sensitive headers removed. Returns nil if src is nil or empty after
// filtering.
//
// Two kinds of header are dropped:
//   - every header in hopByHopHeaders;
//   - every header named in a Connection header value (RFC 7230 §6.1).
//
// The value slices are copied as well as the map, so writes to the returned
// header can never show through in src (or vice versa).
func FilterUpstreamHeaders(src http.Header) http.Header {
	if src == nil {
		return nil
	}
	connectionScoped := connectionScopedHeaders(src)
	dst := make(http.Header, len(src))
	for key, values := range src {
		canonicalKey := http.CanonicalHeaderKey(key)
		if _, blocked := hopByHopHeaders[canonicalKey]; blocked {
			continue
		}
		if _, scoped := connectionScoped[canonicalKey]; scoped {
			continue
		}
		// Clone rather than alias: sharing the backing array would let an
		// element write or append on one header leak into the other.
		dst[key] = append([]string(nil), values...)
	}
	if len(dst) == 0 {
		return nil
	}
	return dst
}
// connectionScopedHeaders gathers the header names listed in src's Connection
// header(s).
//
// Behaviour:
//   - every Connection value is split on commas;
//   - each entry is trimmed, and empty entries are skipped;
//   - names are stored in canonical MIME form.
//
// The result is always a non-nil set, empty when src has no Connection header.
func connectionScopedHeaders(src http.Header) map[string]struct{} {
	names := make(map[string]struct{})
	isComma := func(r rune) bool { return r == ',' }
	for _, connectionValue := range src.Values("Connection") {
		for _, field := range strings.FieldsFunc(connectionValue, isComma) {
			if name := strings.TrimSpace(field); name != "" {
				names[http.CanonicalHeaderKey(name)] = struct{}{}
			}
		}
	}
	return names
}
// WriteUpstreamHeaders copies upstream headers into dst (in the handlers, the
// gin response writer's header map). src is expected to be filtered already,
// for example by FilterUpstreamHeaders; no filtering happens here.
//
// Headers already present in dst (e.g. Content-Type set by CPA) are NOT
// overwritten. This holds even when their value is an empty string, so an
// upstream value is never appended next to a header CPA already set. A nil src
// or nil dst is a no-op.
func WriteUpstreamHeaders(dst http.Header, src http.Header) {
	if src == nil || dst == nil {
		return
	}
	for key, values := range src {
		// Test for presence, not a non-empty first value. dst.Get would report
		// an explicitly empty header as unset and the upstream value would be
		// added alongside it.
		if len(dst.Values(key)) > 0 {
			continue
		}
		for _, v := range values {
			dst.Add(key, v)
		}
	}
}
================================================
FILE: sdk/api/handlers/header_filter_test.go
================================================
package handlers
import (
"net/http"
"testing"
)
// TestFilterUpstreamHeaders_RemovesConnectionScopedHeaders checks which headers
// FilterUpstreamHeaders keeps and drops:
//   - headers named across two Connection values are dropped, along with
//     Connection and Keep-Alive themselves;
//   - Set-Cookie is dropped;
//   - an ordinary header (X-Request-Id) is preserved.
func TestFilterUpstreamHeaders_RemovesConnectionScopedHeaders(t *testing.T) {
	upstream := make(http.Header)
	upstream.Add("Connection", "keep-alive, x-hop-a, x-hop-b")
	upstream.Add("Connection", "x-hop-c")
	for name, value := range map[string]string{
		"Keep-Alive":   "timeout=5",
		"X-Hop-A":      "a",
		"X-Hop-B":      "b",
		"X-Hop-C":      "c",
		"X-Request-Id": "req-1",
		"Set-Cookie":   "session=secret",
	} {
		upstream.Set(name, value)
	}

	filtered := FilterUpstreamHeaders(upstream)
	if filtered == nil {
		t.Fatalf("expected filtered headers, got nil")
	}
	if got := filtered.Get("X-Request-Id"); got != "req-1" {
		t.Fatalf("expected X-Request-Id to be preserved, got %q", got)
	}
	for _, key := range []string{"Connection", "Keep-Alive", "X-Hop-A", "X-Hop-B", "X-Hop-C", "Set-Cookie"} {
		if got := filtered.Get(key); got != "" {
			t.Fatalf("expected %s to be removed, got %q", key, got)
		}
	}
}
// TestFilterUpstreamHeaders_ReturnsNilWhenAllHeadersBlocked checks that a header
// set in which every entry is blocked filters down to nil, not an empty map.
func TestFilterUpstreamHeaders_ReturnsNilWhenAllHeadersBlocked(t *testing.T) {
	upstream := http.Header{
		"Connection": {"x-hop-a"},
		"X-Hop-A":    {"a"},
		"Set-Cookie": {"session=secret"},
	}
	if filtered := FilterUpstreamHeaders(upstream); filtered != nil {
		t.Fatalf("expected nil when all headers are filtered, got %#v", filtered)
	}
}
================================================
FILE: sdk/api/handlers/openai/openai_handlers.go
================================================
// Package openai provides HTTP handlers for OpenAI API endpoints.
// This package implements the OpenAI-compatible API interface, including model listing
// and chat completion functionality. It supports both streaming and non-streaming responses,
// and manages a pool of clients to interact with backend services.
// The handlers translate OpenAI API requests to the appropriate backend format and
// convert responses back to OpenAI-compatible format.
package openai
import (
"context"
"encoding/json"
"fmt"
"net/http"
"sync"
"github.com/gin-gonic/gin"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
responsesconverter "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/openai/responses"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// OpenAIAPIHandler serves the OpenAI-compatible endpoints (/v1/models,
// /v1/chat/completions and /v1/completions). Request execution is delegated to
// the embedded BaseAPIHandler.
type OpenAIAPIHandler struct {
	*handlers.BaseAPIHandler
}

// NewOpenAIAPIHandler wraps apiHandlers in an OpenAIAPIHandler.
//
// Parameters:
//   - apiHandlers: The base API handlers instance
//
// Returns:
//   - *OpenAIAPIHandler: A new OpenAI API handlers instance
func NewOpenAIAPIHandler(apiHandlers *handlers.BaseAPIHandler) *OpenAIAPIHandler {
	h := new(OpenAIAPIHandler)
	h.BaseAPIHandler = apiHandlers
	return h
}

// HandlerType returns the OpenAI handler identifier from the constant package.
func (h *OpenAIAPIHandler) HandlerType() string {
	return OpenAI
}

// Models returns the model metadata that the global registry reports as
// available for the "openai" handler.
func (h *OpenAIAPIHandler) Models() []map[string]any {
	return registry.GetGlobalRegistry().GetAvailableModels("openai")
}
// OpenAIModels handles the /v1/models endpoint. It answers 200 with an OpenAI
// "list" object. Each entry keeps only "id" and "object", plus "created" and
// "owned_by" when the registry supplies them.
func (h *OpenAIAPIHandler) OpenAIModels(c *gin.Context) {
	available := h.Models()
	data := make([]map[string]any, 0, len(available))
	for _, model := range available {
		entry := map[string]any{
			"id":     model["id"],
			"object": model["object"],
		}
		// Optional fields are copied only when present in the source metadata.
		for _, optional := range [...]string{"created", "owned_by"} {
			if value, ok := model[optional]; ok {
				entry[optional] = value
			}
		}
		data = append(data, entry)
	}
	c.JSON(http.StatusOK, gin.H{
		"object": "list",
		"data":   data,
	})
}
// ChatCompletions handles the /v1/chat/completions endpoint.
//
// Processing steps:
//  1. Read the raw body; an unreadable body gets 400 Bad Request.
//  2. If the payload is in OpenAI Responses format, convert it to Chat
//     Completions first.
//  3. Dispatch to the streaming or non-streaming path.
//
// Parameters:
//   - c: The Gin context containing the HTTP request and response
func (h *OpenAIAPIHandler) ChatCompletions(c *gin.Context) {
	rawJSON, err := c.GetRawData()
	if err != nil {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Invalid request: %v", err),
				Type:    "invalid_request_error",
			},
		})
		return
	}

	// Only a literal JSON true requests streaming on the original payload.
	wantsStream := gjson.GetBytes(rawJSON, "stream").Type == gjson.True

	if shouldTreatAsResponsesFormat(rawJSON) {
		// Some clients post Responses-format bodies to this route. Convert them
		// so downstream translators keep tool metadata, then re-read the stream
		// flag from the converted request.
		model := gjson.GetBytes(rawJSON, "model").String()
		rawJSON = responsesconverter.ConvertOpenAIResponsesRequestToOpenAIChatCompletions(model, rawJSON, wantsStream)
		wantsStream = gjson.GetBytes(rawJSON, "stream").Bool()
	}

	if !wantsStream {
		h.handleNonStreamingResponse(c, rawJSON)
		return
	}
	h.handleStreamingResponse(c, rawJSON)
}
// shouldTreatAsResponsesFormat reports whether a body sent to the Chat
// Completions endpoint is in OpenAI Responses format. That is the case when it
// has no "messages" field but does carry "input" or "instructions".
func shouldTreatAsResponsesFormat(rawJSON []byte) bool {
	// A chat-style "messages" array always wins.
	if gjson.GetBytes(rawJSON, "messages").Exists() {
		return false
	}
	for _, marker := range []string{"input", "instructions"} {
		if gjson.GetBytes(rawJSON, marker).Exists() {
			return true
		}
	}
	return false
}
// Completions handles the legacy /v1/completions endpoint.
//
// It reads the raw body, answering 400 Bad Request when the body cannot be
// read. It then hands the request to the streaming path when "stream" is a
// literal JSON true, and to the non-streaming path otherwise.
//
// Parameters:
//   - c: The Gin context containing the HTTP request and response
func (h *OpenAIAPIHandler) Completions(c *gin.Context) {
	rawJSON, err := c.GetRawData()
	if err != nil {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Invalid request: %v", err),
				Type:    "invalid_request_error",
			},
		})
		return
	}

	switch gjson.GetBytes(rawJSON, "stream").Type {
	case gjson.True:
		h.handleCompletionsStreamingResponse(c, rawJSON)
	default:
		h.handleCompletionsNonStreamingResponse(c, rawJSON)
	}
}
// convertCompletionsRequestToChatCompletions converts an OpenAI (legacy)
// completions request into chat completions format. This lets the completions
// endpoint reuse the chat completions pipeline.
//
// Prompt handling:
//   - the prompt becomes the content of a single user message;
//   - when the prompt is an array (the legacy API accepts batched prompts), only
//     the first element is used, because a chat request carries a single
//     conversation;
//   - a missing or empty prompt falls back to "Complete this:".
//
// Logprobs handling: the legacy "logprobs" parameter is an integer count of
// top tokens to report. It is mapped to logprobs=true plus top_logprobs=N. An
// explicit "top_logprobs" in the request takes precedence, and a boolean
// "logprobs" is passed through unchanged.
//
// Parameters:
//   - rawJSON: The raw JSON bytes of the completions request
//
// Returns:
//   - []byte: The converted chat completions request
func convertCompletionsRequestToChatCompletions(rawJSON []byte) []byte {
	root := gjson.ParseBytes(rawJSON)

	// Extract the prompt. Without this, an array prompt would be stringified
	// to its raw JSON text (e.g. `["hi"]`) and sent verbatim as the message.
	promptResult := root.Get("prompt")
	if promptResult.IsArray() {
		promptResult = promptResult.Get("0")
	}
	prompt := promptResult.String()
	if prompt == "" {
		prompt = "Complete this:"
	}

	// Create chat completions structure
	out := `{"model":"","messages":[{"role":"user","content":""}]}`
	if model := root.Get("model"); model.Exists() {
		out, _ = sjson.Set(out, "model", model.String())
	}
	out, _ = sjson.Set(out, "messages.0.content", prompt)

	// Copy parameters whose meaning is the same in both APIs.
	if maxTokens := root.Get("max_tokens"); maxTokens.Exists() {
		out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
	}
	if temperature := root.Get("temperature"); temperature.Exists() {
		out, _ = sjson.Set(out, "temperature", temperature.Float())
	}
	if topP := root.Get("top_p"); topP.Exists() {
		out, _ = sjson.Set(out, "top_p", topP.Float())
	}
	if frequencyPenalty := root.Get("frequency_penalty"); frequencyPenalty.Exists() {
		out, _ = sjson.Set(out, "frequency_penalty", frequencyPenalty.Float())
	}
	if presencePenalty := root.Get("presence_penalty"); presencePenalty.Exists() {
		out, _ = sjson.Set(out, "presence_penalty", presencePenalty.Float())
	}
	if stop := root.Get("stop"); stop.Exists() {
		out, _ = sjson.SetRaw(out, "stop", stop.Raw)
	}
	if stream := root.Get("stream"); stream.Exists() {
		out, _ = sjson.Set(out, "stream", stream.Bool())
	}

	// Map legacy logprobs onto the chat logprobs/top_logprobs pair.
	if logprobs := root.Get("logprobs"); logprobs.Exists() {
		switch logprobs.Type {
		case gjson.Number:
			out, _ = sjson.Set(out, "logprobs", true)
			out, _ = sjson.Set(out, "top_logprobs", logprobs.Int())
		case gjson.Null:
			// null means "no log probabilities": leave the chat defaults.
		default:
			out, _ = sjson.Set(out, "logprobs", logprobs.Bool())
		}
	}
	if topLogprobs := root.Get("top_logprobs"); topLogprobs.Exists() {
		out, _ = sjson.Set(out, "top_logprobs", topLogprobs.Int())
	}
	if echo := root.Get("echo"); echo.Exists() {
		out, _ = sjson.Set(out, "echo", echo.Bool())
	}
	return []byte(out)
}
// convertChatCompletionsResponseToCompletions converts a chat completions API
// response back to (legacy) completions format.
//
// Field mapping:
//   - id, created, model and usage are copied;
//   - each choice becomes a text choice whose "text" is message.content (or
//     delta.content), defaulting to "" when neither is present;
//   - finish_reason is copied only when it is present and not JSON null;
//   - logprobs is passed through as-is.
//
// Parameters:
//   - rawJSON: The raw JSON bytes of the chat completions response
//
// Returns:
//   - []byte: The converted completions response
func convertChatCompletionsResponseToCompletions(rawJSON []byte) []byte {
	root := gjson.ParseBytes(rawJSON)

	// Base completions response structure
	out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`

	// Copy basic fields
	if id := root.Get("id"); id.Exists() {
		out, _ = sjson.Set(out, "id", id.String())
	}
	if created := root.Get("created"); created.Exists() {
		out, _ = sjson.Set(out, "created", created.Int())
	}
	if model := root.Get("model"); model.Exists() {
		out, _ = sjson.Set(out, "model", model.String())
	}
	if usage := root.Get("usage"); usage.Exists() {
		out, _ = sjson.SetRaw(out, "usage", usage.Raw)
	}

	// Convert choices from chat completions to completions format
	var choices []any
	if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
		chatChoices.ForEach(func(_, choice gjson.Result) bool {
			// Default text to "" so every choice carries a text field, as the
			// stream converter already does.
			completionsChoice := map[string]any{
				"index": choice.Get("index").Int(),
				"text":  "",
			}

			if message := choice.Get("message"); message.Exists() {
				if content := message.Get("content"); content.Exists() {
					completionsChoice["text"] = content.String()
				}
			} else if delta := choice.Get("delta"); delta.Exists() {
				if content := delta.Get("content"); content.Exists() {
					completionsChoice["text"] = content.String()
				}
			}

			// gjson reports a JSON null as String() == "", so test the type to
			// avoid emitting "finish_reason":"".
			if finishReason := choice.Get("finish_reason"); finishReason.Exists() && finishReason.Type != gjson.Null {
				completionsChoice["finish_reason"] = finishReason.String()
			}

			if logprobs := choice.Get("logprobs"); logprobs.Exists() {
				completionsChoice["logprobs"] = logprobs.Value()
			}

			choices = append(choices, completionsChoice)
			return true
		})
	}

	if len(choices) > 0 {
		choicesJSON, _ := json.Marshal(choices)
		out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
	}

	return []byte(out)
}
// convertChatCompletionsStreamChunkToCompletions converts a streaming chat
// completions chunk to completions format.
//
// Filtering: a chunk is forwarded only if it has usage, or at least one choice
// with non-empty delta.content or a non-empty finish_reason. Anything else
// (e.g. a role-only delta) yields nil so the caller can skip it.
//
// Choice conversion:
//   - "text" is taken from delta.content, defaulting to "";
//   - finish_reason is copied only when it is present and is neither JSON null
//     nor the string "null";
//   - logprobs is passed through as-is.
//
// Parameters:
//   - chunkData: The raw JSON bytes of a single chat completions stream chunk
//
// Returns:
//   - []byte: The converted completions stream chunk, or nil if it should be filtered out
func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
	root := gjson.ParseBytes(chunkData)

	// Check if this chunk has any meaningful content
	hasContent := false
	hasUsage := root.Get("usage").Exists()
	if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
		chatChoices.ForEach(func(_, choice gjson.Result) bool {
			if delta := choice.Get("delta"); delta.Exists() {
				if content := delta.Get("content"); content.Exists() && content.String() != "" {
					hasContent = true
					return false // stop iterating
				}
			}
			// A final chunk may carry only finish_reason; it must not be skipped.
			// JSON null yields String() == "" and is excluded here.
			if finishReason := choice.Get("finish_reason"); finishReason.Exists() && finishReason.String() != "" && finishReason.String() != "null" {
				hasContent = true
				return false // stop iterating
			}
			return true
		})
	}

	if !hasContent && !hasUsage {
		return nil
	}

	// Base completions stream response structure
	out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`

	// Copy basic fields
	if id := root.Get("id"); id.Exists() {
		out, _ = sjson.Set(out, "id", id.String())
	}
	if created := root.Get("created"); created.Exists() {
		out, _ = sjson.Set(out, "created", created.Int())
	}
	if model := root.Get("model"); model.Exists() {
		out, _ = sjson.Set(out, "model", model.String())
	}

	// Convert choices from chat completions delta to completions format
	var choices []any
	if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
		chatChoices.ForEach(func(_, choice gjson.Result) bool {
			completionsChoice := map[string]any{
				"index": choice.Get("index").Int(),
				"text":  "",
			}

			if content := choice.Get("delta.content"); content.Exists() && content.String() != "" {
				completionsChoice["text"] = content.String()
			}

			// Intermediate chunks carry "finish_reason": null. gjson turns that
			// into "" via String(), so the type check keeps such chunks from
			// emitting "finish_reason":"".
			if finishReason := choice.Get("finish_reason"); finishReason.Exists() && finishReason.Type != gjson.Null && finishReason.String() != "null" {
				completionsChoice["finish_reason"] = finishReason.String()
			}

			if logprobs := choice.Get("logprobs"); logprobs.Exists() {
				completionsChoice["logprobs"] = logprobs.Value()
			}

			choices = append(choices, completionsChoice)
			return true
		})
	}

	if len(choices) > 0 {
		choicesJSON, _ := json.Marshal(choices)
		out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
	}

	// Copy usage if present
	if usage := root.Get("usage"); usage.Exists() {
		out, _ = sjson.SetRaw(out, "usage", usage.Raw)
	}

	return []byte(out)
}
// handleNonStreamingResponse runs a non-streaming chat completion through the
// auth manager.
//
// The response Content-Type is set to application/json up front.
//   - On success: the upstream headers are applied (without overriding ones
//     already set), the upstream body is written verbatim, and the request
//     context is cancelled.
//   - On failure: the error response is written and the context is cancelled
//     with the error.
//
// Parameters:
//   - c: The Gin context containing the HTTP request and response
//   - rawJSON: The raw JSON bytes of the OpenAI-compatible request
func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "application/json")

	model := gjson.GetBytes(rawJSON, "model").String()
	execCtx, cancel := h.GetContextWithCancel(h, c, context.Background())
	body, headers, errMsg := h.ExecuteWithAuthManager(execCtx, h.HandlerType(), model, rawJSON, h.GetAlt(c))
	if errMsg == nil {
		handlers.WriteUpstreamHeaders(c.Writer.Header(), headers)
		_, _ = c.Writer.Write(body)
		cancel()
		return
	}

	h.WriteErrorResponse(c, errMsg)
	cancel(errMsg.Error)
}
// handleStreamingResponse handles streaming chat completion responses.
// It establishes a streaming connection with the backend service and forwards
// the response chunks to the client in real-time using Server-Sent Events.
//
// The handler waits for the first upstream event before committing to SSE. An
// upstream failure that happens before any data is sent can therefore still be
// reported with a proper HTTP status and JSON error body.
//
// Parameters:
//   - c: The Gin context containing the HTTP request and response
//   - rawJSON: The raw JSON bytes of the OpenAI-compatible request
func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
	// Get the http.Flusher interface to manually flush the response.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}

	modelName := gjson.GetBytes(rawJSON, "model").String()
	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
	dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))

	setSSEHeaders := func() {
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")
		c.Header("Access-Control-Allow-Origin", "*")
	}

	// failBeforeStream reports an upstream error while nothing has been sent to
	// the client yet, so a regular HTTP error status and JSON body are possible.
	failBeforeStream := func(errMsg *interfaces.ErrorMessage) {
		h.WriteErrorResponse(c, errMsg)
		if errMsg != nil {
			cliCancel(errMsg.Error)
		} else {
			cliCancel(nil)
		}
	}

	// Peek at the first event to determine success or failure before setting headers.
	for {
		select {
		case <-c.Request.Context().Done():
			cliCancel(c.Request.Context().Err())
			return
		case errMsg, ok := <-errChan:
			if !ok {
				// Err channel closed cleanly; wait for data channel.
				errChan = nil
				continue
			}
			// Upstream failed immediately. Return proper error status and JSON.
			failBeforeStream(errMsg)
			return
		case chunk, ok := <-dataChan:
			if !ok {
				// select chooses randomly among ready cases. A closed data channel can
				// therefore win over an error already queued on errChan. Check for such
				// an error before reporting an empty but successful stream.
				if errChan != nil {
					select {
					case errMsg, errOK := <-errChan:
						if errOK {
							failBeforeStream(errMsg)
							return
						}
					default:
					}
				}
				// Stream closed without data: send headers and the DONE marker.
				setSSEHeaders()
				handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
				_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
				flusher.Flush()
				cliCancel(nil)
				return
			}
			// Success! Commit to streaming headers.
			setSSEHeaders()
			handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
			_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
			flusher.Flush()
			// Continue streaming the rest.
			h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
			return
		}
	}
}
// handleCompletionsNonStreamingResponse serves a non-streaming legacy
// completions request.
//
// The completions request is translated into chat-completions form and
// executed through the auth manager with an empty alt. A non-streaming
// keep-alive runs during execution. The chat response is translated back into
// the completions shape before it is written to the client.
//
// Parameters:
//   - c: The Gin context containing the HTTP request and response
//   - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "application/json")

	chatReq := convertCompletionsRequestToChatCompletions(rawJSON)
	model := gjson.GetBytes(chatReq, "model").String()

	execCtx, cancelExec := h.GetContextWithCancel(h, c, context.Background())
	stopKeepAlive := h.StartNonStreamingKeepAlive(c, execCtx)
	chatResp, headers, failure := h.ExecuteWithAuthManager(execCtx, h.HandlerType(), model, chatReq, "")
	stopKeepAlive()

	if failure != nil {
		h.WriteErrorResponse(c, failure)
		cancelExec(failure.Error)
		return
	}

	handlers.WriteUpstreamHeaders(c.Writer.Header(), headers)
	_, _ = c.Writer.Write(convertChatCompletionsResponseToCompletions(chatResp))
	cancelExec()
}
// handleCompletionsStreamingResponse handles streaming completions responses.
// It converts the completions request to chat completions format, streams from
// the backend, and converts each response chunk back to completions format
// before sending it to the client as a Server-Sent Event.
//
// The handler waits for the first upstream event before committing to SSE. An
// upstream failure that happens before any data is sent can therefore still be
// reported with a proper HTTP status and JSON error body.
//
// Parameters:
//   - c: The Gin context containing the HTTP request and response
//   - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, rawJSON []byte) {
	// Get the http.Flusher interface to manually flush the response.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}

	// Convert completions request to chat completions format.
	chatCompletionsJSON := convertCompletionsRequestToChatCompletions(rawJSON)
	modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
	dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")

	setSSEHeaders := func() {
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")
		c.Header("Access-Control-Allow-Origin", "*")
	}

	// failBeforeStream reports an upstream error while nothing has been sent to
	// the client yet, so a regular HTTP error status and JSON body are possible.
	failBeforeStream := func(errMsg *interfaces.ErrorMessage) {
		h.WriteErrorResponse(c, errMsg)
		if errMsg != nil {
			cliCancel(errMsg.Error)
		} else {
			cliCancel(nil)
		}
	}

	// Peek at the first event.
	for {
		select {
		case <-c.Request.Context().Done():
			cliCancel(c.Request.Context().Err())
			return
		case errMsg, ok := <-errChan:
			if !ok {
				// Err channel closed cleanly; wait for data channel.
				errChan = nil
				continue
			}
			failBeforeStream(errMsg)
			return
		case chunk, ok := <-dataChan:
			if !ok {
				// select chooses randomly among ready cases. A closed data channel can
				// therefore win over an error already queued on errChan. Check for such
				// an error before reporting an empty but successful stream.
				if errChan != nil {
					select {
					case errMsg, errOK := <-errChan:
						if errOK {
							failBeforeStream(errMsg)
							return
						}
					default:
					}
				}
				setSSEHeaders()
				handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
				_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
				flusher.Flush()
				cliCancel(nil)
				return
			}

			// Success! Set headers and write the first converted chunk.
			setSSEHeaders()
			handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
			if converted := convertChatCompletionsStreamChunkToCompletions(chunk); converted != nil {
				_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted))
				flusher.Flush()
			}

			// Convert the remaining chat chunks on a separate goroutine. Chunks that
			// convert to nil are skipped. closing done stops the goroutine.
			done := make(chan struct{})
			var doneOnce sync.Once
			stop := func() { doneOnce.Do(func() { close(done) }) }
			convertedChan := make(chan []byte)
			go func() {
				defer close(convertedChan)
				for {
					select {
					case <-done:
						return
					case raw, more := <-dataChan:
						if !more {
							return
						}
						out := convertChatCompletionsStreamChunkToCompletions(raw)
						if out == nil {
							continue
						}
						select {
						case <-done:
							return
						case convertedChan <- out:
						}
					}
				}
			}()

			h.handleStreamResult(c, flusher, func(err error) {
				stop()
				cliCancel(err)
			}, convertedChan, errChan)
			// Nothing reads convertedChan any more. Stop the converter even if
			// forwarding ended without invoking the cancel callback, so the goroutine
			// cannot stay blocked on a send forever. stop is idempotent.
			stop()
			return
		}
	}
}
// handleStreamResult forwards the rest of a chat/completions stream to the
// client as Server-Sent Events via ForwardStream.
//
// Each payload chunk is framed as a "data: <chunk>" event.
//
// A terminal upstream error is written as an SSE data event carrying the body
// from BuildErrorResponseBody:
//   - the status is the upstream status code, or 500 when it is unset;
//   - the message is the error text, or the status text when that is empty.
//
// Normal completion is signalled with "data: [DONE]".
func (h *OpenAIAPIHandler) handleStreamResult(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
	writeEvent := func(payload string) {
		_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", payload)
	}

	opts := handlers.StreamForwardOptions{
		WriteChunk: func(chunk []byte) {
			writeEvent(string(chunk))
		},
		WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
			if errMsg == nil {
				return
			}
			status := errMsg.StatusCode
			if status <= 0 {
				status = http.StatusInternalServerError
			}
			var errText string
			if errMsg.Error != nil {
				errText = errMsg.Error.Error()
			}
			if errText == "" {
				errText = http.StatusText(status)
			}
			writeEvent(string(handlers.BuildErrorResponseBody(status, errText)))
		},
		WriteDone: func() {
			writeEvent("[DONE]")
		},
	}
	h.ForwardStream(c, flusher, cancel, data, errs, opts)
}
================================================
FILE: sdk/api/handlers/openai/openai_responses_compact_test.go
================================================
package openai
import (
"context"
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// compactCaptureExecutor is a stub provider executor for the compact endpoint
// tests. Execute records how it was invoked and returns a fixed {"ok":true}
// payload. Every other operation is either unimplemented or a no-op.
type compactCaptureExecutor struct {
	alt          string // opts.Alt seen by the most recent Execute call
	sourceFormat string // opts.SourceFormat of the most recent Execute call
	calls        int    // number of Execute invocations
}

// Identifier reports the provider name the test auths are registered under.
func (e *compactCaptureExecutor) Identifier() string {
	return "test-provider"
}

// Execute captures the call options and returns a canned successful payload.
func (e *compactCaptureExecutor) Execute(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (coreexecutor.Response, error) {
	e.calls++
	e.alt, e.sourceFormat = opts.Alt, opts.SourceFormat.String()
	return coreexecutor.Response{Payload: []byte(`{"ok":true}`)}, nil
}

// ExecuteStream is not exercised by the compact tests.
func (e *compactCaptureExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
	return nil, errors.New("not implemented")
}

// Refresh returns the auth unchanged.
func (e *compactCaptureExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
	return auth, nil
}

// CountTokens is not exercised by the compact tests.
func (e *compactCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
	return coreexecutor.Response{}, errors.New("not implemented")
}

// HttpRequest is not exercised by the compact tests.
func (e *compactCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
	return nil, errors.New("not implemented")
}
// newCompactTestRouter registers a compactCaptureExecutor-backed auth under
// authID, advertising "test-model" in the global registry. It returns a gin
// router that exposes only POST /v1/responses/compact. The registry entry is
// removed when the test finishes.
func newCompactTestRouter(t *testing.T, authID string) (*gin.Engine, *compactCaptureExecutor) {
	t.Helper()
	gin.SetMode(gin.TestMode)

	executor := &compactCaptureExecutor{}
	manager := coreauth.NewManager(nil, nil, nil)
	manager.RegisterExecutor(executor)

	auth := &coreauth.Auth{ID: authID, Provider: executor.Identifier(), Status: coreauth.StatusActive}
	if _, err := manager.Register(context.Background(), auth); err != nil {
		t.Fatalf("Register auth: %v", err)
	}
	registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
	t.Cleanup(func() { registry.GetGlobalRegistry().UnregisterClient(auth.ID) })

	h := NewOpenAIResponsesAPIHandler(handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager))
	router := gin.New()
	router.POST("/v1/responses/compact", h.Compact)
	return router, executor
}

// postCompact sends a JSON body to the compact endpoint and returns the recording.
func postCompact(router *gin.Engine, body string) *httptest.ResponseRecorder {
	req := httptest.NewRequest(http.MethodPost, "/v1/responses/compact", strings.NewReader(body))
	req.Header.Set("Content-Type", "application/json")
	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	return rec
}

// A streaming compact request must be rejected before reaching the executor.
func TestOpenAIResponsesCompactRejectsStream(t *testing.T) {
	router, executor := newCompactTestRouter(t, "auth1")

	resp := postCompact(router, `{"model":"test-model","stream":true}`)

	if resp.Code != http.StatusBadRequest {
		t.Fatalf("status = %d, want %d", resp.Code, http.StatusBadRequest)
	}
	if executor.calls != 0 {
		t.Fatalf("executor calls = %d, want 0", executor.calls)
	}
}

// A plain compact request is executed with the compact alt and the
// openai-response source format, and the upstream body is returned as-is.
func TestOpenAIResponsesCompactExecute(t *testing.T) {
	router, executor := newCompactTestRouter(t, "auth2")

	resp := postCompact(router, `{"model":"test-model","input":"hello"}`)

	if resp.Code != http.StatusOK {
		t.Fatalf("status = %d, want %d", resp.Code, http.StatusOK)
	}
	if executor.alt != "responses/compact" {
		t.Fatalf("alt = %q, want %q", executor.alt, "responses/compact")
	}
	if executor.sourceFormat != "openai-response" {
		t.Fatalf("source format = %q, want %q", executor.sourceFormat, "openai-response")
	}
	if strings.TrimSpace(resp.Body.String()) != `{"ok":true}` {
		t.Fatalf("body = %s", resp.Body.String())
	}
}
================================================
FILE: sdk/api/handlers/openai/openai_responses_handlers.go
================================================
// Package openai provides HTTP handlers for OpenAIResponses API endpoints.
// This package implements the OpenAIResponses-compatible API interface, including model listing
// and chat completion functionality. It supports both streaming and non-streaming responses,
// and manages a pool of clients to interact with backend services.
// The handlers translate OpenAIResponses API requests to the appropriate backend format and
// convert responses back to OpenAIResponses-compatible format.
package openai
import (
"bytes"
"context"
"fmt"
"net/http"
"github.com/gin-gonic/gin"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// OpenAIResponsesAPIHandler provides the HTTP handlers for the OpenAI
// Responses API. All execution plumbing (auth manager, context handling,
// error writing, stream forwarding) comes from the embedded BaseAPIHandler.
type OpenAIResponsesAPIHandler struct {
	*handlers.BaseAPIHandler
}

// NewOpenAIResponsesAPIHandler wraps a BaseAPIHandler in an
// OpenAIResponsesAPIHandler.
//
// Parameters:
//   - apiHandlers: The base API handlers instance to embed
//
// Returns:
//   - *OpenAIResponsesAPIHandler: The new Responses API handler
func NewOpenAIResponsesAPIHandler(apiHandlers *handlers.BaseAPIHandler) *OpenAIResponsesAPIHandler {
	h := &OpenAIResponsesAPIHandler{BaseAPIHandler: apiHandlers}
	return h
}

// HandlerType reports the OpenaiResponse identifier for this handler.
func (h *OpenAIResponsesAPIHandler) HandlerType() string {
	return OpenaiResponse
}

// Models lists the models the global registry currently exposes in the
// "openai" format.
func (h *OpenAIResponsesAPIHandler) Models() []map[string]any {
	return registry.GetGlobalRegistry().GetAvailableModels("openai")
}

// OpenAIResponsesModels handles the /v1/models endpoint by answering with an
// OpenAI-style list object whose data field holds the output of Models.
func (h *OpenAIResponsesAPIHandler) OpenAIResponsesModels(c *gin.Context) {
	payload := gin.H{
		"object": "list",
		"data":   h.Models(),
	}
	c.JSON(http.StatusOK, payload)
}
// Responses handles the /v1/responses endpoint.
//
// The raw body is read once. An unreadable body yields 400 with an
// invalid_request_error. Otherwise a body whose "stream" field is the JSON
// literal true is served as SSE, and every other body is served as a single
// JSON response.
//
// Parameters:
//   - c: The Gin context containing the HTTP request and response
func (h *OpenAIResponsesAPIHandler) Responses(c *gin.Context) {
	rawJSON, err := c.GetRawData()
	if err != nil {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Invalid request: %v", err),
				Type:    "invalid_request_error",
			},
		})
		return
	}

	if gjson.GetBytes(rawJSON, "stream").Type != gjson.True {
		h.handleNonStreamingResponse(c, rawJSON)
		return
	}
	h.handleStreamingResponse(c, rawJSON)
}
// Compact handles compact Responses API requests. The tests route them as
// POST /v1/responses/compact.
//
// Compact answers are always a single JSON body:
//   - a request whose "stream" field is true is rejected with 400;
//   - any other "stream" field is stripped before execution;
//   - the payload is executed through the auth manager with the
//     "responses/compact" alt, with a non-streaming keep-alive running
//     meanwhile;
//   - the upstream body and headers are then forwarded as-is.
func (h *OpenAIResponsesAPIHandler) Compact(c *gin.Context) {
	rawJSON, err := c.GetRawData()
	if err != nil {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: fmt.Sprintf("Invalid request: %v", err),
				Type:    "invalid_request_error",
			},
		})
		return
	}

	switch stream := gjson.GetBytes(rawJSON, "stream"); {
	case stream.Type == gjson.True:
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported for compact responses",
				Type:    "invalid_request_error",
			},
		})
		return
	case stream.Exists():
		// Drop non-true stream flags so upstream always sees a plain request.
		if stripped, errDel := sjson.DeleteBytes(rawJSON, "stream"); errDel == nil {
			rawJSON = stripped
		}
	}

	c.Header("Content-Type", "application/json")
	model := gjson.GetBytes(rawJSON, "model").String()
	execCtx, cancelExec := h.GetContextWithCancel(h, c, context.Background())

	stopKeepAlive := h.StartNonStreamingKeepAlive(c, execCtx)
	payload, headers, failure := h.ExecuteWithAuthManager(execCtx, h.HandlerType(), model, rawJSON, "responses/compact")
	stopKeepAlive()

	if failure != nil {
		h.WriteErrorResponse(c, failure)
		cancelExec(failure.Error)
		return
	}
	handlers.WriteUpstreamHeaders(c.Writer.Header(), headers)
	_, _ = c.Writer.Write(payload)
	cancelExec()
}
// handleNonStreamingResponse serves a non-streaming Responses API request.
//
// The request is executed through the auth manager with an empty alt, and a
// non-streaming keep-alive runs during execution. The upstream payload is then
// written back verbatim as JSON together with the forwarded upstream headers.
// Failures are rendered via WriteErrorResponse and cancel the execution
// context with the upstream error.
//
// Parameters:
//   - c: The Gin context containing the HTTP request and response
//   - rawJSON: The raw JSON bytes of the OpenAIResponses-compatible request
func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "application/json")

	model := gjson.GetBytes(rawJSON, "model").String()
	execCtx, cancelExec := h.GetContextWithCancel(h, c, context.Background())

	stopKeepAlive := h.StartNonStreamingKeepAlive(c, execCtx)
	payload, headers, failure := h.ExecuteWithAuthManager(execCtx, h.HandlerType(), model, rawJSON, "")
	stopKeepAlive()

	if failure != nil {
		h.WriteErrorResponse(c, failure)
		cancelExec(failure.Error)
		return
	}
	handlers.WriteUpstreamHeaders(c.Writer.Header(), headers)
	_, _ = c.Writer.Write(payload)
	cancelExec()
}
// handleStreamingResponse serves a streaming Responses API request.
// It establishes a streaming connection with the backend service and forwards
// the response chunks to the client in real-time using Server-Sent Events.
//
// The handler waits for the first upstream event before committing to SSE. An
// upstream failure that happens before any data is sent can therefore still be
// reported with a proper HTTP status and JSON error body.
//
// Parameters:
//   - c: The Gin context containing the HTTP request and response
//   - rawJSON: The raw JSON bytes of the OpenAIResponses-compatible request
func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
	// Get the http.Flusher interface to manually flush the response.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}

	// New core execution path
	modelName := gjson.GetBytes(rawJSON, "model").String()
	cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
	dataChan, upstreamHeaders, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")

	setSSEHeaders := func() {
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")
		c.Header("Access-Control-Allow-Origin", "*")
	}

	// failBeforeStream reports an upstream error while nothing has been sent to
	// the client yet, so a regular HTTP error status and JSON body are possible.
	failBeforeStream := func(errMsg *interfaces.ErrorMessage) {
		h.WriteErrorResponse(c, errMsg)
		if errMsg != nil {
			cliCancel(errMsg.Error)
		} else {
			cliCancel(nil)
		}
	}

	// Peek at the first event.
	for {
		select {
		case <-c.Request.Context().Done():
			cliCancel(c.Request.Context().Err())
			return
		case errMsg, ok := <-errChan:
			if !ok {
				// Err channel closed cleanly; wait for data channel.
				errChan = nil
				continue
			}
			// Upstream failed immediately. Return proper error status and JSON.
			failBeforeStream(errMsg)
			return
		case chunk, ok := <-dataChan:
			if !ok {
				// select chooses randomly among ready cases. A closed data channel can
				// therefore win over an error already queued on errChan. Check for such
				// an error before reporting an empty but successful stream.
				if errChan != nil {
					select {
					case errMsg, errOK := <-errChan:
						if errOK {
							failBeforeStream(errMsg)
							return
						}
					default:
					}
				}
				// Stream closed without data: send headers and a terminating newline.
				setSSEHeaders()
				handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
				_, _ = c.Writer.Write([]byte("\n"))
				flusher.Flush()
				cliCancel(nil)
				return
			}
			// Success! Set headers.
			setSSEHeaders()
			handlers.WriteUpstreamHeaders(c.Writer.Header(), upstreamHeaders)
			// Write the first chunk with the same framing forwardResponsesStream uses.
			if bytes.HasPrefix(chunk, []byte("event:")) {
				_, _ = c.Writer.Write([]byte("\n"))
			}
			_, _ = c.Writer.Write(chunk)
			_, _ = c.Writer.Write([]byte("\n"))
			flusher.Flush()
			// Continue with the rest of the stream.
			h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
			return
		}
	}
}
// forwardResponsesStream relays the rest of a Responses stream to the client
// through ForwardStream.
//
// Each upstream chunk is written as-is, followed by a newline. A chunk that
// begins with "event:" is preceded by an extra newline so it starts a fresh
// SSE event block.
//
// A terminal upstream error is emitted as an "event: error" SSE event. Its data
// comes from BuildOpenAIResponsesStreamErrorChunk, using the upstream status
// (500 when unset) and the error text (the status text when empty).
//
// Normal completion writes a single trailing newline.
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
	newline := []byte("\n")
	eventPrefix := []byte("event:")

	opts := handlers.StreamForwardOptions{
		WriteChunk: func(chunk []byte) {
			if bytes.HasPrefix(chunk, eventPrefix) {
				_, _ = c.Writer.Write(newline)
			}
			_, _ = c.Writer.Write(chunk)
			_, _ = c.Writer.Write(newline)
		},
		WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
			if errMsg == nil {
				return
			}
			status := errMsg.StatusCode
			if status <= 0 {
				status = http.StatusInternalServerError
			}
			var errText string
			if errMsg.Error != nil {
				errText = errMsg.Error.Error()
			}
			if errText == "" {
				errText = http.StatusText(status)
			}
			errChunk := handlers.BuildOpenAIResponsesStreamErrorChunk(status, errText, 0)
			_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(errChunk))
		},
		WriteDone: func() {
			_, _ = c.Writer.Write(newline)
		},
	}
	h.ForwardStream(c, flusher, cancel, data, errs, opts)
}
================================================
FILE: sdk/api/handlers/openai/openai_responses_handlers_stream_error_test.go
================================================
package openai
import (
"errors"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// A terminal error delivered with no data chunks must be rendered as a
// Responses-style streaming error event (top-level "type":"error"). It must not
// be rendered as an HTTP error body with a nested "error" object.
func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T) {
	gin.SetMode(gin.TestMode)
	h := NewOpenAIResponsesAPIHandler(handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil))

	recorder := httptest.NewRecorder()
	c, _ := gin.CreateTestContext(recorder)
	c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		t.Fatalf("expected gin writer to implement http.Flusher")
	}

	// No data chunks; a single buffered terminal error, then the channel closes.
	data := make(chan []byte)
	errs := make(chan *interfaces.ErrorMessage, 1)
	errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")}
	close(errs)

	h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)

	switch body := recorder.Body.String(); {
	case !strings.Contains(body, `"type":"error"`):
		t.Fatalf("expected responses error chunk, got: %q", body)
	case strings.Contains(body, `"error":{`):
		t.Fatalf("expected streaming error chunk (top-level type), got HTTP error body: %q", body)
	}
}
================================================
FILE: sdk/api/handlers/openai/openai_responses_websocket.go
================================================
package openai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// Request types accepted from websocket clients.
const (
	wsRequestTypeCreate = "response.create"
	wsRequestTypeAppend = "response.append"
)

// Event types and stream markers used on the websocket.
const (
	wsEventTypeError     = "error"
	wsEventTypeCompleted = "response.completed"
	wsDoneMarker         = "[DONE]"
)

// Header name echoed on upgrade and the request-body override key.
const (
	wsTurnStateHeader = "x-codex-turn-state"
	wsRequestBodyKey  = "REQUEST_BODY_OVERRIDE"
)

// Limits and marker for websocket logging (sizes presumably in bytes — confirm
// against their users).
const (
	wsPayloadLogMaxSize = 2048
	wsBodyLogMaxSize    = 64 * 1024
	wsBodyLogTruncated  = "\n[websocket log truncated]\n"
)
// responsesWebsocketUpgrader upgrades /v1/responses requests to websocket
// connections with 4 KiB read and write buffers.
//
// NOTE(review): CheckOrigin accepts every Origin. Browser pages on any site
// can therefore open this websocket (cross-site websocket hijacking). That is
// only safe if authentication never relies on ambient browser credentials such
// as cookies. Confirm before exposing this endpoint to browsers.
var responsesWebsocketUpgrader = websocket.Upgrader{
	ReadBufferSize:  4 << 10,
	WriteBufferSize: 4 << 10,
	CheckOrigin:     func(*http.Request) bool { return true },
}
// ResponsesWebsocket handles websocket requests for /v1/responses.
// It accepts `response.create` and `response.append` requests and streams
// response events back as JSON websocket text messages.
//
// One execution session, keyed by a per-connection UUID, spans the whole
// connection. For each inbound message the handler:
//   - normalizes it against the previous request and the previous response's
//     output;
//   - executes it as a streaming request;
//   - forwards the result back over the socket.
//
// The first selected auth that supports incremental input is pinned for the
// rest of the session. A transcript of inbound and outbound events is
// accumulated in wsBodyLog and passed to setWebsocketRequestBody when the
// connection ends.
func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
	// Echo the client's turn-state header (if any) on the upgrade response.
	conn, err := responsesWebsocketUpgrader.Upgrade(c.Writer, c.Request, websocketUpgradeHeaders(c.Request))
	if err != nil {
		// gorilla's Upgrade has already replied to the client with an HTTP error.
		return
	}
	// Per-connection ID, used for logging and as the upstream execution session ID.
	passthroughSessionID := uuid.NewString()
	clientRemoteAddr := ""
	// NOTE(review): c and c.Request were already dereferenced by Upgrade above,
	// so this nil check can never be false here.
	if c != nil && c.Request != nil {
		clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
	}
	log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientRemoteAddr)
	var wsTerminateErr error
	// wsBodyLog accumulates the session transcript (requests, responses, disconnects).
	var wsBodyLog strings.Builder
	defer func() {
		// The non-nil branch is empty: its log line is commented out.
		if wsTerminateErr != nil {
			// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
		} else {
			log.Infof("responses websocket: session closing id=%s", passthroughSessionID)
		}
		// Release the upstream execution session tied to this connection.
		if h != nil && h.AuthManager != nil {
			h.AuthManager.CloseExecutionSession(passthroughSessionID)
			log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID)
		}
		setWebsocketRequestBody(c, wsBodyLog.String())
		if errClose := conn.Close(); errClose != nil {
			log.Warnf("responses websocket: close connection error: %v", errClose)
		}
	}()
	// Conversation state carried from one message to the next on this connection.
	var lastRequest []byte
	lastResponseOutput := []byte("[]")
	pinnedAuthID := ""
	for {
		msgType, payload, errReadMessage := conn.ReadMessage()
		if errReadMessage != nil {
			// Any read failure ends the session. Only close-type errors are logged;
			// the else branch is empty because its log line is commented out.
			wsTerminateErr = errReadMessage
			appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errReadMessage.Error()))
			if websocket.IsCloseError(errReadMessage, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
				log.Infof("responses websocket: client disconnected id=%s error=%v", passthroughSessionID, errReadMessage)
			} else {
				// log.Warnf("responses websocket: read message failed id=%s error=%v", passthroughSessionID, errReadMessage)
			}
			return
		}
		// Messages that are neither text nor binary are skipped.
		if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
			continue
		}
		// log.Infof(
		// 	"responses websocket: downstream_in id=%s type=%d event=%s payload=%s",
		// 	passthroughSessionID,
		// 	msgType,
		// 	websocketPayloadEventType(payload),
		// 	websocketPayloadPreview(payload),
		// )
		appendWebsocketEvent(&wsBodyLog, "request", payload)
		// Decide whether previous_response_id + incremental input may be passed
		// through as-is. When an auth is pinned, ask that auth. Otherwise infer it
		// from this message's model, falling back to the previous request's model.
		allowIncrementalInputWithPreviousResponseID := false
		if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
			if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
				allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
			}
		} else {
			requestModelName := strings.TrimSpace(gjson.GetBytes(payload, "model").String())
			if requestModelName == "" {
				requestModelName = strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
			}
			allowIncrementalInputWithPreviousResponseID = h.websocketUpstreamSupportsIncrementalInputForModel(requestModelName)
		}
		var requestJSON []byte
		var updatedLastRequest []byte
		var errMsg *interfaces.ErrorMessage
		requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithMode(
			payload,
			lastRequest,
			lastResponseOutput,
			allowIncrementalInputWithPreviousResponseID,
		)
		if errMsg != nil {
			// Report the normalization error to the client and keep the connection
			// open; lastRequest is left unchanged. Only a failed write ends the session.
			// NOTE(review): "gin" is a plain string context key; presumably it matches
			// what LoggingAPIResponseError looks up — confirm before changing it.
			h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
			markAPIResponseTimestamp(c)
			errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
			appendWebsocketEvent(&wsBodyLog, "response", errorPayload)
			log.Infof(
				"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
				passthroughSessionID,
				websocket.TextMessage,
				websocketPayloadEventType(errorPayload),
				websocketPayloadPreview(errorPayload),
			)
			if errWrite != nil {
				log.Warnf(
					"responses websocket: downstream_out write failed id=%s event=%s error=%v",
					passthroughSessionID,
					websocketPayloadEventType(errorPayload),
					errWrite,
				)
				return
			}
			continue
		}
		// Prewarm requests are answered through writeResponsesWebsocketSyntheticPrewarm
		// instead of being executed upstream. The "generate" field is stripped from
		// both the outgoing and the remembered request, and the previous output is
		// reset.
		if shouldHandleResponsesWebsocketPrewarmLocally(payload, lastRequest, allowIncrementalInputWithPreviousResponseID) {
			if updated, errDelete := sjson.DeleteBytes(requestJSON, "generate"); errDelete == nil {
				requestJSON = updated
			}
			if updated, errDelete := sjson.DeleteBytes(updatedLastRequest, "generate"); errDelete == nil {
				updatedLastRequest = updated
			}
			lastRequest = updatedLastRequest
			lastResponseOutput = []byte("[]")
			if errWrite := writeResponsesWebsocketSyntheticPrewarm(c, conn, requestJSON, &wsBodyLog, passthroughSessionID); errWrite != nil {
				wsTerminateErr = errWrite
				appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errWrite.Error()))
				return
			}
			continue
		}
		lastRequest = updatedLastRequest
		modelName := gjson.GetBytes(requestJSON, "model").String()
		cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
		cliCtx = cliproxyexecutor.WithDownstreamWebsocket(cliCtx)
		cliCtx = handlers.WithExecutionSessionID(cliCtx, passthroughSessionID)
		// Reuse the pinned auth if there is one. Otherwise pin the first selected auth
		// that supports incremental input, so later turns stay on the same upstream.
		// NOTE(review): the callback assigns pinnedAuthID; confirm it runs on this
		// goroutine, otherwise this is a data race.
		if pinnedAuthID != "" {
			cliCtx = handlers.WithPinnedAuthID(cliCtx, pinnedAuthID)
		} else {
			cliCtx = handlers.WithSelectedAuthIDCallback(cliCtx, func(authID string) {
				authID = strings.TrimSpace(authID)
				if authID == "" || h == nil || h.AuthManager == nil {
					return
				}
				selectedAuth, ok := h.AuthManager.GetByID(authID)
				if !ok || selectedAuth == nil {
					return
				}
				if websocketUpstreamSupportsIncrementalInput(selectedAuth.Attributes, selectedAuth.Metadata) {
					pinnedAuthID = authID
				}
			})
		}
		// Upstream headers are discarded; the connection has already been upgraded.
		dataChan, _, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
		completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
		if errForward != nil {
			wsTerminateErr = errForward
			appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error()))
			log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
			return
		}
		// Remember this turn's output so the next request can be merged onto it.
		lastResponseOutput = completedOutput
	}
}
// websocketUpgradeHeaders builds the extra response headers sent with the
// websocket upgrade.
//
// When the client supplied a non-blank turn-state header, its trimmed value is
// echoed back so the same sticky turn state is kept across reconnects. The
// returned header set is never nil, even for a nil request.
func websocketUpgradeHeaders(req *http.Request) http.Header {
	out := make(http.Header)
	if req == nil {
		return out
	}
	if state := strings.TrimSpace(req.Header.Get(wsTurnStateHeader)); state != "" {
		out.Set(wsTurnStateHeader, state)
	}
	return out
}
// normalizeResponsesWebsocketRequest normalizes a websocket request with
// incremental input via previous_response_id permitted. It is shorthand for
// normalizeResponsesWebsocketRequestWithMode(..., true).
func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
	const allowIncremental = true
	return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, allowIncremental)
}
// normalizeResponsesWebsocketRequestWithMode dispatches on the websocket
// message "type" field (compared after trimming whitespace):
//   - response.create with no prior request in the session starts a new
//     conversation;
//   - response.create after a prior request, and any response.append, are
//     treated as follow-ups to lastRequest;
//   - any other type is rejected with 400, and lastRequest is returned
//     unchanged.
//
// It returns the body to send upstream, the value to remember as the
// session's last request, and an error message (nil on success).
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
	requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())

	isCreate := requestType == wsRequestTypeCreate
	if isCreate && len(lastRequest) == 0 {
		return normalizeResponseCreateRequest(rawJSON)
	}
	if isCreate || requestType == wsRequestTypeAppend {
		return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
	}

	return nil, lastRequest, &interfaces.ErrorMessage{
		StatusCode: http.StatusBadRequest,
		Error:      fmt.Errorf("unsupported websocket request type: %s", requestType),
	}
}
// normalizeResponseCreateRequest turns the first response.create message of a
// session into an upstream request body.
//
// The websocket "type" field is removed, "stream" is forced to true, and an
// empty "input" array is added when absent. A request without a non-blank
// "model" is rejected with 400. On success the normalized body is returned
// twice: once to send upstream, and once as an independent copy to remember as
// the session's last request.
func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
	body, errDelete := sjson.DeleteBytes(rawJSON, "type")
	if errDelete != nil {
		// Could not strip "type": continue from a copy of the original payload.
		body = bytes.Clone(rawJSON)
	}
	body, _ = sjson.SetBytes(body, "stream", true)
	if hasInput := gjson.GetBytes(body, "input").Exists(); !hasInput {
		body, _ = sjson.SetRawBytes(body, "input", []byte("[]"))
	}

	if model := strings.TrimSpace(gjson.GetBytes(body, "model").String()); model == "" {
		return nil, nil, &interfaces.ErrorMessage{
			StatusCode: http.StatusBadRequest,
			Error:      fmt.Errorf("missing model in response.create request"),
		}
	}
	return body, bytes.Clone(body), nil
}
// normalizeResponseSubsequentRequest normalizes a response.create or
// response.append frame that follows an earlier request on the same socket.
//
// The frame must carry an array "input", and lastRequest must be non-empty;
// otherwise a 400 error is returned alongside the unchanged lastRequest.
//
// With allowIncrementalInputWithPreviousResponseID set and a non-blank
// previous_response_id, the frame's incremental input is forwarded as-is.
// Otherwise the input is rebuilt as lastRequest.input + lastResponseOutput +
// frame input, and previous_response_id is removed. In both cases "type" is
// dropped, "model" and "instructions" are inherited from lastRequest when the
// frame omits them, and "stream" is forced to true. The second result is a copy
// of the normalized request for use as the next snapshot.
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
	badRequest := func(err error) ([]byte, []byte, *interfaces.ErrorMessage) {
		return nil, lastRequest, &interfaces.ErrorMessage{
			StatusCode: http.StatusBadRequest,
			Error:      err,
		}
	}
	if len(lastRequest) == 0 {
		return badRequest(fmt.Errorf("websocket request received before response.create"))
	}
	nextInput := gjson.GetBytes(rawJSON, "input")
	if !nextInput.Exists() || !nextInput.IsArray() {
		return badRequest(fmt.Errorf("websocket request requires array field: input"))
	}

	// withoutType returns the frame minus its "type" field, or an untouched copy
	// if the deletion fails.
	withoutType := func() []byte {
		stripped, errDelete := sjson.DeleteBytes(rawJSON, "type")
		if errDelete != nil {
			return bytes.Clone(rawJSON)
		}
		return stripped
	}
	// finalize fills in model/instructions from lastRequest when absent, forces
	// streaming, and returns the request with its snapshot copy.
	finalize := func(normalized []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
		if !gjson.GetBytes(normalized, "model").Exists() {
			if modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()); modelName != "" {
				normalized, _ = sjson.SetBytes(normalized, "model", modelName)
			}
		}
		if !gjson.GetBytes(normalized, "instructions").Exists() {
			if instructions := gjson.GetBytes(lastRequest, "instructions"); instructions.Exists() {
				normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
			}
		}
		normalized, _ = sjson.SetBytes(normalized, "stream", true)
		return normalized, bytes.Clone(normalized), nil
	}

	// Websocket v2 mode uses response.create with previous_response_id + incremental
	// input. Do not expand it into a full transcript; upstream expects the
	// incremental payload.
	if allowIncrementalInputWithPreviousResponseID &&
		strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()) != "" {
		return finalize(withoutType())
	}

	mergedInput, errMerge := mergeJSONArrayRaw(gjson.GetBytes(lastRequest, "input").Raw, normalizeJSONArrayRaw(lastResponseOutput))
	if errMerge != nil {
		return badRequest(fmt.Errorf("invalid previous response output: %w", errMerge))
	}
	if mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw); errMerge != nil {
		return badRequest(fmt.Errorf("invalid request input: %w", errMerge))
	}

	normalized := withoutType()
	normalized, _ = sjson.DeleteBytes(normalized, "previous_response_id")
	normalized, errSet := sjson.SetRawBytes(normalized, "input", []byte(mergedInput))
	if errSet != nil {
		return badRequest(fmt.Errorf("failed to merge websocket input: %w", errSet))
	}
	return finalize(normalized)
}
// websocketUpstreamSupportsIncrementalInput reports whether an auth entry has
// opted into websocket incremental input via its "websockets" flag.
//
// A parsable boolean in attributes["websockets"] (surrounding whitespace ignored)
// wins. Otherwise metadata["websockets"] is consulted; it may be a bool or a
// boolean string. Anything missing or unparsable means false.
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
	// Indexing a nil map yields "", which ParseBool rejects, so no length check is needed.
	if enabled, errParse := strconv.ParseBool(strings.TrimSpace(attributes["websockets"])); errParse == nil {
		return enabled
	}
	switch flag := metadata["websockets"].(type) {
	case bool:
		return flag
	case string:
		enabled, errParse := strconv.ParseBool(strings.TrimSpace(flag))
		return errParse == nil && enabled
	}
	return false
}
// websocketUpstreamSupportsIncrementalInputForModel reports whether at least one
// auth that could serve modelName has websocket incremental input enabled.
//
// "auto" is resolved through util.ResolveAutoModel, keeping any thinking suffix.
// An auth qualifies when its provider is one of the model's providers, the global
// registry (when present) says the auth supports the base model, it is currently
// available for that model (responsesWebsocketAuthAvailableForModel), and its
// attributes/metadata enable the "websockets" flag.
func (h *OpenAIResponsesAPIHandler) websocketUpstreamSupportsIncrementalInputForModel(modelName string) bool {
	if h == nil || h.AuthManager == nil {
		return false
	}

	var resolvedModelName string
	if suffix := thinking.ParseSuffix(modelName); suffix.ModelName == "auto" {
		resolvedModelName = util.ResolveAutoModel(suffix.ModelName)
		if suffix.HasSuffix {
			resolvedModelName = fmt.Sprintf("%s(%s)", resolvedModelName, suffix.RawSuffix)
		}
	} else {
		resolvedModelName = util.ResolveAutoModel(modelName)
	}

	baseModel := strings.TrimSpace(thinking.ParseSuffix(resolvedModelName).ModelName)
	providers := util.GetProviderName(baseModel)
	if len(providers) == 0 && baseModel != resolvedModelName {
		providers = util.GetProviderName(resolvedModelName)
	}
	providerSet := make(map[string]struct{}, len(providers))
	for _, provider := range providers {
		if key := strings.TrimSpace(strings.ToLower(provider)); key != "" {
			providerSet[key] = struct{}{}
		}
	}
	if len(providerSet) == 0 {
		return false
	}

	modelKey := baseModel
	if modelKey == "" {
		modelKey = strings.TrimSpace(resolvedModelName)
	}
	registryRef := registry.GetGlobalRegistry()
	now := time.Now()
	for _, auth := range h.AuthManager.List() {
		if auth == nil {
			continue
		}
		if _, known := providerSet[strings.TrimSpace(strings.ToLower(auth.Provider))]; !known {
			continue
		}
		if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(auth.ID, modelKey) {
			continue
		}
		if !responsesWebsocketAuthAvailableForModel(auth, modelKey, now) {
			continue
		}
		if websocketUpstreamSupportsIncrementalInput(auth.Attributes, auth.Metadata) {
			return true
		}
	}
	return false
}
// responsesWebsocketAuthAvailableForModel reports whether auth may currently be
// used for modelName.
//
// Disabled auths are never available. If the auth tracks a per-model state for
// modelName (or, failing that, for its suffix-stripped base name), that state
// alone decides: disabled or still inside an unavailability window means false.
// Without such a state, the auth-level unavailability window is checked instead.
func responsesWebsocketAuthAvailableForModel(auth *coreauth.Auth, modelName string, now time.Time) bool {
	if auth == nil || auth.Disabled || auth.Status == coreauth.StatusDisabled {
		return false
	}
	// coolingDown reports an unavailability window that has not expired yet.
	coolingDown := func(unavailable bool, retryAfter time.Time) bool {
		return unavailable && !retryAfter.IsZero() && retryAfter.After(now)
	}
	if modelName != "" && len(auth.ModelStates) > 0 {
		// A missing key yields a nil state, so one nil check covers both cases.
		state := auth.ModelStates[modelName]
		if state == nil {
			if baseModel := strings.TrimSpace(thinking.ParseSuffix(modelName).ModelName); baseModel != "" && baseModel != modelName {
				state = auth.ModelStates[baseModel]
			}
		}
		if state != nil {
			return state.Status != coreauth.StatusDisabled && !coolingDown(state.Unavailable, state.NextRetryAfter)
		}
	}
	return !coolingDown(auth.Unavailable, auth.NextRetryAfter)
}
// shouldHandleResponsesWebsocketPrewarmLocally reports whether a frame is a
// first-on-connection response.create with an explicit "generate": false while
// incremental upstream input is not in use. Such frames are presumably answered
// locally (see writeResponsesWebsocketSyntheticPrewarm) instead of upstream.
func shouldHandleResponsesWebsocketPrewarmLocally(rawJSON []byte, lastRequest []byte, allowIncrementalInputWithPreviousResponseID bool) bool {
	if allowIncrementalInputWithPreviousResponseID {
		return false
	}
	if len(lastRequest) > 0 {
		return false
	}
	if strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) != wsRequestTypeCreate {
		return false
	}
	generate := gjson.GetBytes(rawJSON, "generate")
	if !generate.Exists() {
		return false
	}
	return !generate.Bool()
}
// writeResponsesWebsocketSyntheticPrewarm sends the synthetic prewarm events
// (response.created then response.completed) for requestJSON to conn.
//
// Each event is recorded in wsBodyLog and stamps the API response timestamp on c
// before being written as a text frame. The first payload-building or write
// error is returned; write failures are also logged with sessionID.
func writeResponsesWebsocketSyntheticPrewarm(
	c *gin.Context,
	conn *websocket.Conn,
	requestJSON []byte,
	wsBodyLog *strings.Builder,
	sessionID string,
) error {
	events, errBuild := syntheticResponsesWebsocketPrewarmPayloads(requestJSON)
	if errBuild != nil {
		return errBuild
	}
	for _, event := range events {
		markAPIResponseTimestamp(c)
		appendWebsocketEvent(wsBodyLog, "response", event)
		errWrite := conn.WriteMessage(websocket.TextMessage, event)
		if errWrite == nil {
			continue
		}
		log.Warnf(
			"responses websocket: downstream_out write failed id=%s event=%s error=%v",
			sessionID,
			websocketPayloadEventType(event),
			errWrite,
		)
		return errWrite
	}
	return nil
}
// syntheticResponsesWebsocketPrewarmPayloads builds a response.created and a
// response.completed event that describe an empty, already-finished response.
//
// Both events share a fresh "resp_prewarm_<uuid>" id and the current Unix time
// as created_at. The request's trimmed "model" is copied in when non-blank.
func syntheticResponsesWebsocketPrewarmPayloads(requestJSON []byte) ([][]byte, error) {
	responseID := "resp_prewarm_" + uuid.NewString()
	createdAt := time.Now().Unix()
	modelName := strings.TrimSpace(gjson.GetBytes(requestJSON, "model").String())

	templates := []string{
		`{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`,
		`{"type":"response.completed","sequence_number":1,"response":{"id":"","object":"response","created_at":0,"status":"completed","background":false,"error":null,"output":[],"usage":{"input_tokens":0,"output_tokens":0,"total_tokens":0}}}`,
	}
	events := make([][]byte, 0, len(templates))
	for _, template := range templates {
		event := []byte(template)
		var errSet error
		if event, errSet = sjson.SetBytes(event, "response.id", responseID); errSet != nil {
			return nil, errSet
		}
		if event, errSet = sjson.SetBytes(event, "response.created_at", createdAt); errSet != nil {
			return nil, errSet
		}
		if modelName != "" {
			if event, errSet = sjson.SetBytes(event, "response.model", modelName); errSet != nil {
				return nil, errSet
			}
		}
		events = append(events, event)
	}
	return events, nil
}
// mergeJSONArrayRaw concatenates two raw JSON arrays (existing items first) and
// returns the result as compact JSON text.
//
// Blank input counts as an empty array. The JSON literal null also counts as an
// empty array: json.Unmarshal leaves the slice nil for it. The result is always a
// JSON array ("[]" at minimum, never "null") because the output slice is
// allocated explicitly instead of appending onto a possibly nil slice. Input that
// is not a JSON array produces an error that names the offending side.
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
	decode := func(raw string) ([]json.RawMessage, error) {
		raw = strings.TrimSpace(raw)
		if raw == "" {
			return nil, nil
		}
		var items []json.RawMessage
		if err := json.Unmarshal([]byte(raw), &items); err != nil {
			return nil, err
		}
		return items, nil
	}
	existing, err := decode(existingRaw)
	if err != nil {
		return "", fmt.Errorf("existing items: %w", err)
	}
	appendItems, err := decode(appendRaw)
	if err != nil {
		return "", fmt.Errorf("appended items: %w", err)
	}
	// Never nil, so Marshal always yields an array rather than "null".
	merged := make([]json.RawMessage, 0, len(existing)+len(appendItems))
	merged = append(merged, existing...)
	merged = append(merged, appendItems...)
	out, err := json.Marshal(merged)
	if err != nil {
		return "", err
	}
	return string(out), nil
}
// normalizeJSONArrayRaw returns raw (whitespace-trimmed) when it looks like a
// JSON array, and "[]" otherwise, including for empty input.
//
// gjson's IsArray already implies a JSON-typed value starting with '[', so no
// separate type check is needed. The array contents are not validated here.
func normalizeJSONArrayRaw(raw []byte) string {
	const emptyArray = "[]"
	trimmed := strings.TrimSpace(string(raw))
	if trimmed != "" && gjson.Parse(trimmed).IsArray() {
		return trimmed
	}
	return emptyArray
}
// forwardResponsesWebsocket relays upstream stream output to the downstream
// websocket until the stream finishes, fails, or the request context ends.
//
// Each data chunk is split into JSON payloads (websocketJSONPayloadsFromChunk),
// and each payload is written as its own text frame and recorded in wsBodyLog.
// The "output" array of the most recent response.completed event (default "[]")
// is returned so the caller can replay it as history.
//
// An error received on errs is reported to the client as a websocket error
// event. The same happens when data closes before response.completed was seen.
// cancel is invoked exactly once before returning. The returned error is
// non-nil only when the request context ended or sending to the client failed.
func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
	c *gin.Context,
	conn *websocket.Conn,
	cancel handlers.APIHandlerCancelFunc,
	data <-chan []byte,
	errs <-chan *interfaces.ErrorMessage,
	wsBodyLog *strings.Builder,
	sessionID string,
) ([]byte, error) {
	completed := false
	completedOutput := []byte("[]")

	// finishWithError logs errMsg, records it in the body log, sends it to the
	// client as an error event, and cancels the upstream call with it. Both error
	// paths share this, so write failures are logged consistently.
	finishWithError := func(errMsg *interfaces.ErrorMessage) ([]byte, error) {
		h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
		markAPIResponseTimestamp(c)
		errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
		appendWebsocketEvent(wsBodyLog, "response", errorPayload)
		log.Infof(
			"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
			sessionID,
			websocket.TextMessage,
			websocketPayloadEventType(errorPayload),
			websocketPayloadPreview(errorPayload),
		)
		if errWrite != nil {
			log.Warnf(
				"responses websocket: downstream_out write failed id=%s event=%s error=%v",
				sessionID,
				websocketPayloadEventType(errorPayload),
				errWrite,
			)
		}
		cancel(errMsg.Error)
		return completedOutput, errWrite
	}

	for {
		select {
		case <-c.Request.Context().Done():
			cancel(c.Request.Context().Err())
			return completedOutput, c.Request.Context().Err()
		case errMsg, ok := <-errs:
			if !ok {
				// errs is closed: stop selecting on it and keep draining data.
				errs = nil
				continue
			}
			if errMsg == nil {
				cancel(nil)
				return completedOutput, nil
			}
			return finishWithError(errMsg)
		case chunk, ok := <-data:
			if !ok {
				if completed {
					cancel(nil)
					return completedOutput, nil
				}
				// The producer may queue its real error just before closing data.
				// select picks randomly among ready cases, so prefer an
				// already-queued error over the generic "stream closed" message.
				var errMsg *interfaces.ErrorMessage
				if errs != nil {
					select {
					case queued, queuedOK := <-errs:
						if queuedOK {
							errMsg = queued
						}
					default:
					}
				}
				if errMsg == nil {
					errMsg = &interfaces.ErrorMessage{
						StatusCode: http.StatusRequestTimeout,
						Error:      fmt.Errorf("stream closed before response.completed"),
					}
				}
				return finishWithError(errMsg)
			}
			payloads := websocketJSONPayloadsFromChunk(chunk)
			for i := range payloads {
				eventType := gjson.GetBytes(payloads[i], "type").String()
				if eventType == wsEventTypeCompleted {
					completed = true
					completedOutput = responseCompletedOutputFromPayload(payloads[i])
				}
				markAPIResponseTimestamp(c)
				appendWebsocketEvent(wsBodyLog, "response", payloads[i])
				if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
					log.Warnf(
						"responses websocket: downstream_out write failed id=%s event=%s error=%v",
						sessionID,
						websocketPayloadEventType(payloads[i]),
						errWrite,
					)
					cancel(errWrite)
					return completedOutput, errWrite
				}
			}
		}
	}
}
// responseCompletedOutputFromPayload extracts response.output from a
// response.completed event as a fresh byte slice, or "[]" when that field is
// absent or not an array.
func responseCompletedOutputFromPayload(payload []byte) []byte {
	if output := gjson.GetBytes(payload, "response.output"); output.IsArray() {
		// Converting the string already copies, so the result does not alias payload.
		return []byte(output.Raw)
	}
	return []byte("[]")
}
// websocketJSONPayloadsFromChunk extracts the JSON documents carried by one
// upstream stream chunk.
//
// The chunk is first read line by line in SSE style: blank lines and "event:"
// lines are skipped, a "data:" prefix is stripped, and the wsDoneMarker sentinel
// is ignored. Each remaining line that is valid JSON is copied into the result.
// If no line qualifies (e.g. a pretty-printed document spanning lines), the whole
// chunk, minus an optional leading "data:", is tried as a single document.
func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte {
	dataPrefix := []byte("data:")
	doneMarker := []byte(wsDoneMarker)
	stripData := func(candidate []byte) []byte {
		if bytes.HasPrefix(candidate, dataPrefix) {
			return bytes.TrimSpace(candidate[len(dataPrefix):])
		}
		return candidate
	}
	usable := func(candidate []byte) bool {
		return len(candidate) > 0 && !bytes.Equal(candidate, doneMarker) && json.Valid(candidate)
	}

	payloads := make([][]byte, 0, 2)
	for _, rawLine := range bytes.Split(chunk, []byte("\n")) {
		line := bytes.TrimSpace(rawLine)
		if len(line) == 0 || bytes.HasPrefix(line, []byte("event:")) {
			continue
		}
		if line = stripData(line); usable(line) {
			payloads = append(payloads, bytes.Clone(line))
		}
	}
	if len(payloads) > 0 {
		return payloads
	}
	if whole := stripData(bytes.TrimSpace(chunk)); usable(whole) {
		payloads = append(payloads, bytes.Clone(whole))
	}
	return payloads
}
// writeResponsesWebsocketError encodes errMsg as a websocket "error" event and
// writes it to conn as a text frame.
//
// The event holds the status (errMsg.StatusCode when positive, else 500), the
// "error" member of handlers.BuildErrorResponseBody (or the whole body when it has
// none, or a server_error fallback), and, when errMsg.Addon has entries, a
// "headers" object with the first value of each header. It returns the encoded
// payload and the write error. If the payload cannot be built, nothing is
// written and (nil, err) is returned.
func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.ErrorMessage) ([]byte, error) {
	status := http.StatusInternalServerError
	errText := http.StatusText(status)
	if errMsg != nil {
		if errMsg.StatusCode > 0 {
			status = errMsg.StatusCode
			errText = http.StatusText(status)
		}
		if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" {
			errText = errMsg.Error.Error()
		}
	}
	body := handlers.BuildErrorResponseBody(status, errText)
	payload := []byte(`{}`)
	var errSet error
	payload, errSet = sjson.SetBytes(payload, "type", wsEventTypeError)
	if errSet != nil {
		return nil, errSet
	}
	payload, errSet = sjson.SetBytes(payload, "status", status)
	if errSet != nil {
		return nil, errSet
	}
	if errMsg != nil && errMsg.Addon != nil {
		// Build the headers object with encoding/json instead of one sjson path per
		// key. Header names may contain characters that are special in sjson paths
		// ('.', '\', '*', '?'), which previously produced nested objects or made
		// the whole error frame fail to build.
		headerValues := make(map[string]any, len(errMsg.Addon))
		for key, values := range errMsg.Addon {
			if len(values) == 0 {
				continue
			}
			headerValues[key] = values[0]
		}
		if len(headerValues) > 0 {
			headersJSON, errMarshal := json.Marshal(headerValues)
			if errMarshal != nil {
				return nil, errMarshal
			}
			payload, errSet = sjson.SetRawBytes(payload, "headers", headersJSON)
			if errSet != nil {
				return nil, errSet
			}
		}
	}
	if len(body) > 0 && json.Valid(body) {
		errorNode := gjson.GetBytes(body, "error")
		if errorNode.Exists() {
			payload, errSet = sjson.SetRawBytes(payload, "error", []byte(errorNode.Raw))
		} else {
			payload, errSet = sjson.SetRawBytes(payload, "error", body)
		}
		if errSet != nil {
			return nil, errSet
		}
	}
	if !gjson.GetBytes(payload, "error").Exists() {
		payload, errSet = sjson.SetBytes(payload, "error.type", "server_error")
		if errSet != nil {
			return nil, errSet
		}
		payload, errSet = sjson.SetBytes(payload, "error.message", errText)
		if errSet != nil {
			return nil, errSet
		}
	}
	return payload, conn.WriteMessage(websocket.TextMessage, payload)
}
// appendWebsocketEvent appends a "websocket.<eventType>" header line followed by
// the whitespace-trimmed payload and a newline, separating entries with a blank
// line.
//
// The builder never grows past wsBodyLogMaxSize. Once the limit is reached nothing
// more is written. A payload that does not fit is cut short, leaving room for
// wsBodyLogTruncated, and is followed by as much of that marker as fits. A nil
// builder or a blank payload is ignored.
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
	if builder == nil || builder.Len() >= wsBodyLogMaxSize {
		return
	}
	body := bytes.TrimSpace(payload)
	if len(body) == 0 {
		return
	}
	header := []string{"websocket.", eventType, "\n"}
	if builder.Len() > 0 {
		header = append([]string{"\n"}, header...)
	}
	for _, part := range header {
		if !appendWebsocketLogString(builder, part) {
			return
		}
	}
	if !appendWebsocketLogBytes(builder, body, len(wsBodyLogTruncated)) {
		appendWebsocketLogString(builder, wsBodyLogTruncated)
		return
	}
	appendWebsocketLogString(builder, "\n")
}
// appendWebsocketLogString writes as much of value as fits below
// wsBodyLogMaxSize. It returns true only when value was written in full; a nil or
// already-full builder gets nothing and yields false.
func appendWebsocketLogString(builder *strings.Builder, value string) bool {
	if builder == nil {
		return false
	}
	room := wsBodyLogMaxSize - builder.Len()
	switch {
	case room <= 0:
		return false
	case len(value) > room:
		builder.WriteString(value[:room])
		return false
	default:
		builder.WriteString(value)
		return true
	}
}
// appendWebsocketLogBytes writes value if it fits below wsBodyLogMaxSize and
// returns true. Otherwise it writes a prefix that leaves reserveForSuffix bytes
// free (clamped to [0, len(value)]) and returns false. A nil or already-full
// builder gets nothing and yields false.
func appendWebsocketLogBytes(builder *strings.Builder, value []byte, reserveForSuffix int) bool {
	if builder == nil {
		return false
	}
	room := wsBodyLogMaxSize - builder.Len()
	if room <= 0 {
		return false
	}
	if len(value) <= room {
		builder.Write(value)
		return true
	}
	keep := room - reserveForSuffix
	switch {
	case keep < 0:
		keep = 0
	case keep > len(value):
		keep = len(value)
	}
	builder.Write(value[:keep])
	return false
}
// websocketPayloadEventType returns the trimmed "type" field of payload for log
// lines, or "-" when it is missing or blank.
func websocketPayloadEventType(payload []byte) string {
	if eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String()); eventType != "" {
		return eventType
	}
	return "-"
}
// websocketPayloadPreview renders payload as a single log-friendly line.
//
// Surrounding whitespace is trimmed and CR/LF are escaped as \r and \n. Payloads
// longer than wsPayloadLogMaxSize bytes are cut to that many bytes and suffixed
// with "...(truncated,total=N)". The cut is byte-based and may split a multi-byte
// rune.
func websocketPayloadPreview(payload []byte) string {
	trimmed := bytes.TrimSpace(payload)
	if len(trimmed) == 0 {
		return ""
	}
	truncated := len(trimmed) > wsPayloadLogMaxSize
	head := trimmed
	if truncated {
		head = trimmed[:wsPayloadLogMaxSize]
	}
	text := strings.NewReplacer("\n", "\\n", "\r", "\\r").Replace(string(head))
	if !truncated {
		return text
	}
	return fmt.Sprintf("%s...(truncated,total=%d)", text, len(trimmed))
}
// setWebsocketRequestBody stores the trimmed body on c under wsRequestBodyKey as
// a []byte. A nil context or a blank body leaves c untouched.
func setWebsocketRequestBody(c *gin.Context, body string) {
	if c == nil {
		return
	}
	if trimmed := strings.TrimSpace(body); trimmed != "" {
		c.Set(wsRequestBodyKey, []byte(trimmed))
	}
}
// markAPIResponseTimestamp records the current time under
// "API_RESPONSE_TIMESTAMP" on c the first time it is called; later calls keep
// the original value. A nil context is ignored.
func markAPIResponseTimestamp(c *gin.Context) {
	if c == nil {
		return
	}
	const timestampKey = "API_RESPONSE_TIMESTAMP"
	if _, alreadySet := c.Get(timestampKey); !alreadySet {
		c.Set(timestampKey, time.Now())
	}
}
================================================
FILE: sdk/api/handlers/openai/openai_responses_websocket_test.go
================================================
package openai
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"net/http/httptest"
"strings"
"sync"
"testing"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/tidwall/gjson"
)
// websocketCaptureExecutor is a test executor that records every streaming
// request it receives. Its ExecuteStream replies with one canned
// response.completed chunk.
// NOTE(review): the fields are not mutex-guarded; tests presumably read them only
// after the handler under test is done — confirm before using it concurrently.
type websocketCaptureExecutor struct {
	// streamCalls counts ExecuteStream invocations.
	streamCalls int
	// payloads holds a copy of each ExecuteStream request payload, in call order.
	payloads [][]byte
}
// orderedWebsocketSelector is a test auth selector that hands out auths in a
// scripted order of IDs.
type orderedWebsocketSelector struct {
	mu     sync.Mutex
	order  []string
	cursor int
}

// Pick consumes entries from order until one matches an offered auth ID and
// returns that auth. Entries with no match are skipped permanently. Once order is
// exhausted, it falls back to the first non-nil auth. It errors when none is
// available.
func (s *orderedWebsocketSelector) Pick(_ context.Context, _ string, _ string, _ coreexecutor.Options, auths []*coreauth.Auth) (*coreauth.Auth, error) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if len(auths) == 0 {
		return nil, errors.New("no auth available")
	}
	findByID := func(id string) *coreauth.Auth {
		for _, candidate := range auths {
			if candidate != nil && candidate.ID == id {
				return candidate
			}
		}
		return nil
	}
	for s.cursor < len(s.order) {
		wanted := strings.TrimSpace(s.order[s.cursor])
		s.cursor++
		if match := findByID(wanted); match != nil {
			return match, nil
		}
	}
	for _, candidate := range auths {
		if candidate != nil {
			return candidate, nil
		}
	}
	return nil, errors.New("no auth available")
}
// websocketAuthCaptureExecutor is a test executor that records which auth ID
// served each streaming request.
type websocketAuthCaptureExecutor struct {
	mu      sync.Mutex
	authIDs []string
}

// Identifier reports the provider key this executor serves.
func (e *websocketAuthCaptureExecutor) Identifier() string { return "test-provider" }

// Execute is unused by the websocket tests and always fails.
func (e *websocketAuthCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
	return coreexecutor.Response{}, errors.New("not implemented")
}

// ExecuteStream records the auth ID (when auth is non-nil) and returns a closed
// stream carrying a single response.completed chunk.
func (e *websocketAuthCaptureExecutor) ExecuteStream(_ context.Context, auth *coreauth.Auth, _ coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
	if auth != nil {
		e.mu.Lock()
		e.authIDs = append(e.authIDs, auth.ID)
		e.mu.Unlock()
	}
	stream := make(chan coreexecutor.StreamChunk, 1)
	stream <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)}
	close(stream)
	return &coreexecutor.StreamResult{Chunks: stream}, nil
}

// Refresh returns auth unchanged.
func (e *websocketAuthCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
	return auth, nil
}

// CountTokens is unused by the websocket tests and always fails.
func (e *websocketAuthCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
	return coreexecutor.Response{}, errors.New("not implemented")
}

// HttpRequest is unused by the websocket tests and always fails.
func (e *websocketAuthCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
	return nil, errors.New("not implemented")
}

// AuthIDs returns a snapshot copy of the recorded auth IDs (nil when none).
func (e *websocketAuthCaptureExecutor) AuthIDs() []string {
	e.mu.Lock()
	snapshot := append([]string(nil), e.authIDs...)
	e.mu.Unlock()
	return snapshot
}
// Identifier reports the provider key this executor serves.
func (e *websocketCaptureExecutor) Identifier() string { return "test-provider" }

// Execute is unused by the websocket tests and always fails.
func (e *websocketCaptureExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
	return coreexecutor.Response{}, errors.New("not implemented")
}

// ExecuteStream counts the call and keeps a copy of the request payload, then
// returns a closed stream carrying a single response.completed chunk.
func (e *websocketCaptureExecutor) ExecuteStream(_ context.Context, _ *coreauth.Auth, req coreexecutor.Request, _ coreexecutor.Options) (*coreexecutor.StreamResult, error) {
	e.streamCalls++
	e.payloads = append(e.payloads, bytes.Clone(req.Payload))
	stream := make(chan coreexecutor.StreamChunk, 1)
	stream <- coreexecutor.StreamChunk{Payload: []byte(`{"type":"response.completed","response":{"id":"resp-upstream","output":[{"type":"message","id":"out-1"}]}}`)}
	close(stream)
	return &coreexecutor.StreamResult{Chunks: stream}, nil
}

// Refresh returns auth unchanged.
func (e *websocketCaptureExecutor) Refresh(_ context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
	return auth, nil
}

// CountTokens is unused by the websocket tests and always fails.
func (e *websocketCaptureExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
	return coreexecutor.Response{}, errors.New("not implemented")
}

// HttpRequest is unused by the websocket tests and always fails.
func (e *websocketCaptureExecutor) HttpRequest(context.Context, *coreauth.Auth, *http.Request) (*http.Response, error) {
	return nil, errors.New("not implemented")
}
// TestNormalizeResponsesWebsocketRequestCreate checks that a first response.create
// frame loses "type", is forced to stream, keeps its model, and is returned as its
// own snapshot.
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
	frame := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
	got, snapshot, errMsg := normalizeResponsesWebsocketRequest(frame, nil, nil)
	if errMsg != nil {
		t.Fatalf("unexpected error: %v", errMsg.Error)
	}
	if gjson.GetBytes(got, "type").Exists() {
		t.Fatalf("normalized create request must not include type field")
	}
	if stream := gjson.GetBytes(got, "stream"); !stream.Bool() {
		t.Fatalf("normalized create request must force stream=true")
	}
	if model := gjson.GetBytes(got, "model").String(); model != "test-model" {
		t.Fatalf("unexpected model: %s", model)
	}
	if !bytes.Equal(snapshot, got) {
		t.Fatalf("last request snapshot should match normalized request")
	}
}
// TestNormalizeResponsesWebsocketRequestCreateWithHistory checks that a later
// response.create inherits the model and has its input rebuilt as previous input +
// previous output + new input.
func TestNormalizeResponsesWebsocketRequestCreateWithHistory(t *testing.T) {
	previous := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
	previousOutput := []byte(`[
		{"type":"function_call","id":"fc-1","call_id":"call-1"},
		{"type":"message","id":"assistant-1"}
	]`)
	frame := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
	got, snapshot, errMsg := normalizeResponsesWebsocketRequest(frame, previous, previousOutput)
	if errMsg != nil {
		t.Fatalf("unexpected error: %v", errMsg.Error)
	}
	if gjson.GetBytes(got, "type").Exists() {
		t.Fatalf("normalized subsequent create request must not include type field")
	}
	if model := gjson.GetBytes(got, "model").String(); model != "test-model" {
		t.Fatalf("unexpected model: %s", model)
	}
	wantIDs := []string{"msg-1", "fc-1", "assistant-1", "tool-out-1"}
	items := gjson.GetBytes(got, "input").Array()
	if len(items) != len(wantIDs) {
		t.Fatalf("merged input len = %d, want %d", len(items), len(wantIDs))
	}
	for i, want := range wantIDs {
		if items[i].Get("id").String() != want {
			t.Fatalf("unexpected merged input order")
		}
	}
	if !bytes.Equal(snapshot, got) {
		t.Fatalf("next request snapshot should match normalized request")
	}
}
// TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDIncremental checks
// that in incremental mode a frame with previous_response_id keeps that id and
// only its own input, while inheriting model and instructions.
func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDIncremental(t *testing.T) {
	previous := []byte(`{"model":"test-model","stream":true,"instructions":"be helpful","input":[{"type":"message","id":"msg-1"}]}`)
	previousOutput := []byte(`[
		{"type":"function_call","id":"fc-1","call_id":"call-1"},
		{"type":"message","id":"assistant-1"}
	]`)
	frame := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
	got, snapshot, errMsg := normalizeResponsesWebsocketRequestWithMode(frame, previous, previousOutput, true)
	if errMsg != nil {
		t.Fatalf("unexpected error: %v", errMsg.Error)
	}
	if gjson.GetBytes(got, "type").Exists() {
		t.Fatalf("normalized request must not include type field")
	}
	if prevID := gjson.GetBytes(got, "previous_response_id").String(); prevID != "resp-1" {
		t.Fatalf("previous_response_id must be preserved in incremental mode")
	}
	items := gjson.GetBytes(got, "input").Array()
	if len(items) != 1 {
		t.Fatalf("incremental input len = %d, want 1", len(items))
	}
	if id := items[0].Get("id").String(); id != "tool-out-1" {
		t.Fatalf("unexpected incremental input item id: %s", id)
	}
	if model := gjson.GetBytes(got, "model").String(); model != "test-model" {
		t.Fatalf("unexpected model: %s", model)
	}
	if instructions := gjson.GetBytes(got, "instructions").String(); instructions != "be helpful" {
		t.Fatalf("unexpected instructions: %s", instructions)
	}
	if !bytes.Equal(snapshot, got) {
		t.Fatalf("next request snapshot should match normalized request")
	}
}
// TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDMergedWhenIncrementalDisabled
// checks that without incremental mode previous_response_id is dropped and the
// full input transcript is rebuilt.
func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDMergedWhenIncrementalDisabled(t *testing.T) {
	previous := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
	previousOutput := []byte(`[
		{"type":"function_call","id":"fc-1","call_id":"call-1"},
		{"type":"message","id":"assistant-1"}
	]`)
	frame := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
	got, snapshot, errMsg := normalizeResponsesWebsocketRequestWithMode(frame, previous, previousOutput, false)
	if errMsg != nil {
		t.Fatalf("unexpected error: %v", errMsg.Error)
	}
	if gjson.GetBytes(got, "previous_response_id").Exists() {
		t.Fatalf("previous_response_id must be removed when incremental mode is disabled")
	}
	wantIDs := []string{"msg-1", "fc-1", "assistant-1", "tool-out-1"}
	items := gjson.GetBytes(got, "input").Array()
	if len(items) != len(wantIDs) {
		t.Fatalf("merged input len = %d, want %d", len(items), len(wantIDs))
	}
	for i, want := range wantIDs {
		if items[i].Get("id").String() != want {
			t.Fatalf("unexpected merged input order")
		}
	}
	if !bytes.Equal(snapshot, got) {
		t.Fatalf("next request snapshot should match normalized request")
	}
}
// TestNormalizeResponsesWebsocketRequestAppend checks that response.append
// concatenates previous input, previous output and the appended items in order.
func TestNormalizeResponsesWebsocketRequestAppend(t *testing.T) {
	previous := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
	previousOutput := []byte(`[
		{"type":"message","id":"assistant-1"},
		{"type":"function_call_output","id":"tool-out-1"}
	]`)
	frame := []byte(`{"type":"response.append","input":[{"type":"message","id":"msg-2"},{"type":"message","id":"msg-3"}]}`)
	got, snapshot, errMsg := normalizeResponsesWebsocketRequest(frame, previous, previousOutput)
	if errMsg != nil {
		t.Fatalf("unexpected error: %v", errMsg.Error)
	}
	wantIDs := []string{"msg-1", "assistant-1", "tool-out-1", "msg-2", "msg-3"}
	items := gjson.GetBytes(got, "input").Array()
	if len(items) != len(wantIDs) {
		t.Fatalf("merged input len = %d, want %d", len(items), len(wantIDs))
	}
	for i, want := range wantIDs {
		if items[i].Get("id").String() != want {
			t.Fatalf("unexpected merged input order")
		}
	}
	if !bytes.Equal(snapshot, got) {
		t.Fatalf("next request snapshot should match normalized append request")
	}
}
// TestNormalizeResponsesWebsocketRequestAppendWithoutCreate checks that an append
// with no prior request is rejected with 400.
func TestNormalizeResponsesWebsocketRequestAppendWithoutCreate(t *testing.T) {
	frame := []byte(`{"type":"response.append","input":[]}`)
	_, _, errMsg := normalizeResponsesWebsocketRequest(frame, nil, nil)
	switch {
	case errMsg == nil:
		t.Fatalf("expected error for append without previous request")
	case errMsg.StatusCode != http.StatusBadRequest:
		t.Fatalf("status = %d, want %d", errMsg.StatusCode, http.StatusBadRequest)
	}
}
// TestWebsocketJSONPayloadsFromChunk checks that SSE framing, event lines and the
// [DONE] sentinel are stripped, leaving one JSON payload.
func TestWebsocketJSONPayloadsFromChunk(t *testing.T) {
	chunk := []byte("event: response.created\n\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\ndata: [DONE]\n")
	payloads := websocketJSONPayloadsFromChunk(chunk)
	if count := len(payloads); count != 1 {
		t.Fatalf("payloads len = %d, want 1", count)
	}
	if eventType := gjson.GetBytes(payloads[0], "type").String(); eventType != "response.created" {
		t.Fatalf("unexpected payload type: %s", eventType)
	}
}
// TestWebsocketJSONPayloadsFromPlainJSONChunk checks that a bare JSON chunk with
// no SSE framing is passed through as a single payload.
func TestWebsocketJSONPayloadsFromPlainJSONChunk(t *testing.T) {
	payloads := websocketJSONPayloadsFromChunk([]byte(`{"type":"response.completed","response":{"id":"resp-1"}}`))
	if count := len(payloads); count != 1 {
		t.Fatalf("payloads len = %d, want 1", count)
	}
	if eventType := gjson.GetBytes(payloads[0], "type").String(); eventType != "response.completed" {
		t.Fatalf("unexpected payload type: %s", eventType)
	}
}
func TestResponseCompletedOutputFromPayload(t *testing.T) {
payload := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"message","id":"out-1"}]}}`)
output := responseCompletedOutputFromPayload(payload)
items := gjson.ParseBytes(output).Array()
if len(items) != 1 {
t.Fatalf("output len = %d, want 1", len(items))
}
if items[0].Get("id").String() != "out-1" {
t.Fatalf("unexpected output id: %s", items[0].Get("id").String())
}
}
// TestAppendWebsocketEvent checks that request and response events are logged with
// their "websocket.<type>" headers and trimmed payloads.
func TestAppendWebsocketEvent(t *testing.T) {
	var body strings.Builder
	appendWebsocketEvent(&body, "request", []byte(" {\"type\":\"response.create\"}\n"))
	appendWebsocketEvent(&body, "response", []byte("{\"type\":\"response.created\"}"))
	got := body.String()
	expectations := []struct {
		label string
		want  string
	}{
		{"request", "websocket.request\n{\"type\":\"response.create\"}\n"},
		{"response", "websocket.response\n{\"type\":\"response.created\"}\n"},
	}
	for _, exp := range expectations {
		if !strings.Contains(got, exp.want) {
			t.Fatalf("%s event not found in body: %s", exp.label, got)
		}
	}
}
func TestAppendWebsocketEventTruncatesAtLimit(t *testing.T) {
var builder strings.Builder
payload := bytes.Repeat([]byte("x"), wsBodyLogMaxSize)
appendWebsocketEvent(&builder, "request", payload)
got := builder.String()
if len(got) > wsBodyLogMaxSize {
t.Fatalf("body log len = %d, want <= %d", len(got), wsBodyLogMaxSize)
}
if !strings.Contains(got, wsBodyLogTruncated) {
t.Fatalf("expected truncation marker in body log")
}
}
func TestAppendWebsocketEventNoGrowthAfterLimit(t *testing.T) {
var builder strings.Builder
appendWebsocketEvent(&builder, "request", bytes.Repeat([]byte("x"), wsBodyLogMaxSize))
initial := builder.String()
appendWebsocketEvent(&builder, "response", []byte(`{"type":"response.completed"}`))
if builder.String() != initial {
t.Fatalf("builder grew after reaching limit")
}
}
// TestSetWebsocketRequestBody checks that whitespace-only bodies are not stored
// and that a real body is stored as []byte under wsRequestBodyKey.
func TestSetWebsocketRequestBody(t *testing.T) {
	gin.SetMode(gin.TestMode)
	c, _ := gin.CreateTestContext(httptest.NewRecorder())

	setWebsocketRequestBody(c, " \n ")
	if _, found := c.Get(wsRequestBodyKey); found {
		t.Fatalf("request body key should not be set for empty body")
	}

	setWebsocketRequestBody(c, "event body")
	stored, found := c.Get(wsRequestBodyKey)
	if !found {
		t.Fatalf("request body key not set")
	}
	raw, isBytes := stored.([]byte)
	if !isBytes {
		t.Fatalf("request body key type mismatch")
	}
	if got := string(raw); got != "event body" {
		t.Fatalf("request body = %q, want %q", got, "event body")
	}
}
// TestForwardResponsesWebsocketPreservesCompletedEvent runs forwardResponsesWebsocket
// behind a real websocket. It checks two things:
//   - an upstream `response.completed` SSE event reaches the client with its type unchanged
//     (not rewritten to `response.done`);
//   - the event's output array is returned to the caller.
func TestForwardResponsesWebsocketPreservesCompletedEvent(t *testing.T) {
	gin.SetMode(gin.TestMode)
	// The handler goroutine reports its outcome here; the buffer of 1 lets it
	// send without waiting for the test goroutine.
	serverErrCh := make(chan error, 1)
	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		conn, err := responsesWebsocketUpgrader.Upgrade(w, r, nil)
		if err != nil {
			serverErrCh <- err
			return
		}
		defer func() {
			// NOTE(review): this runs after the result (possibly nil) was already sent,
			// so a close error here may never be observed by the test — confirm intent.
			errClose := conn.Close()
			if errClose != nil {
				serverErrCh <- errClose
			}
		}()
		ctx, _ := gin.CreateTestContext(httptest.NewRecorder())
		ctx.Request = r
		// Pre-fill one SSE chunk, then close both channels so the forwarder sees a
		// single event followed by a clean end of stream.
		data := make(chan []byte, 1)
		errCh := make(chan *interfaces.ErrorMessage)
		data <- []byte("data: {\"type\":\"response.completed\",\"response\":{\"id\":\"resp-1\",\"output\":[{\"type\":\"message\",\"id\":\"out-1\"}]}}\n\n")
		close(data)
		close(errCh)
		var bodyLog strings.Builder
		// Called on a nil handler: this path presumably never dereferences the receiver.
		completedOutput, err := (*OpenAIResponsesAPIHandler)(nil).forwardResponsesWebsocket(
			ctx,
			conn,
			func(...interface{}) {},
			data,
			errCh,
			&bodyLog,
			"session-1",
		)
		if err != nil {
			serverErrCh <- err
			return
		}
		// The returned value should be the completed event's output array.
		if gjson.GetBytes(completedOutput, "0.id").String() != "out-1" {
			serverErrCh <- errors.New("completed output not captured")
			return
		}
		serverErrCh <- nil
	}))
	defer server.Close()
	wsURL := "ws" + strings.TrimPrefix(server.URL, "http")
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		t.Fatalf("dial websocket: %v", err)
	}
	defer func() {
		errClose := conn.Close()
		if errClose != nil {
			t.Fatalf("close websocket: %v", errClose)
		}
	}()
	// NOTE(review): ReadMessage has no deadline; a stuck server would hang the test.
	_, payload, errReadMessage := conn.ReadMessage()
	if errReadMessage != nil {
		t.Fatalf("read websocket message: %v", errReadMessage)
	}
	if gjson.GetBytes(payload, "type").String() != wsEventTypeCompleted {
		t.Fatalf("payload type = %s, want %s", gjson.GetBytes(payload, "type").String(), wsEventTypeCompleted)
	}
	if strings.Contains(string(payload), "response.done") {
		t.Fatalf("payload unexpectedly rewrote completed event: %s", payload)
	}
	if errServer := <-serverErrCh; errServer != nil {
		t.Fatalf("server error: %v", errServer)
	}
}
// TestWebsocketUpstreamSupportsIncrementalInputForModel checks that a model
// served by an auth carrying the "websockets" attribute is reported as having a
// websocket-capable upstream.
func TestWebsocketUpstreamSupportsIncrementalInputForModel(t *testing.T) {
	mgr := coreauth.NewManager(nil, nil, nil)
	wsAuth := &coreauth.Auth{
		ID:         "auth-ws",
		Provider:   "test-provider",
		Status:     coreauth.StatusActive,
		Attributes: map[string]string{"websockets": "true"},
	}
	if _, errRegister := mgr.Register(context.Background(), wsAuth); errRegister != nil {
		t.Fatalf("Register auth: %v", errRegister)
	}

	models := []*registry.ModelInfo{{ID: "test-model"}}
	registry.GetGlobalRegistry().RegisterClient(wsAuth.ID, wsAuth.Provider, models)
	t.Cleanup(func() { registry.GetGlobalRegistry().UnregisterClient(wsAuth.ID) })

	handler := NewOpenAIResponsesAPIHandler(handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, mgr))
	if supported := handler.websocketUpstreamSupportsIncrementalInputForModel("test-model"); !supported {
		t.Fatalf("expected websocket-capable upstream for test-model")
	}
}
// TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream checks the prewarm
// flow against an SSE-only (non-websocket) upstream:
//   - a `generate:false` request is answered locally with `response.created` and
//     `response.completed` events (zero total tokens) without calling the executor;
//   - a follow-up request that references the prewarm response ID makes exactly one
//     upstream stream call;
//   - that call has `previous_response_id` and `generate` stripped and the model
//     taken from the prewarm request.
func TestResponsesWebsocketPrewarmHandledLocallyForSSEUpstream(t *testing.T) {
	gin.SetMode(gin.TestMode)
	executor := &websocketCaptureExecutor{}
	manager := coreauth.NewManager(nil, nil, nil)
	manager.RegisterExecutor(executor)
	// No "websockets" attribute: this auth is an SSE-only upstream.
	auth := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
	if _, err := manager.Register(context.Background(), auth); err != nil {
		t.Fatalf("Register auth: %v", err)
	}
	registry.GetGlobalRegistry().RegisterClient(auth.ID, auth.Provider, []*registry.ModelInfo{{ID: "test-model"}})
	t.Cleanup(func() {
		registry.GetGlobalRegistry().UnregisterClient(auth.ID)
	})
	base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
	h := NewOpenAIResponsesAPIHandler(base)
	router := gin.New()
	router.GET("/v1/responses/ws", h.ResponsesWebsocket)
	server := httptest.NewServer(router)
	defer server.Close()
	wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		t.Fatalf("dial websocket: %v", err)
	}
	defer func() {
		errClose := conn.Close()
		if errClose != nil {
			t.Fatalf("close websocket: %v", errClose)
		}
	}()
	// Step 1: prewarm request (generate:false).
	errWrite := conn.WriteMessage(websocket.TextMessage, []byte(`{"type":"response.create","model":"test-model","generate":false}`))
	if errWrite != nil {
		t.Fatalf("write prewarm websocket message: %v", errWrite)
	}
	_, createdPayload, errReadMessage := conn.ReadMessage()
	if errReadMessage != nil {
		t.Fatalf("read prewarm created message: %v", errReadMessage)
	}
	if gjson.GetBytes(createdPayload, "type").String() != "response.created" {
		t.Fatalf("created payload type = %s, want response.created", gjson.GetBytes(createdPayload, "type").String())
	}
	prewarmResponseID := gjson.GetBytes(createdPayload, "response.id").String()
	if prewarmResponseID == "" {
		t.Fatalf("prewarm response id is empty")
	}
	// NOTE(review): executor fields are read here without synchronization while the
	// server goroutine may be using the executor; confirm websocketCaptureExecutor is
	// race-safe (run with -race).
	if executor.streamCalls != 0 {
		t.Fatalf("stream calls after prewarm = %d, want 0", executor.streamCalls)
	}
	_, completedPayload, errReadMessage := conn.ReadMessage()
	if errReadMessage != nil {
		t.Fatalf("read prewarm completed message: %v", errReadMessage)
	}
	if gjson.GetBytes(completedPayload, "type").String() != wsEventTypeCompleted {
		t.Fatalf("completed payload type = %s, want %s", gjson.GetBytes(completedPayload, "type").String(), wsEventTypeCompleted)
	}
	if gjson.GetBytes(completedPayload, "response.id").String() != prewarmResponseID {
		t.Fatalf("completed response id = %s, want %s", gjson.GetBytes(completedPayload, "response.id").String(), prewarmResponseID)
	}
	if gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int() != 0 {
		t.Fatalf("prewarm total tokens = %d, want 0", gjson.GetBytes(completedPayload, "response.usage.total_tokens").Int())
	}
	// Step 2: follow-up chained to the prewarm response; no model given here.
	secondRequest := fmt.Sprintf(`{"type":"response.create","previous_response_id":%q,"input":[{"type":"message","id":"msg-1"}]}`, prewarmResponseID)
	errWrite = conn.WriteMessage(websocket.TextMessage, []byte(secondRequest))
	if errWrite != nil {
		t.Fatalf("write follow-up websocket message: %v", errWrite)
	}
	_, upstreamPayload, errReadMessage := conn.ReadMessage()
	if errReadMessage != nil {
		t.Fatalf("read upstream completed message: %v", errReadMessage)
	}
	if gjson.GetBytes(upstreamPayload, "type").String() != wsEventTypeCompleted {
		t.Fatalf("upstream payload type = %s, want %s", gjson.GetBytes(upstreamPayload, "type").String(), wsEventTypeCompleted)
	}
	if executor.streamCalls != 1 {
		t.Fatalf("stream calls after follow-up = %d, want 1", executor.streamCalls)
	}
	if len(executor.payloads) != 1 {
		t.Fatalf("captured upstream payloads = %d, want 1", len(executor.payloads))
	}
	// The forwarded upstream request must be a plain SSE request: no chaining or
	// prewarm fields, the model from the prewarm request, and only the new input.
	forwarded := executor.payloads[0]
	if gjson.GetBytes(forwarded, "previous_response_id").Exists() {
		t.Fatalf("previous_response_id leaked upstream: %s", forwarded)
	}
	if gjson.GetBytes(forwarded, "generate").Exists() {
		t.Fatalf("generate leaked upstream: %s", forwarded)
	}
	if gjson.GetBytes(forwarded, "model").String() != "test-model" {
		t.Fatalf("forwarded model = %s, want test-model", gjson.GetBytes(forwarded, "model").String())
	}
	input := gjson.GetBytes(forwarded, "input").Array()
	if len(input) != 1 || input[0].Get("id").String() != "msg-1" {
		t.Fatalf("unexpected forwarded input: %s", forwarded)
	}
}
// TestResponsesWebsocketPinsOnlyWebsocketCapableAuth sends two requests on one
// websocket session. The selector offers an SSE-only auth first and a websocket-
// capable auth second. The test expects the first request to use auth-sse and the
// second to use auth-ws, i.e. the session is not pinned to the SSE-only auth.
func TestResponsesWebsocketPinsOnlyWebsocketCapableAuth(t *testing.T) {
	gin.SetMode(gin.TestMode)
	// Presumably hands out auths in `order` sequence — see orderedWebsocketSelector.
	selector := &orderedWebsocketSelector{order: []string{"auth-sse", "auth-ws"}}
	executor := &websocketAuthCaptureExecutor{}
	manager := coreauth.NewManager(nil, selector, nil)
	manager.RegisterExecutor(executor)
	authSSE := &coreauth.Auth{ID: "auth-sse", Provider: executor.Identifier(), Status: coreauth.StatusActive}
	if _, err := manager.Register(context.Background(), authSSE); err != nil {
		t.Fatalf("Register SSE auth: %v", err)
	}
	authWS := &coreauth.Auth{
		ID:         "auth-ws",
		Provider:   executor.Identifier(),
		Status:     coreauth.StatusActive,
		Attributes: map[string]string{"websockets": "true"},
	}
	if _, err := manager.Register(context.Background(), authWS); err != nil {
		t.Fatalf("Register websocket auth: %v", err)
	}
	registry.GetGlobalRegistry().RegisterClient(authSSE.ID, authSSE.Provider, []*registry.ModelInfo{{ID: "test-model"}})
	registry.GetGlobalRegistry().RegisterClient(authWS.ID, authWS.Provider, []*registry.ModelInfo{{ID: "test-model"}})
	t.Cleanup(func() {
		registry.GetGlobalRegistry().UnregisterClient(authSSE.ID)
		registry.GetGlobalRegistry().UnregisterClient(authWS.ID)
	})
	base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
	h := NewOpenAIResponsesAPIHandler(base)
	router := gin.New()
	router.GET("/v1/responses/ws", h.ResponsesWebsocket)
	server := httptest.NewServer(router)
	defer server.Close()
	wsURL := "ws" + strings.TrimPrefix(server.URL, "http") + "/v1/responses/ws"
	conn, _, err := websocket.DefaultDialer.Dial(wsURL, nil)
	if err != nil {
		t.Fatalf("dial websocket: %v", err)
	}
	defer func() {
		if errClose := conn.Close(); errClose != nil {
			t.Fatalf("close websocket: %v", errClose)
		}
	}()
	// The second request omits the model, so it must come from the session state.
	requests := []string{
		`{"type":"response.create","model":"test-model","input":[{"type":"message","id":"msg-1"}]}`,
		`{"type":"response.create","input":[{"type":"message","id":"msg-2"}]}`,
	}
	for i := range requests {
		if errWrite := conn.WriteMessage(websocket.TextMessage, []byte(requests[i])); errWrite != nil {
			t.Fatalf("write websocket message %d: %v", i+1, errWrite)
		}
		_, payload, errReadMessage := conn.ReadMessage()
		if errReadMessage != nil {
			t.Fatalf("read websocket message %d: %v", i+1, errReadMessage)
		}
		if got := gjson.GetBytes(payload, "type").String(); got != wsEventTypeCompleted {
			t.Fatalf("message %d payload type = %s, want %s", i+1, got, wsEventTypeCompleted)
		}
	}
	// One executor call per request, in order: the SSE auth was not pinned.
	if got := executor.AuthIDs(); len(got) != 2 || got[0] != "auth-sse" || got[1] != "auth-ws" {
		t.Fatalf("selected auth IDs = %v, want [auth-sse auth-ws]", got)
	}
}
================================================
FILE: sdk/api/handlers/openai_responses_stream_error.go
================================================
package handlers
import (
"encoding/json"
"fmt"
"net/http"
"strings"
)
type openAIResponsesStreamErrorChunk struct {
Type string `json:"type"`
Code string `json:"code"`
Message string `json:"message"`
SequenceNumber int `json:"sequence_number"`
}
func openAIResponsesStreamErrorCode(status int) string {
switch status {
case http.StatusUnauthorized:
return "invalid_api_key"
case http.StatusForbidden:
return "insufficient_quota"
case http.StatusTooManyRequests:
return "rate_limit_exceeded"
case http.StatusNotFound:
return "model_not_found"
case http.StatusRequestTimeout:
return "request_timeout"
default:
if status >= http.StatusInternalServerError {
return "internal_server_error"
}
if status >= http.StatusBadRequest {
return "invalid_request_error"
}
return "unknown_error"
}
}
// BuildOpenAIResponsesStreamErrorChunk builds an OpenAI Responses streaming error chunk.
//
// Important: OpenAI's HTTP error bodies are shaped like {"error":{...}}; those are valid for
// non-streaming responses, but streaming clients validate SSE `data:` payloads against a union
// of chunks that requires a top-level `type` field.
func BuildOpenAIResponsesStreamErrorChunk(status int, errText string, sequenceNumber int) []byte {
if status <= 0 {
status = http.StatusInternalServerError
}
if sequenceNumber < 0 {
sequenceNumber = 0
}
message := strings.TrimSpace(errText)
if message == "" {
message = http.StatusText(status)
}
code := openAIResponsesStreamErrorCode(status)
trimmed := strings.TrimSpace(errText)
if trimmed != "" && json.Valid([]byte(trimmed)) {
var payload map[string]any
if err := json.Unmarshal([]byte(trimmed), &payload); err == nil {
if t, ok := payload["type"].(string); ok && strings.TrimSpace(t) == "error" {
if m, ok := payload["message"].(string); ok && strings.TrimSpace(m) != "" {
message = strings.TrimSpace(m)
}
if v, ok := payload["code"]; ok && v != nil {
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
code = strings.TrimSpace(c)
} else {
code = strings.TrimSpace(fmt.Sprint(v))
}
}
if v, ok := payload["sequence_number"].(float64); ok && sequenceNumber == 0 {
sequenceNumber = int(v)
}
}
if e, ok := payload["error"].(map[string]any); ok {
if m, ok := e["message"].(string); ok && strings.TrimSpace(m) != "" {
message = strings.TrimSpace(m)
}
if v, ok := e["code"]; ok && v != nil {
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
code = strings.TrimSpace(c)
} else {
code = strings.TrimSpace(fmt.Sprint(v))
}
}
}
}
}
if strings.TrimSpace(code) == "" {
code = "unknown_error"
}
data, err := json.Marshal(openAIResponsesStreamErrorChunk{
Type: "error",
Code: code,
Message: message,
SequenceNumber: sequenceNumber,
})
if err == nil {
return data
}
// Extremely defensive fallback.
data, _ = json.Marshal(openAIResponsesStreamErrorChunk{
Type: "error",
Code: "internal_server_error",
Message: message,
SequenceNumber: sequenceNumber,
})
if len(data) > 0 {
return data
}
return []byte(`{"type":"error","code":"internal_server_error","message":"internal error","sequence_number":0}`)
}
================================================
FILE: sdk/api/handlers/openai_responses_stream_error_test.go
================================================
package handlers
import (
"encoding/json"
"net/http"
"testing"
)
// TestBuildOpenAIResponsesStreamErrorChunk checks the chunk built from a plain
// (non-JSON) error text.
func TestBuildOpenAIResponsesStreamErrorChunk(t *testing.T) {
	raw := BuildOpenAIResponsesStreamErrorChunk(http.StatusInternalServerError, "unexpected EOF", 0)
	var got map[string]any
	if errDecode := json.Unmarshal(raw, &got); errDecode != nil {
		t.Fatalf("unmarshal: %v", errDecode)
	}
	expectations := []struct {
		key  string
		want string
	}{
		{key: "type", want: "error"},
		{key: "code", want: "internal_server_error"},
		{key: "message", want: "unexpected EOF"},
	}
	for _, e := range expectations {
		if got[e.key] != e.want {
			t.Fatalf("%s = %v, want %q", e.key, got[e.key], e.want)
		}
	}
	if got["sequence_number"] != float64(0) {
		t.Fatalf("sequence_number = %v, want %v", got["sequence_number"], 0)
	}
}

// TestBuildOpenAIResponsesStreamErrorChunkExtractsHTTPErrorBody checks that the
// message and code of an {"error":{...}} HTTP body are lifted into the chunk.
func TestBuildOpenAIResponsesStreamErrorChunkExtractsHTTPErrorBody(t *testing.T) {
	body := `{"error":{"message":"oops","type":"server_error","code":"internal_server_error"}}`
	raw := BuildOpenAIResponsesStreamErrorChunk(http.StatusInternalServerError, body, 0)
	var got map[string]any
	if errDecode := json.Unmarshal(raw, &got); errDecode != nil {
		t.Fatalf("unmarshal: %v", errDecode)
	}
	expectations := []struct {
		key  string
		want string
	}{
		{key: "type", want: "error"},
		{key: "code", want: "internal_server_error"},
		{key: "message", want: "oops"},
	}
	for _, e := range expectations {
		if got[e.key] != e.want {
			t.Fatalf("%s = %v, want %q", e.key, got[e.key], e.want)
		}
	}
}
================================================
FILE: sdk/api/handlers/stream_forwarder.go
================================================
package handlers
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
)
// StreamForwardOptions customises how ForwardStream writes a streaming response.
//
// Every callback is optional:
//   - a nil WriteChunk drops data chunks;
//   - a nil WriteKeepAlive falls back to an SSE comment heartbeat.
//
// ForwardStream flushes after invoking these callbacks, so they must not flush themselves.
type StreamForwardOptions struct {
	// KeepAliveInterval overrides the configured streaming keep-alive interval.
	// If nil, the configured default is used. If set to <= 0, keep-alives are disabled.
	KeepAliveInterval *time.Duration
	// WriteChunk writes a single data chunk to the response body. It should not flush.
	WriteChunk func(chunk []byte)
	// WriteTerminalError writes an error payload to the response body when streaming fails
	// after headers have already been committed. It should not flush.
	WriteTerminalError func(errMsg *interfaces.ErrorMessage)
	// WriteDone optionally writes a terminal marker when the upstream data channel closes
	// without an error (e.g. OpenAI's `[DONE]`). It should not flush.
	WriteDone func()
	// WriteKeepAlive optionally writes a keep-alive heartbeat. It should not flush.
	// When nil, a standard SSE comment heartbeat is used.
	WriteKeepAlive func()
}
// ForwardStream pumps upstream chunks from data to the client, flushing after each
// write. It stops when one of the following happens:
//   - data is closed;
//   - a message arrives on errs;
//   - the client request context is done.
//
// Unless c or cancel is nil (in which case it returns immediately), cancel is called
// exactly once before returning:
//   - with nil on clean completion (or when a nil message is received on errs);
//   - with the upstream error when streaming fails;
//   - with the request context error when the client goes away.
//
// A nil flusher falls back to c.Writer. A closed errs channel is tolerated:
// forwarding continues until data closes.
func (h *BaseAPIHandler) ForwardStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, opts StreamForwardOptions) {
	if c == nil || cancel == nil {
		return
	}
	if flusher == nil {
		// gin.ResponseWriter embeds http.Flusher.
		flusher = c.Writer
	}
	writeChunk := opts.WriteChunk
	if writeChunk == nil {
		writeChunk = func([]byte) {}
	}
	writeKeepAlive := opts.WriteKeepAlive
	if writeKeepAlive == nil {
		writeKeepAlive = func() {
			_, _ = c.Writer.Write([]byte(": keep-alive\n\n"))
		}
	}
	keepAliveInterval := StreamingKeepAliveInterval(h.Cfg)
	if opts.KeepAliveInterval != nil {
		keepAliveInterval = *opts.KeepAliveInterval
	}
	// A nil channel never fires, which disables the keep-alive case in select.
	var keepAliveC <-chan time.Time
	if keepAliveInterval > 0 {
		keepAlive := time.NewTicker(keepAliveInterval)
		defer keepAlive.Stop()
		keepAliveC = keepAlive.C
	}
	for {
		select {
		case <-c.Request.Context().Done():
			cancel(c.Request.Context().Err())
			return
		case chunk, ok := <-data:
			if ok {
				writeChunk(chunk)
				flusher.Flush()
				continue
			}
			// Upstream finished. Prefer surfacing a terminal error if one is already queued.
			var terminalErr *interfaces.ErrorMessage
			select {
			case errMsg, open := <-errs:
				if open {
					terminalErr = errMsg
				}
			default:
			}
			if terminalErr != nil {
				if opts.WriteTerminalError != nil {
					opts.WriteTerminalError(terminalErr)
				}
				flusher.Flush()
				cancel(terminalErr.Error)
				return
			}
			if opts.WriteDone != nil {
				opts.WriteDone()
			}
			flusher.Flush()
			cancel(nil)
			return
		case errMsg, ok := <-errs:
			if !ok {
				// A closed channel is always ready; without this the loop would spin
				// at 100% CPU until data produced or closed.
				errs = nil
				continue
			}
			if errMsg == nil {
				cancel(nil)
				return
			}
			if opts.WriteTerminalError != nil {
				opts.WriteTerminalError(errMsg)
				flusher.Flush()
			}
			cancel(errMsg.Error)
			return
		case <-keepAliveC:
			writeKeepAlive()
			flusher.Flush()
		}
	}
}
================================================
FILE: sdk/api/management.go
================================================
// Package api exposes helpers for embedding CLIProxyAPI.
//
// It wraps internal management handler types so external projects can integrate
// management endpoints without importing internal packages.
package api
import (
"github.com/gin-gonic/gin"
internalmanagement "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// ManagementTokenRequester exposes a limited subset of management endpoints for requesting tokens.
type ManagementTokenRequester interface {
RequestAnthropicToken(*gin.Context)
RequestGeminiCLIToken(*gin.Context)
RequestCodexToken(*gin.Context)
RequestAntigravityToken(*gin.Context)
RequestQwenToken(*gin.Context)
RequestKimiToken(*gin.Context)
RequestIFlowToken(*gin.Context)
RequestIFlowCookieToken(*gin.Context)
GetAuthStatus(c *gin.Context)
PostOAuthCallback(c *gin.Context)
}
// managementTokenRequester implements ManagementTokenRequester by delegating
// every method to the wrapped internal management handler.
type managementTokenRequester struct {
	handler *internalmanagement.Handler
}
// NewManagementTokenRequester builds an internal management handler from cfg and
// manager (without a config file path) and returns it behind the narrow
// ManagementTokenRequester interface.
func NewManagementTokenRequester(cfg *config.Config, manager *coreauth.Manager) ManagementTokenRequester {
	inner := internalmanagement.NewHandlerWithoutConfigFilePath(cfg, manager)
	return &managementTokenRequester{handler: inner}
}
// RequestAnthropicToken delegates to the wrapped management handler.
func (r *managementTokenRequester) RequestAnthropicToken(c *gin.Context) { r.handler.RequestAnthropicToken(c) }

// RequestGeminiCLIToken delegates to the wrapped management handler.
func (r *managementTokenRequester) RequestGeminiCLIToken(c *gin.Context) { r.handler.RequestGeminiCLIToken(c) }

// RequestCodexToken delegates to the wrapped management handler.
func (r *managementTokenRequester) RequestCodexToken(c *gin.Context) { r.handler.RequestCodexToken(c) }

// RequestAntigravityToken delegates to the wrapped management handler.
func (r *managementTokenRequester) RequestAntigravityToken(c *gin.Context) {
	r.handler.RequestAntigravityToken(c)
}

// RequestQwenToken delegates to the wrapped management handler.
func (r *managementTokenRequester) RequestQwenToken(c *gin.Context) { r.handler.RequestQwenToken(c) }

// RequestKimiToken delegates to the wrapped management handler.
func (r *managementTokenRequester) RequestKimiToken(c *gin.Context) { r.handler.RequestKimiToken(c) }

// RequestIFlowToken delegates to the wrapped management handler.
func (r *managementTokenRequester) RequestIFlowToken(c *gin.Context) { r.handler.RequestIFlowToken(c) }

// RequestIFlowCookieToken delegates to the wrapped management handler.
func (r *managementTokenRequester) RequestIFlowCookieToken(c *gin.Context) {
	r.handler.RequestIFlowCookieToken(c)
}

// GetAuthStatus delegates to the wrapped management handler.
func (r *managementTokenRequester) GetAuthStatus(c *gin.Context) { r.handler.GetAuthStatus(c) }

// PostOAuthCallback delegates to the wrapped management handler.
func (r *managementTokenRequester) PostOAuthCallback(c *gin.Context) { r.handler.PostOAuthCallback(c) }
================================================
FILE: sdk/api/options.go
================================================
// Package api exposes server option helpers for embedding CLIProxyAPI.
//
// It wraps internal server option types so external projects can configure the embedded
// HTTP server without importing internal packages.
package api
import (
"time"
"github.com/gin-gonic/gin"
internalapi "github.com/router-for-me/CLIProxyAPI/v6/internal/api"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/logging"
)
// ServerOption customises how the embedded HTTP server is constructed. It is a
// type alias of the internal option type, so values are interchangeable with it.
type ServerOption = internalapi.ServerOption

// WithMiddleware appends the given Gin middleware during server construction.
func WithMiddleware(mw ...gin.HandlerFunc) ServerOption {
	return internalapi.WithMiddleware(mw...)
}

// WithEngineConfigurator lets the caller mutate the Gin engine before middleware
// is installed.
func WithEngineConfigurator(fn func(*gin.Engine)) ServerOption {
	opt := internalapi.WithEngineConfigurator(fn)
	return opt
}

// WithRouterConfigurator registers a callback that runs after the default routes
// have been registered.
func WithRouterConfigurator(fn func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)) ServerOption {
	opt := internalapi.WithRouterConfigurator(fn)
	return opt
}

// WithLocalManagementPassword sets a runtime-only management password that is
// accepted for requests from localhost.
func WithLocalManagementPassword(password string) ServerOption {
	opt := internalapi.WithLocalManagementPassword(password)
	return opt
}

// WithKeepAliveEndpoint enables a keep-alive endpoint; onTimeout is invoked after
// timeout elapses without a keep-alive.
func WithKeepAliveEndpoint(timeout time.Duration, onTimeout func()) ServerOption {
	opt := internalapi.WithKeepAliveEndpoint(timeout, onTimeout)
	return opt
}

// WithRequestLoggerFactory replaces the function used to create the request logger.
func WithRequestLoggerFactory(factory func(*config.Config, string) logging.RequestLogger) ServerOption {
	opt := internalapi.WithRequestLoggerFactory(factory)
	return opt
}
================================================
FILE: sdk/auth/antigravity.go
================================================
package auth
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/antigravity"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// AntigravityAuthenticator performs the interactive OAuth login for the
// antigravity provider. It carries no state.
type AntigravityAuthenticator struct{}

// NewAntigravityAuthenticator returns a ready-to-use antigravity authenticator.
func NewAntigravityAuthenticator() Authenticator {
	return &AntigravityAuthenticator{}
}

// Provider reports the provider key, "antigravity".
func (AntigravityAuthenticator) Provider() string {
	return "antigravity"
}

// RefreshLead asks the manager to refresh tokens five minutes ahead of expiry.
func (AntigravityAuthenticator) RefreshLead() *time.Duration {
	lead := 5 * time.Minute
	return &lead
}
// Login runs the interactive antigravity OAuth flow:
//  1. start a local callback server;
//  2. show the authorization URL, opening a browser unless opts.NoBrowser is set;
//  3. wait for the callback, or for a pasted callback URL when opts.Prompt is set;
//  4. exchange the authorization code for tokens;
//  5. fetch the account email and, best-effort, the GCP project ID.
//
// It returns an Auth whose Metadata carries the tokens; this function does not
// write it to disk (presumably the caller persists it under FileName).
// Cancelling ctx aborts the wait for the callback.
func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
	if cfg == nil {
		return nil, fmt.Errorf("cliproxy auth: configuration is required")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if opts == nil {
		opts = &LoginOptions{}
	}
	callbackPort := antigravity.CallbackPort
	if opts.CallbackPort > 0 {
		callbackPort = opts.CallbackPort
	}

	authSvc := antigravity.NewAntigravityAuth(cfg, nil)
	state, err := misc.GenerateRandomState()
	if err != nil {
		return nil, fmt.Errorf("antigravity: failed to generate state: %w", err)
	}

	srv, port, cbChan, errServer := startAntigravityCallbackServer(callbackPort)
	if errServer != nil {
		return nil, fmt.Errorf("antigravity: failed to start callback server: %w", errServer)
	}
	defer func() {
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		_ = srv.Shutdown(shutdownCtx)
	}()

	redirectURI := fmt.Sprintf("http://localhost:%d/oauth-callback", port)
	authURL := authSvc.BuildAuthURL(state, redirectURI)
	if !opts.NoBrowser {
		fmt.Println("Opening browser for antigravity authentication")
		if !browser.IsAvailable() {
			log.Warn("No browser available; please open the URL manually")
			util.PrintSSHTunnelInstructions(port)
			fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
		} else if errOpen := browser.OpenURL(authURL); errOpen != nil {
			log.Warnf("Failed to open browser automatically: %v", errOpen)
			util.PrintSSHTunnelInstructions(port)
			fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
		}
	} else {
		util.PrintSSHTunnelInstructions(port)
		fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
	}

	fmt.Println("Waiting for antigravity authentication callback...")
	cbRes, errWait := waitForAntigravityCallback(ctx, cbChan, opts)
	if errWait != nil {
		return nil, errWait
	}
	if cbRes.Error != "" {
		return nil, fmt.Errorf("antigravity: authentication failed: %s", cbRes.Error)
	}
	// Reject callbacks whose state does not match the one generated above.
	if cbRes.State != state {
		return nil, fmt.Errorf("antigravity: invalid state")
	}
	if cbRes.Code == "" {
		return nil, fmt.Errorf("antigravity: missing authorization code")
	}

	tokenResp, errToken := authSvc.ExchangeCodeForTokens(ctx, cbRes.Code, redirectURI)
	if errToken != nil {
		return nil, fmt.Errorf("antigravity: token exchange failed: %w", errToken)
	}
	accessToken := strings.TrimSpace(tokenResp.AccessToken)
	if accessToken == "" {
		return nil, fmt.Errorf("antigravity: token exchange returned empty access token")
	}

	email, errInfo := authSvc.FetchUserInfo(ctx, accessToken)
	if errInfo != nil {
		return nil, fmt.Errorf("antigravity: fetch user info failed: %w", errInfo)
	}
	email = strings.TrimSpace(email)
	if email == "" {
		return nil, fmt.Errorf("antigravity: empty email returned from user info")
	}

	// Fetch project ID via loadCodeAssist (same approach as Gemini CLI). A failure
	// here is only logged; the credential is returned without a project_id.
	projectID := ""
	if fetchedProjectID, errProject := authSvc.FetchProjectID(ctx, accessToken); errProject != nil {
		log.Warnf("antigravity: failed to fetch project ID: %v", errProject)
	} else {
		projectID = fetchedProjectID
		log.Infof("antigravity: obtained project ID %s", projectID)
	}

	now := time.Now()
	metadata := map[string]any{
		"type": "antigravity",
		// Store the same trimmed token that was validated and used above.
		"access_token":  accessToken,
		"refresh_token": tokenResp.RefreshToken,
		"expires_in":    tokenResp.ExpiresIn,
		"timestamp":     now.UnixMilli(),
		"expired":       now.Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
		"email":         email,
	}
	if projectID != "" {
		metadata["project_id"] = projectID
	}

	fileName := antigravity.CredentialFileName(email)
	fmt.Println("Antigravity authentication successful")
	if projectID != "" {
		fmt.Printf("Using GCP project: %s\n", projectID)
	}
	return &coreauth.Auth{
		ID:       fileName,
		Provider: "antigravity",
		FileName: fileName,
		Label:    email,
		Metadata: metadata,
	}, nil
}

// waitForAntigravityCallback blocks until one of the following happens:
//   - the OAuth callback arrives on cbChan;
//   - the user pastes a callback URL via opts.Prompt;
//   - the five-minute timeout elapses;
//   - ctx is done.
//
// The paste prompt is offered once, 15 seconds in, and only when opts.Prompt is
// set. Pressing Enter at the prompt keeps waiting for the browser callback.
func waitForAntigravityCallback(ctx context.Context, cbChan <-chan callbackResult, opts *LoginOptions) (callbackResult, error) {
	timeoutTimer := time.NewTimer(5 * time.Minute)
	defer timeoutTimer.Stop()

	// A nil channel never fires, so no prompt is offered without opts.Prompt.
	var manualPromptC <-chan time.Time
	if opts.Prompt != nil {
		manualPromptTimer := time.NewTimer(15 * time.Second)
		defer manualPromptTimer.Stop()
		manualPromptC = manualPromptTimer.C
	}

	for {
		select {
		case res := <-cbChan:
			return res, nil
		case <-ctx.Done():
			return callbackResult{}, fmt.Errorf("antigravity: authentication cancelled: %w", ctx.Err())
		case <-timeoutTimer.C:
			return callbackResult{}, fmt.Errorf("antigravity: authentication timed out")
		case <-manualPromptC:
			manualPromptC = nil
			// Prefer a callback that raced with the prompt timer.
			select {
			case res := <-cbChan:
				return res, nil
			default:
			}
			input, errPrompt := opts.Prompt("Paste the antigravity callback URL (or press Enter to keep waiting): ")
			if errPrompt != nil {
				return callbackResult{}, errPrompt
			}
			parsed, errParse := misc.ParseOAuthCallback(input)
			if errParse != nil {
				return callbackResult{}, errParse
			}
			if parsed == nil {
				continue
			}
			return callbackResult{
				Code:  parsed.Code,
				State: parsed.State,
				Error: parsed.Error,
			}, nil
		}
	}
}
// callbackResult carries the OAuth callback parameters delivered to Login. The
// callback server fills it from the whitespace-trimmed query string; it is also
// built from a URL the user pastes at the manual prompt.
type callbackResult struct {
	Code  string // authorization code ("code" query parameter)
	Error string // provider error ("error" query parameter); Login fails when non-empty
	State string // "state" query parameter; Login requires it to match the state it generated
}
func startAntigravityCallbackServer(port int) (*http.Server, int, <-chan callbackResult, error) {
if port <= 0 {
port = antigravity.CallbackPort
}
addr := fmt.Sprintf(":%d", port)
listener, err := net.Listen("tcp", addr)
if err != nil {
return nil, 0, nil, err
}
port = listener.Addr().(*net.TCPAddr).Port
resultCh := make(chan callbackResult, 1)
mux := http.NewServeMux()
mux.HandleFunc("/oauth-callback", func(w http.ResponseWriter, r *http.Request) {
q := r.URL.Query()
res := callbackResult{
Code: strings.TrimSpace(q.Get("code")),
Error: strings.TrimSpace(q.Get("error")),
State: strings.TrimSpace(q.Get("state")),
}
resultCh <- res
if res.Code != "" && res.Error == "" {
_, _ = w.Write([]byte("Login successful You can close this window.
"))
} else {
_, _ = w.Write([]byte("Login failed Please check the CLI output.
"))
}
})
srv := &http.Server{Handler: mux}
go func() {
if errServe := srv.Serve(listener); errServe != nil && !strings.Contains(errServe.Error(), "Server closed") {
log.Warnf("antigravity callback server error: %v", errServe)
}
}()
return srv, port, resultCh, nil
}
// FetchAntigravityProjectID exposes project discovery for external callers.
func FetchAntigravityProjectID(ctx context.Context, accessToken string, httpClient *http.Client) (string, error) {
cfg := &config.Config{}
authSvc := antigravity.NewAntigravityAuth(cfg, httpClient)
return authSvc.FetchProjectID(ctx, accessToken)
}
================================================
FILE: sdk/auth/claude.go
================================================
package auth
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
// legacy client removed
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// ClaudeAuthenticator implements the OAuth login flow for Anthropic Claude accounts.
type ClaudeAuthenticator struct {
CallbackPort int
}
// NewClaudeAuthenticator constructs a Claude authenticator with default settings.
func NewClaudeAuthenticator() *ClaudeAuthenticator {
return &ClaudeAuthenticator{CallbackPort: 54545}
}
func (a *ClaudeAuthenticator) Provider() string {
return "claude"
}
func (a *ClaudeAuthenticator) RefreshLead() *time.Duration {
return new(4 * time.Hour)
}
func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if cfg == nil {
return nil, fmt.Errorf("cliproxy auth: configuration is required")
}
if ctx == nil {
ctx = context.Background()
}
if opts == nil {
opts = &LoginOptions{}
}
callbackPort := a.CallbackPort
if opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
pkceCodes, err := claude.GeneratePKCECodes()
if err != nil {
return nil, fmt.Errorf("claude pkce generation failed: %w", err)
}
state, err := misc.GenerateRandomState()
if err != nil {
return nil, fmt.Errorf("claude state generation failed: %w", err)
}
oauthServer := claude.NewOAuthServer(callbackPort)
if err = oauthServer.Start(); err != nil {
if strings.Contains(err.Error(), "already in use") {
return nil, claude.NewAuthenticationError(claude.ErrPortInUse, err)
}
return nil, claude.NewAuthenticationError(claude.ErrServerStartFailed, err)
}
defer func() {
stopCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if stopErr := oauthServer.Stop(stopCtx); stopErr != nil {
log.Warnf("claude oauth server stop error: %v", stopErr)
}
}()
authSvc := claude.NewClaudeAuth(cfg)
authURL, returnedState, err := authSvc.GenerateAuthURL(state, pkceCodes)
if err != nil {
return nil, fmt.Errorf("claude authorization url generation failed: %w", err)
}
state = returnedState
if !opts.NoBrowser {
fmt.Println("Opening browser for Claude authentication")
if !browser.IsAvailable() {
log.Warn("No browser available; please open the URL manually")
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
} else if err = browser.OpenURL(authURL); err != nil {
log.Warnf("Failed to open browser automatically: %v", err)
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
}
} else {
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
}
fmt.Println("Waiting for Claude authentication callback...")
callbackCh := make(chan *claude.OAuthResult, 1)
callbackErrCh := make(chan error, 1)
manualDescription := ""
go func() {
result, errWait := oauthServer.WaitForCallback(5 * time.Minute)
if errWait != nil {
callbackErrCh <- errWait
return
}
callbackCh <- result
}()
var result *claude.OAuthResult
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err)
}
return nil, err
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err)
}
return nil, err
default:
}
input, errPrompt := opts.Prompt("Paste the Claude callback URL (or press Enter to keep waiting): ")
if errPrompt != nil {
return nil, errPrompt
}
parsed, errParse := misc.ParseOAuthCallback(input)
if errParse != nil {
return nil, errParse
}
if parsed == nil {
continue
}
manualDescription = parsed.ErrorDescription
result = &claude.OAuthResult{
Code: parsed.Code,
State: parsed.State,
Error: parsed.Error,
}
break waitForCallback
}
}
if result.Error != "" {
return nil, claude.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest)
}
if result.State != state {
log.Errorf("State mismatch: expected %s, got %s", state, result.State)
return nil, claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("state mismatch"))
}
log.Debug("Claude authorization code received; exchanging for tokens")
log.Debugf("Code: %s, State: %s", result.Code[:min(20, len(result.Code))], state)
authBundle, err := authSvc.ExchangeCodeForTokens(ctx, result.Code, state, pkceCodes)
if err != nil {
log.Errorf("Token exchange failed: %v", err)
return nil, claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, err)
}
tokenStorage := authSvc.CreateTokenStorage(authBundle)
if tokenStorage == nil || tokenStorage.Email == "" {
return nil, fmt.Errorf("claude token storage missing account information")
}
fileName := fmt.Sprintf("claude-%s.json", tokenStorage.Email)
metadata := map[string]any{
"email": tokenStorage.Email,
}
fmt.Println("Claude authentication successful")
if authBundle.APIKey != "" {
fmt.Println("Claude API key obtained and stored")
}
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Storage: tokenStorage,
Metadata: metadata,
}, nil
}
================================================
FILE: sdk/auth/codex.go
================================================
package auth
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
// legacy client removed
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// CodexAuthenticator implements the OAuth login flow for Codex accounts.
type CodexAuthenticator struct {
CallbackPort int
}
// NewCodexAuthenticator constructs a Codex authenticator with default settings.
func NewCodexAuthenticator() *CodexAuthenticator {
return &CodexAuthenticator{CallbackPort: 1455}
}
func (a *CodexAuthenticator) Provider() string {
return "codex"
}
func (a *CodexAuthenticator) RefreshLead() *time.Duration {
return new(5 * 24 * time.Hour)
}
func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if cfg == nil {
return nil, fmt.Errorf("cliproxy auth: configuration is required")
}
if ctx == nil {
ctx = context.Background()
}
if opts == nil {
opts = &LoginOptions{}
}
if shouldUseCodexDeviceFlow(opts) {
return a.loginWithDeviceFlow(ctx, cfg, opts)
}
callbackPort := a.CallbackPort
if opts.CallbackPort > 0 {
callbackPort = opts.CallbackPort
}
pkceCodes, err := codex.GeneratePKCECodes()
if err != nil {
return nil, fmt.Errorf("codex pkce generation failed: %w", err)
}
state, err := misc.GenerateRandomState()
if err != nil {
return nil, fmt.Errorf("codex state generation failed: %w", err)
}
oauthServer := codex.NewOAuthServer(callbackPort)
if err = oauthServer.Start(); err != nil {
if strings.Contains(err.Error(), "already in use") {
return nil, codex.NewAuthenticationError(codex.ErrPortInUse, err)
}
return nil, codex.NewAuthenticationError(codex.ErrServerStartFailed, err)
}
defer func() {
stopCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
defer cancel()
if stopErr := oauthServer.Stop(stopCtx); stopErr != nil {
log.Warnf("codex oauth server stop error: %v", stopErr)
}
}()
authSvc := codex.NewCodexAuth(cfg)
authURL, err := authSvc.GenerateAuthURL(state, pkceCodes)
if err != nil {
return nil, fmt.Errorf("codex authorization url generation failed: %w", err)
}
if !opts.NoBrowser {
fmt.Println("Opening browser for Codex authentication")
if !browser.IsAvailable() {
log.Warn("No browser available; please open the URL manually")
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
} else if err = browser.OpenURL(authURL); err != nil {
log.Warnf("Failed to open browser automatically: %v", err)
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
}
} else {
util.PrintSSHTunnelInstructions(callbackPort)
fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
}
fmt.Println("Waiting for Codex authentication callback...")
callbackCh := make(chan *codex.OAuthResult, 1)
callbackErrCh := make(chan error, 1)
manualDescription := ""
go func() {
result, errWait := oauthServer.WaitForCallback(5 * time.Minute)
if errWait != nil {
callbackErrCh <- errWait
return
}
callbackCh <- result
}()
var result *codex.OAuthResult
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err)
}
return nil, err
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err)
}
return nil, err
default:
}
input, errPrompt := opts.Prompt("Paste the Codex callback URL (or press Enter to keep waiting): ")
if errPrompt != nil {
return nil, errPrompt
}
parsed, errParse := misc.ParseOAuthCallback(input)
if errParse != nil {
return nil, errParse
}
if parsed == nil {
continue
}
manualDescription = parsed.ErrorDescription
result = &codex.OAuthResult{
Code: parsed.Code,
State: parsed.State,
Error: parsed.Error,
}
break waitForCallback
}
}
if result.Error != "" {
return nil, codex.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest)
}
if result.State != state {
return nil, codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("state mismatch"))
}
log.Debug("Codex authorization code received; exchanging for tokens")
authBundle, err := authSvc.ExchangeCodeForTokens(ctx, result.Code, pkceCodes)
if err != nil {
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
}
return a.buildAuthRecord(authSvc, authBundle)
}
================================================
FILE: sdk/auth/codex_device.go
================================================
package auth
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"strconv"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
const (
codexLoginModeMetadataKey = "codex_login_mode"
codexLoginModeDevice = "device"
codexDeviceUserCodeURL = "https://auth.openai.com/api/accounts/deviceauth/usercode"
codexDeviceTokenURL = "https://auth.openai.com/api/accounts/deviceauth/token"
codexDeviceVerificationURL = "https://auth.openai.com/codex/device"
codexDeviceTokenExchangeRedirectURI = "https://auth.openai.com/deviceauth/callback"
codexDeviceTimeout = 15 * time.Minute
codexDeviceDefaultPollIntervalSeconds = 5
)
type codexDeviceUserCodeRequest struct {
ClientID string `json:"client_id"`
}
type codexDeviceUserCodeResponse struct {
DeviceAuthID string `json:"device_auth_id"`
UserCode string `json:"user_code"`
UserCodeAlt string `json:"usercode"`
Interval json.RawMessage `json:"interval"`
}
type codexDeviceTokenRequest struct {
DeviceAuthID string `json:"device_auth_id"`
UserCode string `json:"user_code"`
}
type codexDeviceTokenResponse struct {
AuthorizationCode string `json:"authorization_code"`
CodeVerifier string `json:"code_verifier"`
CodeChallenge string `json:"code_challenge"`
}
func shouldUseCodexDeviceFlow(opts *LoginOptions) bool {
if opts == nil || opts.Metadata == nil {
return false
}
return strings.EqualFold(strings.TrimSpace(opts.Metadata[codexLoginModeMetadataKey]), codexLoginModeDevice)
}
func (a *CodexAuthenticator) loginWithDeviceFlow(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
if ctx == nil {
ctx = context.Background()
}
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{})
userCodeResp, err := requestCodexDeviceUserCode(ctx, httpClient)
if err != nil {
return nil, err
}
deviceCode := strings.TrimSpace(userCodeResp.UserCode)
if deviceCode == "" {
deviceCode = strings.TrimSpace(userCodeResp.UserCodeAlt)
}
deviceAuthID := strings.TrimSpace(userCodeResp.DeviceAuthID)
if deviceCode == "" || deviceAuthID == "" {
return nil, fmt.Errorf("codex device flow did not return required fields")
}
pollInterval := parseCodexDevicePollInterval(userCodeResp.Interval)
fmt.Println("Starting Codex device authentication...")
fmt.Printf("Codex device URL: %s\n", codexDeviceVerificationURL)
fmt.Printf("Codex device code: %s\n", deviceCode)
if !opts.NoBrowser {
if !browser.IsAvailable() {
log.Warn("No browser available; please open the device URL manually")
} else if errOpen := browser.OpenURL(codexDeviceVerificationURL); errOpen != nil {
log.Warnf("Failed to open browser automatically: %v", errOpen)
}
}
tokenResp, err := pollCodexDeviceToken(ctx, httpClient, deviceAuthID, deviceCode, pollInterval)
if err != nil {
return nil, err
}
authCode := strings.TrimSpace(tokenResp.AuthorizationCode)
codeVerifier := strings.TrimSpace(tokenResp.CodeVerifier)
codeChallenge := strings.TrimSpace(tokenResp.CodeChallenge)
if authCode == "" || codeVerifier == "" || codeChallenge == "" {
return nil, fmt.Errorf("codex device flow token response missing required fields")
}
authSvc := codex.NewCodexAuth(cfg)
authBundle, err := authSvc.ExchangeCodeForTokensWithRedirect(
ctx,
authCode,
codexDeviceTokenExchangeRedirectURI,
&codex.PKCECodes{
CodeVerifier: codeVerifier,
CodeChallenge: codeChallenge,
},
)
if err != nil {
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
}
return a.buildAuthRecord(authSvc, authBundle)
}
func requestCodexDeviceUserCode(ctx context.Context, client *http.Client) (*codexDeviceUserCodeResponse, error) {
body, err := json.Marshal(codexDeviceUserCodeRequest{ClientID: codex.ClientID})
if err != nil {
return nil, fmt.Errorf("failed to encode codex device request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceUserCodeURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("failed to create codex device request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to request codex device code: %w", err)
}
defer func() { _ = resp.Body.Close() }()
respBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("failed to read codex device code response: %w", err)
}
if !codexDeviceIsSuccessStatus(resp.StatusCode) {
trimmed := strings.TrimSpace(string(respBody))
if resp.StatusCode == http.StatusNotFound {
return nil, fmt.Errorf("codex device endpoint is unavailable (status %d)", resp.StatusCode)
}
if trimmed == "" {
trimmed = "empty response body"
}
return nil, fmt.Errorf("codex device code request failed with status %d: %s", resp.StatusCode, trimmed)
}
var parsed codexDeviceUserCodeResponse
if err := json.Unmarshal(respBody, &parsed); err != nil {
return nil, fmt.Errorf("failed to decode codex device code response: %w", err)
}
return &parsed, nil
}
func pollCodexDeviceToken(ctx context.Context, client *http.Client, deviceAuthID, userCode string, interval time.Duration) (*codexDeviceTokenResponse, error) {
deadline := time.Now().Add(codexDeviceTimeout)
for {
if time.Now().After(deadline) {
return nil, fmt.Errorf("codex device authentication timed out after 15 minutes")
}
body, err := json.Marshal(codexDeviceTokenRequest{
DeviceAuthID: deviceAuthID,
UserCode: userCode,
})
if err != nil {
return nil, fmt.Errorf("failed to encode codex device poll request: %w", err)
}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceTokenURL, bytes.NewReader(body))
if err != nil {
return nil, fmt.Errorf("failed to create codex device poll request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to poll codex device token: %w", err)
}
respBody, readErr := io.ReadAll(resp.Body)
_ = resp.Body.Close()
if readErr != nil {
return nil, fmt.Errorf("failed to read codex device poll response: %w", readErr)
}
switch {
case codexDeviceIsSuccessStatus(resp.StatusCode):
var parsed codexDeviceTokenResponse
if err := json.Unmarshal(respBody, &parsed); err != nil {
return nil, fmt.Errorf("failed to decode codex device token response: %w", err)
}
return &parsed, nil
case resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusNotFound:
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-time.After(interval):
continue
}
default:
trimmed := strings.TrimSpace(string(respBody))
if trimmed == "" {
trimmed = "empty response body"
}
return nil, fmt.Errorf("codex device token polling failed with status %d: %s", resp.StatusCode, trimmed)
}
}
}
func parseCodexDevicePollInterval(raw json.RawMessage) time.Duration {
defaultInterval := time.Duration(codexDeviceDefaultPollIntervalSeconds) * time.Second
if len(raw) == 0 {
return defaultInterval
}
var asString string
if err := json.Unmarshal(raw, &asString); err == nil {
if seconds, convErr := strconv.Atoi(strings.TrimSpace(asString)); convErr == nil && seconds > 0 {
return time.Duration(seconds) * time.Second
}
}
var asInt int
if err := json.Unmarshal(raw, &asInt); err == nil && asInt > 0 {
return time.Duration(asInt) * time.Second
}
return defaultInterval
}
func codexDeviceIsSuccessStatus(code int) bool {
return code >= 200 && code < 300
}
func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundle *codex.CodexAuthBundle) (*coreauth.Auth, error) {
tokenStorage := authSvc.CreateTokenStorage(authBundle)
if tokenStorage == nil || tokenStorage.Email == "" {
return nil, fmt.Errorf("codex token storage missing account information")
}
planType := ""
hashAccountID := ""
if tokenStorage.IDToken != "" {
if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil {
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)
if accountID != "" {
digest := sha256.Sum256([]byte(accountID))
hashAccountID = hex.EncodeToString(digest[:])[:8]
}
}
}
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
metadata := map[string]any{
"email": tokenStorage.Email,
}
fmt.Println("Codex authentication successful")
if authBundle.APIKey != "" {
fmt.Println("Codex API key obtained and stored")
}
return &coreauth.Auth{
ID: fileName,
Provider: a.Provider(),
FileName: fileName,
Storage: tokenStorage,
Metadata: metadata,
Attributes: map[string]string{
"plan_type": planType,
},
}, nil
}
================================================
FILE: sdk/auth/errors.go
================================================
package auth
import (
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
)
// ProjectSelectionError indicates that the user must choose a specific project ID.
type ProjectSelectionError struct {
Email string
Projects []interfaces.GCPProjectProjects
}
func (e *ProjectSelectionError) Error() string {
if e == nil {
return "cliproxy auth: project selection required"
}
return fmt.Sprintf("cliproxy auth: project selection required for %s", e.Email)
}
// ProjectsDisplay returns the projects list for caller presentation.
func (e *ProjectSelectionError) ProjectsDisplay() []interfaces.GCPProjectProjects {
if e == nil {
return nil
}
return e.Projects
}
// EmailRequiredError indicates that the calling context must provide an email or alias.
type EmailRequiredError struct {
Prompt string
}
func (e *EmailRequiredError) Error() string {
if e == nil || e.Prompt == "" {
return "cliproxy auth: email is required"
}
return e.Prompt
}
================================================
FILE: sdk/auth/filestore.go
================================================
package auth
import (
"context"
"encoding/json"
"fmt"
"io"
"io/fs"
"net/http"
"net/url"
"os"
"path/filepath"
"runtime"
"strings"
"sync"
"time"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// FileTokenStore persists token records and auth metadata using the filesystem as backing storage.
type FileTokenStore struct {
mu sync.Mutex
dirLock sync.RWMutex
baseDir string
}
// NewFileTokenStore creates a token store that saves credentials to disk through the
// TokenStorage implementation embedded in the token record.
func NewFileTokenStore() *FileTokenStore {
return &FileTokenStore{}
}
// SetBaseDir updates the default directory used for auth JSON persistence when no explicit path is provided.
func (s *FileTokenStore) SetBaseDir(dir string) {
s.dirLock.Lock()
s.baseDir = strings.TrimSpace(dir)
s.dirLock.Unlock()
}
// Save persists token storage and metadata to the resolved auth file path.
func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (string, error) {
if auth == nil {
return "", fmt.Errorf("auth filestore: auth is nil")
}
path, err := s.resolveAuthPath(auth)
if err != nil {
return "", err
}
if path == "" {
return "", fmt.Errorf("auth filestore: missing file path attribute for %s", auth.ID)
}
if auth.Disabled {
if _, statErr := os.Stat(path); os.IsNotExist(statErr) {
return "", nil
}
}
s.mu.Lock()
defer s.mu.Unlock()
if err = os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
return "", fmt.Errorf("auth filestore: create dir failed: %w", err)
}
// metadataSetter is a private interface for TokenStorage implementations that support metadata injection.
type metadataSetter interface {
SetMetadata(map[string]any)
}
switch {
case auth.Storage != nil:
if setter, ok := auth.Storage.(metadataSetter); ok {
setter.SetMetadata(auth.Metadata)
}
if err = auth.Storage.SaveTokenToFile(path); err != nil {
return "", err
}
case auth.Metadata != nil:
auth.Metadata["disabled"] = auth.Disabled
raw, errMarshal := json.Marshal(auth.Metadata)
if errMarshal != nil {
return "", fmt.Errorf("auth filestore: marshal metadata failed: %w", errMarshal)
}
if existing, errRead := os.ReadFile(path); errRead == nil {
if jsonEqual(existing, raw) {
return path, nil
}
file, errOpen := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0o600)
if errOpen != nil {
return "", fmt.Errorf("auth filestore: open existing failed: %w", errOpen)
}
if _, errWrite := file.Write(raw); errWrite != nil {
_ = file.Close()
return "", fmt.Errorf("auth filestore: write existing failed: %w", errWrite)
}
if errClose := file.Close(); errClose != nil {
return "", fmt.Errorf("auth filestore: close existing failed: %w", errClose)
}
return path, nil
} else if !os.IsNotExist(errRead) {
return "", fmt.Errorf("auth filestore: read existing failed: %w", errRead)
}
if errWrite := os.WriteFile(path, raw, 0o600); errWrite != nil {
return "", fmt.Errorf("auth filestore: write file failed: %w", errWrite)
}
default:
return "", fmt.Errorf("auth filestore: nothing to persist for %s", auth.ID)
}
if auth.Attributes == nil {
auth.Attributes = make(map[string]string)
}
auth.Attributes["path"] = path
if strings.TrimSpace(auth.FileName) == "" {
auth.FileName = auth.ID
}
return path, nil
}
// List enumerates all auth JSON files under the configured directory.
func (s *FileTokenStore) List(ctx context.Context) ([]*cliproxyauth.Auth, error) {
dir := s.baseDirSnapshot()
if dir == "" {
return nil, fmt.Errorf("auth filestore: directory not configured")
}
entries := make([]*cliproxyauth.Auth, 0)
err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, walkErr error) error {
if walkErr != nil {
return walkErr
}
if d.IsDir() {
return nil
}
if !strings.HasSuffix(strings.ToLower(d.Name()), ".json") {
return nil
}
auth, err := s.readAuthFile(path, dir)
if err != nil {
return nil
}
if auth != nil {
entries = append(entries, auth)
}
return nil
})
if err != nil {
return nil, err
}
return entries, nil
}
// Delete removes the auth file.
func (s *FileTokenStore) Delete(ctx context.Context, id string) error {
id = strings.TrimSpace(id)
if id == "" {
return fmt.Errorf("auth filestore: id is empty")
}
path, err := s.resolveDeletePath(id)
if err != nil {
return err
}
if err = os.Remove(path); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("auth filestore: delete failed: %w", err)
}
return nil
}
func (s *FileTokenStore) resolveDeletePath(id string) (string, error) {
if strings.ContainsRune(id, os.PathSeparator) || filepath.IsAbs(id) {
return id, nil
}
dir := s.baseDirSnapshot()
if dir == "" {
return "", fmt.Errorf("auth filestore: directory not configured")
}
return filepath.Join(dir, id), nil
}
func (s *FileTokenStore) readAuthFile(path, baseDir string) (*cliproxyauth.Auth, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("read file: %w", err)
}
if len(data) == 0 {
return nil, nil
}
metadata := make(map[string]any)
if err = json.Unmarshal(data, &metadata); err != nil {
return nil, fmt.Errorf("unmarshal auth json: %w", err)
}
provider, _ := metadata["type"].(string)
if provider == "" {
provider = "unknown"
}
if provider == "antigravity" || provider == "gemini" {
projectID := ""
if pid, ok := metadata["project_id"].(string); ok {
projectID = strings.TrimSpace(pid)
}
if projectID == "" {
accessToken := extractAccessToken(metadata)
// For gemini type, the stored access_token is likely expired (~1h lifetime).
// Refresh it using the long-lived refresh_token before querying.
if provider == "gemini" {
if tokenMap, ok := metadata["token"].(map[string]any); ok {
if refreshed, errRefresh := refreshGeminiAccessToken(tokenMap, http.DefaultClient); errRefresh == nil {
accessToken = refreshed
}
}
}
if accessToken != "" {
fetchedProjectID, errFetch := FetchAntigravityProjectID(context.Background(), accessToken, http.DefaultClient)
if errFetch == nil && strings.TrimSpace(fetchedProjectID) != "" {
metadata["project_id"] = strings.TrimSpace(fetchedProjectID)
if raw, errMarshal := json.Marshal(metadata); errMarshal == nil {
if file, errOpen := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0o600); errOpen == nil {
_, _ = file.Write(raw)
_ = file.Close()
}
}
}
}
}
}
info, err := os.Stat(path)
if err != nil {
return nil, fmt.Errorf("stat file: %w", err)
}
id := s.idFor(path, baseDir)
disabled, _ := metadata["disabled"].(bool)
status := cliproxyauth.StatusActive
if disabled {
status = cliproxyauth.StatusDisabled
}
auth := &cliproxyauth.Auth{
ID: id,
Provider: provider,
FileName: id,
Label: s.labelFor(metadata),
Status: status,
Disabled: disabled,
Attributes: map[string]string{"path": path},
Metadata: metadata,
CreatedAt: info.ModTime(),
UpdatedAt: info.ModTime(),
LastRefreshedAt: time.Time{},
NextRefreshAfter: time.Time{},
}
if email, ok := metadata["email"].(string); ok && email != "" {
auth.Attributes["email"] = email
}
return auth, nil
}
func (s *FileTokenStore) idFor(path, baseDir string) string {
id := path
if baseDir != "" {
if rel, errRel := filepath.Rel(baseDir, path); errRel == nil && rel != "" {
id = rel
}
}
// On Windows, normalize ID casing to avoid duplicate auth entries caused by case-insensitive paths.
if runtime.GOOS == "windows" {
id = strings.ToLower(id)
}
return id
}
func (s *FileTokenStore) resolveAuthPath(auth *cliproxyauth.Auth) (string, error) {
if auth == nil {
return "", fmt.Errorf("auth filestore: auth is nil")
}
if auth.Attributes != nil {
if p := strings.TrimSpace(auth.Attributes["path"]); p != "" {
return p, nil
}
}
if fileName := strings.TrimSpace(auth.FileName); fileName != "" {
if filepath.IsAbs(fileName) {
return fileName, nil
}
if dir := s.baseDirSnapshot(); dir != "" {
return filepath.Join(dir, fileName), nil
}
return fileName, nil
}
if auth.ID == "" {
return "", fmt.Errorf("auth filestore: missing id")
}
if filepath.IsAbs(auth.ID) {
return auth.ID, nil
}
dir := s.baseDirSnapshot()
if dir == "" {
return "", fmt.Errorf("auth filestore: directory not configured")
}
return filepath.Join(dir, auth.ID), nil
}
func (s *FileTokenStore) labelFor(metadata map[string]any) string {
if metadata == nil {
return ""
}
if v, ok := metadata["label"].(string); ok && v != "" {
return v
}
if v, ok := metadata["email"].(string); ok && v != "" {
return v
}
if project, ok := metadata["project_id"].(string); ok && project != "" {
return project
}
return ""
}
func (s *FileTokenStore) baseDirSnapshot() string {
s.dirLock.RLock()
defer s.dirLock.RUnlock()
return s.baseDir
}
func extractAccessToken(metadata map[string]any) string {
if at, ok := metadata["access_token"].(string); ok {
if v := strings.TrimSpace(at); v != "" {
return v
}
}
if tokenMap, ok := metadata["token"].(map[string]any); ok {
if at, ok := tokenMap["access_token"].(string); ok {
if v := strings.TrimSpace(at); v != "" {
return v
}
}
}
return ""
}
func refreshGeminiAccessToken(tokenMap map[string]any, httpClient *http.Client) (string, error) {
refreshToken, _ := tokenMap["refresh_token"].(string)
clientID, _ := tokenMap["client_id"].(string)
clientSecret, _ := tokenMap["client_secret"].(string)
tokenURI, _ := tokenMap["token_uri"].(string)
if refreshToken == "" || clientID == "" || clientSecret == "" {
return "", fmt.Errorf("missing refresh credentials")
}
if tokenURI == "" {
tokenURI = "https://oauth2.googleapis.com/token"
}
data := url.Values{
"grant_type": {"refresh_token"},
"refresh_token": {refreshToken},
"client_id": {clientID},
"client_secret": {clientSecret},
}
resp, err := httpClient.PostForm(tokenURI, data)
if err != nil {
return "", fmt.Errorf("refresh request: %w", err)
}
defer func() { _ = resp.Body.Close() }()
body, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("refresh failed: status %d", resp.StatusCode)
}
var result map[string]any
if errUnmarshal := json.Unmarshal(body, &result); errUnmarshal != nil {
return "", fmt.Errorf("decode refresh response: %w", errUnmarshal)
}
newAccessToken, _ := result["access_token"].(string)
if newAccessToken == "" {
return "", fmt.Errorf("no access_token in refresh response")
}
tokenMap["access_token"] = newAccessToken
return newAccessToken, nil
}
// jsonEqual compares two JSON blobs by parsing them into Go objects and deep comparing.
func jsonEqual(a, b []byte) bool {
var objA any
var objB any
if err := json.Unmarshal(a, &objA); err != nil {
return false
}
if err := json.Unmarshal(b, &objB); err != nil {
return false
}
return deepEqualJSON(objA, objB)
}
func deepEqualJSON(a, b any) bool {
switch valA := a.(type) {
case map[string]any:
valB, ok := b.(map[string]any)
if !ok || len(valA) != len(valB) {
return false
}
for key, subA := range valA {
subB, ok1 := valB[key]
if !ok1 || !deepEqualJSON(subA, subB) {
return false
}
}
return true
case []any:
sliceB, ok := b.([]any)
if !ok || len(valA) != len(sliceB) {
return false
}
for i := range valA {
if !deepEqualJSON(valA[i], sliceB[i]) {
return false
}
}
return true
case float64:
valB, ok := b.(float64)
if !ok {
return false
}
return valA == valB
case string:
valB, ok := b.(string)
if !ok {
return false
}
return valA == valB
case bool:
valB, ok := b.(bool)
if !ok {
return false
}
return valA == valB
case nil:
return b == nil
default:
return false
}
}
================================================
FILE: sdk/auth/filestore_test.go
================================================
package auth
import "testing"
// TestExtractAccessToken covers both token layouts (top-level "access_token" and
// nested "token.access_token"), their precedence, and rejection of blank or
// wrongly typed values.
func TestExtractAccessToken(t *testing.T) {
	t.Parallel()
	tests := []struct {
		name     string
		metadata map[string]any
		expected string
	}{
		{
			"antigravity top-level access_token",
			map[string]any{"access_token": "tok-abc"},
			"tok-abc",
		},
		{
			"gemini nested token.access_token",
			map[string]any{
				"token": map[string]any{"access_token": "tok-nested"},
			},
			"tok-nested",
		},
		{
			"top-level takes precedence over nested",
			map[string]any{
				"access_token": "tok-top",
				"token":        map[string]any{"access_token": "tok-nested"},
			},
			"tok-top",
		},
		{
			"empty metadata",
			map[string]any{},
			"",
		},
		{
			"whitespace-only access_token",
			map[string]any{"access_token": "   "},
			"",
		},
		{
			"wrong type access_token",
			map[string]any{"access_token": 12345},
			"",
		},
		{
			"token is not a map",
			map[string]any{"token": "not-a-map"},
			"",
		},
		{
			"nested whitespace-only",
			map[string]any{
				"token": map[string]any{"access_token": "  "},
			},
			"",
		},
		{
			"fallback to nested when top-level empty",
			map[string]any{
				"access_token": "",
				"token":        map[string]any{"access_token": "tok-fallback"},
			},
			"tok-fallback",
		},
	}
	// No `tt := tt` copy is needed: since Go 1.22 each loop iteration has its own
	// tt, so the parallel subtests below never observe a shared variable.
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			t.Parallel()
			got := extractAccessToken(tt.metadata)
			if got != tt.expected {
				t.Errorf("extractAccessToken() = %q, want %q", got, tt.expected)
			}
		})
	}
}
================================================
FILE: sdk/auth/gemini.go
================================================
package auth
import (
"context"
"fmt"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
// legacy client removed
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// GeminiAuthenticator drives the interactive login flow for Google Gemini CLI
// accounts. It has no fields, so the zero value is ready to use.
type GeminiAuthenticator struct{}

// NewGeminiAuthenticator returns a ready-to-use Gemini authenticator.
func NewGeminiAuthenticator() *GeminiAuthenticator {
	return new(GeminiAuthenticator)
}

// Provider reports the provider key ("gemini") handled by this authenticator.
func (a *GeminiAuthenticator) Provider() string {
	const providerKey = "gemini"
	return providerKey
}

// RefreshLead returns nil: this authenticator specifies no refresh lead time.
func (a *GeminiAuthenticator) RefreshLead() *time.Duration {
	var noLead *time.Duration
	return noLead
}
// Login runs the Gemini CLI OAuth web login and returns the resulting auth
// record; it does not persist the record itself.
//
// cfg is required. A nil ctx falls back to context.Background and nil opts are
// treated as empty options. opts.ProjectID, when set, seeds the token storage.
// The record's ID and file name are "<email>-<project_id>.json".
func (a *GeminiAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
	if cfg == nil {
		return nil, fmt.Errorf("cliproxy auth: configuration is required")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if opts == nil {
		opts = &LoginOptions{}
	}

	var ts gemini.GeminiTokenStorage
	// Unconditional assignment is equivalent to guarding on "": an empty
	// ProjectID is already the field's zero value.
	ts.ProjectID = opts.ProjectID

	webOpts := &gemini.WebLoginOptions{
		NoBrowser:    opts.NoBrowser,
		CallbackPort: opts.CallbackPort,
		Prompt:       opts.Prompt,
	}
	if _, err := gemini.NewGeminiAuth().GetAuthenticatedClient(ctx, &ts, cfg, webOpts); err != nil {
		return nil, fmt.Errorf("gemini authentication failed: %w", err)
	}

	// Onboarding is intentionally skipped here; upstream configuration is relied on.
	fileName := fmt.Sprintf("%s-%s.json", ts.Email, ts.ProjectID)
	fmt.Println("Gemini authentication successful")
	return &coreauth.Auth{
		ID:       fileName,
		Provider: a.Provider(),
		FileName: fileName,
		Storage:  &ts,
		Metadata: map[string]any{
			"email":      ts.Email,
			"project_id": ts.ProjectID,
		},
	}, nil
}
================================================
FILE: sdk/auth/iflow.go
================================================
package auth
import (
"context"
"fmt"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// IFlowAuthenticator performs the browser-based OAuth login for iFlow
// accounts. It has no fields, so the zero value is ready to use.
type IFlowAuthenticator struct{}

// NewIFlowAuthenticator returns a new iFlow authenticator.
func NewIFlowAuthenticator() *IFlowAuthenticator {
	return new(IFlowAuthenticator)
}

// Provider returns "iflow", the provider key handled by this authenticator.
func (a *IFlowAuthenticator) Provider() string {
	return "iflow"
}

// RefreshLead reports how long before expiry a refresh should be attempted:
// 24 hours. Each call returns a freshly allocated value.
func (a *IFlowAuthenticator) RefreshLead() *time.Duration {
	lead := 24 * time.Hour
	return &lead
}
// Login performs the iFlow OAuth authorization-code flow through a local
// callback server and returns the resulting auth record. It does not persist
// the record.
//
// The callback is awaited for up to five minutes. If opts.Prompt is set, the
// user is additionally offered, after 15 seconds, the option to paste the
// callback URL by hand; pressing Enter keeps waiting. Cancelling ctx aborts
// the wait. The returned state must match the generated one.
func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
	if cfg == nil {
		return nil, fmt.Errorf("cliproxy auth: configuration is required")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if opts == nil {
		opts = &LoginOptions{}
	}

	callbackPort := iflow.CallbackPort
	if opts.CallbackPort > 0 {
		callbackPort = opts.CallbackPort
	}

	authSvc := iflow.NewIFlowAuth(cfg)
	oauthServer := iflow.NewOAuthServer(callbackPort)
	if err := oauthServer.Start(); err != nil {
		if strings.Contains(err.Error(), "already in use") {
			return nil, fmt.Errorf("iflow authentication server port in use: %w", err)
		}
		return nil, fmt.Errorf("iflow authentication server failed: %w", err)
	}
	// Always stop the callback server, bounding shutdown to two seconds.
	defer func() {
		stopCtx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		if stopErr := oauthServer.Stop(stopCtx); stopErr != nil {
			log.Warnf("iflow oauth server stop error: %v", stopErr)
		}
	}()

	state, err := misc.GenerateRandomState()
	if err != nil {
		return nil, fmt.Errorf("iflow auth: failed to generate state: %w", err)
	}

	authURL, redirectURI := authSvc.AuthorizationURL(state, callbackPort)
	if !opts.NoBrowser {
		fmt.Println("Opening browser for iFlow authentication")
		if !browser.IsAvailable() {
			log.Warn("No browser available; please open the URL manually")
			util.PrintSSHTunnelInstructions(callbackPort)
			fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
		} else if err = browser.OpenURL(authURL); err != nil {
			log.Warnf("Failed to open browser automatically: %v", err)
			util.PrintSSHTunnelInstructions(callbackPort)
			fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
		}
	} else {
		util.PrintSSHTunnelInstructions(callbackPort)
		fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
	}

	fmt.Println("Waiting for iFlow authentication callback...")

	// Both channels have capacity 1, so the waiter's single send never blocks
	// even when Login has already returned (manual paste or ctx cancellation).
	callbackCh := make(chan *iflow.OAuthResult, 1)
	callbackErrCh := make(chan error, 1)
	go func() {
		result, errWait := oauthServer.WaitForCallback(5 * time.Minute)
		if errWait != nil {
			callbackErrCh <- errWait
			return
		}
		callbackCh <- result
	}()

	var result *iflow.OAuthResult
	// manualPromptC stays nil (a nil channel never fires in select) unless a
	// Prompt callback is available to ask the user for a pasted URL.
	var manualPromptTimer *time.Timer
	var manualPromptC <-chan time.Time
	if opts.Prompt != nil {
		manualPromptTimer = time.NewTimer(15 * time.Second)
		manualPromptC = manualPromptTimer.C
		defer manualPromptTimer.Stop()
	}

waitForCallback:
	for {
		select {
		case result = <-callbackCh:
			break waitForCallback
		case err = <-callbackErrCh:
			return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err)
		case <-ctx.Done():
			// Previously ctx was ignored here, so a cancelled login kept waiting
			// for up to five minutes.
			return nil, fmt.Errorf("iflow auth: callback wait canceled: %w", ctx.Err())
		case <-manualPromptC:
			// Offer the manual prompt only once.
			manualPromptC = nil
			if manualPromptTimer != nil {
				manualPromptTimer.Stop()
			}
			// Prefer a callback that arrived while the timer fired.
			select {
			case result = <-callbackCh:
				break waitForCallback
			case err = <-callbackErrCh:
				return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err)
			default:
			}
			input, errPrompt := opts.Prompt("Paste the iFlow callback URL (or press Enter to keep waiting): ")
			if errPrompt != nil {
				return nil, errPrompt
			}
			parsed, errParse := misc.ParseOAuthCallback(input)
			if errParse != nil {
				return nil, errParse
			}
			if parsed == nil {
				// Nothing usable was pasted; keep waiting for the automatic callback.
				continue
			}
			result = &iflow.OAuthResult{
				Code:  parsed.Code,
				State: parsed.State,
				Error: parsed.Error,
			}
			break waitForCallback
		}
	}

	// Guard against a nil result so the field accesses below cannot panic.
	if result == nil {
		return nil, fmt.Errorf("iflow auth: callback returned no result")
	}
	if result.Error != "" {
		return nil, fmt.Errorf("iflow auth: provider returned error %s", result.Error)
	}
	if result.State != state {
		return nil, fmt.Errorf("iflow auth: state mismatch")
	}

	tokenData, err := authSvc.ExchangeCodeForTokens(ctx, result.Code, redirectURI)
	if err != nil {
		return nil, fmt.Errorf("iflow authentication failed: %w", err)
	}
	tokenStorage := authSvc.CreateTokenStorage(tokenData)
	email := strings.TrimSpace(tokenStorage.Email)
	if email == "" {
		return nil, fmt.Errorf("iflow authentication failed: missing account identifier")
	}

	fileName := fmt.Sprintf("iflow-%s-%d.json", email, time.Now().Unix())
	metadata := map[string]any{
		"email":         email,
		"api_key":       tokenStorage.APIKey,
		"access_token":  tokenStorage.AccessToken,
		"refresh_token": tokenStorage.RefreshToken,
		"expired":       tokenStorage.Expire,
	}
	fmt.Println("iFlow authentication successful")
	return &coreauth.Auth{
		ID:       fileName,
		Provider: a.Provider(),
		FileName: fileName,
		Storage:  tokenStorage,
		Metadata: metadata,
		Attributes: map[string]string{
			"api_key": tokenStorage.APIKey,
		},
	}, nil
}
================================================
FILE: sdk/auth/interfaces.go
================================================
package auth
import (
"context"
"errors"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
var (
	// ErrRefreshNotSupported signals that credential refresh is not supported.
	ErrRefreshNotSupported = errors.New("cliproxy auth: refresh not supported")
)

// LoginOptions carries provider-agnostic settings for a login flow.
// Provider-specific extras travel in Metadata.
type LoginOptions struct {
	// NoBrowser, when true, tells login flows not to open a browser
	// automatically; the URL is shown for manual use instead.
	NoBrowser bool
	// ProjectID optionally selects the upstream project (the Gemini flow seeds
	// its token storage with it).
	ProjectID string
	// CallbackPort optionally overrides the provider's default local OAuth
	// callback port; a zero value keeps the default.
	CallbackPort int
	// Metadata holds free-form provider-specific parameters, for example the
	// "email" / "alias" keys read by the Qwen flow.
	Metadata map[string]string
	// Prompt, when non-nil, asks the user for interactive input: it receives
	// the prompt text and returns the user's answer.
	Prompt func(prompt string) (string, error)
}
// Authenticator drives a provider's interactive login and advertises how early
// its credentials should be refreshed.
type Authenticator interface {
	// Provider returns the provider key the authenticator is registered under.
	Provider() string
	// RefreshLead returns how long before expiry a refresh should be attempted,
	// or nil when the authenticator specifies no lead.
	RefreshLead() *time.Duration
	// Login runs the provider's login flow and returns the resulting auth record.
	Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error)
}
================================================
FILE: sdk/auth/kimi.go
================================================
package auth
import (
"context"
"fmt"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// kimiRefreshLead is the duration before token expiry when refresh should occur.
// RefreshLead hands out copies, so callers can never modify this shared value.
var kimiRefreshLead = 5 * time.Minute

// KimiAuthenticator implements the OAuth device flow login for Kimi (Moonshot AI).
type KimiAuthenticator struct{}

// NewKimiAuthenticator constructs a new Kimi authenticator.
func NewKimiAuthenticator() Authenticator {
	return &KimiAuthenticator{}
}

// Provider returns the provider key for kimi.
func (KimiAuthenticator) Provider() string {
	return "kimi"
}

// RefreshLead returns the duration before token expiry when refresh should occur.
// Each call returns a fresh copy. Returning &kimiRefreshLead would let any
// caller that writes through the pointer change the lead for every Kimi auth.
func (KimiAuthenticator) RefreshLead() *time.Duration {
	lead := kimiRefreshLead
	return &lead
}
// Login runs the Kimi OAuth device flow. It requests a device code, shows the
// verification URL (and tries to open a browser unless opts.NoBrowser is set),
// waits for the user to authorize, and returns the resulting auth record. It
// does not persist the record.
//
// cfg is required and nil opts are treated as empty options. A nil ctx falls
// back to context.Background, like the other authenticators in this package,
// so the device-flow calls never receive a nil context.
func (a KimiAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
	if cfg == nil {
		return nil, fmt.Errorf("cliproxy auth: configuration is required")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if opts == nil {
		opts = &LoginOptions{}
	}

	authSvc := kimi.NewKimiAuth(cfg)

	// Start the device flow.
	fmt.Println("Starting Kimi authentication...")
	deviceCode, err := authSvc.StartDeviceFlow(ctx)
	if err != nil {
		return nil, fmt.Errorf("kimi: failed to start device flow: %w", err)
	}

	// Prefer the pre-filled verification URL; fall back to the bare one.
	verificationURL := deviceCode.VerificationURIComplete
	if verificationURL == "" {
		verificationURL = deviceCode.VerificationURI
	}
	fmt.Printf("\nTo authenticate, please visit:\n%s\n\n", verificationURL)
	if deviceCode.UserCode != "" {
		fmt.Printf("User code: %s\n\n", deviceCode.UserCode)
	}

	// Best effort: the URL has already been printed, so a browser failure is only logged.
	if !opts.NoBrowser {
		if browser.IsAvailable() {
			if errOpen := browser.OpenURL(verificationURL); errOpen != nil {
				log.Warnf("Failed to open browser automatically: %v", errOpen)
			} else {
				fmt.Println("Browser opened automatically.")
			}
		}
	}

	fmt.Println("Waiting for authorization...")
	if deviceCode.ExpiresIn > 0 {
		fmt.Printf("(This will timeout in %d seconds if not authorized)\n", deviceCode.ExpiresIn)
	}

	// Wait for user authorization.
	authBundle, err := authSvc.WaitForAuthorization(ctx, deviceCode)
	if err != nil {
		return nil, fmt.Errorf("kimi: %w", err)
	}

	tokenStorage := authSvc.CreateTokenStorage(authBundle)

	// Metadata mirrors the token fields plus a millisecond creation timestamp.
	metadata := map[string]any{
		"type":          "kimi",
		"access_token":  authBundle.TokenData.AccessToken,
		"refresh_token": authBundle.TokenData.RefreshToken,
		"token_type":    authBundle.TokenData.TokenType,
		"scope":         authBundle.TokenData.Scope,
		"timestamp":     time.Now().UnixMilli(),
	}
	if authBundle.TokenData.ExpiresAt > 0 {
		exp := time.Unix(authBundle.TokenData.ExpiresAt, 0).UTC().Format(time.RFC3339)
		metadata["expired"] = exp
	}
	if strings.TrimSpace(authBundle.DeviceID) != "" {
		metadata["device_id"] = strings.TrimSpace(authBundle.DeviceID)
	}

	// Millisecond timestamp keeps file names distinct across logins.
	fileName := fmt.Sprintf("kimi-%d.json", time.Now().UnixMilli())

	fmt.Println("\nKimi authentication successful!")

	return &coreauth.Auth{
		ID:       fileName,
		Provider: a.Provider(),
		FileName: fileName,
		Label:    "Kimi User",
		Storage:  tokenStorage,
		Metadata: metadata,
	}, nil
}
================================================
FILE: sdk/auth/manager.go
================================================
package auth
import (
"context"
"fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// Manager maps provider keys to authenticators and persists login results
// through an optional token store.
type Manager struct {
	authenticators map[string]Authenticator
	store          coreauth.Store
}

// NewManager builds a manager backed by store and registers each supplied
// authenticator. A nil store is allowed; set one later with SetStore.
func NewManager(store coreauth.Store, authenticators ...Authenticator) *Manager {
	mgr := &Manager{
		store:          store,
		authenticators: map[string]Authenticator{},
	}
	for _, a := range authenticators {
		mgr.Register(a)
	}
	return mgr
}

// Register stores a under its Provider() key, replacing any authenticator
// previously registered for that key. A nil authenticator is ignored.
func (m *Manager) Register(a Authenticator) {
	if a == nil {
		return
	}
	if m.authenticators == nil {
		m.authenticators = map[string]Authenticator{}
	}
	key := a.Provider()
	m.authenticators[key] = a
}

// SetStore replaces the token store used for persistence.
func (m *Manager) SetStore(store coreauth.Store) {
	m.store = store
}
// Login runs the login flow of the authenticator registered for provider and,
// when a store is configured, persists the resulting record.
//
// It returns the record, the path reported by the store ("" when nothing was
// saved), and any error. If saving fails, the record is still returned
// together with the error. Stores exposing SetBaseDir(string) are pointed at
// cfg.AuthDir before saving when cfg is non-nil.
func (m *Manager) Login(ctx context.Context, provider string, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, string, error) {
	authenticator, registered := m.authenticators[provider]
	if !registered {
		return nil, "", fmt.Errorf("cliproxy auth: authenticator %s not registered", provider)
	}

	record, err := authenticator.Login(ctx, cfg, opts)
	switch {
	case err != nil:
		return nil, "", err
	case record == nil:
		return nil, "", fmt.Errorf("cliproxy auth: authenticator %s returned nil record", provider)
	case m.store == nil:
		return record, "", nil
	}

	if cfg != nil {
		if setter, ok := m.store.(interface{ SetBaseDir(string) }); ok {
			setter.SetBaseDir(cfg.AuthDir)
		}
	}

	savedPath, errSave := m.store.Save(ctx, record)
	if errSave != nil {
		return record, "", errSave
	}
	return record, savedPath, nil
}
================================================
FILE: sdk/auth/qwen.go
================================================
package auth
import (
"context"
"fmt"
"strings"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
// legacy client removed
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// QwenAuthenticator performs the OAuth device-flow login for Qwen accounts.
// It has no fields, so the zero value is ready to use.
type QwenAuthenticator struct{}

// NewQwenAuthenticator returns a new Qwen authenticator.
func NewQwenAuthenticator() *QwenAuthenticator {
	return new(QwenAuthenticator)
}

// Provider returns "qwen", the provider key handled by this authenticator.
func (a *QwenAuthenticator) Provider() string {
	return "qwen"
}

// RefreshLead asks for a refresh three hours before expiry. Each call returns
// a freshly allocated value.
func (a *QwenAuthenticator) RefreshLead() *time.Duration {
	lead := 3 * time.Hour
	return &lead
}
// Login runs the Qwen OAuth device flow and returns the resulting auth record.
// It does not persist the record.
//
// The account label comes from opts.Metadata["email"], then
// opts.Metadata["alias"], then opts.Prompt if available. A label that is still
// blank after trimming yields an *EmailRequiredError.
func (a *QwenAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
	if cfg == nil {
		return nil, fmt.Errorf("cliproxy auth: configuration is required")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if opts == nil {
		opts = &LoginOptions{}
	}

	authSvc := qwen.NewQwenAuth(cfg)
	deviceFlow, err := authSvc.InitiateDeviceFlow(ctx)
	if err != nil {
		return nil, fmt.Errorf("qwen device flow initiation failed: %w", err)
	}

	authURL := deviceFlow.VerificationURIComplete
	printURL := func() {
		fmt.Printf("Visit the following URL to continue authentication:\n%s\n", authURL)
	}
	switch {
	case opts.NoBrowser:
		printURL()
	case !browser.IsAvailable():
		fmt.Println("Opening browser for Qwen authentication")
		log.Warn("No browser available; please open the URL manually")
		printURL()
	default:
		fmt.Println("Opening browser for Qwen authentication")
		if errOpen := browser.OpenURL(authURL); errOpen != nil {
			log.Warnf("Failed to open browser automatically: %v", errOpen)
			printURL()
		}
	}

	fmt.Println("Waiting for Qwen authentication...")
	tokenData, err := authSvc.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
	if err != nil {
		return nil, fmt.Errorf("qwen authentication failed: %w", err)
	}
	tokenStorage := authSvc.CreateTokenStorage(tokenData)

	// Reading a nil map yields "", so no explicit nil check is needed.
	email := opts.Metadata["email"]
	if email == "" {
		email = opts.Metadata["alias"]
	}
	if email == "" && opts.Prompt != nil {
		answer, errPrompt := opts.Prompt("Please input your email address or alias for Qwen:")
		if errPrompt != nil {
			return nil, errPrompt
		}
		email = answer
	}
	if email = strings.TrimSpace(email); email == "" {
		return nil, &EmailRequiredError{Prompt: "Please provide an email address or alias for Qwen."}
	}
	tokenStorage.Email = email

	fileName := fmt.Sprintf("qwen-%s.json", tokenStorage.Email)
	fmt.Println("Qwen authentication successful")
	return &coreauth.Auth{
		ID:       fileName,
		Provider: a.Provider(),
		FileName: fileName,
		Storage:  tokenStorage,
		Metadata: map[string]any{
			"email": tokenStorage.Email,
		},
	}, nil
}
================================================
FILE: sdk/auth/refresh_registry.go
================================================
package auth
import (
"time"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
func init() {
registerRefreshLead("codex", func() Authenticator { return NewCodexAuthenticator() })
registerRefreshLead("claude", func() Authenticator { return NewClaudeAuthenticator() })
registerRefreshLead("qwen", func() Authenticator { return NewQwenAuthenticator() })
registerRefreshLead("iflow", func() Authenticator { return NewIFlowAuthenticator() })
registerRefreshLead("gemini", func() Authenticator { return NewGeminiAuthenticator() })
registerRefreshLead("gemini-cli", func() Authenticator { return NewGeminiAuthenticator() })
registerRefreshLead("antigravity", func() Authenticator { return NewAntigravityAuthenticator() })
registerRefreshLead("kimi", func() Authenticator { return NewKimiAuthenticator() })
}
// registerRefreshLead publishes provider's refresh lead to the core auth
// package. The lead is resolved lazily: every lookup builds a fresh
// authenticator from factory and asks it for RefreshLead. A nil factory, or a
// factory yielding a nil authenticator, reports no lead (nil).
func registerRefreshLead(provider string, factory func() Authenticator) {
	lookup := func() *time.Duration {
		if factory == nil {
			return nil
		}
		if a := factory(); a != nil {
			return a.RefreshLead()
		}
		return nil
	}
	cliproxyauth.RegisterRefreshLeadProvider(provider, lookup)
}
================================================
FILE: sdk/auth/store_registry.go
================================================
package auth
import (
"sync"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
var (
	// storeMu guards registeredStore.
	storeMu sync.RWMutex
	// registeredStore is the process-wide token store. It stays nil until it is
	// registered or lazily defaulted by GetTokenStore.
	registeredStore coreauth.Store
)

// RegisterTokenStore installs store as the global token store used by the
// authentication helpers.
func RegisterTokenStore(store coreauth.Store) {
	storeMu.Lock()
	defer storeMu.Unlock()
	registeredStore = store
}

// GetTokenStore returns the globally registered token store. If none has been
// registered, a FileTokenStore is created, installed and returned.
func GetTokenStore() coreauth.Store {
	storeMu.RLock()
	current := registeredStore
	storeMu.RUnlock()
	if current != nil {
		return current
	}

	storeMu.Lock()
	defer storeMu.Unlock()
	// Re-check under the write lock: another goroutine may have installed a
	// store between the read unlock and this lock.
	if registeredStore == nil {
		registeredStore = NewFileTokenStore()
	}
	return registeredStore
}
================================================
FILE: sdk/cliproxy/auth/api_key_model_alias_test.go
================================================
package auth
import (
"context"
"testing"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
func TestLookupAPIKeyUpstreamModel(t *testing.T) {
cfg := &internalconfig.Config{
GeminiKey: []internalconfig.GeminiKey{
{
APIKey: "k",
BaseURL: "https://example.com",
Models: []internalconfig.GeminiModel{
{Name: "gemini-2.5-pro-exp-03-25", Alias: "g25p"},
{Name: "gemini-2.5-flash(low)", Alias: "g25f"},
},
},
},
}
mgr := NewManager(nil, nil, nil)
mgr.SetConfig(cfg)
ctx := context.Background()
_, _ = mgr.Register(ctx, &Auth{ID: "a1", Provider: "gemini", Attributes: map[string]string{"api_key": "k", "base_url": "https://example.com"}})
tests := []struct {
name string
authID string
input string
want string
}{
// Fast path + suffix preservation
{"alias with suffix", "a1", "g25p(8192)", "gemini-2.5-pro-exp-03-25(8192)"},
{"alias without suffix", "a1", "g25p", "gemini-2.5-pro-exp-03-25"},
// Config suffix takes priority
{"config suffix priority", "a1", "g25f(high)", "gemini-2.5-flash(low)"},
{"config suffix no user suffix", "a1", "g25f", "gemini-2.5-flash(low)"},
// Case insensitive
{"uppercase alias", "a1", "G25P", "gemini-2.5-pro-exp-03-25"},
{"mixed case with suffix", "a1", "G25p(4096)", "gemini-2.5-pro-exp-03-25(4096)"},
// Direct name lookup
{"upstream name direct", "a1", "gemini-2.5-pro-exp-03-25", "gemini-2.5-pro-exp-03-25"},
{"upstream name with suffix", "a1", "gemini-2.5-pro-exp-03-25(8192)", "gemini-2.5-pro-exp-03-25(8192)"},
// Cache miss scenarios
{"non-existent auth", "non-existent", "g25p", ""},
{"unknown alias", "a1", "unknown-alias", ""},
{"empty auth ID", "", "g25p", ""},
{"empty model", "a1", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resolved := mgr.lookupAPIKeyUpstreamModel(tt.authID, tt.input)
if resolved != tt.want {
t.Errorf("lookupAPIKeyUpstreamModel(%q, %q) = %q, want %q", tt.authID, tt.input, resolved, tt.want)
}
})
}
}
// TestAPIKeyModelAlias_ConfigHotReload verifies that a SetConfig call that
// remaps an existing alias takes effect for an already-registered auth.
func TestAPIKeyModelAlias_ConfigHotReload(t *testing.T) {
	// configWith builds a single-key Gemini config mapping alias "g25p" to upstream.
	configWith := func(upstream string) *internalconfig.Config {
		return &internalconfig.Config{
			GeminiKey: []internalconfig.GeminiKey{{
				APIKey: "k",
				Models: []internalconfig.GeminiModel{{Name: upstream, Alias: "g25p"}},
			}},
		}
	}

	mgr := NewManager(nil, nil, nil)
	mgr.SetConfig(configWith("gemini-2.5-pro-exp-03-25"))
	_, _ = mgr.Register(context.Background(), &Auth{ID: "a1", Provider: "gemini", Attributes: map[string]string{"api_key": "k"}})

	// The initial alias resolves to the first upstream model.
	if got, want := mgr.lookupAPIKeyUpstreamModel("a1", "g25p"), "gemini-2.5-pro-exp-03-25"; got != want {
		t.Fatalf("before reload: got %q, want %q", got, want)
	}

	// After a hot reload the same alias must point at the new upstream model.
	mgr.SetConfig(configWith("gemini-2.5-flash"))
	if got, want := mgr.lookupAPIKeyUpstreamModel("a1", "g25p"), "gemini-2.5-flash"; got != want {
		t.Fatalf("after reload: got %q, want %q", got, want)
	}
}
func TestAPIKeyModelAlias_MultipleProviders(t *testing.T) {
cfg := &internalconfig.Config{
GeminiKey: []internalconfig.GeminiKey{{APIKey: "gemini-key", Models: []internalconfig.GeminiModel{{Name: "gemini-2.5-pro", Alias: "gp"}}}},
ClaudeKey: []internalconfig.ClaudeKey{{APIKey: "claude-key", Models: []internalconfig.ClaudeModel{{Name: "claude-sonnet-4", Alias: "cs4"}}}},
CodexKey: []internalconfig.CodexKey{{APIKey: "codex-key", Models: []internalconfig.CodexModel{{Name: "o3", Alias: "o"}}}},
}
mgr := NewManager(nil, nil, nil)
mgr.SetConfig(cfg)
ctx := context.Background()
_, _ = mgr.Register(ctx, &Auth{ID: "gemini-auth", Provider: "gemini", Attributes: map[string]string{"api_key": "gemini-key"}})
_, _ = mgr.Register(ctx, &Auth{ID: "claude-auth", Provider: "claude", Attributes: map[string]string{"api_key": "claude-key"}})
_, _ = mgr.Register(ctx, &Auth{ID: "codex-auth", Provider: "codex", Attributes: map[string]string{"api_key": "codex-key"}})
tests := []struct {
authID, input, want string
}{
{"gemini-auth", "gp", "gemini-2.5-pro"},
{"claude-auth", "cs4", "claude-sonnet-4"},
{"codex-auth", "o", "o3"},
}
for _, tt := range tests {
if resolved := mgr.lookupAPIKeyUpstreamModel(tt.authID, tt.input); resolved != tt.want {
t.Errorf("lookupAPIKeyUpstreamModel(%q, %q) = %q, want %q", tt.authID, tt.input, resolved, tt.want)
}
}
}
func TestApplyAPIKeyModelAlias(t *testing.T) {
cfg := &internalconfig.Config{
GeminiKey: []internalconfig.GeminiKey{
{APIKey: "k", Models: []internalconfig.GeminiModel{{Name: "gemini-2.5-pro-exp-03-25", Alias: "g25p"}}},
},
}
mgr := NewManager(nil, nil, nil)
mgr.SetConfig(cfg)
ctx := context.Background()
apiKeyAuth := &Auth{ID: "a1", Provider: "gemini", Attributes: map[string]string{"api_key": "k"}}
oauthAuth := &Auth{ID: "oauth-auth", Provider: "gemini", Attributes: map[string]string{"auth_kind": "oauth"}}
_, _ = mgr.Register(ctx, apiKeyAuth)
tests := []struct {
name string
auth *Auth
inputModel string
wantModel string
}{
{
name: "api_key auth with alias",
auth: apiKeyAuth,
inputModel: "g25p(8192)",
wantModel: "gemini-2.5-pro-exp-03-25(8192)",
},
{
name: "oauth auth passthrough",
auth: oauthAuth,
inputModel: "some-model",
wantModel: "some-model",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resolvedModel := mgr.applyAPIKeyModelAlias(tt.auth, tt.inputModel)
if resolvedModel != tt.wantModel {
t.Errorf("model = %q, want %q", resolvedModel, tt.wantModel)
}
})
}
}
================================================
FILE: sdk/cliproxy/auth/conductor.go
================================================
package auth
import (
"bytes"
"context"
"encoding/json"
"errors"
"io"
"net/http"
"path/filepath"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/google/uuid"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
)
// ProviderExecutor is the contract Manager relies on to perform calls against a
// single provider.
type ProviderExecutor interface {
	// Identifier reports the provider key this executor serves.
	Identifier() string
	// Execute performs a non-streaming call and returns the provider response payload.
	Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error)
	// ExecuteStream performs a streaming call. The returned StreamResult carries
	// the upstream headers and a channel of provider chunks.
	ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error)
	// CountTokens returns the token count for the given request.
	CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error)
	// Refresh tries to refresh the provider credentials and returns the updated auth state.
	Refresh(ctx context.Context, auth *Auth) (*Auth, error)
	// HttpRequest adds the provider credentials to req and executes it.
	// Callers must close the response body when it is non-nil.
	HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error)
}

// ExecutionSessionCloser is implemented by executors that can release
// per-session runtime resources.
type ExecutionSessionCloser interface {
	CloseExecutionSession(sessionID string)
}

// CloseAllExecutionSessionsID is a sentinel session ID that asks an executor to
// release every active execution session. Executors that do not support this
// marker may ignore it.
const CloseAllExecutionSessionsID = "__all_execution_sessions__"

// RefreshEvaluator lets runtime state override refresh decisions.
type RefreshEvaluator interface {
	ShouldRefresh(now time.Time, auth *Auth) bool
}
// Timing and concurrency limits for credential auto-refresh and quota cooldown
// backoff. refreshMaxConcurrency is the capacity of Manager.refreshSemaphore.
const (
	refreshCheckInterval  = 5 * time.Second
	refreshMaxConcurrency = 16
	refreshPendingBackoff = time.Minute
	refreshFailureBackoff = 5 * time.Minute
	quotaBackoffBase      = time.Second
	quotaBackoffMax       = 30 * time.Minute
)

// quotaCooldownDisabled is the process-wide default for whether quota cooldown
// is scheduled. Individual auths may override it.
var quotaCooldownDisabled atomic.Bool

// SetQuotaCooldownDisabled switches quota cooldown scheduling off (true) or on
// (false) globally.
func SetQuotaCooldownDisabled(disable bool) {
	quotaCooldownDisabled.Store(disable)
}

// quotaCooldownDisabledForAuth reports whether quota cooldown is disabled for
// auth. An explicit per-auth override wins; otherwise, and for a nil auth, the
// global setting applies.
func quotaCooldownDisabledForAuth(auth *Auth) bool {
	if auth == nil {
		return quotaCooldownDisabled.Load()
	}
	if override, ok := auth.DisableCoolingOverride(); ok {
		return override
	}
	return quotaCooldownDisabled.Load()
}
// Result describes the outcome of one execution and is used to update auth state.
type Result struct {
	AuthID     string         // ID of the auth that produced this result
	Provider   string         // copied for convenience when emitting hooks
	Model      string         // upstream model identifier used for the request
	Success    bool           // whether the execution succeeded
	RetryAfter *time.Duration // provider-supplied retry hint (e.g. 429 retryDelay)
	Error      *Error         // failure detail when Success is false
}

// Selector picks one auth among the candidates for an execution.
type Selector interface {
	Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error)
}

// Hook receives lifecycle callbacks for observing auth changes.
type Hook interface {
	OnAuthRegistered(ctx context.Context, auth *Auth) // a new auth was registered
	OnAuthUpdated(ctx context.Context, auth *Auth)    // an existing auth changed state
	OnResult(ctx context.Context, result Result)      // an execution result was recorded
}

// NoopHook implements Hook with methods that do nothing, providing optional
// defaults for hook implementations.
type NoopHook struct{}

// Compile-time check that NoopHook satisfies Hook.
var _ Hook = NoopHook{}

// OnAuthRegistered implements Hook and does nothing.
func (NoopHook) OnAuthRegistered(context.Context, *Auth) {}

// OnAuthUpdated implements Hook and does nothing.
func (NoopHook) OnAuthUpdated(context.Context, *Auth) {}

// OnResult implements Hook and does nothing.
func (NoopHook) OnResult(context.Context, Result) {}
// Manager orchestrates auth lifecycle, selection, execution, and persistence.
// Construct it with NewManager, which seeds the maps and atomic values below.
type Manager struct {
	// store persists auth records; it may be nil and is replaced via SetStore.
	store Store
	// executors holds provider executors keyed by provider string
	// (presumably ProviderExecutor.Identifier — confirm at the registration site).
	executors map[string]ProviderExecutor
	// selector picks an auth among candidates; defaults to RoundRobinSelector.
	selector Selector
	// hook receives lifecycle callbacks; defaults to NoopHook.
	hook Hook
	// mu protects auths and is also held when store, selector and rtProvider
	// are replaced.
	mu sync.RWMutex
	// auths indexes registered auths by ID.
	auths map[string]*Auth
	// scheduler mirrors registered auths for selection and shares the active
	// selector (see RefreshSchedulerEntry / SetSelector).
	scheduler *authScheduler
	// providerOffsets tracks per-model provider rotation state for multi-provider routing.
	providerOffsets map[string]int
	// Retry controls request retry behavior.
	requestRetry atomic.Int32
	maxRetryCredentials atomic.Int32
	maxRetryInterval atomic.Int64
	// oauthModelAlias stores global OAuth model alias mappings (alias -> upstream name) keyed by channel.
	oauthModelAlias atomic.Value
	// apiKeyModelAlias caches resolved model alias mappings for API-key auths.
	// Keyed by auth.ID, value is alias(lower) -> upstream model (including suffix).
	apiKeyModelAlias atomic.Value
	// modelPoolOffsets tracks per-auth alias pool rotation state.
	modelPoolOffsets map[string]int
	// runtimeConfig stores the latest application config for request-time decisions.
	// It is initialized in NewManager; never Load() before first Store().
	runtimeConfig atomic.Value
	// Optional HTTP RoundTripper provider injected by host.
	rtProvider RoundTripperProvider
	// Auto refresh state
	refreshCancel context.CancelFunc
	// refreshSemaphore has capacity refreshMaxConcurrency (see NewManager);
	// presumably it bounds concurrent refreshes — confirm in the refresh loop.
	refreshSemaphore chan struct{}
}
// NewManager builds a Manager around store. A nil selector defaults to
// RoundRobinSelector and a nil hook to NoopHook. The scheduler is bound to the
// effective selector, and the atomic runtime config and API-key alias table
// are seeded with empty values.
func NewManager(store Store, selector Selector, hook Hook) *Manager {
	if selector == nil {
		selector = &RoundRobinSelector{}
	}
	if hook == nil {
		hook = NoopHook{}
	}
	m := &Manager{
		store:            store,
		executors:        map[string]ProviderExecutor{},
		selector:         selector,
		hook:             hook,
		auths:            map[string]*Auth{},
		scheduler:        newAuthScheduler(selector),
		providerOffsets:  map[string]int{},
		modelPoolOffsets: map[string]int{},
		refreshSemaphore: make(chan struct{}, refreshMaxConcurrency),
	}
	// atomic.Value panics on Store(nil), and its first Store fixes the concrete
	// type, so both values are seeded before the manager is handed out.
	m.runtimeConfig.Store(&internalconfig.Config{})
	m.apiKeyModelAlias.Store(apiKeyModelAliasTable(nil))
	return m
}
// isBuiltInSelector reports whether selector is one of the selectors shipped
// with this package (round-robin or fill-first). A nil selector is not.
func isBuiltInSelector(selector Selector) bool {
	if _, isRoundRobin := selector.(*RoundRobinSelector); isRoundRobin {
		return true
	}
	_, isFillFirst := selector.(*FillFirstSelector)
	return isFillFirst
}
// syncSchedulerFromSnapshot rebuilds the scheduler from the given auth list.
// It does nothing on a nil manager or one without a scheduler.
func (m *Manager) syncSchedulerFromSnapshot(auths []*Auth) {
	if m != nil && m.scheduler != nil {
		m.scheduler.rebuild(auths)
	}
}
// syncScheduler rebuilds the scheduler from a fresh snapshot of the manager's
// auths. The snapshot is only taken when a scheduler is present.
func (m *Manager) syncScheduler() {
	if m == nil || m.scheduler == nil {
		return
	}
	current := m.snapshotAuths()
	m.syncSchedulerFromSnapshot(current)
}
// RefreshSchedulerEntry re-upserts the auth identified by authID into the
// scheduler so that its supported-model set is recomputed from the current
// global model registry.
//
// Call it after registering models for a newly added auth: the upsert performed
// by Register/Update runs before registerModelsForAuth and therefore captures
// an empty model set. Unknown IDs, an empty ID, or a missing scheduler make
// this a no-op. The scheduler receives a clone taken under the read lock.
func (m *Manager) RefreshSchedulerEntry(authID string) {
	if m == nil || m.scheduler == nil || authID == "" {
		return
	}
	snapshot, found := func() (*Auth, bool) {
		m.mu.RLock()
		defer m.mu.RUnlock()
		stored := m.auths[authID]
		if stored == nil {
			return nil, false
		}
		return stored.Clone(), true
	}()
	if !found {
		return
	}
	m.scheduler.upsertAuth(snapshot)
}
// SetSelector replaces the auth selector; a nil selector falls back to
// round-robin selection. When a scheduler is present it is switched to the new
// selector and then rebuilt from the current auths.
func (m *Manager) SetSelector(selector Selector) {
	if m == nil {
		return
	}
	next := selector
	if next == nil {
		next = &RoundRobinSelector{}
	}
	func() {
		m.mu.Lock()
		defer m.mu.Unlock()
		m.selector = next
	}()
	if m.scheduler == nil {
		return
	}
	m.scheduler.setSelector(next)
	m.syncScheduler()
}
// SetStore swaps the underlying persistence store. With a nil store, Load
// becomes a no-op. Calling SetStore on a nil Manager does nothing, matching
// SetSelector and SetConfig.
func (m *Manager) SetStore(store Store) {
	if m == nil {
		return
	}
	m.mu.Lock()
	defer m.mu.Unlock()
	m.store = store
}
// SetRoundTripperProvider registers a provider that returns a per-auth
// RoundTripper; passing nil removes the provider. Calling it on a nil Manager
// does nothing, matching SetSelector and SetConfig.
func (m *Manager) SetRoundTripperProvider(p RoundTripperProvider) {
	if m == nil {
		return
	}
	m.mu.Lock()
	m.rtProvider = p
	m.mu.Unlock()
}
// SetConfig stores cfg as the runtime config snapshot consulted by
// request-time helpers, then rebuilds the per-auth API-key alias table from it.
// A nil cfg is replaced by an empty config. Call it on every config reload so
// per-credential alias mappings stay in sync.
func (m *Manager) SetConfig(cfg *internalconfig.Config) {
	if m == nil {
		return
	}
	snapshot := cfg
	if snapshot == nil {
		snapshot = &internalconfig.Config{}
	}
	m.runtimeConfig.Store(snapshot)
	m.rebuildAPIKeyModelAliasFromRuntimeConfig()
}
// lookupAPIKeyUpstreamModel resolves requestedModel through the cached per-auth
// alias table. The lookup key is the lower-cased model name with its thinking
// suffix removed (or the whole lower-cased name when parsing yields nothing).
// The resolved upstream name is then passed through
// preserveRequestedModelSuffix together with requestedModel. It returns ""
// when the manager is nil, either argument is blank, or no mapping exists.
func (m *Manager) lookupAPIKeyUpstreamModel(authID, requestedModel string) string {
	if m == nil {
		return ""
	}
	id := strings.TrimSpace(authID)
	model := strings.TrimSpace(requestedModel)
	if id == "" || model == "" {
		return ""
	}
	// Indexing a nil table is safe and yields an empty alias map.
	table, _ := m.apiKeyModelAlias.Load().(apiKeyModelAliasTable)
	aliases := table[id]
	if len(aliases) == 0 {
		return ""
	}
	lookupKey := strings.ToLower(thinking.ParseSuffix(model).ModelName)
	if lookupKey == "" {
		lookupKey = strings.ToLower(model)
	}
	upstream := strings.TrimSpace(aliases[lookupKey])
	if upstream == "" {
		return ""
	}
	return preserveRequestedModelSuffix(model, upstream)
}
// isAPIKeyAuth reports whether auth's account kind is "api_key", compared
// case-insensitively and ignoring surrounding whitespace. A nil auth is not.
func isAPIKeyAuth(auth *Auth) bool {
	if auth == nil {
		return false
	}
	accountKind, _ := auth.AccountInfo()
	return strings.EqualFold(strings.TrimSpace(accountKind), "api_key")
}
// isOpenAICompatAPIKeyAuth reports whether auth is an API-key auth that targets
// an OpenAI-compatible provider. That is the case when its provider is
// "openai-compatibility" (case-insensitive) or it carries a non-blank
// "compat_name" attribute.
func isOpenAICompatAPIKeyAuth(auth *Auth) bool {
	// Short-circuiting keeps auth non-nil for the remaining operands; indexing
	// a nil Attributes map simply yields "".
	return isAPIKeyAuth(auth) &&
		(strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") ||
			strings.TrimSpace(auth.Attributes["compat_name"]) != "")
}
// openAICompatProviderKey returns the lower-cased key identifying auth's
// OpenAI-compatible provider. It prefers the "provider_key" attribute, then
// the "compat_name" attribute, then auth.Provider, skipping blank values. It
// returns "" for a nil auth.
func openAICompatProviderKey(auth *Auth) string {
	if auth == nil {
		return ""
	}
	for _, attr := range [...]string{"provider_key", "compat_name"} {
		if value := strings.TrimSpace(auth.Attributes[attr]); value != "" {
			return strings.ToLower(value)
		}
	}
	return strings.ToLower(strings.TrimSpace(auth.Provider))
}
// openAICompatModelPoolKey builds the rotation-state key for an alias pool,
// in the form "<auth id>|<provider key>|<base model>", all lower-cased. The
// base model is requestedModel without its thinking suffix, or the trimmed
// requestedModel when parsing yields nothing. auth must be non-nil.
func openAICompatModelPoolKey(auth *Auth, requestedModel string) string {
	baseModel := strings.TrimSpace(thinking.ParseSuffix(requestedModel).ModelName)
	if baseModel == "" {
		baseModel = strings.TrimSpace(requestedModel)
	}
	parts := []string{
		strings.ToLower(strings.TrimSpace(auth.ID)),
		openAICompatProviderKey(auth),
		strings.ToLower(baseModel),
	}
	return strings.Join(parts, "|")
}
// nextModelPoolOffset returns the rotation offset, in [0, size), for the alias
// pool identified by key, and advances that pool's counter. A nil manager, a
// pool of size 0 or 1, or a blank key always yields 0 and leaves state
// untouched. The stored counter restarts at 0 just below the int32 maximum, so
// it never overflows.
func (m *Manager) nextModelPoolOffset(key string, size int) int {
	if m == nil || size <= 1 {
		return 0
	}
	key = strings.TrimSpace(key)
	if key == "" {
		return 0
	}
	m.mu.Lock()
	defer m.mu.Unlock()
	if m.modelPoolOffsets == nil {
		m.modelPoolOffsets = make(map[string]int)
	}
	offset := m.modelPoolOffsets[key]
	// Wrap before reaching the int32 limit so the counter stays bounded on all platforms.
	if offset >= 2_147_483_640 {
		offset = 0
	}
	m.modelPoolOffsets[key] = offset + 1
	// size > 1 is guaranteed by the guard above, so the modulo is well defined.
	return offset % size
}
// rotateStrings returns values rotated left so that values[offset%len(values)]
// comes first. Slices with fewer than two elements are returned as-is (not
// copied). Otherwise a new slice is always returned, and a non-positive offset
// yields an unrotated copy.
func rotateStrings(values []string, offset int) []string {
	n := len(values)
	if n < 2 {
		return values
	}
	start := 0
	if offset > 0 {
		start = offset % n
	}
	rotated := make([]string, n)
	copy(rotated, values[start:])
	copy(rotated[n-start:], values[:start])
	return rotated
}
// resolveOpenAICompatUpstreamModelPool returns the upstream model pool that
// the runtime config maps requestedModel to for an OpenAI-compatible API-key
// auth. The config entry is selected from auth's "provider_key"/"compat_name"
// attributes and auth.Provider. It returns nil for other auths, a blank model,
// or when no config entry matches.
func (m *Manager) resolveOpenAICompatUpstreamModelPool(auth *Auth, requestedModel string) []string {
	if m == nil || !isOpenAICompatAPIKeyAuth(auth) {
		return nil
	}
	model := strings.TrimSpace(requestedModel)
	if model == "" {
		return nil
	}
	cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
	if cfg == nil {
		cfg = &internalconfig.Config{}
	}
	// Indexing a nil Attributes map yields "".
	providerKey := strings.TrimSpace(auth.Attributes["provider_key"])
	compatName := strings.TrimSpace(auth.Attributes["compat_name"])
	compat := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider)
	if compat == nil {
		return nil
	}
	return resolveModelAliasPoolFromConfigModels(model, asModelAliasEntries(compat.Models))
}
func preserveRequestedModelSuffix(requestedModel, resolved string) string {
return preserveResolvedModelSuffix(resolved, thinking.ParseSuffix(requestedModel))
}
// executionModelCandidates returns the ordered upstream model names to try for
// routeModel on auth. It delegates to prepareExecutionModels, which may advance
// alias-pool rotation state as a side effect.
func (m *Manager) executionModelCandidates(auth *Auth, routeModel string) []string {
	candidates := m.prepareExecutionModels(auth, routeModel)
	return candidates
}
// prepareExecutionModels computes the ordered upstream model names to try for
// routeModel on auth, in three steps:
//  1. Strip the auth's "<prefix>/" from the model and apply the OAuth model alias.
//  2. If an OpenAI-compatible alias pool exists for the model, return it. Pools
//     with two or more entries are rotated by a per-auth/model offset, which
//     advances on every call.
//  3. Otherwise, return a single model resolved through API-key aliases. When
//     that resolution is blank, fall back to the model from step 1.
func (m *Manager) prepareExecutionModels(auth *Auth, routeModel string) []string {
	model := m.applyOAuthModelAlias(auth, rewriteModelForAuth(routeModel, auth))
	pool := m.resolveOpenAICompatUpstreamModelPool(auth, model)
	switch len(pool) {
	case 0:
		resolved := m.applyAPIKeyModelAlias(auth, model)
		if strings.TrimSpace(resolved) == "" {
			return []string{model}
		}
		return []string{resolved}
	case 1:
		return pool
	default:
		offset := m.nextModelPoolOffset(openAICompatModelPoolKey(auth, model), len(pool))
		return rotateStrings(pool, offset)
	}
}
// discardStreamChunks drains ch in a background goroutine so that a producer
// sending on it is not left blocked after the consumer gives up. The goroutine
// exits only once ch is closed. A nil channel is ignored.
func discardStreamChunks(ch <-chan cliproxyexecutor.StreamChunk) {
	if ch == nil {
		return
	}
	go func(src <-chan cliproxyexecutor.StreamChunk) {
		for range src {
			// Drop every chunk until the producer closes the channel.
		}
	}(ch)
}
// readStreamBootstrap reads from ch until the first chunk with a non-empty
// payload arrives. This lets callers detect upstream failures before they
// commit to a stream.
//
// It returns:
//   - (buffered, false, nil) once a payload-bearing chunk is read. buffered holds
//     every chunk read so far: any empty-payload chunks first, the payload chunk last.
//   - (buffered, true, nil) when ch closes first; buffered may be empty. A nil ch
//     is treated as already closed and yields (nil, true, nil).
//   - (nil, false, err) when a chunk carries an error (err is chunk.Err), or when
//     ctx is done (err is ctx.Err()). Previously buffered chunks are dropped, and
//     the caller remains responsible for draining ch.
//
// A nil ctx disables cancellation, and each read blocks until ch delivers.
func readStreamBootstrap(ctx context.Context, ch <-chan cliproxyexecutor.StreamChunk) ([]cliproxyexecutor.StreamChunk, bool, error) {
	if ch == nil {
		return nil, true, nil
	}
	buffered := make([]cliproxyexecutor.StreamChunk, 0, 1)
	for {
		var (
			chunk cliproxyexecutor.StreamChunk
			ok bool
		)
		if ctx != nil {
			select {
			case <-ctx.Done():
				return nil, false, ctx.Err()
			case chunk, ok = <-ch:
			}
		} else {
			chunk, ok = <-ch
		}
		if !ok {
			return buffered, true, nil
		}
		if chunk.Err != nil {
			return nil, false, chunk.Err
		}
		buffered = append(buffered, chunk)
		// Empty-payload chunks are kept for replay but do not end the bootstrap;
		// keep reading until real payload arrives.
		if len(chunk.Payload) > 0 {
			return buffered, false, nil
		}
	}
}
// wrapStreamResult returns a StreamResult whose Chunks channel first yields the
// already-read buffered chunks and then everything received from remaining. A
// background goroutine does the forwarding and closes the output channel when
// it finishes. The output channel is unbuffered, so the goroutine advances only
// as the consumer reads, or until ctx is done.
//
// Outcomes are reported via MarkResult, keyed by auth.ID, provider and routeModel:
//   - The first chunk carrying an error records one failure. The HTTP status is
//     included when the error implements cliproxyexecutor.StatusError. The error
//     chunk itself is still forwarded.
//   - If no error chunk was seen and remaining was consumed until it closed, a
//     single success is recorded.
//
// If ctx is done while a chunk is waiting to be delivered, forwarding stops and
// remaining is drained in the background via discardStreamChunks; no success is
// recorded. With a nil ctx, delivery blocks until the consumer reads.
// NOTE(review): if remaining is never closed, the goroutine never finishes.
func (m *Manager) wrapStreamResult(ctx context.Context, auth *Auth, provider, routeModel string, headers http.Header, buffered []cliproxyexecutor.StreamChunk, remaining <-chan cliproxyexecutor.StreamChunk) *cliproxyexecutor.StreamResult {
	out := make(chan cliproxyexecutor.StreamChunk)
	go func() {
		defer close(out)
		// failed is set once an error chunk has been recorded for this stream.
		var failed bool
		// forward turns false once ctx is done; no further chunks are delivered.
		forward := true
		// emit records the first error (if any) and delivers chunk to out. It
		// returns false when delivery was abandoned because ctx is done.
		emit := func(chunk cliproxyexecutor.StreamChunk) bool {
			if chunk.Err != nil && !failed {
				failed = true
				rerr := &Error{Message: chunk.Err.Error()}
				if se, ok := errors.AsType[cliproxyexecutor.StatusError](chunk.Err); ok && se != nil {
					rerr.HTTPStatus = se.StatusCode()
				}
				m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr})
			}
			if !forward {
				return false
			}
			if ctx == nil {
				out <- chunk
				return true
			}
			select {
			case <-ctx.Done():
				forward = false
				return false
			case out <- chunk:
				return true
			}
		}
		// Replay the chunks consumed during bootstrap before streaming the rest.
		for _, chunk := range buffered {
			if ok := emit(chunk); !ok {
				discardStreamChunks(remaining)
				return
			}
		}
		for chunk := range remaining {
			if ok := emit(chunk); !ok {
				discardStreamChunks(remaining)
				return
			}
		}
		// The stream completed without an error chunk: count it as a success.
		if !failed {
			m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: true})
		}
	}()
	return &cliproxyexecutor.StreamResult{Headers: headers, Chunks: out}
}
// executeStreamWithModelPool opens a stream for a single auth. It tries each
// upstream model from prepareExecutionModels in order until one yields a
// usable stream. Before committing to a candidate, it reads the stream up to
// the first payload-bearing chunk (readStreamBootstrap), so early upstream
// failures can fall through to the next model.
//
// Each outcome is recorded exactly once via MarkResult, under routeModel (not
// the upstream model name). Per candidate:
//   - ExecuteStream error: return ctx.Err() if ctx is done. Otherwise record the
//     failure; return request-invalid errors immediately, and move on to the
//     next model for any other error.
//   - Bootstrap error: return ctx.Err() if ctx is done. Request-invalid errors
//     are recorded and returned. Other errors are recorded and the next model is
//     tried. On the last candidate the error is instead delivered as the
//     stream's only chunk, and wrapStreamResult records it.
//   - Stream closed before any chunk: an "empty_stream" error is recorded and the
//     next model is tried. On the last candidate that error is returned as the
//     stream's only chunk without being recorded a second time.
//   - Otherwise the buffered chunks, followed by the rest of the upstream stream,
//     are returned via wrapStreamResult, which records the final outcome.
//
// When every candidate fails with a non-terminal error, the last error is
// returned. An empty candidate list yields an "auth_not_found" error.
func (m *Manager) executeStreamWithModelPool(ctx context.Context, executor ProviderExecutor, auth *Auth, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, routeModel string) (*cliproxyexecutor.StreamResult, error) {
	if executor == nil {
		return nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
	}
	execModels := m.prepareExecutionModels(auth, routeModel)
	var lastErr error
	for idx, execModel := range execModels {
		isLast := idx == len(execModels)-1
		execReq := req
		execReq.Model = execModel
		streamResult, errStream := executor.ExecuteStream(ctx, auth, execReq, opts)
		if errStream != nil {
			if errCtx := ctx.Err(); errCtx != nil {
				return nil, errCtx
			}
			rerr := &Error{Message: errStream.Error()}
			if se, ok := errors.AsType[cliproxyexecutor.StatusError](errStream); ok && se != nil {
				rerr.HTTPStatus = se.StatusCode()
			}
			result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
			result.RetryAfter = retryAfterFromError(errStream)
			m.MarkResult(ctx, result)
			if isRequestInvalidError(errStream) {
				return nil, errStream
			}
			lastErr = errStream
			continue
		}
		// Defensive: an executor returning neither a result nor an error is
		// treated as an empty stream rather than dereferencing a nil pointer.
		if streamResult == nil {
			streamResult = &cliproxyexecutor.StreamResult{}
		}
		buffered, closed, bootstrapErr := readStreamBootstrap(ctx, streamResult.Chunks)
		if bootstrapErr != nil {
			if errCtx := ctx.Err(); errCtx != nil {
				discardStreamChunks(streamResult.Chunks)
				return nil, errCtx
			}
			if isRequestInvalidError(bootstrapErr) {
				rerr := &Error{Message: bootstrapErr.Error()}
				if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
					rerr.HTTPStatus = se.StatusCode()
				}
				result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
				result.RetryAfter = retryAfterFromError(bootstrapErr)
				m.MarkResult(ctx, result)
				discardStreamChunks(streamResult.Chunks)
				return nil, bootstrapErr
			}
			if !isLast {
				rerr := &Error{Message: bootstrapErr.Error()}
				if se, ok := errors.AsType[cliproxyexecutor.StatusError](bootstrapErr); ok && se != nil {
					rerr.HTTPStatus = se.StatusCode()
				}
				result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
				result.RetryAfter = retryAfterFromError(bootstrapErr)
				m.MarkResult(ctx, result)
				discardStreamChunks(streamResult.Chunks)
				lastErr = bootstrapErr
				continue
			}
			// Last candidate: surface the error in-stream; wrapStreamResult records it
			// when the chunk is forwarded, so it is not marked here.
			errCh := make(chan cliproxyexecutor.StreamChunk, 1)
			errCh <- cliproxyexecutor.StreamChunk{Err: bootstrapErr}
			close(errCh)
			return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, nil, errCh), nil
		}
		if closed && len(buffered) == 0 {
			emptyErr := &Error{Code: "empty_stream", Message: "upstream stream closed before first payload", Retryable: true}
			m.MarkResult(ctx, Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: emptyErr})
			if !isLast {
				lastErr = emptyErr
				continue
			}
			// The failure was recorded just above. Hand back the error chunk directly:
			// routing it through wrapStreamResult would record the same failure again
			// when the chunk is forwarded.
			errCh := make(chan cliproxyexecutor.StreamChunk, 1)
			errCh <- cliproxyexecutor.StreamChunk{Err: emptyErr}
			close(errCh)
			return &cliproxyexecutor.StreamResult{Headers: streamResult.Headers, Chunks: errCh}, nil
		}
		remaining := streamResult.Chunks
		if closed {
			// Upstream already closed after delivering the buffered chunks; use a
			// closed stand-in so wrapStreamResult finishes right after the replay.
			closedCh := make(chan cliproxyexecutor.StreamChunk)
			close(closedCh)
			remaining = closedCh
		}
		return m.wrapStreamResult(ctx, auth.Clone(), provider, routeModel, streamResult.Headers, buffered, remaining), nil
	}
	if lastErr == nil {
		lastErr = &Error{Code: "auth_not_found", Message: "no upstream model available"}
	}
	return nil, lastErr
}
// rebuildAPIKeyModelAliasFromRuntimeConfig recompiles the per-auth API-key
// alias table from the current runtime config (an empty config when none is
// stored). It holds the manager lock while doing so.
func (m *Manager) rebuildAPIKeyModelAliasFromRuntimeConfig() {
	if m == nil {
		return
	}
	cfg, ok := m.runtimeConfig.Load().(*internalconfig.Config)
	if !ok || cfg == nil {
		cfg = &internalconfig.Config{}
	}
	m.mu.Lock()
	defer m.mu.Unlock()
	m.rebuildAPIKeyModelAliasLocked(cfg)
}
// rebuildAPIKeyModelAliasLocked recompiles the per-auth API-key alias table
// from cfg (nil means an empty config) and publishes it atomically.
//
// Only auths with a non-blank ID and account kind "api_key" participate. The
// config entry is chosen by provider (gemini, claude, codex, vertex). Any other
// provider is treated as OpenAI-compatible when it has a "compat_name"
// attribute or its provider is "openai-compatibility". Auths that end up with
// no aliases get no table entry. The caller must hold m.mu.
func (m *Manager) rebuildAPIKeyModelAliasLocked(cfg *internalconfig.Config) {
	if m == nil {
		return
	}
	if cfg == nil {
		cfg = &internalconfig.Config{}
	}
	table := make(apiKeyModelAliasTable)
	for _, auth := range m.auths {
		if auth == nil || strings.TrimSpace(auth.ID) == "" {
			continue
		}
		if kind, _ := auth.AccountInfo(); !strings.EqualFold(strings.TrimSpace(kind), "api_key") {
			continue
		}
		aliases := make(map[string]string)
		switch strings.ToLower(strings.TrimSpace(auth.Provider)) {
		case "gemini":
			if entry := resolveGeminiAPIKeyConfig(cfg, auth); entry != nil {
				compileAPIKeyModelAliasForModels(aliases, entry.Models)
			}
		case "claude":
			if entry := resolveClaudeAPIKeyConfig(cfg, auth); entry != nil {
				compileAPIKeyModelAliasForModels(aliases, entry.Models)
			}
		case "codex":
			if entry := resolveCodexAPIKeyConfig(cfg, auth); entry != nil {
				compileAPIKeyModelAliasForModels(aliases, entry.Models)
			}
		case "vertex":
			if entry := resolveVertexAPIKeyConfig(cfg, auth); entry != nil {
				compileAPIKeyModelAliasForModels(aliases, entry.Models)
			}
		default:
			// OpenAI-compatible auths select their config entry from attributes;
			// indexing a nil Attributes map yields "".
			providerKey := strings.TrimSpace(auth.Attributes["provider_key"])
			compatName := strings.TrimSpace(auth.Attributes["compat_name"])
			isCompat := compatName != "" || strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility")
			if isCompat {
				if entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider); entry != nil {
					compileAPIKeyModelAliasForModels(aliases, entry.Models)
				}
			}
		}
		if len(aliases) != 0 {
			table[auth.ID] = aliases
		}
	}
	m.apiKeyModelAlias.Store(table)
}
// compileAPIKeyModelAliasForModels adds lookup entries for each configured
// model to out, mapping lower-cased lookup keys to the upstream model name. A
// nil out is ignored. Entries with a blank alias or name are skipped.
//
// Per model, in order:
//   - alias key (alias without thinking suffix, lower-cased) -> name; if the key
//     already exists the whole model is skipped, so the first config entry wins;
//   - name key (name without thinking suffix, lower-cased) -> name, so lookups of
//     an already-upstream model resolve to themselves;
//   - when name carries a suffix, its trimmed base name -> name.
//
// The last two never overwrite existing keys.
func compileAPIKeyModelAliasForModels[T interface {
	GetName() string
	GetAlias() string
}](out map[string]string, models []T) {
	if out == nil {
		return
	}
	// putIfAbsent records key -> upstream unless key is blank or already present.
	putIfAbsent := func(key, upstream string) {
		if key == "" {
			return
		}
		if _, taken := out[key]; !taken {
			out[key] = upstream
		}
	}
	for i := range models {
		alias := strings.TrimSpace(models[i].GetAlias())
		name := strings.TrimSpace(models[i].GetName())
		if alias == "" || name == "" {
			continue
		}
		aliasKey := strings.ToLower(thinking.ParseSuffix(alias).ModelName)
		if aliasKey == "" {
			aliasKey = strings.ToLower(alias)
		}
		if _, exists := out[aliasKey]; exists {
			continue
		}
		out[aliasKey] = name
		parsedName := thinking.ParseSuffix(name)
		nameKey := strings.ToLower(parsedName.ModelName)
		if nameKey == "" {
			nameKey = strings.ToLower(name)
		}
		putIfAbsent(nameKey, name)
		if parsedName.HasSuffix {
			putIfAbsent(strings.ToLower(strings.TrimSpace(parsedName.ModelName)), name)
		}
	}
}
// SetRetryConfig updates the request retry count, the per-request credential
// retry limit, and the maximum cooldown wait interval.
//
// Negative values are treated as 0. The two counts are also capped at the
// int32 maximum because they are stored in atomic.Int32 fields; a plain int32
// conversion would wrap large values to negative numbers.
func (m *Manager) SetRetryConfig(retry int, maxRetryInterval time.Duration, maxRetryCredentials int) {
	if m == nil {
		return
	}
	const maxInt32 = 1<<31 - 1
	clampCount := func(v int) int32 {
		if v < 0 {
			return 0
		}
		if v > maxInt32 {
			return maxInt32
		}
		return int32(v)
	}
	if maxRetryInterval < 0 {
		maxRetryInterval = 0
	}
	m.requestRetry.Store(clampCount(retry))
	m.maxRetryCredentials.Store(clampCount(maxRetryCredentials))
	m.maxRetryInterval.Store(maxRetryInterval.Nanoseconds())
}
// RegisterExecutor registers executor under its trimmed Identifier(). A nil
// executor or a blank identifier is ignored. If a different executor was
// already registered under that identifier, it is replaced. When the replaced
// executor implements ExecutionSessionCloser, all of its execution sessions
// are closed, outside the manager lock.
func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
	if executor == nil {
		return
	}
	key := strings.TrimSpace(executor.Identifier())
	if key == "" {
		return
	}
	m.mu.Lock()
	previous := m.executors[key]
	m.executors[key] = executor
	m.mu.Unlock()
	if previous == nil || previous == executor {
		return
	}
	closer, ok := previous.(ExecutionSessionCloser)
	if !ok || closer == nil {
		return
	}
	closer.CloseExecutionSession(CloseAllExecutionSessionsID)
}
// UnregisterExecutor removes the executor registered for provider. A blank
// provider is ignored. Unlike replacement in RegisterExecutor, the removed
// executor's execution sessions are not closed here.
func (m *Manager) UnregisterExecutor(provider string) {
	trimmed := strings.TrimSpace(provider)
	if trimmed == "" {
		return
	}
	m.mu.Lock()
	defer m.mu.Unlock()
	delete(m.executors, strings.ToLower(trimmed))
	// RegisterExecutor stores executors under their trimmed Identifier() without
	// lower-casing, so also remove the exact key; otherwise an executor whose
	// identifier contains upper-case letters could never be unregistered.
	delete(m.executors, trimmed)
}
// Register inserts auth into the manager. A nil auth yields (nil, nil).
//
// The caller's auth is mutated: it receives a fresh UUID when its ID is empty,
// and auth.EnsureIndex() is called on it. The manager then:
//   - stores its own clone;
//   - rebuilds the API-key alias table and upserts the clone into the scheduler;
//   - persists auth on a best-effort basis (persistence errors are ignored);
//   - notifies the hook.
//
// It returns another clone.
func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
	if auth == nil {
		return nil, nil
	}
	if auth.ID == "" {
		auth.ID = uuid.NewString()
	}
	auth.EnsureIndex()
	stored := auth.Clone()
	func() {
		m.mu.Lock()
		defer m.mu.Unlock()
		m.auths[auth.ID] = stored
	}()
	m.rebuildAPIKeyModelAliasFromRuntimeConfig()
	if sched := m.scheduler; sched != nil {
		sched.upsertAuth(stored)
	}
	// Persistence is best-effort; a failure does not abort registration.
	_ = m.persist(ctx, auth)
	m.hook.OnAuthRegistered(ctx, auth.Clone())
	return auth.Clone(), nil
}
// Update replaces an existing auth entry and notifies hooks. It also inserts
// the entry when no auth with that ID exists yet; there is no existence check.
// A nil auth or an empty ID yields (nil, nil).
//
// When an entry already exists, the caller's auth inherits some state from it:
//   - Index and indexAssigned, if auth has no index of its own;
//   - ModelStates, if auth has none.
//
// The caller's auth is then mutated by EnsureIndex. After that, a clone is
// stored, the API-key alias table and scheduler entry are refreshed, auth is
// persisted on a best-effort basis (errors are ignored), and the hook is
// notified. Another clone is returned.
// NOTE(review): ModelStates is copied by reference from the stored entry into
// the caller's auth. If Auth.Clone does not deep-copy it, the caller and the
// stored clone may share one map. Confirm.
func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
	if auth == nil || auth.ID == "" {
		return nil, nil
	}
	m.mu.Lock()
	if existing, ok := m.auths[auth.ID]; ok && existing != nil {
		// Keep the previously assigned index when the caller did not set one.
		if !auth.indexAssigned && auth.Index == "" {
			auth.Index = existing.Index
			auth.indexAssigned = existing.indexAssigned
		}
		// Keep per-model state when the caller's copy carries none.
		if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 {
			auth.ModelStates = existing.ModelStates
		}
	}
	auth.EnsureIndex()
	authClone := auth.Clone()
	m.auths[auth.ID] = authClone
	m.mu.Unlock()
	m.rebuildAPIKeyModelAliasFromRuntimeConfig()
	if m.scheduler != nil {
		m.scheduler.upsertAuth(authClone)
	}
	// Best-effort persistence: errors are deliberately ignored.
	_ = m.persist(ctx, auth)
	m.hook.OnAuthUpdated(ctx, auth.Clone())
	return auth.Clone(), nil
}
// Load resets manager state from the backing store. Specifically:
//   - the in-memory auth map is replaced by clones of the stored auths;
//   - entries that are nil or have an empty ID are skipped;
//   - EnsureIndex is called on each stored auth before it is cloned;
//   - the API-key alias table is rebuilt from the runtime config;
//   - the scheduler is rebuilt.
//
// Load is a no-op returning nil when no store is configured. If store.List
// fails, the current state is left untouched and the error is returned. Hooks
// are not notified.
// NOTE(review): store.List runs while m.mu is held, so a slow store blocks every
// other manager operation that takes the lock for the duration of the call.
func (m *Manager) Load(ctx context.Context) error {
	m.mu.Lock()
	if m.store == nil {
		m.mu.Unlock()
		return nil
	}
	items, err := m.store.List(ctx)
	if err != nil {
		m.mu.Unlock()
		return err
	}
	m.auths = make(map[string]*Auth, len(items))
	for _, auth := range items {
		if auth == nil || auth.ID == "" {
			continue
		}
		auth.EnsureIndex()
		m.auths[auth.ID] = auth.Clone()
	}
	cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
	if cfg == nil {
		cfg = &internalconfig.Config{}
	}
	// Already holding m.mu, so use the Locked variant.
	m.rebuildAPIKeyModelAliasLocked(cfg)
	m.mu.Unlock()
	// syncScheduler snapshots auths itself, so it must run after the unlock.
	m.syncScheduler()
	return nil
}
// Execute performs a non-streaming request against the given providers, using
// the configured selector and executors. Per the original design, it supports
// multiple providers for the same model and round-robins the starting provider
// per model.
//
// Each attempt may try several credentials, bounded by the max-retry-credentials
// setting. After a failed attempt, shouldRetryAfterError decides whether to wait
// out a cooldown and try again. When retrying stops, the last attempt's error is
// returned. An error from the cooldown wait (presumably ctx cancellation) is
// returned as-is.
func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	candidates := m.normalizeProviders(providers)
	if len(candidates) == 0 {
		return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
	}
	_, credentialLimit, waitLimit := m.retrySettings()
	var failure error
	for attempt := 0; ; attempt++ {
		resp, err := m.executeMixedOnce(ctx, candidates, req, opts, credentialLimit)
		if err == nil {
			return resp, nil
		}
		failure = err
		delay, retry := m.shouldRetryAfterError(err, attempt, candidates, req.Model, waitLimit)
		if !retry {
			break
		}
		if waitErr := waitForCooldown(ctx, delay); waitErr != nil {
			return cliproxyexecutor.Response{}, waitErr
		}
	}
	if failure == nil {
		return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	return cliproxyexecutor.Response{}, failure
}
// ExecuteCount performs a non-streaming token-count request against the given
// providers, using the configured selector and executors. Per the original
// design, it supports multiple providers for the same model and round-robins
// the starting provider per model.
//
// The retry/cooldown policy matches Execute: attempts repeat while
// shouldRetryAfterError allows it, and the last error is returned once retrying
// stops. An error from the cooldown wait is returned as-is.
func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	candidates := m.normalizeProviders(providers)
	if len(candidates) == 0 {
		return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
	}
	_, credentialLimit, waitLimit := m.retrySettings()
	var failure error
	for attempt := 0; ; attempt++ {
		resp, err := m.executeCountMixedOnce(ctx, candidates, req, opts, credentialLimit)
		if err == nil {
			return resp, nil
		}
		failure = err
		delay, retry := m.shouldRetryAfterError(err, attempt, candidates, req.Model, waitLimit)
		if !retry {
			break
		}
		if waitErr := waitForCooldown(ctx, delay); waitErr != nil {
			return cliproxyexecutor.Response{}, waitErr
		}
	}
	if failure == nil {
		return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	return cliproxyexecutor.Response{}, failure
}
// ExecuteStream performs a streaming request against the given providers, using
// the configured selector and executors. Per the original design, it supports
// multiple providers for the same model and round-robins the starting provider
// per model.
//
// The retry/cooldown policy matches Execute. Only failures that surface before
// a stream is handed back are retried; errors that arrive later travel inside
// the stream.
func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	candidates := m.normalizeProviders(providers)
	if len(candidates) == 0 {
		return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
	}
	_, credentialLimit, waitLimit := m.retrySettings()
	var failure error
	for attempt := 0; ; attempt++ {
		stream, err := m.executeStreamMixedOnce(ctx, candidates, req, opts, credentialLimit)
		if err == nil {
			return stream, nil
		}
		failure = err
		delay, retry := m.shouldRetryAfterError(err, attempt, candidates, req.Model, waitLimit)
		if !retry {
			break
		}
		if waitErr := waitForCooldown(ctx, delay); waitErr != nil {
			return nil, waitErr
		}
	}
	if failure == nil {
		return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	return nil, failure
}
// executeMixedOnce performs one pass of credential selection for a
// non-streaming request. It repeatedly picks an untried auth across providers
// via pickNextMixed. For each auth it tries every upstream model from
// prepareExecutionModels in order, recording each outcome with MarkResult under
// the route model.
//
// It returns the first successful response. Otherwise it stops with:
//   - ctx.Err(), when an execution error coincides with cancellation;
//   - the error itself, when it is a request-invalid error (other credentials
//     cannot help);
//   - the last execution error (or an "auth_not_found" error) once
//     maxRetryCredentials distinct auths have been tried (0 means no limit);
//   - the last execution error, or pickNextMixed's error if nothing ran yet,
//     once no further auth can be picked.
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (cliproxyexecutor.Response, error) {
	if len(providers) == 0 {
		return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
	}
	routeModel := req.Model
	// Record the client-facing model in metadata (copying rather than mutating the caller's map).
	opts = ensureRequestedModelMetadata(opts, routeModel)
	// tried holds the auth IDs already attempted in this pass and is passed to
	// pickNextMixed (presumably so they are skipped — confirm).
	tried := make(map[string]struct{})
	var lastErr error
	for {
		if maxRetryCredentials > 0 && len(tried) >= maxRetryCredentials {
			if lastErr != nil {
				return cliproxyexecutor.Response{}, lastErr
			}
			return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
		}
		auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
		if errPick != nil {
			if lastErr != nil {
				return cliproxyexecutor.Response{}, lastErr
			}
			return cliproxyexecutor.Response{}, errPick
		}
		entry := logEntryWithRequestID(ctx)
		debugLogAuthSelection(entry, auth, provider, req.Model)
		// opts.Metadata may be the copy made above; any selection callback stored in it is still invoked.
		publishSelectedAuthMetadata(opts.Metadata, auth.ID)
		tried[auth.ID] = struct{}{}
		execCtx := ctx
		if rt := m.roundTripperFor(auth); rt != nil {
			// Expose the per-auth RoundTripper under both the typed key and a string key.
			// NOTE(review): string context keys trip go vet/staticcheck (SA1029);
			// presumably kept for executors that still read "cliproxy.roundtripper".
			execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
			execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
		}
		models := m.prepareExecutionModels(auth, routeModel)
		var authErr error
		for _, upstreamModel := range models {
			execReq := req
			execReq.Model = upstreamModel
			resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
			result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
			if errExec != nil {
				if errCtx := execCtx.Err(); errCtx != nil {
					return cliproxyexecutor.Response{}, errCtx
				}
				result.Error = &Error{Message: errExec.Error()}
				if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
					result.Error.HTTPStatus = se.StatusCode()
				}
				if ra := retryAfterFromError(errExec); ra != nil {
					result.RetryAfter = ra
				}
				m.MarkResult(execCtx, result)
				if isRequestInvalidError(errExec) {
					return cliproxyexecutor.Response{}, errExec
				}
				// Try the next upstream model for this same auth.
				authErr = errExec
				continue
			}
			m.MarkResult(execCtx, result)
			return resp, nil
		}
		if authErr != nil {
			// Safety net: request-invalid errors already returned inside the loop above.
			if isRequestInvalidError(authErr) {
				return cliproxyexecutor.Response{}, authErr
			}
			lastErr = authErr
			continue
		}
	}
}
// executeCountMixedOnce performs one pass of credential selection for a
// token-count request. Its control flow mirrors executeMixedOnce, but it calls
// executor.CountTokens for each upstream model from prepareExecutionModels.
//
// It returns the first successful response. Otherwise it stops with:
//   - ctx.Err(), when an execution error coincides with cancellation;
//   - a request-invalid error, as soon as one occurs;
//   - the last error (or an "auth_not_found" error) once maxRetryCredentials
//     distinct auths have been tried (0 means no limit);
//   - the last error, or pickNextMixed's error if nothing ran yet, once no
//     further auth can be picked.
//
// NOTE(review): unlike executeMixedOnce, outcomes here are reported only via
// m.hook.OnResult, not MarkResult. Token-count failures therefore presumably
// do not change auth availability/cooldown state. Confirm this is intentional.
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (cliproxyexecutor.Response, error) {
	if len(providers) == 0 {
		return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
	}
	routeModel := req.Model
	// Record the client-facing model in metadata (copying rather than mutating the caller's map).
	opts = ensureRequestedModelMetadata(opts, routeModel)
	// tried holds the auth IDs already attempted in this pass.
	tried := make(map[string]struct{})
	var lastErr error
	for {
		if maxRetryCredentials > 0 && len(tried) >= maxRetryCredentials {
			if lastErr != nil {
				return cliproxyexecutor.Response{}, lastErr
			}
			return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
		}
		auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
		if errPick != nil {
			if lastErr != nil {
				return cliproxyexecutor.Response{}, lastErr
			}
			return cliproxyexecutor.Response{}, errPick
		}
		entry := logEntryWithRequestID(ctx)
		debugLogAuthSelection(entry, auth, provider, req.Model)
		publishSelectedAuthMetadata(opts.Metadata, auth.ID)
		tried[auth.ID] = struct{}{}
		execCtx := ctx
		if rt := m.roundTripperFor(auth); rt != nil {
			// Typed key plus a string key; see the note in executeMixedOnce's twin code.
			execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
			execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
		}
		models := m.prepareExecutionModels(auth, routeModel)
		var authErr error
		for _, upstreamModel := range models {
			execReq := req
			execReq.Model = upstreamModel
			resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
			result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
			if errExec != nil {
				if errCtx := execCtx.Err(); errCtx != nil {
					return cliproxyexecutor.Response{}, errCtx
				}
				result.Error = &Error{Message: errExec.Error()}
				if se, ok := errors.AsType[cliproxyexecutor.StatusError](errExec); ok && se != nil {
					result.Error.HTTPStatus = se.StatusCode()
				}
				if ra := retryAfterFromError(errExec); ra != nil {
					result.RetryAfter = ra
				}
				m.hook.OnResult(execCtx, result)
				if isRequestInvalidError(errExec) {
					return cliproxyexecutor.Response{}, errExec
				}
				// Try the next upstream model for this same auth.
				authErr = errExec
				continue
			}
			m.hook.OnResult(execCtx, result)
			return resp, nil
		}
		if authErr != nil {
			// Safety net: request-invalid errors already returned inside the loop above.
			if isRequestInvalidError(authErr) {
				return cliproxyexecutor.Response{}, authErr
			}
			lastErr = authErr
			continue
		}
	}
}
// executeStreamMixedOnce performs one pass of credential selection for a
// streaming request. It picks untried auths across providers via pickNextMixed
// and delegates each one to executeStreamWithModelPool, which tries the auth's
// upstream models and records outcomes.
//
// It returns the first stream obtained. Otherwise it stops with:
//   - ctx.Err(), when a failure coincides with cancellation;
//   - a request-invalid error, as soon as one occurs;
//   - the last error (or an "auth_not_found" error) once maxRetryCredentials
//     distinct auths have been tried (0 means no limit);
//   - the last error, or pickNextMixed's error if nothing ran yet, once no
//     further auth can be picked.
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, maxRetryCredentials int) (*cliproxyexecutor.StreamResult, error) {
	if len(providers) == 0 {
		return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
	}
	routeModel := req.Model
	// Record the client-facing model in metadata (copying rather than mutating the caller's map).
	opts = ensureRequestedModelMetadata(opts, routeModel)
	// tried holds the auth IDs already attempted in this pass.
	tried := make(map[string]struct{})
	var lastErr error
	for {
		if maxRetryCredentials > 0 && len(tried) >= maxRetryCredentials {
			if lastErr != nil {
				return nil, lastErr
			}
			return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
		}
		auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
		if errPick != nil {
			if lastErr != nil {
				return nil, lastErr
			}
			return nil, errPick
		}
		entry := logEntryWithRequestID(ctx)
		debugLogAuthSelection(entry, auth, provider, req.Model)
		publishSelectedAuthMetadata(opts.Metadata, auth.ID)
		tried[auth.ID] = struct{}{}
		execCtx := ctx
		if rt := m.roundTripperFor(auth); rt != nil {
			// Typed key plus a string key (the latter presumably for legacy lookups — confirm).
			execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
			execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
		}
		streamResult, errStream := m.executeStreamWithModelPool(execCtx, executor, auth, provider, req, opts, routeModel)
		if errStream != nil {
			if errCtx := execCtx.Err(); errCtx != nil {
				return nil, errCtx
			}
			if isRequestInvalidError(errStream) {
				return nil, errStream
			}
			// Outcome already recorded by executeStreamWithModelPool; move to the next auth.
			lastErr = errStream
			continue
		}
		return streamResult, nil
	}
}
// ensureRequestedModelMetadata returns opts with the trimmed requestedModel
// recorded under cliproxyexecutor.RequestedModelMetadataKey. The caller's
// metadata map is never mutated; when the key must be added, a copy is made.
// A blank requestedModel, or metadata that already carries a non-blank
// requested model, leaves opts unchanged.
func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {
	model := strings.TrimSpace(requestedModel)
	if model == "" || hasRequestedModelMetadata(opts.Metadata) {
		return opts
	}
	merged := make(map[string]any, len(opts.Metadata)+1)
	for key, value := range opts.Metadata {
		merged[key] = value
	}
	merged[cliproxyexecutor.RequestedModelMetadataKey] = model
	opts.Metadata = merged
	return opts
}
// hasRequestedModelMetadata reports whether meta holds a non-blank string or
// []byte under cliproxyexecutor.RequestedModelMetadataKey. Values of any other
// type, and a nil or empty map, count as absent.
func hasRequestedModelMetadata(meta map[string]any) bool {
	var text string
	// Indexing a nil map yields nil, which falls into the default case.
	switch value := meta[cliproxyexecutor.RequestedModelMetadataKey].(type) {
	case string:
		text = value
	case []byte:
		text = string(value)
	default:
		return false
	}
	return strings.TrimSpace(text) != ""
}
// pinnedAuthIDFromMetadata returns the trimmed auth ID stored under
// cliproxyexecutor.PinnedAuthMetadataKey when it is a string or []byte, and ""
// otherwise (including for a nil or empty map).
func pinnedAuthIDFromMetadata(meta map[string]any) string {
	switch value := meta[cliproxyexecutor.PinnedAuthMetadataKey].(type) {
	case string:
		return strings.TrimSpace(value)
	case []byte:
		return strings.TrimSpace(string(value))
	}
	return ""
}
// publishSelectedAuthMetadata writes the trimmed authID into meta under
// cliproxyexecutor.SelectedAuthMetadataKey. If a func(string) is stored under
// cliproxyexecutor.SelectedAuthCallbackMetadataKey, it is invoked with that ID.
// Nothing happens when meta is empty (including nil) or authID is blank.
func publishSelectedAuthMetadata(meta map[string]any, authID string) {
	id := strings.TrimSpace(authID)
	if len(meta) == 0 || id == "" {
		return
	}
	meta[cliproxyexecutor.SelectedAuthMetadataKey] = id
	notify, _ := meta[cliproxyexecutor.SelectedAuthCallbackMetadataKey].(func(string))
	if notify != nil {
		notify(id)
	}
}
// rewriteModelForAuth strips the auth's routing prefix ("<prefix>/") from
// model. The model is returned unchanged when auth is nil, model is empty, the
// auth has no (non-blank) prefix, or the model does not start with it.
func rewriteModelForAuth(model string, auth *Auth) string {
	if auth == nil || model == "" {
		return model
	}
	prefix := strings.TrimSpace(auth.Prefix)
	if prefix == "" {
		return model
	}
	if stripped, found := strings.CutPrefix(model, prefix+"/"); found {
		return stripped
	}
	return model
}
// applyAPIKeyModelAlias maps requestedModel to its upstream name for API-key
// auths. Non-API-key auths (and a nil manager or auth) get requestedModel back
// untouched; otherwise the model is trimmed first.
//
// Resolution order:
//  1. the cached per-auth alias table (fast path);
//  2. a direct scan of the runtime config entry matching the credential,
//     chosen by provider. Unknown providers are treated as OpenAI-compatible.
//     This is a safety net for stale tables.
//
// When neither resolves, the trimmed model is returned.
func (m *Manager) applyAPIKeyModelAlias(auth *Auth, requestedModel string) string {
	if m == nil || !isAPIKeyAuth(auth) {
		return requestedModel
	}
	model := strings.TrimSpace(requestedModel)
	if model == "" {
		return model
	}
	if cached := m.lookupAPIKeyUpstreamModel(auth.ID, model); cached != "" {
		return cached
	}
	cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
	if cfg == nil {
		cfg = &internalconfig.Config{}
	}
	var upstream string
	switch strings.ToLower(strings.TrimSpace(auth.Provider)) {
	case "gemini":
		upstream = resolveUpstreamModelForGeminiAPIKey(cfg, auth, model)
	case "claude":
		upstream = resolveUpstreamModelForClaudeAPIKey(cfg, auth, model)
	case "codex":
		upstream = resolveUpstreamModelForCodexAPIKey(cfg, auth, model)
	case "vertex":
		upstream = resolveUpstreamModelForVertexAPIKey(cfg, auth, model)
	default:
		upstream = resolveUpstreamModelForOpenAICompatAPIKey(cfg, auth, model)
	}
	if upstream == "" {
		return model
	}
	return upstream
}
// APIKeyConfigEntry is implemented by API-key credential config entries, so
// that resolveAPIKeyConfig can match them generically against an auth's
// "api_key" and "base_url" attributes.
type APIKeyConfigEntry interface {
	// GetBaseURL returns the entry's configured base URL; it may be empty.
	GetBaseURL() string
	// GetAPIKey returns the entry's configured API key.
	GetAPIKey() string
}
// resolveAPIKeyConfig returns a pointer to the entry in entries that corresponds
// to auth, or nil when auth is nil, entries is empty, or nothing matches.
//
// The auth's "api_key" and "base_url" attributes (trimmed) drive the match, and
// comparisons are case-insensitive. A first pass picks the first entry that fits
// whichever attributes are present:
//   - key and base URL: both must match;
//   - key only: the key must match and the entry's base URL must be empty;
//   - base URL only: the base URL must match.
//
// If that fails and the auth has a key, a second pass accepts the first entry
// whose key matches, regardless of base URL.
func resolveAPIKeyConfig[T APIKeyConfigEntry](entries []T, auth *Auth) *T {
	if auth == nil || len(entries) == 0 {
		return nil
	}
	var wantKey, wantBase string
	if auth.Attributes != nil {
		wantKey = strings.TrimSpace(auth.Attributes["api_key"])
		wantBase = strings.TrimSpace(auth.Attributes["base_url"])
	}

	for idx := range entries {
		key := strings.TrimSpace(entries[idx].GetAPIKey())
		base := strings.TrimSpace(entries[idx].GetBaseURL())
		keyMatches := wantKey != "" && strings.EqualFold(key, wantKey)
		baseMatches := strings.EqualFold(base, wantBase)
		switch {
		case wantKey != "" && wantBase != "":
			if keyMatches && baseMatches {
				return &entries[idx]
			}
		case wantKey != "":
			if keyMatches && (base == "" || baseMatches) {
				return &entries[idx]
			}
		case wantBase != "":
			if baseMatches {
				return &entries[idx]
			}
		}
	}

	if wantKey == "" {
		return nil
	}
	for idx := range entries {
		if strings.EqualFold(strings.TrimSpace(entries[idx].GetAPIKey()), wantKey) {
			return &entries[idx]
		}
	}
	return nil
}
// resolveGeminiAPIKeyConfig finds the Gemini API-key entry in cfg matching auth
// (see resolveAPIKeyConfig). A nil cfg yields nil.
func resolveGeminiAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.GeminiKey {
	if cfg != nil {
		return resolveAPIKeyConfig(cfg.GeminiKey, auth)
	}
	return nil
}
// resolveClaudeAPIKeyConfig finds the Claude API-key entry in cfg matching auth
// (see resolveAPIKeyConfig). A nil cfg yields nil.
func resolveClaudeAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.ClaudeKey {
	if cfg != nil {
		return resolveAPIKeyConfig(cfg.ClaudeKey, auth)
	}
	return nil
}
// resolveCodexAPIKeyConfig finds the Codex API-key entry in cfg matching auth
// (see resolveAPIKeyConfig). A nil cfg yields nil.
func resolveCodexAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.CodexKey {
	if cfg != nil {
		return resolveAPIKeyConfig(cfg.CodexKey, auth)
	}
	return nil
}
// resolveVertexAPIKeyConfig finds the Vertex-compatible API-key entry in cfg
// matching auth (see resolveAPIKeyConfig). A nil cfg yields nil.
func resolveVertexAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.VertexCompatKey {
	if cfg != nil {
		return resolveAPIKeyConfig(cfg.VertexCompatAPIKey, auth)
	}
	return nil
}
// resolveUpstreamModelForGeminiAPIKey resolves requestedModel through the model
// aliases of the Gemini key entry matching auth. It returns "" when no entry
// matches or no alias applies.
func resolveUpstreamModelForGeminiAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
	if entry := resolveGeminiAPIKeyConfig(cfg, auth); entry != nil {
		return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
	}
	return ""
}
// resolveUpstreamModelForClaudeAPIKey resolves requestedModel through the model
// aliases of the Claude key entry matching auth. It returns "" when no entry
// matches or no alias applies.
func resolveUpstreamModelForClaudeAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
	if entry := resolveClaudeAPIKeyConfig(cfg, auth); entry != nil {
		return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
	}
	return ""
}
// resolveUpstreamModelForCodexAPIKey resolves requestedModel through the model
// aliases of the Codex key entry matching auth. It returns "" when no entry
// matches or no alias applies.
func resolveUpstreamModelForCodexAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
	if entry := resolveCodexAPIKeyConfig(cfg, auth); entry != nil {
		return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
	}
	return ""
}
// resolveUpstreamModelForVertexAPIKey resolves requestedModel through the model
// aliases of the Vertex-compatible key entry matching auth. It returns "" when no
// entry matches or no alias applies.
func resolveUpstreamModelForVertexAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
	if entry := resolveVertexAPIKeyConfig(cfg, auth); entry != nil {
		return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
	}
	return ""
}
// resolveUpstreamModelForOpenAICompatAPIKey resolves requestedModel through the
// model aliases of the OpenAI-compatibility entry that auth belongs to.
//
// The entry is located by the auth's "compat_name" attribute, its "provider_key"
// attribute, or its provider string (see resolveOpenAICompatConfig). Auths with
// neither a compat name nor the "openai-compatibility" provider are not
// considered compat auths. It returns "" when auth is nil, the auth is not a
// compat auth, no entry matches, or no alias applies.
func resolveUpstreamModelForOpenAICompatAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
	// Guard up front: auth.Provider is dereferenced below unconditionally.
	if auth == nil {
		return ""
	}
	var providerKey, compatName string
	if len(auth.Attributes) > 0 {
		providerKey = strings.TrimSpace(auth.Attributes["provider_key"])
		compatName = strings.TrimSpace(auth.Attributes["compat_name"])
	}
	if compatName == "" && !strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
		return ""
	}
	entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider)
	if entry == nil {
		return ""
	}
	return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
}
// apiKeyModelAliasTable is a two-level string map used for API-key model aliases.
// NOTE(review): presumably keyed first by auth ID and then by requested model,
// yielding the upstream model name (applyAPIKeyModelAlias describes its fast-path
// lookup as "per-auth mapping table (keyed by auth.ID)") — confirm against the
// code that builds it.
type apiKeyModelAliasTable map[string]map[string]string
// resolveOpenAICompatConfig returns the OpenAI-compatibility entry in cfg whose
// Name matches one of the supplied identifiers, or nil when cfg is nil or nothing
// matches.
//
// Identifiers are tried in priority order: compatName, then providerKey, then
// authProvider. Blank identifiers are skipped. Names are compared
// case-insensitively after trimming whitespace on both sides. Candidates drive the
// outer loop, so an explicit compat name always wins. A lower-priority identifier
// that matches an earlier config entry cannot shadow it.
func resolveOpenAICompatConfig(cfg *internalconfig.Config, providerKey, compatName, authProvider string) *internalconfig.OpenAICompatibility {
	if cfg == nil {
		return nil
	}
	for _, raw := range []string{compatName, providerKey, authProvider} {
		candidate := strings.TrimSpace(raw)
		if candidate == "" {
			continue
		}
		for i := range cfg.OpenAICompatibility {
			compat := &cfg.OpenAICompatibility[i]
			if strings.EqualFold(candidate, strings.TrimSpace(compat.Name)) {
				return compat
			}
		}
	}
	return nil
}
// asModelAliasEntries converts a typed slice of config model definitions into the
// generic modelAliasEntry form consumed by the alias resolver. An empty input
// yields nil.
func asModelAliasEntries[T interface {
	GetName() string
	GetAlias() string
}](models []T) []modelAliasEntry {
	if len(models) == 0 {
		return nil
	}
	converted := make([]modelAliasEntry, len(models))
	for i, model := range models {
		converted[i] = model
	}
	return converted
}
// normalizeProviders lowercases and trims every provider name, drops blanks and
// duplicates, and keeps first-seen order. A nil/empty input yields nil. A
// non-empty input whose entries are all blank yields an empty, non-nil slice.
func (m *Manager) normalizeProviders(providers []string) []string {
	if len(providers) == 0 {
		return nil
	}
	seen := make(map[string]struct{}, len(providers))
	normalized := make([]string, 0, len(providers))
	for i := range providers {
		key := strings.TrimSpace(strings.ToLower(providers[i]))
		if _, dup := seen[key]; key == "" || dup {
			continue
		}
		seen[key] = struct{}{}
		normalized = append(normalized, key)
	}
	return normalized
}
// retrySettings returns the manager's requestRetry count, its maxRetryCredentials
// limit, and its maxRetryInterval duration, in that order. A nil manager yields
// all zeros.
func (m *Manager) retrySettings() (int, int, time.Duration) {
	if m == nil {
		return 0, 0, 0
	}
	retries := int(m.requestRetry.Load())
	maxCredentials := int(m.maxRetryCredentials.Load())
	maxInterval := time.Duration(m.maxRetryInterval.Load())
	return retries, maxCredentials, maxInterval
}
// closestCooldownWait returns the shortest remaining cooldown for model among
// auths of the given providers that still have a retry budget at attempt.
//
// An auth's retry budget is its RequestRetryOverride when present, otherwise the
// manager's requestRetry. Negative values count as zero, and the auth is only
// considered while attempt is below that budget. Auths that are not blocked, have
// no retry time, or are blocked because they are disabled are ignored. The
// boolean reports whether any candidate wait was found.
func (m *Manager) closestCooldownWait(providers []string, model string, attempt int) (time.Duration, bool) {
	if m == nil || len(providers) == 0 {
		return 0, false
	}
	now := time.Now()
	baseRetry := int(m.requestRetry.Load())
	if baseRetry < 0 {
		baseRetry = 0
	}
	wanted := make(map[string]struct{}, len(providers))
	for _, p := range providers {
		if key := strings.TrimSpace(strings.ToLower(p)); key != "" {
			wanted[key] = struct{}{}
		}
	}

	m.mu.RLock()
	defer m.mu.RUnlock()

	var best time.Duration
	haveBest := false
	for _, candidate := range m.auths {
		if candidate == nil {
			continue
		}
		if _, ok := wanted[strings.TrimSpace(strings.ToLower(candidate.Provider))]; !ok {
			continue
		}
		budget := baseRetry
		if override, ok := candidate.RequestRetryOverride(); ok {
			budget = override
		}
		if budget < 0 {
			budget = 0
		}
		if attempt >= budget {
			continue
		}
		blocked, reason, next := isAuthBlockedForModel(candidate, model, now)
		if !blocked || next.IsZero() || reason == blockReasonDisabled {
			continue
		}
		remaining := next.Sub(now)
		if remaining < 0 {
			continue
		}
		if !haveBest || remaining < best {
			best, haveBest = remaining, true
		}
	}
	return best, haveBest
}
// shouldRetryAfterError decides whether a failed attempt should be retried after
// waiting for a credential cooldown.
//
// It declines for a nil error, a non-positive maxWait, an error carrying HTTP
// 200, or a request-shape error (isRequestInvalidError). Otherwise it returns the
// closest cooldown from closestCooldownWait, provided one exists and does not
// exceed maxWait.
func (m *Manager) shouldRetryAfterError(err error, attempt int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) {
	switch {
	case err == nil, maxWait <= 0:
		return 0, false
	case statusCodeFromError(err) == http.StatusOK:
		return 0, false
	case isRequestInvalidError(err):
		return 0, false
	}
	wait, ok := m.closestCooldownWait(providers, model, attempt)
	if ok && wait <= maxWait {
		return wait, true
	}
	return 0, false
}
// waitForCooldown blocks for wait, or until ctx is done, whichever comes first.
// A non-positive wait returns nil immediately without consulting ctx. If ctx
// finishes first, its error is returned.
func waitForCooldown(ctx context.Context, wait time.Duration) error {
	if wait > 0 {
		t := time.NewTimer(wait)
		defer t.Stop()
		select {
		case <-t.C:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	return nil
}
// MarkResult records an execution result and notifies hooks.
//
// Results with an empty AuthID are ignored entirely (not even the hook runs).
// Otherwise, while holding m.mu, the matching auth (if present and non-nil) is
// updated:
//   - success with a model: that model's state is reset, aggregate availability
//     is recomputed, and auth-level error fields are cleared only when no model
//     still reports an error; afterwards the model is resumed and its quota flag
//     cleared in the global registry.
//   - success without a model: clearAuthStateOnSuccess clears all auth-level
//     failure state.
//   - failure with a model: the model state is marked unavailable/errored and
//     given a status-specific NextRetryAfter; 401/402/403/404/429 additionally
//     suspend the model in the global registry (429 also records quota state).
//   - failure without a model: applyAuthFailureState marks the whole auth.
//
// The auth is then persisted (best-effort; the error is discarded) and a clone is
// pushed to the scheduler. Registry updates and m.hook.OnResult run after the lock
// is released; the hook is called even when no auth with this ID exists.
func (m *Manager) MarkResult(ctx context.Context, result Result) {
	if result.AuthID == "" {
		return
	}
	// Registry side effects are decided under the lock but applied after it is released.
	shouldResumeModel := false
	shouldSuspendModel := false
	suspendReason := ""
	clearModelQuota := false
	setModelQuota := false
	var authSnapshot *Auth
	m.mu.Lock()
	if auth, ok := m.auths[result.AuthID]; ok && auth != nil {
		now := time.Now()
		if result.Success {
			if result.Model != "" {
				state := ensureModelState(auth, result.Model)
				resetModelState(state, now)
				updateAggregatedAvailability(auth, now)
				// Other models may still be failing; only clear the auth-level
				// error when none of them report one.
				if !hasModelError(auth, now) {
					auth.LastError = nil
					auth.StatusMessage = ""
					auth.Status = StatusActive
				}
				auth.UpdatedAt = now
				shouldResumeModel = true
				clearModelQuota = true
			} else {
				clearAuthStateOnSuccess(auth, now)
			}
		} else {
			if result.Model != "" {
				state := ensureModelState(auth, result.Model)
				state.Unavailable = true
				state.Status = StatusError
				state.UpdatedAt = now
				if result.Error != nil {
					state.LastError = cloneError(result.Error)
					state.StatusMessage = result.Error.Message
					auth.LastError = cloneError(result.Error)
					auth.StatusMessage = result.Error.Message
				}
				statusCode := statusCodeFromResult(result.Error)
				switch statusCode {
				case 401:
					// Unauthorized: back off this model for 30 minutes.
					next := now.Add(30 * time.Minute)
					state.NextRetryAfter = next
					suspendReason = "unauthorized"
					shouldSuspendModel = true
				case 402, 403:
					// Payment required / forbidden: 30 minutes, reported as payment_required.
					next := now.Add(30 * time.Minute)
					state.NextRetryAfter = next
					suspendReason = "payment_required"
					shouldSuspendModel = true
				case 404:
					// Model not found for this auth: back off for 12 hours.
					next := now.Add(12 * time.Hour)
					state.NextRetryAfter = next
					suspendReason = "not_found"
					shouldSuspendModel = true
				case 429:
					// Quota/rate limited: honour an explicit Retry-After; otherwise
					// escalate this model's backoff level via nextQuotaCooldown.
					var next time.Time
					backoffLevel := state.Quota.BackoffLevel
					if result.RetryAfter != nil {
						next = now.Add(*result.RetryAfter)
					} else {
						cooldown, nextLevel := nextQuotaCooldown(backoffLevel, quotaCooldownDisabledForAuth(auth))
						if cooldown > 0 {
							next = now.Add(cooldown)
						}
						backoffLevel = nextLevel
					}
					state.NextRetryAfter = next
					state.Quota = QuotaState{
						Exceeded:      true,
						Reason:        "quota",
						NextRecoverAt: next,
						BackoffLevel:  backoffLevel,
					}
					suspendReason = "quota"
					shouldSuspendModel = true
					setModelQuota = true
				case 408, 500, 502, 503, 504:
					// Transient upstream failure: short cooldown (unless cooling is
					// disabled for this auth); the model is not suspended.
					if quotaCooldownDisabledForAuth(auth) {
						state.NextRetryAfter = time.Time{}
					} else {
						next := now.Add(1 * time.Minute)
						state.NextRetryAfter = next
					}
				default:
					// Other failures impose no cooldown on the model.
					state.NextRetryAfter = time.Time{}
				}
				auth.Status = StatusError
				auth.UpdatedAt = now
				updateAggregatedAvailability(auth, now)
			} else {
				applyAuthFailureState(auth, result.Error, result.RetryAfter, now)
			}
		}
		// Best-effort persistence; the error is intentionally ignored.
		// NOTE(review): this runs while m.mu is held, so a slow store blocks
		// other manager operations.
		_ = m.persist(ctx, auth)
		authSnapshot = auth.Clone()
	}
	m.mu.Unlock()
	if m.scheduler != nil && authSnapshot != nil {
		m.scheduler.upsertAuth(authSnapshot)
	}
	if clearModelQuota && result.Model != "" {
		registry.GetGlobalRegistry().ClearModelQuotaExceeded(result.AuthID, result.Model)
	}
	if setModelQuota && result.Model != "" {
		registry.GetGlobalRegistry().SetModelQuotaExceeded(result.AuthID, result.Model)
	}
	if shouldResumeModel {
		registry.GetGlobalRegistry().ResumeClientModel(result.AuthID, result.Model)
	} else if shouldSuspendModel {
		registry.GetGlobalRegistry().SuspendClientModel(result.AuthID, result.Model, suspendReason)
	}
	m.hook.OnResult(ctx, result)
}
// ensureModelState returns the per-model state for model on auth. A fresh
// StatusActive state is created (and the map allocated) when none exists or the
// stored entry is nil. It returns nil when auth is nil or model is empty.
func ensureModelState(auth *Auth, model string) *ModelState {
	if auth == nil || model == "" {
		return nil
	}
	if auth.ModelStates == nil {
		auth.ModelStates = make(map[string]*ModelState)
	}
	existing := auth.ModelStates[model]
	if existing == nil {
		existing = &ModelState{Status: StatusActive}
		auth.ModelStates[model] = existing
	}
	return existing
}
// resetModelState returns state to a clean, available condition as of now:
// active status, no error or status message, no retry deadline, and an empty
// quota record. A nil state is ignored.
func resetModelState(state *ModelState, now time.Time) {
	if state == nil {
		return
	}
	state.Status = StatusActive
	state.Unavailable = false
	state.LastError = nil
	state.StatusMessage = ""
	state.Quota = QuotaState{}
	state.NextRetryAfter = time.Time{}
	state.UpdatedAt = now
}
// updateAggregatedAvailability recomputes the auth-level availability and quota
// fields from the per-model states at time now. It does nothing when auth is nil
// or has no model states.
//
// Availability: a model state counts as unavailable when it is StatusDisabled, or
// when it is flagged Unavailable with a NextRetryAfter still in the future.
// Flagged states with a zero NextRetryAfter count as available. Flagged states
// whose NextRetryAfter has passed are un-flagged in place (this mutates
// auth.ModelStates). The auth is Unavailable only if every non-nil state is
// unavailable, and then auth.NextRetryAfter is the earliest future model retry
// time; otherwise it is cleared.
// NOTE(review): if every map entry is nil, allUnavailable stays true and the auth
// is marked unavailable with a zero retry time — confirm this is intended.
//
// Quota: if any model's quota is exceeded, the auth's quota is marked exceeded
// with the earliest non-zero recovery time (zero if none has one) and the highest
// backoff level among those models; otherwise the auth quota is cleared.
func updateAggregatedAvailability(auth *Auth, now time.Time) {
	if auth == nil || len(auth.ModelStates) == 0 {
		return
	}
	allUnavailable := true
	earliestRetry := time.Time{}
	quotaExceeded := false
	quotaRecover := time.Time{}
	maxBackoffLevel := 0
	for _, state := range auth.ModelStates {
		if state == nil {
			continue
		}
		stateUnavailable := false
		if state.Status == StatusDisabled {
			stateUnavailable = true
		} else if state.Unavailable {
			if state.NextRetryAfter.IsZero() {
				// No retry deadline recorded: treat as available for aggregation.
				stateUnavailable = false
			} else if state.NextRetryAfter.After(now) {
				stateUnavailable = true
				if earliestRetry.IsZero() || state.NextRetryAfter.Before(earliestRetry) {
					earliestRetry = state.NextRetryAfter
				}
			} else {
				// Cooldown has elapsed: clear the flag on the model state itself.
				state.Unavailable = false
				state.NextRetryAfter = time.Time{}
			}
		}
		if !stateUnavailable {
			allUnavailable = false
		}
		if state.Quota.Exceeded {
			quotaExceeded = true
			// Keep the earliest non-zero recovery time; a zero value only
			// survives when no exceeded model has one.
			if quotaRecover.IsZero() || (!state.Quota.NextRecoverAt.IsZero() && state.Quota.NextRecoverAt.Before(quotaRecover)) {
				quotaRecover = state.Quota.NextRecoverAt
			}
			if state.Quota.BackoffLevel > maxBackoffLevel {
				maxBackoffLevel = state.Quota.BackoffLevel
			}
		}
	}
	auth.Unavailable = allUnavailable
	if allUnavailable {
		auth.NextRetryAfter = earliestRetry
	} else {
		auth.NextRetryAfter = time.Time{}
	}
	if quotaExceeded {
		auth.Quota.Exceeded = true
		auth.Quota.Reason = "quota"
		auth.Quota.NextRecoverAt = quotaRecover
		auth.Quota.BackoffLevel = maxBackoffLevel
	} else {
		auth.Quota.Exceeded = false
		auth.Quota.Reason = ""
		auth.Quota.NextRecoverAt = time.Time{}
		auth.Quota.BackoffLevel = 0
	}
}
// hasModelError reports whether any per-model state of auth still reflects an
// error at now. That is either a recorded LastError, or StatusError while flagged
// unavailable with no retry time or with one that has not yet passed. Nil auths
// and nil states are ignored.
func hasModelError(auth *Auth, now time.Time) bool {
	if auth == nil {
		return false
	}
	for _, st := range auth.ModelStates {
		switch {
		case st == nil:
			continue
		case st.LastError != nil:
			return true
		case st.Status == StatusError && st.Unavailable &&
			(st.NextRetryAfter.IsZero() || st.NextRetryAfter.After(now)):
			return true
		}
	}
	return false
}
// clearAuthStateOnSuccess wipes auth-level failure state after a successful call
// that was not tied to a specific model. The auth becomes active and available,
// its error, status message, retry deadline and quota tracking are cleared, and
// UpdatedAt is set to now. A nil auth is ignored.
func clearAuthStateOnSuccess(auth *Auth, now time.Time) {
	if auth == nil {
		return
	}
	auth.Status = StatusActive
	auth.StatusMessage = ""
	auth.Unavailable = false
	auth.NextRetryAfter = time.Time{}
	auth.LastError = nil

	auth.Quota.Exceeded = false
	auth.Quota.Reason = ""
	auth.Quota.NextRecoverAt = time.Time{}
	auth.Quota.BackoffLevel = 0

	auth.UpdatedAt = now
}
// cloneError returns a copy of err holding its Code, Message, Retryable and
// HTTPStatus fields, or nil when err is nil.
func cloneError(err *Error) *Error {
	if err == nil {
		return nil
	}
	clone := new(Error)
	clone.Code = err.Code
	clone.Message = err.Message
	clone.Retryable = err.Retryable
	clone.HTTPStatus = err.HTTPStatus
	return clone
}
// statusCodeFromError returns the HTTP status reported by the first error in
// err's wrap chain that implements StatusCode() int, or 0 when there is none.
func statusCodeFromError(err error) int {
	var coder interface{ StatusCode() int }
	if err == nil || !errors.As(err, &coder) || coder == nil {
		return 0
	}
	return coder.StatusCode()
}
// retryAfterFromError extracts a retry-after hint from err.
//
// Like statusCodeFromError, it searches the whole wrap chain (errors.As) for an
// error implementing RetryAfter() *time.Duration, so wrapped provider errors keep
// their hint. It returns nil when err is nil, no error in the chain provides a
// hint, or the hint itself is nil. Otherwise it returns a freshly allocated copy,
// so callers may keep or modify it without aliasing the source error's value.
func retryAfterFromError(err error) *time.Duration {
	if err == nil {
		return nil
	}
	var provider interface{ RetryAfter() *time.Duration }
	if !errors.As(err, &provider) || provider == nil {
		return nil
	}
	hint := provider.RetryAfter()
	if hint == nil {
		return nil
	}
	d := *hint
	return &d
}
// statusCodeFromResult returns err's HTTP status code, or 0 for a nil error.
func statusCodeFromResult(err *Error) int {
	if err != nil {
		return err.StatusCode()
	}
	return 0
}
// isRequestInvalidError reports whether err describes a malformed client request
// that retrying cannot fix. That covers every 422 response and any 400 response
// whose message mentions "invalid_request_error". Switching auths or pooled
// upstream models will not help for these, so callers skip retries.
func isRequestInvalidError(err error) bool {
	if err == nil {
		return false
	}
	code := statusCodeFromError(err)
	if code == http.StatusUnprocessableEntity {
		return true
	}
	return code == http.StatusBadRequest && strings.Contains(err.Error(), "invalid_request_error")
}
// applyAuthFailureState marks the whole auth as failed at time now. MarkResult
// uses it for failures that are not tied to a specific model.
//
// It sets Unavailable and StatusError, copies resultErr into LastError (and its
// non-empty Message into StatusMessage), then applies a status-specific message
// and cooldown:
//   - 401: "unauthorized", retry after 30 minutes;
//   - 402/403: "payment_required", retry after 30 minutes;
//   - 404: "not_found", retry after 12 hours;
//   - 429: "quota exhausted", quota marked exceeded; the retry time comes from
//     retryAfter when given, otherwise from nextQuotaCooldown (which also bumps
//     Quota.BackoffLevel);
//   - 408/500/502/503/504: "transient upstream error", retry after 1 minute
//     unless quota cooling is disabled for this auth;
//   - anything else: keeps the error message (or "request failed") and leaves
//     NextRetryAfter unchanged.
//
// A nil auth is ignored.
func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) {
	if auth == nil {
		return
	}
	auth.Unavailable = true
	auth.Status = StatusError
	auth.UpdatedAt = now
	if resultErr != nil {
		auth.LastError = cloneError(resultErr)
		if resultErr.Message != "" {
			auth.StatusMessage = resultErr.Message
		}
	}
	statusCode := statusCodeFromResult(resultErr)
	switch statusCode {
	case 401:
		auth.StatusMessage = "unauthorized"
		auth.NextRetryAfter = now.Add(30 * time.Minute)
	case 402, 403:
		auth.StatusMessage = "payment_required"
		auth.NextRetryAfter = now.Add(30 * time.Minute)
	case 404:
		auth.StatusMessage = "not_found"
		auth.NextRetryAfter = now.Add(12 * time.Hour)
	case 429:
		auth.StatusMessage = "quota exhausted"
		auth.Quota.Exceeded = true
		auth.Quota.Reason = "quota"
		var next time.Time
		if retryAfter != nil {
			// An explicit Retry-After wins and leaves the backoff level untouched.
			next = now.Add(*retryAfter)
		} else {
			cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel, quotaCooldownDisabledForAuth(auth))
			// A zero cooldown (cooling disabled) leaves next as the zero time.
			if cooldown > 0 {
				next = now.Add(cooldown)
			}
			auth.Quota.BackoffLevel = nextLevel
		}
		auth.Quota.NextRecoverAt = next
		auth.NextRetryAfter = next
	case 408, 500, 502, 503, 504:
		auth.StatusMessage = "transient upstream error"
		if quotaCooldownDisabledForAuth(auth) {
			auth.NextRetryAfter = time.Time{}
		} else {
			auth.NextRetryAfter = now.Add(1 * time.Minute)
		}
	default:
		if auth.StatusMessage == "" {
			auth.StatusMessage = "request failed"
		}
	}
}
// nextQuotaCooldown returns the next cooldown duration and updated backoff level for repeated quota errors.
func nextQuotaCooldown(prevLevel int, disableCooling bool) (time.Duration, int) {
if prevLevel < 0 {
prevLevel = 0
}
if disableCooling {
return 0, prevLevel
}
cooldown := quotaBackoffBase * time.Duration(1<= quotaBackoffMax {
return quotaBackoffMax, prevLevel
}
return cooldown, prevLevel + 1
}
// List returns clones of every auth entry the manager currently knows about.
// The clones are taken under the read lock, so callers may inspect or modify them
// freely. Order is unspecified (map iteration).
func (m *Manager) List() []*Auth {
	m.mu.RLock()
	defer m.mu.RUnlock()
	out := make([]*Auth, 0, len(m.auths))
	for id := range m.auths {
		out = append(out, m.auths[id].Clone())
	}
	return out
}
// GetByID returns a clone of the auth stored under id and true, or nil and false
// when id is empty or unknown.
func (m *Manager) GetByID(id string) (*Auth, bool) {
	if id == "" {
		return nil, false
	}
	m.mu.RLock()
	defer m.mu.RUnlock()
	if auth, ok := m.auths[id]; ok {
		return auth.Clone(), true
	}
	return nil, false
}
// Executor returns the registered provider executor for a provider key.
// The key is trimmed. An exact-case match is tried first, then the lowercase form.
// It reports false for a nil manager, a blank key, or a missing/nil executor.
func (m *Manager) Executor(provider string) (ProviderExecutor, bool) {
	if m == nil {
		return nil, false
	}
	key := strings.TrimSpace(provider)
	if key == "" {
		return nil, false
	}
	m.mu.RLock()
	impl, found := m.executors[key]
	if !found {
		if lowered := strings.ToLower(key); lowered != key {
			impl, found = m.executors[lowered]
		}
	}
	m.mu.RUnlock()
	if found && impl != nil {
		return impl, true
	}
	return nil, false
}
// CloseExecutionSession asks every registered executor that implements
// ExecutionSessionCloser to release the given (trimmed) session. Executors are
// snapshotted under the read lock and notified after it is released. A nil
// manager or blank session ID is a no-op.
func (m *Manager) CloseExecutionSession(sessionID string) {
	sessionID = strings.TrimSpace(sessionID)
	if m == nil || sessionID == "" {
		return
	}
	m.mu.RLock()
	snapshot := make([]ProviderExecutor, 0, len(m.executors))
	for _, e := range m.executors {
		snapshot = append(snapshot, e)
	}
	m.mu.RUnlock()
	for _, e := range snapshot {
		closer, ok := e.(ExecutionSessionCloser)
		if !ok || closer == nil {
			continue
		}
		closer.CloseExecutionSession(sessionID)
	}
}
// useSchedulerFastPath reports whether auth selection can go through the
// scheduler. That requires a non-nil manager with a scheduler and a built-in
// selector.
func (m *Manager) useSchedulerFastPath() bool {
	return m != nil && m.scheduler != nil && isBuiltInSelector(m.selector)
}
// shouldRetrySchedulerPick reports whether a failed scheduler pick is worth
// retrying after a resync. This applies to model cooldown errors and to *Error
// values coded "auth_not_found" or "auth_unavailable" (anywhere in the wrap chain).
func shouldRetrySchedulerPick(err error) bool {
	if err == nil {
		return false
	}
	var cooldown *modelCooldownError
	if errors.As(err, &cooldown) {
		return true
	}
	var authErr *Error
	if errors.As(err, &authErr) && authErr != nil {
		switch authErr.Code {
		case "auth_not_found", "auth_unavailable":
			return true
		}
	}
	return false
}
// pickNextLegacy selects the next auth for provider/model by running the
// configured selector over the manager's own auth map. pickNext uses it when the
// scheduler fast path is unavailable.
//
// A candidate must be non-nil and belong to provider (exact match). It must be
// enabled and equal the pinned auth ID from opts.Metadata when one is set. It must
// be absent from tried. When model is non-empty, it must also be registered in the
// global registry as supporting the model's base name (thinking suffix removed).
// The returned auth is a clone. If the selected auth has no index yet, one is
// assigned on the live entry under the write lock first.
//
// Errors: executor_not_found when provider has no executor; auth_not_found when
// no candidate qualifies or the selector returns nil; otherwise the selector's
// own error.
func (m *Manager) pickNextLegacy(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
	pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
	m.mu.RLock()
	executor, okExecutor := m.executors[provider]
	if !okExecutor {
		m.mu.RUnlock()
		return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
	}
	candidates := make([]*Auth, 0, len(m.auths))
	modelKey := strings.TrimSpace(model)
	// Always use base model name (without thinking suffix) for auth matching.
	if modelKey != "" {
		parsed := thinking.ParseSuffix(modelKey)
		if parsed.ModelName != "" {
			modelKey = strings.TrimSpace(parsed.ModelName)
		}
	}
	registryRef := registry.GetGlobalRegistry()
	for _, candidate := range m.auths {
		// Skip nil entries (as pickNextMixedLegacy does) instead of dereferencing them.
		if candidate == nil || candidate.Provider != provider || candidate.Disabled {
			continue
		}
		if pinnedAuthID != "" && candidate.ID != pinnedAuthID {
			continue
		}
		if _, used := tried[candidate.ID]; used {
			continue
		}
		if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) {
			continue
		}
		candidates = append(candidates, candidate)
	}
	if len(candidates) == 0 {
		m.mu.RUnlock()
		return nil, nil, &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	selected, errPick := m.selector.Pick(ctx, provider, model, opts, candidates)
	if errPick != nil {
		m.mu.RUnlock()
		return nil, nil, errPick
	}
	if selected == nil {
		m.mu.RUnlock()
		return nil, nil, &Error{Code: "auth_not_found", Message: "selector returned no auth"}
	}
	authCopy := selected.Clone()
	// selected is (presumably) a live entry of m.auths, and EnsureIndex writes its
	// indexAssigned field under m.mu.Lock. Sample the field while the read lock is
	// still held, not after RUnlock, to avoid a data race.
	needsIndex := !selected.indexAssigned
	m.mu.RUnlock()
	if needsIndex {
		m.mu.Lock()
		if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
			current.EnsureIndex()
			authCopy = current.Clone()
		}
		m.mu.Unlock()
	}
	return authCopy, executor, nil
}
// pickNext selects the next auth and executor for a single provider.
// Without the scheduler fast path it defers to pickNextLegacy. Otherwise it asks
// the scheduler. If that pick fails with a retryable error for a non-empty model,
// the scheduler is resynchronised and asked once more. The returned auth is a
// clone, with an index assigned on the live entry first if it lacked one.
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
	if !m.useSchedulerFastPath() {
		return m.pickNextLegacy(ctx, provider, model, opts, tried)
	}
	executor, found := m.Executor(provider)
	if !found {
		return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
	}

	chosen, err := m.scheduler.pickSingle(ctx, provider, model, opts, tried)
	if err != nil && model != "" && shouldRetrySchedulerPick(err) {
		m.syncScheduler()
		chosen, err = m.scheduler.pickSingle(ctx, provider, model, opts, tried)
	}
	switch {
	case err != nil:
		return nil, nil, err
	case chosen == nil:
		return nil, nil, &Error{Code: "auth_not_found", Message: "selector returned no auth"}
	}

	result := chosen.Clone()
	if !chosen.indexAssigned {
		m.mu.Lock()
		if live := m.auths[result.ID]; live != nil && !live.indexAssigned {
			live.EnsureIndex()
			result = live.Clone()
		}
		m.mu.Unlock()
	}
	return result, executor, nil
}
// pickNextMixedLegacy selects the next auth across several providers by running
// the configured selector (with provider "mixed") over the manager's auth map.
// pickNextMixed uses it when the scheduler fast path is unavailable.
//
// Provider names are trimmed and lowercased. A candidate must be non-nil, enabled,
// and belong to one of those providers. It must have a registered executor and
// equal the pinned auth ID from opts.Metadata when one is set. It must be absent
// from tried. When model is non-empty, it must also support the model's base name
// (thinking suffix removed) in the global registry. The function returns a clone
// of the selected auth, its executor, and the normalised provider key.
//
// Errors: provider_not_found when no non-blank provider is supplied;
// auth_not_found when nothing qualifies or the selector returns nil;
// executor_not_found if the selected auth's provider has no executor; otherwise
// the selector's own error.
func (m *Manager) pickNextMixedLegacy(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
	pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
	providerSet := make(map[string]struct{}, len(providers))
	for _, provider := range providers {
		p := strings.TrimSpace(strings.ToLower(provider))
		if p == "" {
			continue
		}
		providerSet[p] = struct{}{}
	}
	if len(providerSet) == 0 {
		return nil, nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
	}
	m.mu.RLock()
	candidates := make([]*Auth, 0, len(m.auths))
	modelKey := strings.TrimSpace(model)
	// Always use base model name (without thinking suffix) for auth matching.
	if modelKey != "" {
		parsed := thinking.ParseSuffix(modelKey)
		if parsed.ModelName != "" {
			modelKey = strings.TrimSpace(parsed.ModelName)
		}
	}
	registryRef := registry.GetGlobalRegistry()
	for _, candidate := range m.auths {
		if candidate == nil || candidate.Disabled {
			continue
		}
		if pinnedAuthID != "" && candidate.ID != pinnedAuthID {
			continue
		}
		providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider))
		if providerKey == "" {
			continue
		}
		if _, ok := providerSet[providerKey]; !ok {
			continue
		}
		if _, used := tried[candidate.ID]; used {
			continue
		}
		if _, ok := m.executors[providerKey]; !ok {
			continue
		}
		if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) {
			continue
		}
		candidates = append(candidates, candidate)
	}
	if len(candidates) == 0 {
		m.mu.RUnlock()
		return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	selected, errPick := m.selector.Pick(ctx, "mixed", model, opts, candidates)
	if errPick != nil {
		m.mu.RUnlock()
		return nil, nil, "", errPick
	}
	if selected == nil {
		m.mu.RUnlock()
		return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
	}
	providerKey := strings.TrimSpace(strings.ToLower(selected.Provider))
	executor, okExecutor := m.executors[providerKey]
	if !okExecutor {
		m.mu.RUnlock()
		return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
	}
	authCopy := selected.Clone()
	// selected is (presumably) a live entry of m.auths, and EnsureIndex writes its
	// indexAssigned field under m.mu.Lock. Sample the field while the read lock is
	// still held, not after RUnlock, to avoid a data race.
	needsIndex := !selected.indexAssigned
	m.mu.RUnlock()
	if needsIndex {
		m.mu.Lock()
		if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
			current.EnsureIndex()
			authCopy = current.Clone()
		}
		m.mu.Unlock()
	}
	return authCopy, executor, providerKey, nil
}
// pickNextMixed selects the next auth across several providers.
// Without the scheduler fast path it defers to pickNextMixedLegacy. Otherwise the
// provider list is normalised and deduplicated, keeping only providers with a
// registered executor, and the scheduler picks among them. A retryable failure for
// a non-empty model triggers one resync-and-retry. It returns a clone of the auth
// (indexed on the live entry first if needed), its executor, and the provider key.
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
	if !m.useSchedulerFastPath() {
		return m.pickNextMixedLegacy(ctx, providers, model, opts, tried)
	}

	seen := make(map[string]struct{}, len(providers))
	eligible := make([]string, 0, len(providers))
	for i := range providers {
		key := strings.TrimSpace(strings.ToLower(providers[i]))
		if key == "" {
			continue
		}
		if _, dup := seen[key]; dup {
			continue
		}
		if _, registered := m.Executor(key); !registered {
			continue
		}
		seen[key] = struct{}{}
		eligible = append(eligible, key)
	}
	if len(eligible) == 0 {
		return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
	}

	chosen, providerKey, err := m.scheduler.pickMixed(ctx, eligible, model, opts, tried)
	if err != nil && model != "" && shouldRetrySchedulerPick(err) {
		m.syncScheduler()
		chosen, providerKey, err = m.scheduler.pickMixed(ctx, eligible, model, opts, tried)
	}
	switch {
	case err != nil:
		return nil, nil, "", err
	case chosen == nil:
		return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
	}

	executor, found := m.Executor(providerKey)
	if !found {
		return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
	}
	result := chosen.Clone()
	if !chosen.indexAssigned {
		m.mu.Lock()
		if live := m.auths[result.ID]; live != nil && !live.indexAssigned {
			live.EnsureIndex()
			result = live.Clone()
		}
		m.mu.Unlock()
	}
	return result, executor, providerKey, nil
}
// persist saves auth through the manager's store. Persistence is skipped (nil is
// returned) in these cases:
//   - there is no store, or auth is nil;
//   - the context asks to skip persistence;
//   - the auth is flagged runtime_only=true;
//   - the auth has no metadata.
//
// Otherwise the store's error is returned.
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
	switch {
	case m.store == nil, auth == nil:
		return nil
	case shouldSkipPersist(ctx):
		return nil
	case auth.Attributes != nil && strings.ToLower(strings.TrimSpace(auth.Attributes["runtime_only"])) == "true":
		return nil
	case auth.Metadata == nil:
		return nil
	}
	_, err := m.store.Save(ctx, auth)
	return err
}
// StartAutoRefresh launches a background goroutine that runs checkRefreshes
// immediately and then on every tick of interval, until parent (or a later
// Start/StopAutoRefresh) cancels it. A non-positive interval falls back to
// refreshCheckInterval. Any previously started loop is cancelled first, so only
// one loop runs at a time.
func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duration) {
	if interval <= 0 {
		interval = refreshCheckInterval
	}
	m.StopAutoRefresh()
	ctx, cancel := context.WithCancel(parent)
	m.refreshCancel = cancel
	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		for {
			m.checkRefreshes(ctx)
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
			}
		}
	}()
}
// StopAutoRefresh cancels the background refresh loop, if one is running, and
// forgets its cancel function. Calling it with no loop running is a no-op.
func (m *Manager) StopAutoRefresh() {
	cancel := m.refreshCancel
	if cancel == nil {
		return
	}
	cancel()
	m.refreshCancel = nil
}
// checkRefreshes scans a snapshot of all auths and starts an asynchronous,
// concurrency-limited refresh for each one that is due. An auth is refreshed only
// if it is not an api_key account, shouldRefresh says it is due, its provider has
// an executor, and it can be marked refresh-pending.
func (m *Manager) checkRefreshes(ctx context.Context) {
	now := time.Now()
	for _, candidate := range m.snapshotAuths() {
		kind, _ := candidate.AccountInfo()
		if kind == "api_key" || !m.shouldRefresh(candidate, now) {
			continue
		}
		log.Debugf("checking refresh for %s, %s, %s", candidate.Provider, candidate.ID, kind)
		if m.executorFor(candidate.Provider) == nil || !m.markRefreshPending(candidate.ID, now) {
			continue
		}
		go m.refreshAuthWithLimit(ctx, candidate.ID)
	}
}
// refreshAuthWithLimit runs refreshAuth for id. When a refresh semaphore is
// configured, a slot is acquired first (released when done). If ctx is cancelled
// while waiting for a slot, the function returns without refreshing.
func (m *Manager) refreshAuthWithLimit(ctx context.Context, id string) {
	if m.refreshSemaphore != nil {
		select {
		case <-ctx.Done():
			return
		case m.refreshSemaphore <- struct{}{}:
		}
		defer func() { <-m.refreshSemaphore }()
	}
	m.refreshAuth(ctx, id)
}
// snapshotAuths returns clones of every auth held by the manager, taken under the
// read lock so callers can inspect them without further locking.
func (m *Manager) snapshotAuths() []*Auth {
	m.mu.RLock()
	defer m.mu.RUnlock()
	clones := make([]*Auth, 0, len(m.auths))
	for key := range m.auths {
		clones = append(clones, m.auths[key].Clone())
	}
	return clones
}
// shouldRefresh reports whether auth a is due for a credential refresh at now.
//
// Decision order:
//  1. Nil or disabled auths never refresh. An auth whose NextRefreshAfter is
//     still ahead of now is skipped.
//  2. If a.Runtime implements RefreshEvaluator, its verdict is final.
//  3. If a preferred refresh interval is configured (metadata or attributes),
//     refresh when the token is expired or expires within that interval, when the
//     last refresh time is unknown, or when the interval has elapsed since it.
//  4. Otherwise the provider's refresh lead (ProviderRefreshLead) decides. With no
//     lead, never refresh. A non-positive lead refreshes only after expiry. A
//     positive lead refreshes when expiry is within the lead; without an expiry,
//     when the lead has elapsed since the last refresh; with neither known, always.
//
// Every comparison uses the caller-supplied now, never the wall clock, so the
// decision is consistent with the timestamp the caller is acting on.
func (m *Manager) shouldRefresh(a *Auth, now time.Time) bool {
	if a == nil || a.Disabled {
		return false
	}
	if !a.NextRefreshAfter.IsZero() && now.Before(a.NextRefreshAfter) {
		return false
	}
	if evaluator, ok := a.Runtime.(RefreshEvaluator); ok && evaluator != nil {
		return evaluator.ShouldRefresh(now, a)
	}

	lastRefresh := a.LastRefreshedAt
	if lastRefresh.IsZero() {
		if ts, ok := authLastRefreshTimestamp(a); ok {
			lastRefresh = ts
		}
	}
	expiry, hasExpiry := a.ExpirationTime()

	if interval := authPreferredInterval(a); interval > 0 {
		if hasExpiry && !expiry.IsZero() {
			if !expiry.After(now) {
				return true
			}
			if expiry.Sub(now) <= interval {
				return true
			}
		}
		if lastRefresh.IsZero() {
			return true
		}
		return now.Sub(lastRefresh) >= interval
	}

	provider := strings.ToLower(a.Provider)
	lead := ProviderRefreshLead(provider, a.Runtime)
	if lead == nil {
		return false
	}
	if *lead <= 0 {
		if hasExpiry && !expiry.IsZero() {
			return now.After(expiry)
		}
		return false
	}
	if hasExpiry && !expiry.IsZero() {
		// Measured from now (not time.Until) to stay consistent with the rest
		// of this function and with the caller's timestamp.
		return expiry.Sub(now) <= *lead
	}
	if !lastRefresh.IsZero() {
		return now.Sub(lastRefresh) >= *lead
	}
	return true
}
// authPreferredInterval returns the refresh interval configured for a, or 0 when
// none is set or a is nil. Metadata takes precedence over attributes. Both are
// searched under the snake_case and camelCase spellings of
// "refresh_interval_seconds" and "refresh_interval".
func authPreferredInterval(a *Auth) time.Duration {
	if a == nil {
		return 0
	}
	keys := []string{"refresh_interval_seconds", "refreshIntervalSeconds", "refresh_interval", "refreshInterval"}
	if d := durationFromMetadata(a.Metadata, keys...); d > 0 {
		return d
	}
	if d := durationFromAttributes(a.Attributes, keys...); d > 0 {
		return d
	}
	return 0
}
// durationFromMetadata returns the first positive duration parsed (via
// parseDurationValue) from meta under keys, checked in order, or 0 if none parses.
func durationFromMetadata(meta map[string]any, keys ...string) time.Duration {
	for _, k := range keys {
		raw, present := meta[k]
		if !present {
			continue
		}
		if d := parseDurationValue(raw); d > 0 {
			return d
		}
	}
	return 0
}
// durationFromAttributes returns the first positive duration parsed (via
// parseDurationString) from attrs under keys, checked in order, or 0 if none parses.
func durationFromAttributes(attrs map[string]string, keys ...string) time.Duration {
	for _, k := range keys {
		if raw, present := attrs[k]; present {
			if d := parseDurationString(raw); d > 0 {
				return d
			}
		}
	}
	return 0
}
// parseDurationValue converts a loosely-typed metadata value into a positive
// duration, returning 0 for non-positive or unsupported values:
//   - time.Duration is used as-is;
//   - signed/unsigned integers, floats and json.Number are seconds;
//   - strings are delegated to parseDurationString.
func parseDurationValue(val any) time.Duration {
	signedSeconds := func(n int64) time.Duration {
		if n <= 0 {
			return 0
		}
		return time.Duration(n) * time.Second
	}
	unsignedSeconds := func(n uint64) time.Duration {
		if n == 0 {
			return 0
		}
		return time.Duration(n) * time.Second
	}
	floatSeconds := func(f float64) time.Duration {
		if f <= 0 {
			return 0
		}
		return time.Duration(f * float64(time.Second))
	}

	switch v := val.(type) {
	case time.Duration:
		if v > 0 {
			return v
		}
		return 0
	case int:
		return signedSeconds(int64(v))
	case int32:
		return signedSeconds(int64(v))
	case int64:
		return signedSeconds(v)
	case uint:
		return unsignedSeconds(uint64(v))
	case uint32:
		return unsignedSeconds(uint64(v))
	case uint64:
		return unsignedSeconds(v)
	case float32:
		return floatSeconds(float64(v))
	case float64:
		return floatSeconds(v)
	case json.Number:
		if i, err := v.Int64(); err == nil {
			return signedSeconds(i)
		}
		if f, err := v.Float64(); err == nil && f > 0 {
			return time.Duration(f * float64(time.Second))
		}
	case string:
		return parseDurationString(v)
	}
	return 0
}
// parseDurationString parses a trimmed duration string. It first tries Go syntax
// such as "90s" or "1h30m", then falls back to a plain (possibly fractional)
// number of seconds. It returns 0 for blank, unparsable, or non-positive input.
func parseDurationString(raw string) time.Duration {
	text := strings.TrimSpace(raw)
	if text == "" {
		return 0
	}
	d, err := time.ParseDuration(text)
	if err == nil && d > 0 {
		return d
	}
	secs, err := strconv.ParseFloat(text, 64)
	if err == nil && secs > 0 {
		return time.Duration(secs * float64(time.Second))
	}
	return 0
}
// authLastRefreshTimestamp looks up when a was last refreshed. It checks metadata
// first, then attributes, under the keys "last_refresh", "lastRefresh",
// "last_refreshed_at" and "lastRefreshedAt" in that order. Blank attribute values
// are ignored. It reports false when a is nil or no key yields a parsable time.
func authLastRefreshTimestamp(a *Auth) (time.Time, bool) {
	if a == nil {
		return time.Time{}, false
	}
	keys := []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
	if a.Metadata != nil {
		if ts, ok := lookupMetadataTime(a.Metadata, keys...); ok {
			return ts, true
		}
	}
	for _, key := range keys {
		raw := strings.TrimSpace(a.Attributes[key])
		if raw == "" {
			continue
		}
		if ts, ok := parseTimeValue(raw); ok {
			return ts, true
		}
	}
	return time.Time{}, false
}
// lookupMetadataTime returns the first value in meta under keys (checked in order)
// that parseTimeValue accepts. It reports false if none does.
func lookupMetadataTime(meta map[string]any, keys ...string) (time.Time, bool) {
	for _, key := range keys {
		raw, present := meta[key]
		if !present {
			continue
		}
		if ts, ok := parseTimeValue(raw); ok {
			return ts, true
		}
	}
	return time.Time{}, false
}
// markRefreshPending claims auth id for a refresh by pushing its NextRefreshAfter
// to now+refreshPendingBackoff. That keeps later scans from scheduling it again
// while the refresh runs. It reports false, changing nothing, when the auth is
// missing, nil, disabled, or still inside a previous refresh window.
func (m *Manager) markRefreshPending(id string, now time.Time) bool {
	m.mu.Lock()
	defer m.mu.Unlock()
	auth := m.auths[id]
	if auth == nil || auth.Disabled {
		return false
	}
	if next := auth.NextRefreshAfter; !next.IsZero() && now.Before(next) {
		return false
	}
	auth.NextRefreshAfter = now.Add(refreshPendingBackoff)
	return true
}
// refreshAuth refreshes the credential identified by id using the executor
// registered under the auth's Provider.
//
// The stored auth is cloned while the read lock is still held, so the executor
// works on a private snapshot. The previous code cloned it after releasing the
// lock, which raced with goroutines that mutate the stored auth under m.mu
// (for example markRefreshPending).
//
// Outcomes:
//   - context.Canceled: logged at debug level; state is left untouched.
//   - any other error: NextRefreshAfter is pushed out by
//     refreshFailureBackoff, LastError is recorded, and the scheduler entry is
//     refreshed.
//   - success: the returned auth (or the snapshot if the executor returned
//     nil) keeps its runtime (falling back to the previous one), is stamped
//     with LastRefreshedAt/UpdatedAt, has its error and refresh backoff
//     cleared, and is stored via Update. An Update failure is logged.
func (m *Manager) refreshAuth(ctx context.Context, id string) {
	if ctx == nil {
		ctx = context.Background()
	}

	m.mu.RLock()
	auth := m.auths[id]
	var exec ProviderExecutor
	if auth != nil {
		exec = m.executors[auth.Provider]
	}
	if auth == nil || exec == nil {
		m.mu.RUnlock()
		return
	}
	// Snapshot everything needed from the shared auth before releasing the lock.
	cloned := auth.Clone()
	prevRuntime := auth.Runtime
	provider, authID := auth.Provider, auth.ID
	m.mu.RUnlock()

	updated, err := exec.Refresh(ctx, cloned)
	if errors.Is(err, context.Canceled) {
		log.Debugf("refresh canceled for %s, %s", provider, authID)
		return
	}
	log.Debugf("refreshed %s, %s, %v", provider, authID, err)
	now := time.Now()
	if err != nil {
		m.mu.Lock()
		if current := m.auths[id]; current != nil {
			current.NextRefreshAfter = now.Add(refreshFailureBackoff)
			current.LastError = &Error{Message: err.Error()}
			m.auths[id] = current
			if m.scheduler != nil {
				m.scheduler.upsertAuth(current.Clone())
			}
		}
		m.mu.Unlock()
		return
	}

	if updated == nil {
		updated = cloned
	}
	// Preserve a runtime created by the executor during Refresh; if it did not
	// set one, fall back to the runtime the auth had before the refresh.
	if updated.Runtime == nil {
		updated.Runtime = prevRuntime
	}
	updated.LastRefreshedAt = now
	updated.NextRefreshAfter = time.Time{}
	updated.LastError = nil
	updated.UpdatedAt = now
	if _, errUpdate := m.Update(ctx, updated); errUpdate != nil {
		log.Warnf("failed to store refreshed auth %s, %s: %v", provider, authID, errUpdate)
	}
}
// executorFor returns the executor registered under the exact key provider,
// or nil if none is registered. The map is read under m.mu's read lock.
func (m *Manager) executorFor(provider string) ProviderExecutor {
	m.mu.RLock()
	exec := m.executors[provider]
	m.mu.RUnlock()
	return exec
}
// roundTripperContextKey is an unexported context key type; because it is a
// distinct type, values stored under it cannot collide with keys defined by
// other packages.
type roundTripperContextKey struct{}

// roundTripperFor returns the HTTP transport that the registered
// RoundTripperProvider supplies for auth. It returns nil when no provider is
// registered or auth is nil.
func (m *Manager) roundTripperFor(auth *Auth) http.RoundTripper {
	m.mu.RLock()
	provider := m.rtProvider
	m.mu.RUnlock()
	if provider != nil && auth != nil {
		return provider.RoundTripperFor(auth)
	}
	return nil
}
// RoundTripperProvider defines a minimal provider of per-auth HTTP transports.
// The Manager's roundTripperFor delegates to the registered provider.
type RoundTripperProvider interface {
// RoundTripperFor returns the transport to use for requests made on behalf of auth.
RoundTripperFor(auth *Auth) http.RoundTripper
}
// RequestPreparer is an optional interface that provider executors can implement
// to mutate outbound HTTP requests with provider credentials.
// The Manager type-asserts executors against it: InjectCredentials silently
// skips executors that do not implement it, while PrepareHttpRequest reports a
// "not_supported" error for them.
type RequestPreparer interface {
// PrepareRequest modifies req for the given auth (e.g. adding credential
// headers). Its error is passed back to the Manager's callers as-is.
PrepareRequest(req *http.Request, auth *Auth) error
}
// executorKeyFromAuth returns the lower-cased key under which the executor for
// auth is registered.
//
// For OpenAI-compatible auths (a non-blank "compat_name" attribute), the key is
// the "provider_key" attribute, or compat_name itself when provider_key is
// blank. Otherwise it is the trimmed Provider. A nil auth yields "".
func executorKeyFromAuth(auth *Auth) string {
	if auth == nil {
		return ""
	}
	// Reading from a nil Attributes map is safe and yields "".
	if compat := strings.TrimSpace(auth.Attributes["compat_name"]); compat != "" {
		key := strings.TrimSpace(auth.Attributes["provider_key"])
		if key == "" {
			key = compat
		}
		return strings.ToLower(key)
	}
	return strings.ToLower(strings.TrimSpace(auth.Provider))
}
// logEntryWithRequestID returns a logrus entry for the standard logger. When
// ctx carries a non-empty request ID, the entry has a "request_id" field set.
// A nil ctx is tolerated.
func logEntryWithRequestID(ctx context.Context) *log.Entry {
	if ctx != nil {
		if reqID := logging.GetRequestID(ctx); reqID != "" {
			return log.WithField("request_id", reqID)
		}
	}
	return log.NewEntry(log.StandardLogger())
}
// debugLogAuthSelection writes a debug line describing which credential was
// selected for model.
//
// API keys are masked with util.HideAPIKey. OAuth credentials are described by
// formatOauthIdentity. Non-empty proxy information from the auth is appended.
// Nothing is written when debug logging is off, entry or auth is nil, or the
// account type is neither "api_key" nor "oauth".
func debugLogAuthSelection(entry *log.Entry, auth *Auth, provider string, model string) {
	if !log.IsLevelEnabled(log.DebugLevel) || entry == nil || auth == nil {
		return
	}
	kind, info := auth.AccountInfo()
	var proxySuffix string
	if proxy := auth.ProxyInfo(); proxy != "" {
		proxySuffix = " " + proxy
	}
	if kind == "api_key" {
		entry.Debugf("Use API key %s for model %s%s", util.HideAPIKey(info), model, proxySuffix)
		return
	}
	if kind == "oauth" {
		entry.Debugf("Use OAuth %s for model %s%s", formatOauthIdentity(auth, provider, info), model, proxySuffix)
	}
}
// formatOauthIdentity builds a log-safe description of an OAuth credential of
// the form "provider=<name> auth_file=<file>", omitting any empty part.
//
// The provider is auth.Provider, falling back to the provider argument. The
// file is the basename of auth.FileName (or of auth.ID when FileName is blank),
// so host directory paths are not logged. When both parts are empty,
// accountInfo is returned unchanged. A nil auth yields "".
func formatOauthIdentity(auth *Auth, provider string, accountInfo string) string {
	if auth == nil {
		return ""
	}
	name := strings.TrimSpace(auth.Provider)
	if name == "" {
		name = strings.TrimSpace(provider)
	}
	file := strings.TrimSpace(auth.FileName)
	if file == "" {
		file = strings.TrimSpace(auth.ID)
	}

	var b strings.Builder
	if name != "" {
		b.WriteString("provider=")
		b.WriteString(name)
	}
	if file != "" {
		if b.Len() > 0 {
			b.WriteByte(' ')
		}
		b.WriteString("auth_file=")
		b.WriteString(filepath.Base(file))
	}
	if b.Len() == 0 {
		return accountInfo
	}
	return b.String()
}
// InjectCredentials lets the executor registered for the auth identified by
// authID prepare req, for example by adding credential headers.
//
// It is best-effort and returns nil without touching req when:
//   - the manager or req is nil, or authID is empty;
//   - the auth is unknown;
//   - no executor is registered under executorKeyFromAuth(auth);
//   - the executor does not implement RequestPreparer.
//
// Otherwise it returns the error from PrepareRequest.
func (m *Manager) InjectCredentials(req *http.Request, authID string) error {
	// PrepareHttpRequest and HttpRequest guard against a nil receiver. Do the
	// same here instead of panicking on m.mu, treating it as nothing to inject.
	if m == nil || req == nil || authID == "" {
		return nil
	}
	m.mu.RLock()
	a := m.auths[authID]
	var exec ProviderExecutor
	if a != nil {
		exec = m.executors[executorKeyFromAuth(a)]
	}
	m.mu.RUnlock()
	if a == nil || exec == nil {
		return nil
	}
	// A successful assertion on a non-nil exec always yields a non-nil
	// interface value, so no extra nil check is needed.
	preparer, ok := exec.(RequestPreparer)
	if !ok {
		return nil
	}
	return preparer.PrepareRequest(req, a)
}
// PrepareHttpRequest injects auth's provider credentials into req.
//
// When ctx is non-nil it becomes req's context, updated in place. The executor
// is looked up by executorKeyFromAuth and must implement RequestPreparer.
// Failures are reported as *Error with these codes:
//   - "provider_not_found": nil manager, empty provider key, or no registered executor;
//   - "auth_not_found": nil auth;
//   - "invalid_request": nil req;
//   - "not_supported": the executor cannot prepare requests.
//
// Otherwise the result of PrepareRequest is returned.
func (m *Manager) PrepareHttpRequest(ctx context.Context, auth *Auth, req *http.Request) error {
	switch {
	case m == nil:
		return &Error{Code: "provider_not_found", Message: "manager is nil"}
	case auth == nil:
		return &Error{Code: "auth_not_found", Message: "auth is nil"}
	case req == nil:
		return &Error{Code: "invalid_request", Message: "http request is nil"}
	}
	if ctx != nil {
		// http.Request has no context setter; overwrite the caller's request
		// with the shallow copy returned by WithContext.
		*req = *req.WithContext(ctx)
	}
	key := executorKeyFromAuth(auth)
	if key == "" {
		return &Error{Code: "provider_not_found", Message: "auth provider is empty"}
	}
	exec := m.executorFor(key)
	if exec == nil {
		return &Error{Code: "provider_not_found", Message: "executor not registered for provider: " + key}
	}
	if preparer, ok := exec.(RequestPreparer); ok && preparer != nil {
		return preparer.PrepareRequest(req, auth)
	}
	return &Error{Code: "not_supported", Message: "executor does not support http request preparation"}
}
// NewHttpRequest builds an HTTP request bound to ctx (context.Background when
// ctx is nil) and injects auth's provider credentials via PrepareHttpRequest.
//
// A blank method (after trimming) defaults to GET. A nil body sends no body.
// Non-nil headers are cloned onto the request before credentials are applied.
// Any construction or preparation error is returned with a nil request.
func (m *Manager) NewHttpRequest(ctx context.Context, auth *Auth, method, targetURL string, body []byte, headers http.Header) (*http.Request, error) {
	if ctx == nil {
		ctx = context.Background()
	}
	verb := strings.TrimSpace(method)
	if verb == "" {
		verb = http.MethodGet
	}
	var payload io.Reader
	if body != nil {
		payload = bytes.NewReader(body)
	}
	req, err := http.NewRequestWithContext(ctx, verb, targetURL, payload)
	if err != nil {
		return nil, err
	}
	if headers != nil {
		req.Header = headers.Clone()
	}
	if err := m.PrepareHttpRequest(ctx, auth, req); err != nil {
		return nil, err
	}
	return req, nil
}
// HttpRequest hands req to the HttpRequest method of the executor registered
// under executorKeyFromAuth(auth). That executor injects credentials and
// performs the request.
//
// It fails with a *Error coded "provider_not_found" (nil manager, empty provider
// key, no registered executor), "auth_not_found" (nil auth) or
// "invalid_request" (nil req).
func (m *Manager) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
	switch {
	case m == nil:
		return nil, &Error{Code: "provider_not_found", Message: "manager is nil"}
	case auth == nil:
		return nil, &Error{Code: "auth_not_found", Message: "auth is nil"}
	case req == nil:
		return nil, &Error{Code: "invalid_request", Message: "http request is nil"}
	}
	key := executorKeyFromAuth(auth)
	if key == "" {
		return nil, &Error{Code: "provider_not_found", Message: "auth provider is empty"}
	}
	if exec := m.executorFor(key); exec != nil {
		return exec.HttpRequest(ctx, auth, req)
	}
	return nil, &Error{Code: "provider_not_found", Message: "executor not registered for provider: " + key}
}
================================================
FILE: sdk/cliproxy/auth/conductor_availability_test.go
================================================
package auth
import (
"testing"
"time"
)
// TestUpdateAggregatedAvailability_UnavailableWithoutNextRetryDoesNotBlockAuth
// checks that a model state flagged Unavailable but lacking a NextRetryAfter
// deadline does not make the aggregated auth unavailable.
func TestUpdateAggregatedAvailability_UnavailableWithoutNextRetryDoesNotBlockAuth(t *testing.T) {
	t.Parallel()

	const model = "test-model"
	now := time.Now()
	auth := &Auth{
		ID: "a",
		ModelStates: map[string]*ModelState{
			model: {Status: StatusError, Unavailable: true},
		},
	}

	updateAggregatedAvailability(auth, now)

	if auth.Unavailable {
		t.Fatalf("auth.Unavailable = true, want false")
	}
	if got := auth.NextRetryAfter; !got.IsZero() {
		t.Fatalf("auth.NextRetryAfter = %v, want zero", got)
	}
}
// TestUpdateAggregatedAvailability_FutureNextRetryBlocksAuth checks that a model
// state with a future NextRetryAfter makes the whole auth unavailable. The
// auth's NextRetryAfter must be within one second of that deadline.
func TestUpdateAggregatedAvailability_FutureNextRetryBlocksAuth(t *testing.T) {
	t.Parallel()

	const model = "test-model"
	now := time.Now()
	next := now.Add(5 * time.Minute)
	auth := &Auth{
		ID: "a",
		ModelStates: map[string]*ModelState{
			model: {Status: StatusError, Unavailable: true, NextRetryAfter: next},
		},
	}

	updateAggregatedAvailability(auth, now)

	if !auth.Unavailable {
		t.Fatalf("auth.Unavailable = false, want true")
	}
	got := auth.NextRetryAfter
	if got.IsZero() {
		t.Fatalf("auth.NextRetryAfter = zero, want %v", next)
	}
	// Tolerate up to one second of drift in either direction.
	if drift := got.Sub(next); drift > time.Second || drift < -time.Second {
		t.Fatalf("auth.NextRetryAfter = %v, want %v", got, next)
	}
}
================================================
FILE: sdk/cliproxy/auth/conductor_executor_replace_test.go
================================================
package auth
import (
"context"
"net/http"
"sync"
"testing"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// replaceAwareExecutor is a no-op ProviderExecutor test double. It records every
// session ID passed to CloseExecutionSession, so tests can observe whether the
// Manager closed its sessions.
type replaceAwareExecutor struct {
	id string

	mu               sync.Mutex
	closedSessionIDs []string // guarded by mu
}

func (e *replaceAwareExecutor) Identifier() string { return e.id }

func (e *replaceAwareExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

// ExecuteStream returns a stream whose chunk channel is already closed.
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	chunks := make(chan cliproxyexecutor.StreamChunk)
	close(chunks)
	return &cliproxyexecutor.StreamResult{Chunks: chunks}, nil
}

// Refresh hands back the auth it received.
func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
	return auth, nil
}

func (e *replaceAwareExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

func (e *replaceAwareExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
	return nil, nil
}

// CloseExecutionSession records sessionID.
func (e *replaceAwareExecutor) CloseExecutionSession(sessionID string) {
	e.mu.Lock()
	e.closedSessionIDs = append(e.closedSessionIDs, sessionID)
	e.mu.Unlock()
}

// ClosedSessionIDs returns a copy of the session IDs recorded so far. The copy
// is never nil.
func (e *replaceAwareExecutor) ClosedSessionIDs() []string {
	e.mu.Lock()
	defer e.mu.Unlock()
	return append(make([]string, 0, len(e.closedSessionIDs)), e.closedSessionIDs...)
}
// TestManagerRegisterExecutorClosesReplacedExecutionSessions verifies the
// behavior when a second executor is registered under the same identifier.
// The replaced executor must receive exactly one CloseAllExecutionSessionsID
// close, and the new executor must be left untouched.
func TestManagerRegisterExecutorClosesReplacedExecutionSessions(t *testing.T) {
	t.Parallel()

	manager := NewManager(nil, nil, nil)
	replaced := &replaceAwareExecutor{id: "codex"}
	current := &replaceAwareExecutor{id: "codex"}
	manager.RegisterExecutor(replaced)
	manager.RegisterExecutor(current)

	switch closed := replaced.ClosedSessionIDs(); {
	case len(closed) != 1:
		t.Fatalf("expected replaced executor close calls = 1, got %d", len(closed))
	case closed[0] != CloseAllExecutionSessionsID:
		t.Fatalf("expected close marker %q, got %q", CloseAllExecutionSessionsID, closed[0])
	}
	if len(current.ClosedSessionIDs()) != 0 {
		t.Fatalf("expected current executor to stay open")
	}
}
// TestManagerExecutorReturnsRegisteredExecutor checks that Manager.Executor
// finds a registered executor regardless of the lookup key's case, returns the
// very instance that was registered, and reports false for unknown providers.
func TestManagerExecutorReturnsRegisteredExecutor(t *testing.T) {
	t.Parallel()

	manager := NewManager(nil, nil, nil)
	current := &replaceAwareExecutor{id: "codex"}
	manager.RegisterExecutor(current)

	// Registered as "codex", looked up as "CODEX".
	resolved, found := manager.Executor("CODEX")
	if !found {
		t.Fatal("expected registered executor to be found")
	}
	typed, isReplaceAware := resolved.(*replaceAwareExecutor)
	if !isReplaceAware {
		t.Fatalf("expected resolved executor type %T, got %T", current, resolved)
	}
	if typed != current {
		t.Fatal("expected resolved executor to match registered executor")
	}

	if _, found := manager.Executor("unknown"); found {
		t.Fatal("expected unknown provider lookup to fail")
	}
}
================================================
FILE: sdk/cliproxy/auth/conductor_overrides_test.go
================================================
package auth
import (
"context"
"net/http"
"sync"
"testing"
"time"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// TestManager_ShouldRetryAfterError_RespectsAuthRequestRetryOverride checks that
// a per-auth "request_retry" metadata value overrides the manager-wide setting
// passed to SetRetryConfig (presumably the retry count; 3 here). With
// request_retry=0 no retry is allowed. With request_retry=1 attempt 0 may retry
// with a positive wait, but attempt 1 may not.
func TestManager_ShouldRetryAfterError_RespectsAuthRequestRetryOverride(t *testing.T) {
m := NewManager(nil, nil, nil)
m.SetRetryConfig(3, 30*time.Second, 0)
model := "test-model"
// The model starts in cooldown for ~5s; presumably this is what gives
// shouldRetryAfterError the positive wait asserted below.
next := time.Now().Add(5 * time.Second)
auth := &Auth{
ID: "auth-1",
Provider: "claude",
Metadata: map[string]any{
// float64 presumably mirrors how JSON-decoded metadata numbers arrive.
"request_retry": float64(0),
},
ModelStates: map[string]*ModelState{
model: {
Unavailable: true,
Status: StatusError,
NextRetryAfter: next,
},
},
}
if _, errRegister := m.Register(context.Background(), auth); errRegister != nil {
t.Fatalf("register auth: %v", errRegister)
}
_, _, maxWait := m.retrySettings()
// request_retry=0: no retry even though the manager default would allow some.
wait, shouldRetry := m.shouldRetryAfterError(&Error{HTTPStatus: 500, Message: "boom"}, 0, []string{"claude"}, model, maxWait)
if shouldRetry {
t.Fatalf("expected shouldRetry=false for request_retry=0, got true (wait=%v)", wait)
}
// Raise the override to 1 on the local auth and push it through Update so the
// manager's stored copy picks up the new metadata.
auth.Metadata["request_retry"] = float64(1)
if _, errUpdate := m.Update(context.Background(), auth); errUpdate != nil {
t.Fatalf("update auth: %v", errUpdate)
}
// request_retry=1, attempt 0: one retry is allowed, with a positive wait.
wait, shouldRetry = m.shouldRetryAfterError(&Error{HTTPStatus: 500, Message: "boom"}, 0, []string{"claude"}, model, maxWait)
if !shouldRetry {
t.Fatalf("expected shouldRetry=true for request_retry=1, got false")
}
if wait <= 0 {
t.Fatalf("expected wait > 0, got %v", wait)
}
// request_retry=1, attempt 1: the single allowed retry is used up.
_, shouldRetry = m.shouldRetryAfterError(&Error{HTTPStatus: 500, Message: "boom"}, 1, []string{"claude"}, model, maxWait)
if shouldRetry {
t.Fatalf("expected shouldRetry=false on attempt=1 for request_retry=1, got true")
}
}
// credentialRetryLimitExecutor is a ProviderExecutor test double. Execute,
// ExecuteStream and CountTokens always fail with an HTTP 500 *Error and count
// how many times they were invoked.
type credentialRetryLimitExecutor struct {
	id    string
	mu    sync.Mutex
	calls int // guarded by mu
}

func (e *credentialRetryLimitExecutor) Identifier() string { return e.id }

func (e *credentialRetryLimitExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	e.recordCall()
	return cliproxyexecutor.Response{}, &Error{HTTPStatus: 500, Message: "boom"}
}

func (e *credentialRetryLimitExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	e.recordCall()
	return nil, &Error{HTTPStatus: 500, Message: "boom"}
}

// Refresh hands back the auth it received and is not counted.
func (e *credentialRetryLimitExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
	return auth, nil
}

func (e *credentialRetryLimitExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	e.recordCall()
	return cliproxyexecutor.Response{}, &Error{HTTPStatus: 500, Message: "boom"}
}

func (e *credentialRetryLimitExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
	return nil, nil
}

// recordCall increments the invocation counter.
func (e *credentialRetryLimitExecutor) recordCall() {
	e.mu.Lock()
	e.calls++
	e.mu.Unlock()
}

// Calls reports how many counted invocations have happened so far.
func (e *credentialRetryLimitExecutor) Calls() int {
	e.mu.Lock()
	defer e.mu.Unlock()
	return e.calls
}
// newCredentialRetryLimitTestManager returns a Manager configured with
// maxRetryCredentials, together with its always-failing "claude" executor.
// Two claude auths with unique IDs are registered. Each is first registered in
// the global model registry as supporting "test-model" (and unregistered on
// test cleanup), so selection can consider both.
func newCredentialRetryLimitTestManager(t *testing.T, maxRetryCredentials int) (*Manager, *credentialRetryLimitExecutor) {
	t.Helper()

	manager := NewManager(nil, nil, nil)
	manager.SetRetryConfig(0, 0, maxRetryCredentials)
	stub := &credentialRetryLimitExecutor{id: "claude"}
	manager.RegisterExecutor(stub)

	prefix := uuid.NewString()
	first := &Auth{ID: prefix + "-auth-1", Provider: "claude"}
	second := &Auth{ID: prefix + "-auth-2", Provider: "claude"}

	// Selection only considers credentials that the global model registry
	// lists as supporting the requested model.
	reg := registry.GetGlobalRegistry()
	reg.RegisterClient(first.ID, "claude", []*registry.ModelInfo{{ID: "test-model"}})
	reg.RegisterClient(second.ID, "claude", []*registry.ModelInfo{{ID: "test-model"}})
	t.Cleanup(func() {
		reg.UnregisterClient(first.ID)
		reg.UnregisterClient(second.ID)
	})

	if _, err := manager.Register(context.Background(), first); err != nil {
		t.Fatalf("register auth1: %v", err)
	}
	if _, err := manager.Register(context.Background(), second); err != nil {
		t.Fatalf("register auth2: %v", err)
	}
	return manager, stub
}
// TestManager_MaxRetryCredentials_LimitsCrossCredentialRetries checks each
// execution entry point (Execute, ExecuteCount, ExecuteStream). With
// max-retry-credentials=1, execution stops after the first failing credential.
// With 0 (no limit), both registered credentials are tried.
func TestManager_MaxRetryCredentials_LimitsCrossCredentialRetries(t *testing.T) {
	request := cliproxyexecutor.Request{Model: "test-model"}

	cases := []struct {
		name string
		run  func(*Manager) error
	}{
		{"execute", func(m *Manager) error {
			_, err := m.Execute(context.Background(), []string{"claude"}, request, cliproxyexecutor.Options{})
			return err
		}},
		{"execute_count", func(m *Manager) error {
			_, err := m.ExecuteCount(context.Background(), []string{"claude"}, request, cliproxyexecutor.Options{})
			return err
		}},
		{"execute_stream", func(m *Manager) error {
			_, err := m.ExecuteStream(context.Background(), []string{"claude"}, request, cliproxyexecutor.Options{})
			return err
		}},
	}

	for _, c := range cases {
		c := c
		t.Run(c.name, func(t *testing.T) {
			limited, limitedStub := newCredentialRetryLimitTestManager(t, 1)
			if err := c.run(limited); err == nil {
				t.Fatalf("expected error for limited retry execution")
			}
			if n := limitedStub.Calls(); n != 1 {
				t.Fatalf("expected 1 call with max-retry-credentials=1, got %d", n)
			}

			unlimited, unlimitedStub := newCredentialRetryLimitTestManager(t, 0)
			if err := c.run(unlimited); err == nil {
				t.Fatalf("expected error for unlimited retry execution")
			}
			if n := unlimitedStub.Calls(); n != 2 {
				t.Fatalf("expected 2 calls with max-retry-credentials=0, got %d", n)
			}
		})
	}
}
// TestManager_MarkResult_RespectsAuthDisableCoolingOverride checks that a failed
// result for an auth with metadata disable_cooling=true leaves the model's
// NextRetryAfter unset.
func TestManager_MarkResult_RespectsAuthDisableCoolingOverride(t *testing.T) {
	// Force the package-wide quotaCooldownDisabled flag off for this test
	// (restored on cleanup), so only the per-auth override can suppress cooling.
	saved := quotaCooldownDisabled.Load()
	quotaCooldownDisabled.Store(false)
	t.Cleanup(func() { quotaCooldownDisabled.Store(saved) })

	const (
		authID = "auth-1"
		model  = "test-model"
	)
	m := NewManager(nil, nil, nil)
	_, errRegister := m.Register(context.Background(), &Auth{
		ID:       authID,
		Provider: "claude",
		Metadata: map[string]any{"disable_cooling": true},
	})
	if errRegister != nil {
		t.Fatalf("register auth: %v", errRegister)
	}

	m.MarkResult(context.Background(), Result{
		AuthID:   authID,
		Provider: "claude",
		Model:    model,
		Success:  false,
		Error:    &Error{HTTPStatus: 500, Message: "boom"},
	})

	updated, ok := m.GetByID(authID)
	if !ok || updated == nil {
		t.Fatalf("expected auth to be present")
	}
	state := updated.ModelStates[model]
	if state == nil {
		t.Fatalf("expected model state to be present")
	}
	if retryAt := state.NextRetryAfter; !retryAt.IsZero() {
		t.Fatalf("expected NextRetryAfter to be zero when disable_cooling=true, got %v", retryAt)
	}
}
================================================
FILE: sdk/cliproxy/auth/conductor_scheduler_refresh_test.go
================================================
package auth
import (
"context"
"errors"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// schedulerProviderTestExecutor is a stateless ProviderExecutor stub whose
// Identifier is the configured provider. Every other method succeeds with zero
// values: Refresh echoes its auth, and ExecuteStream returns a nil result.
type schedulerProviderTestExecutor struct {
	provider string
}

func (e schedulerProviderTestExecutor) Identifier() string { return e.provider }

func (e schedulerProviderTestExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

func (e schedulerProviderTestExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	return nil, nil
}

func (e schedulerProviderTestExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
	return auth, nil
}

func (e schedulerProviderTestExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

func (e schedulerProviderTestExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
	return nil, nil
}
// TestManager_RefreshSchedulerEntry_RebuildsSupportedModelSetAfterModelRegistration
// covers an auth that becomes known to the manager (via Register, or Register
// followed by Update) before its model is registered in the global registry.
// The scheduler must not pick it until RefreshSchedulerEntry is called. After
// that, pickSingle must return it for the newly registered model.
func TestManager_RefreshSchedulerEntry_RebuildsSupportedModelSetAfterModelRegistration(t *testing.T) {
ctx := context.Background()
// Each case primes the manager with the auth in a different way; neither one
// registers any models.
testCases := []struct {
name string
prime func(*Manager, *Auth) error
}{
{
name: "register",
prime: func(manager *Manager, auth *Auth) error {
_, errRegister := manager.Register(ctx, auth)
return errRegister
},
},
{
name: "update",
prime: func(manager *Manager, auth *Auth) error {
_, errRegister := manager.Register(ctx, auth)
if errRegister != nil {
return errRegister
}
updated := auth.Clone()
updated.Metadata = map[string]any{"updated": true}
_, errUpdate := manager.Update(ctx, updated)
return errUpdate
},
},
}
for _, testCase := range testCases {
testCase := testCase
t.Run(testCase.name, func(t *testing.T) {
manager := NewManager(nil, &RoundRobinSelector{}, nil)
auth := &Auth{
ID: "refresh-entry-" + testCase.name,
Provider: "gemini",
}
if errPrime := testCase.prime(manager, auth); errPrime != nil {
t.Fatalf("prime auth %s: %v", testCase.name, errPrime)
}
// The model is registered only after priming, so the scheduler's view of
// this auth is stale at this point. registerSchedulerModels is a helper
// defined elsewhere in this package (presumably it also cleans up).
registerSchedulerModels(t, "gemini", "scheduler-refresh-model", auth.ID)
// Before the refresh the scheduler is expected to report auth_not_found.
got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil)
var authErr *Error
if !errors.As(errPick, &authErr) || authErr == nil {
t.Fatalf("pickSingle() before refresh error = %v, want auth_not_found", errPick)
}
if authErr.Code != "auth_not_found" {
t.Fatalf("pickSingle() before refresh code = %q, want %q", authErr.Code, "auth_not_found")
}
if got != nil {
t.Fatalf("pickSingle() before refresh auth = %v, want nil", got)
}
// Rebuilding this auth's scheduler entry must pick up the new model.
manager.RefreshSchedulerEntry(auth.ID)
got, errPick = manager.scheduler.pickSingle(ctx, "gemini", "scheduler-refresh-model", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickSingle() after refresh error = %v", errPick)
}
if got == nil || got.ID != auth.ID {
t.Fatalf("pickSingle() after refresh auth = %v, want %q", got, auth.ID)
}
})
}
}
// TestManager_PickNext_RebuildsSchedulerAfterModelCooldownError checks that
// pickNext recovers from a stale scheduler snapshot. The only auth the
// scheduler knows for the model is cooling down. A newer auth supports the
// model, but it was registered before its model registration existed.
// pickSingle alone is expected to fail with modelCooldownError; pickNext is
// expected to rebuild and return the new auth together with an executor.
func TestManager_PickNext_RebuildsSchedulerAfterModelCooldownError(t *testing.T) {
ctx := context.Background()
manager := NewManager(nil, &RoundRobinSelector{}, nil)
manager.RegisterExecutor(schedulerProviderTestExecutor{provider: "gemini"})
// The old auth's model support is registered before the auth itself.
registerSchedulerModels(t, "gemini", "scheduler-cooldown-rebuild-model", "cooldown-stale-old")
oldAuth := &Auth{
ID: "cooldown-stale-old",
Provider: "gemini",
}
if _, errRegister := manager.Register(ctx, oldAuth); errRegister != nil {
t.Fatalf("register old auth: %v", errRegister)
}
// A 429 for the old auth is expected to put the model into cooldown for it.
manager.MarkResult(ctx, Result{
AuthID: oldAuth.ID,
Provider: "gemini",
Model: "scheduler-cooldown-rebuild-model",
Success: false,
Error: &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"},
})
// The new auth is registered BEFORE its model support is added to the
// registry, so the scheduler does not yet consider it for the model.
newAuth := &Auth{
ID: "cooldown-stale-new",
Provider: "gemini",
}
if _, errRegister := manager.Register(ctx, newAuth); errRegister != nil {
t.Fatalf("register new auth: %v", errRegister)
}
reg := registry.GetGlobalRegistry()
reg.RegisterClient(newAuth.ID, "gemini", []*registry.ModelInfo{{ID: "scheduler-cooldown-rebuild-model"}})
t.Cleanup(func() {
reg.UnregisterClient(newAuth.ID)
})
// Without a rebuild the scheduler only sees the cooling-down old auth.
got, errPick := manager.scheduler.pickSingle(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil)
var cooldownErr *modelCooldownError
if !errors.As(errPick, &cooldownErr) {
t.Fatalf("pickSingle() before sync error = %v, want modelCooldownError", errPick)
}
if got != nil {
t.Fatalf("pickSingle() before sync auth = %v, want nil", got)
}
// pickNext must rebuild the scheduler after the cooldown error and pick the
// new auth, returning its executor as well.
got, executor, errPick := manager.pickNext(ctx, "gemini", "scheduler-cooldown-rebuild-model", cliproxyexecutor.Options{}, nil)
if errPick != nil {
t.Fatalf("pickNext() error = %v", errPick)
}
if executor == nil {
t.Fatal("pickNext() executor = nil")
}
if got == nil || got.ID != newAuth.ID {
t.Fatalf("pickNext() auth = %v, want %q", got, newAuth.ID)
}
}
================================================
FILE: sdk/cliproxy/auth/conductor_update_test.go
================================================
package auth
import (
"context"
"testing"
)
// TestManager_Update_PreservesModelStates verifies that an Update which carries
// no ModelStates keeps the per-model state recorded at registration time.
func TestManager_Update_PreservesModelStates(t *testing.T) {
	const (
		authID       = "auth-1"
		model        = "test-model"
		backoffLevel = 7
	)
	ctx := context.Background()
	m := NewManager(nil, nil, nil)

	initial := &Auth{
		ID:       authID,
		Provider: "claude",
		Metadata: map[string]any{"k": "v"},
		ModelStates: map[string]*ModelState{
			model: {Quota: QuotaState{BackoffLevel: backoffLevel}},
		},
	}
	if _, err := m.Register(ctx, initial); err != nil {
		t.Fatalf("register auth: %v", err)
	}

	// The update deliberately omits ModelStates.
	replacement := &Auth{ID: authID, Provider: "claude", Metadata: map[string]any{"k": "v2"}}
	if _, err := m.Update(ctx, replacement); err != nil {
		t.Fatalf("update auth: %v", err)
	}

	updated, ok := m.GetByID(authID)
	if !ok || updated == nil {
		t.Fatalf("expected auth to be present")
	}
	if len(updated.ModelStates) == 0 {
		t.Fatalf("expected ModelStates to be preserved")
	}
	state := updated.ModelStates[model]
	if state == nil {
		t.Fatalf("expected model state to be present")
	}
	if got := state.Quota.BackoffLevel; got != backoffLevel {
		t.Fatalf("expected BackoffLevel to be %d, got %d", backoffLevel, got)
	}
}
================================================
FILE: sdk/cliproxy/auth/errors.go
================================================
package auth
// Error is a provider-agnostic description of an authentication-related
// failure. It is JSON-serializable; Code and HTTPStatus are omitted from the
// encoding when they hold their zero values.
type Error struct {
	// Code is a short machine-readable identifier such as "auth_not_found".
	Code string `json:"code,omitempty"`
	// Message is a human-readable explanation of the failure.
	Message string `json:"message"`
	// Retryable reports whether retrying might resolve the failure on its own.
	Retryable bool `json:"retryable"`
	// HTTPStatus optionally carries an HTTP-like status code for the failure.
	HTTPStatus int `json:"http_status,omitempty"`
}

// Error implements the error interface. It renders "<code>: <message>", or
// just the message when Code is empty. A nil *Error renders as "".
func (e *Error) Error() string {
	switch {
	case e == nil:
		return ""
	case e.Code != "":
		return e.Code + ": " + e.Message
	default:
		return e.Message
	}
}

// StatusCode exposes HTTPStatus through an optional status-accessor method
// that callers can look for on errors. A nil *Error reports 0.
func (e *Error) StatusCode() int {
	if e != nil {
		return e.HTTPStatus
	}
	return 0
}
================================================
FILE: sdk/cliproxy/auth/oauth_model_alias.go
================================================
package auth
import (
"strings"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
)
// modelAliasEntry is the minimal view of a configured model entry that alias
// resolution needs: the model name that a resolved request maps to, and the
// client-facing alias that requests may use instead.
type modelAliasEntry interface {
GetName() string
GetAlias() string
}
// oauthModelAliasTable is the compiled lookup form of the OAuth model alias
// configuration, built by compileOAuthModelAliasTable.
type oauthModelAliasTable struct {
// reverse maps channel -> alias (lower) -> original upstream model name.
// Channel keys are lower-cased; a nil map means no usable aliases exist.
reverse map[string]map[string]string
}
// compileOAuthModelAliasTable builds the per-channel reverse lookup that maps a
// client-visible alias back to its upstream model name.
//
// Normalization and filtering:
//   - channel names are trimmed and lower-cased;
//   - alias keys are trimmed and lower-cased;
//   - upstream names are trimmed but otherwise keep their spelling;
//   - entries with an empty name or alias are skipped;
//   - entries whose alias equals the name (case-insensitively) are skipped;
//   - within a channel, the first entry for a given alias wins.
//
// Configuration keys that normalize to the same channel (for example "Claude"
// and "claude") are merged into one alias map. Previously the later key
// replaced the earlier one wholesale. Go map iteration order is random, so
// which of such keys is processed first is unspecified. That only matters when
// both define the same alias.
//
// The returned table is never nil; its reverse map is nil when nothing usable
// was configured.
func compileOAuthModelAliasTable(aliases map[string][]internalconfig.OAuthModelAlias) *oauthModelAliasTable {
	out := &oauthModelAliasTable{}
	if len(aliases) == 0 {
		return out
	}
	reverse := make(map[string]map[string]string, len(aliases))
	for rawChannel, entries := range aliases {
		channel := strings.ToLower(strings.TrimSpace(rawChannel))
		if channel == "" || len(entries) == 0 {
			continue
		}
		// Continue filling an existing map for this channel rather than
		// replacing it. Reading from a nil map is safe.
		rev := reverse[channel]
		for _, entry := range entries {
			name := strings.TrimSpace(entry.Name)
			alias := strings.TrimSpace(entry.Alias)
			if name == "" || alias == "" || strings.EqualFold(name, alias) {
				continue
			}
			aliasKey := strings.ToLower(alias)
			if _, exists := rev[aliasKey]; exists {
				continue
			}
			if rev == nil {
				// Only materialize the channel once it has a usable entry.
				rev = make(map[string]string, len(entries))
				reverse[channel] = rev
			}
			rev[aliasKey] = name
		}
	}
	if len(reverse) > 0 {
		out.reverse = reverse
	}
	return out
}
// SetOAuthModelAlias compiles aliases and installs the result as the OAuth model
// alias table used during execution. The table is consulted per auth channel
// to resolve the upstream model name. The client-visible model name is left
// unchanged for translation and response formatting. A nil Manager is a no-op.
func (m *Manager) SetOAuthModelAlias(aliases map[string][]internalconfig.OAuthModelAlias) {
	if m == nil {
		return
	}
	// Always store a non-nil table; fall back to an empty one if compilation
	// ever returns nil.
	if table := compileOAuthModelAliasTable(aliases); table != nil {
		m.oauthModelAlias.Store(table)
		return
	}
	m.oauthModelAlias.Store(&oauthModelAliasTable{})
}
// applyOAuthModelAlias returns the upstream model that requestedModel maps to on
// auth's alias channel, or requestedModel unchanged when no alias applies.
func (m *Manager) applyOAuthModelAlias(auth *Auth, requestedModel string) string {
	if upstream := m.resolveOAuthUpstreamModel(auth, requestedModel); upstream != "" {
		return upstream
	}
	return requestedModel
}
// modelAliasLookupCandidates trims requestedModel and parses its thinking
// suffix. It returns the parse result and the names to try during alias
// lookup: the suffix-free base name first, then the full trimmed name when it
// differs. A blank input yields a zero result and no candidates. If parsing
// produces an empty base name, the trimmed input is used as the base.
func modelAliasLookupCandidates(requestedModel string) (thinking.SuffixResult, []string) {
	trimmed := strings.TrimSpace(requestedModel)
	if trimmed == "" {
		return thinking.SuffixResult{}, nil
	}
	parsed := thinking.ParseSuffix(trimmed)
	base := parsed.ModelName
	if base == "" {
		base = trimmed
	}
	if base == trimmed {
		return parsed, []string{base}
	}
	return parsed, []string{base, trimmed}
}
// preserveResolvedModelSuffix carries the request's thinking suffix (for
// example "(8192)") over to a resolved model name. The trimmed resolved name is
// returned as-is when it is empty, when it already has its own suffix (which
// takes priority), or when the request had no non-empty raw suffix.
func preserveResolvedModelSuffix(resolved string, requestResult thinking.SuffixResult) string {
	name := strings.TrimSpace(resolved)
	switch {
	case name == "":
		return ""
	case thinking.ParseSuffix(name).HasSuffix:
		return name
	case !requestResult.HasSuffix || requestResult.RawSuffix == "":
		return name
	default:
		return name + "(" + requestResult.RawSuffix + ")"
	}
}
// resolveModelAliasPoolFromConfigModels resolves requestedModel against
// configured model entries and returns the candidate model names to use.
//
// Pass 1 walks the entries in order and collects each one whose alias matches
// a lookup candidate (see modelAliasLookupCandidates), ignoring case. It emits
// the entry's name, or the matched candidate when the name is blank, with the
// request's thinking suffix preserved, de-duplicated case-insensitively.
//
// If no alias matched, pass 2 returns the first entry whose own name matches a
// candidate. The result is nil when nothing matches or the input is blank.
func resolveModelAliasPoolFromConfigModels(requestedModel string, models []modelAliasEntry) []string {
	requestedModel = strings.TrimSpace(requestedModel)
	if requestedModel == "" || len(models) == 0 {
		return nil
	}
	requestResult, candidates := modelAliasLookupCandidates(requestedModel)
	if len(candidates) == 0 {
		return nil
	}

	// firstMatch returns the first non-empty candidate equal to value, ignoring case.
	firstMatch := func(value string) (string, bool) {
		if value == "" {
			return "", false
		}
		for _, c := range candidates {
			if c != "" && strings.EqualFold(value, c) {
				return c, true
			}
		}
		return "", false
	}

	var pool []string
	seen := make(map[string]struct{})
	for _, model := range models {
		name := strings.TrimSpace(model.GetName())
		alias := strings.TrimSpace(model.GetAlias())
		matched, ok := firstMatch(alias)
		if !ok {
			continue
		}
		target := name
		if target == "" {
			target = matched
		}
		target = preserveResolvedModelSuffix(target, requestResult)
		key := strings.ToLower(strings.TrimSpace(target))
		if key == "" {
			continue
		}
		if _, dup := seen[key]; dup {
			continue
		}
		seen[key] = struct{}{}
		pool = append(pool, target)
	}
	if len(pool) > 0 {
		return pool
	}

	for _, model := range models {
		name := strings.TrimSpace(model.GetName())
		if _, ok := firstMatch(name); ok {
			return []string{preserveResolvedModelSuffix(name, requestResult)}
		}
	}
	return nil
}
// resolveModelAliasFromConfigModels returns the first name produced by
// resolveModelAliasPoolFromConfigModels, or "" when nothing matches.
func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string {
	if pool := resolveModelAliasPoolFromConfigModels(requestedModel, models); len(pool) > 0 {
		return pool[0]
	}
	return ""
}
// resolveOAuthUpstreamModel returns the upstream model name that the OAuth
// model alias table maps requestedModel to on auth's channel, or "" when there
// is no applicable alias.
//
// A thinking suffix on the request (e.g. "gemini-2.5-pro(8192)") is carried
// over to the returned name. If the configured upstream name already has a
// suffix, that configured suffix wins.
func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {
	channel := modelAliasChannel(auth)
	return resolveUpstreamModelFromAliasTable(m, auth, requestedModel, channel)
}
// resolveUpstreamModelFromAliasTable looks requestedModel up in m's compiled
// OAuth alias table for channel. It returns the upstream model name, or ""
// when the manager, auth, channel or table is missing or no alias applies.
//
// Candidates come from modelAliasLookupCandidates (the suffix-free base name
// first, then the full name), the same helper the config-model resolver uses.
// The request is therefore trimmed before its thinking suffix is parsed, so
// "model(8192) " resolves like "model(8192)". Previously the raw, untrimmed
// string was parsed, so the suffix was missed.
//
// A mapping whose upstream name equals the requested base model is treated as
// "no alias" and yields "". The request's thinking suffix is carried onto the
// result via preserveResolvedModelSuffix, unless the configured upstream name
// already carries its own suffix, which takes priority.
func resolveUpstreamModelFromAliasTable(m *Manager, auth *Auth, requestedModel, channel string) string {
	if m == nil || auth == nil || channel == "" {
		return ""
	}
	table, _ := m.oauthModelAlias.Load().(*oauthModelAliasTable)
	if table == nil || table.reverse == nil {
		return ""
	}
	rev := table.reverse[channel]
	if rev == nil {
		return ""
	}

	requestResult, candidates := modelAliasLookupCandidates(requestedModel)
	if len(candidates) == 0 {
		return ""
	}
	baseModel := candidates[0]
	for _, candidate := range candidates {
		key := strings.ToLower(strings.TrimSpace(candidate))
		if key == "" {
			continue
		}
		original := strings.TrimSpace(rev[key])
		if original == "" {
			continue
		}
		if strings.EqualFold(original, baseModel) {
			return ""
		}
		return preserveResolvedModelSuffix(original, requestResult)
	}
	return ""
}
// modelAliasChannel derives the OAuth model alias channel of auth.
//
// The provider comes from auth.Provider and the auth kind from the "auth_kind"
// attribute (both trimmed and lower-cased). When that attribute is absent or
// empty, an "api_key" account reported by AccountInfo is treated as kind
// "apikey". The final decision is delegated to OAuthModelAliasChannel.
func modelAliasChannel(auth *Auth) string {
	if auth == nil {
		return ""
	}
	provider := strings.ToLower(strings.TrimSpace(auth.Provider))

	var kind string
	if attrs := auth.Attributes; attrs != nil {
		kind = strings.ToLower(strings.TrimSpace(attrs["auth_kind"]))
	}
	if kind == "" {
		if accountKind, _ := auth.AccountInfo(); strings.EqualFold(accountKind, "api_key") {
			kind = "apikey"
		}
	}
	return OAuthModelAliasChannel(provider, kind)
}
// OAuthModelAliasChannel reports which OAuth model alias channel applies to a
// provider/auth-kind pair. Both inputs are compared case-insensitively after
// trimming whitespace. It returns "" when the combination does not use
// oauth-model-alias, for example API key authentication.
//
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kimi.
func OAuthModelAliasChannel(provider, authKind string) string {
	p := strings.ToLower(strings.TrimSpace(provider))
	isAPIKey := strings.ToLower(strings.TrimSpace(authKind)) == "apikey"

	switch p {
	case "vertex", "claude", "codex":
		// These providers also accept API keys, which are configured elsewhere
		// and therefore never participate in oauth-model-alias.
		if isAPIKey {
			return ""
		}
		return p
	case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow", "kimi":
		return p
	default:
		// Includes plain "gemini": it uses gemini-api-key config, and OAuth-based
		// gemini auth is converted to "gemini-cli" by the synthesizer.
		return ""
	}
}
================================================
FILE: sdk/cliproxy/auth/oauth_model_alias_test.go
================================================
package auth
import (
"testing"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) {
t.Parallel()
tests := []struct {
name string
aliases map[string][]internalconfig.OAuthModelAlias
channel string
input string
want string
}{
{
name: "numeric suffix preserved",
aliases: map[string][]internalconfig.OAuthModelAlias{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
},
channel: "gemini-cli",
input: "gemini-2.5-pro(8192)",
want: "gemini-2.5-pro-exp-03-25(8192)",
},
{
name: "level suffix preserved",
aliases: map[string][]internalconfig.OAuthModelAlias{
"claude": {{Name: "claude-sonnet-4-5-20250514", Alias: "claude-sonnet-4-5"}},
},
channel: "claude",
input: "claude-sonnet-4-5(high)",
want: "claude-sonnet-4-5-20250514(high)",
},
{
name: "no suffix unchanged",
aliases: map[string][]internalconfig.OAuthModelAlias{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
},
channel: "gemini-cli",
input: "gemini-2.5-pro",
want: "gemini-2.5-pro-exp-03-25",
},
{
name: "config suffix takes priority",
aliases: map[string][]internalconfig.OAuthModelAlias{
"claude": {{Name: "claude-sonnet-4-5-20250514(low)", Alias: "claude-sonnet-4-5"}},
},
channel: "claude",
input: "claude-sonnet-4-5(high)",
want: "claude-sonnet-4-5-20250514(low)",
},
{
name: "auto suffix preserved",
aliases: map[string][]internalconfig.OAuthModelAlias{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
},
channel: "gemini-cli",
input: "gemini-2.5-pro(auto)",
want: "gemini-2.5-pro-exp-03-25(auto)",
},
{
name: "none suffix preserved",
aliases: map[string][]internalconfig.OAuthModelAlias{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
},
channel: "gemini-cli",
input: "gemini-2.5-pro(none)",
want: "gemini-2.5-pro-exp-03-25(none)",
},
{
name: "kimi suffix preserved",
aliases: map[string][]internalconfig.OAuthModelAlias{
"kimi": {{Name: "kimi-k2.5", Alias: "k2.5"}},
},
channel: "kimi",
input: "k2.5(high)",
want: "kimi-k2.5(high)",
},
{
name: "case insensitive alias lookup with suffix",
aliases: map[string][]internalconfig.OAuthModelAlias{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "Gemini-2.5-Pro"}},
},
channel: "gemini-cli",
input: "gemini-2.5-pro(high)",
want: "gemini-2.5-pro-exp-03-25(high)",
},
{
name: "no alias returns empty",
aliases: map[string][]internalconfig.OAuthModelAlias{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
},
channel: "gemini-cli",
input: "unknown-model(high)",
want: "",
},
{
name: "wrong channel returns empty",
aliases: map[string][]internalconfig.OAuthModelAlias{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
},
channel: "claude",
input: "gemini-2.5-pro(high)",
want: "",
},
{
name: "empty suffix filtered out",
aliases: map[string][]internalconfig.OAuthModelAlias{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
},
channel: "gemini-cli",
input: "gemini-2.5-pro()",
want: "gemini-2.5-pro-exp-03-25",
},
{
name: "incomplete suffix treated as no suffix",
aliases: map[string][]internalconfig.OAuthModelAlias{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro(high"}},
},
channel: "gemini-cli",
input: "gemini-2.5-pro(high",
want: "gemini-2.5-pro-exp-03-25",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
mgr := NewManager(nil, nil, nil)
mgr.SetConfig(&internalconfig.Config{})
mgr.SetOAuthModelAlias(tt.aliases)
auth := createAuthForChannel(tt.channel)
got := mgr.resolveOAuthUpstreamModel(auth, tt.input)
if got != tt.want {
t.Errorf("resolveOAuthUpstreamModel(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
// createAuthForChannel builds a minimal Auth whose provider is channel.
// Providers that may also be configured with API keys (claude, vertex, codex)
// get an explicit "oauth" auth_kind attribute so that they resolve to their
// OAuth alias channel.
func createAuthForChannel(channel string) *Auth {
	switch channel {
	case "claude", "vertex", "codex":
		return &Auth{Provider: channel, Attributes: map[string]string{"auth_kind": "oauth"}}
	default:
		return &Auth{Provider: channel}
	}
}
// TestOAuthModelAliasChannel_Kimi verifies that kimi OAuth auths map to the kimi channel.
func TestOAuthModelAliasChannel_Kimi(t *testing.T) {
	t.Parallel()
	const want = "kimi"
	got := OAuthModelAliasChannel("kimi", "oauth")
	if got != want {
		t.Fatalf("OAuthModelAliasChannel() = %q, want %q", got, want)
	}
}
func TestApplyOAuthModelAlias_SuffixPreservation(t *testing.T) {
t.Parallel()
aliases := map[string][]internalconfig.OAuthModelAlias{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
}
mgr := NewManager(nil, nil, nil)
mgr.SetConfig(&internalconfig.Config{})
mgr.SetOAuthModelAlias(aliases)
auth := &Auth{ID: "test-auth-id", Provider: "gemini-cli"}
resolvedModel := mgr.applyOAuthModelAlias(auth, "gemini-2.5-pro(8192)")
if resolvedModel != "gemini-2.5-pro-exp-03-25(8192)" {
t.Errorf("applyOAuthModelAlias() model = %q, want %q", resolvedModel, "gemini-2.5-pro-exp-03-25(8192)")
}
}
================================================
FILE: sdk/cliproxy/auth/openai_compat_pool_test.go
================================================
package auth
import (
"context"
"net/http"
"sync"
"testing"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// openAICompatPoolExecutor is a test executor. It records the upstream model
// received by every Execute, ExecuteStream and CountTokens call, and it can
// inject per-model errors or stream chunks.
type openAICompatPoolExecutor struct {
	id string // value reported by Identifier

	// mu guards the fields below; every method reads or writes them while holding it.
	mu                sync.Mutex
	executeModels     []string                                  // models passed to Execute, in call order
	countModels       []string                                  // models passed to CountTokens, in call order
	streamModels      []string                                  // models passed to ExecuteStream, in call order
	executeErrors     map[string]error                          // error Execute returns for a given model
	countErrors       map[string]error                          // error CountTokens returns for a given model
	streamFirstErrors map[string]error                          // error emitted as the sole stream chunk for a given model
	streamPayloads    map[string][]cliproxyexecutor.StreamChunk // custom chunks for a model; an entry (even empty) replaces the default payload
}
// Identifier returns the id this executor was constructed with.
func (e *openAICompatPoolExecutor) Identifier() string {
	id := e.id
	return id
}
// Execute records req.Model. It returns the error configured in executeErrors
// for that model, or a response whose payload is the model name.
func (e *openAICompatPoolExecutor) Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	_, _, _ = ctx, auth, opts
	model := req.Model

	e.mu.Lock()
	e.executeModels = append(e.executeModels, model)
	failure := e.executeErrors[model]
	e.mu.Unlock()

	if failure == nil {
		return cliproxyexecutor.Response{Payload: []byte(model)}, nil
	}
	return cliproxyexecutor.Response{}, failure
}
// ExecuteStream records req.Model and returns a stream whose chunk channel is
// already filled and closed. In order of precedence, the chunks are:
//   - a single error chunk when streamFirstErrors has an entry for the model;
//   - the chunks in streamPayloads when the model has an entry there (possibly none);
//   - a single chunk whose payload is the model name.
//
// The X-Model response header echoes the model.
func (e *openAICompatPoolExecutor) ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	_, _, _ = ctx, auth, opts
	model := req.Model

	e.mu.Lock()
	e.streamModels = append(e.streamModels, model)
	firstErr := e.streamFirstErrors[model]
	configured, hasConfigured := e.streamPayloads[model]
	// Take a private copy while locked so the sends below never touch shared state.
	configured = append([]cliproxyexecutor.StreamChunk(nil), configured...)
	e.mu.Unlock()

	var emit []cliproxyexecutor.StreamChunk
	switch {
	case firstErr != nil:
		emit = []cliproxyexecutor.StreamChunk{{Err: firstErr}}
	case hasConfigured:
		emit = configured
	default:
		emit = []cliproxyexecutor.StreamChunk{{Payload: []byte(model)}}
	}

	// The buffer holds at least one slot and at least the configured chunk count,
	// so every send below completes without a reader.
	chunks := make(chan cliproxyexecutor.StreamChunk, max(1, len(configured)))
	for _, chunk := range emit {
		chunks <- chunk
	}
	close(chunks)
	return &cliproxyexecutor.StreamResult{Headers: http.Header{"X-Model": {model}}, Chunks: chunks}, nil
}
// Refresh is a no-op: it hands back the given auth unchanged and never fails.
func (e *openAICompatPoolExecutor) Refresh(_ context.Context, auth *Auth) (refreshed *Auth, err error) {
	refreshed = auth
	return refreshed, err
}
// CountTokens records req.Model. It fails with the error configured in
// countErrors for that model, or returns the model name as the payload.
func (e *openAICompatPoolExecutor) CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	_, _, _ = ctx, auth, opts

	e.mu.Lock()
	e.countModels = append(e.countModels, req.Model)
	countErr := e.countErrors[req.Model]
	e.mu.Unlock()

	var resp cliproxyexecutor.Response
	if countErr != nil {
		return resp, countErr
	}
	resp.Payload = []byte(req.Model)
	return resp, nil
}
// HttpRequest is not supported by this test executor and always fails with a
// 501 Not Implemented *Error.
func (e *openAICompatPoolExecutor) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
	_, _, _ = ctx, auth, req
	notImplemented := &Error{HTTPStatus: http.StatusNotImplemented, Message: "HttpRequest not implemented"}
	return nil, notImplemented
}
// ExecuteModels returns a copy of the models passed to Execute, in call order.
// The result is never nil, even when no calls were recorded.
func (e *openAICompatPoolExecutor) ExecuteModels() []string {
	e.mu.Lock()
	snapshot := append(make([]string, 0, len(e.executeModels)), e.executeModels...)
	e.mu.Unlock()
	return snapshot
}
// CountModels returns a copy of the models passed to CountTokens, in call order.
// The result is never nil, even when no calls were recorded.
func (e *openAICompatPoolExecutor) CountModels() []string {
	e.mu.Lock()
	snapshot := append(make([]string, 0, len(e.countModels)), e.countModels...)
	e.mu.Unlock()
	return snapshot
}
// StreamModels returns a copy of the models passed to ExecuteStream, in call order.
// The result is never nil, even when no calls were recorded.
func (e *openAICompatPoolExecutor) StreamModels() []string {
	e.mu.Lock()
	snapshot := append(make([]string, 0, len(e.streamModels)), e.streamModels...)
	e.mu.Unlock()
	return snapshot
}
// newOpenAICompatPoolTestManager builds a Manager for alias-pool tests. The
// Manager has a single OpenAI-compatible provider "pool" configured with models,
// one active auth, and alias registered for that auth in the global model
// registry. The registry entry is removed on test cleanup. A nil executor is
// replaced by a fresh openAICompatPoolExecutor with id "pool".
func newOpenAICompatPoolTestManager(t *testing.T, alias string, models []internalconfig.OpenAICompatibilityModel, executor *openAICompatPoolExecutor) *Manager {
	t.Helper()

	m := NewManager(nil, nil, nil)
	m.SetConfig(&internalconfig.Config{
		OpenAICompatibility: []internalconfig.OpenAICompatibility{{Name: "pool", Models: models}},
	})
	if executor == nil {
		executor = &openAICompatPoolExecutor{id: "pool"}
	}
	m.RegisterExecutor(executor)

	poolAuth := &Auth{
		ID:       "pool-auth-" + t.Name(),
		Provider: "pool",
		Status:   StatusActive,
		Attributes: map[string]string{
			"api_key":      "test-key",
			"compat_name":  "pool",
			"provider_key": "pool",
		},
	}
	if _, err := m.Register(context.Background(), poolAuth); err != nil {
		t.Fatalf("register auth: %v", err)
	}

	globalRegistry := registry.GetGlobalRegistry()
	globalRegistry.RegisterClient(poolAuth.ID, "pool", []*registry.ModelInfo{{ID: alias}})
	t.Cleanup(func() { globalRegistry.UnregisterClient(poolAuth.ID) })
	return m
}
// TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest verifies
// that a 422 from the first pooled model is returned without trying the rest
// of the pool.
func TestManagerExecuteCount_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) {
	const alias = "claude-opus-4.66"
	wantErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
	stub := &openAICompatPoolExecutor{
		id:          "pool",
		countErrors: map[string]error{"qwen3.5-plus": wantErr},
	}
	pool := []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}
	m := newOpenAICompatPoolTestManager(t, alias, pool, stub)

	_, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
	if err == nil || err.Error() != wantErr.Error() {
		t.Fatalf("execute count error = %v, want %v", err, wantErr)
	}
	if calls := stub.CountModels(); len(calls) != 1 || calls[0] != "qwen3.5-plus" {
		t.Fatalf("count calls = %v, want only first invalid model", calls)
	}
}
// TestResolveModelAliasPoolFromConfigModels verifies that every model sharing
// an alias is returned in configuration order, each with the request's
// thinking suffix attached.
func TestResolveModelAliasPoolFromConfigModels(t *testing.T) {
	const alias = "claude-opus-4.66"
	upstreams := []string{"qwen3.5-plus", "glm-5", "kimi-k2.5"}

	entries := make([]modelAliasEntry, 0, len(upstreams))
	want := make([]string, 0, len(upstreams))
	for _, name := range upstreams {
		entries = append(entries, internalconfig.OpenAICompatibilityModel{Name: name, Alias: alias})
		want = append(want, name+"(8192)")
	}

	got := resolveModelAliasPoolFromConfigModels(alias+"(8192)", entries)
	if len(got) != len(want) {
		t.Fatalf("pool len = %d, want %d (%v)", len(got), len(want), got)
	}
	for i, expected := range want {
		if got[i] != expected {
			t.Fatalf("pool[%d] = %q, want %q", i, got[i], expected)
		}
	}
}
// TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth verifies that
// successive Execute calls for one alias cycle through the pooled upstream
// models of a single auth.
func TestManagerExecute_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
	const alias = "claude-opus-4.66"
	stub := &openAICompatPoolExecutor{id: "pool"}
	m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}, stub)

	for attempt := 0; attempt < 3; attempt++ {
		resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
		if err != nil {
			t.Fatalf("execute %d: %v", attempt, err)
		}
		if len(resp.Payload) == 0 {
			t.Fatalf("execute %d returned empty payload", attempt)
		}
	}

	want := []string{"qwen3.5-plus", "glm-5", "qwen3.5-plus"}
	got := stub.ExecuteModels()
	if len(got) != len(want) {
		t.Fatalf("execute calls = %v, want %v", got, want)
	}
	for i, model := range want {
		if got[i] != model {
			t.Fatalf("execute call %d model = %q, want %q", i, got[i], model)
		}
	}
}
// TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest verifies that a 400
// invalid-request error from the first pooled model is returned immediately,
// without falling back to the next pooled model.
func TestManagerExecute_OpenAICompatAliasPoolStopsOnBadRequest(t *testing.T) {
	const alias = "claude-opus-4.66"
	wantErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
	stub := &openAICompatPoolExecutor{
		id:            "pool",
		executeErrors: map[string]error{"qwen3.5-plus": wantErr},
	}
	pool := []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}
	m := newOpenAICompatPoolTestManager(t, alias, pool, stub)

	_, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
	if err == nil || err.Error() != wantErr.Error() {
		t.Fatalf("execute error = %v, want %v", err, wantErr)
	}
	if calls := stub.ExecuteModels(); len(calls) != 1 || calls[0] != "qwen3.5-plus" {
		t.Fatalf("execute calls = %v, want only first invalid model", calls)
	}
}
// TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth verifies that
// a 429 from the first pooled model makes Execute try the next pooled model on
// the same auth, and that the response comes from that fallback model.
func TestManagerExecute_OpenAICompatAliasPoolFallsBackWithinSameAuth(t *testing.T) {
	alias := "claude-opus-4.66"
	executor := &openAICompatPoolExecutor{
		id:            "pool",
		executeErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
	}
	m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}, executor)
	resp, err := m.Execute(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
	if err != nil {
		t.Fatalf("execute: %v", err)
	}
	if string(resp.Payload) != "glm-5" {
		t.Fatalf("payload = %q, want %q", string(resp.Payload), "glm-5")
	}
	got := executor.ExecuteModels()
	want := []string{"qwen3.5-plus", "glm-5"}
	// Check the length first. With fewer recorded calls, the indexing loop below
	// would panic with index out of range instead of reporting a clear failure.
	if len(got) != len(want) {
		t.Fatalf("execute calls = %v, want %v", got, want)
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("execute call %d model = %q, want %q", i, got[i], want[i])
		}
	}
}
// TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap verifies
// that when the first pooled model's stream closes without any chunk,
// ExecuteStream moves on to the next pooled model and streams its payload.
func TestManagerExecuteStream_OpenAICompatAliasPoolRetriesOnEmptyBootstrap(t *testing.T) {
	alias := "claude-opus-4.66"
	executor := &openAICompatPoolExecutor{
		id: "pool",
		// An empty (non-nil) entry makes the executor close the stream without chunks.
		streamPayloads: map[string][]cliproxyexecutor.StreamChunk{
			"qwen3.5-plus": {},
		},
	}
	m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}, executor)
	streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
	if err != nil {
		t.Fatalf("execute stream: %v", err)
	}
	var payload []byte
	for chunk := range streamResult.Chunks {
		if chunk.Err != nil {
			t.Fatalf("unexpected stream error: %v", chunk.Err)
		}
		payload = append(payload, chunk.Payload...)
	}
	if string(payload) != "glm-5" {
		t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
	}
	got := executor.StreamModels()
	want := []string{"qwen3.5-plus", "glm-5"}
	// Check the length first. With fewer recorded calls, the indexing loop below
	// would panic with index out of range instead of reporting a clear failure.
	if len(got) != len(want) {
		t.Fatalf("stream calls = %v, want %v", got, want)
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
		}
	}
}
// TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte verifies
// that a 429 delivered as the first stream chunk of one pooled model triggers a
// fallback to the next pooled model. It also checks that both the streamed
// payload and the response headers come from the fallback.
func TestManagerExecuteStream_OpenAICompatAliasPoolFallsBackBeforeFirstByte(t *testing.T) {
	alias := "claude-opus-4.66"
	executor := &openAICompatPoolExecutor{
		id:                "pool",
		streamFirstErrors: map[string]error{"qwen3.5-plus": &Error{HTTPStatus: http.StatusTooManyRequests, Message: "quota"}},
	}
	m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}, executor)
	streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
	if err != nil {
		t.Fatalf("execute stream: %v", err)
	}
	var payload []byte
	for chunk := range streamResult.Chunks {
		if chunk.Err != nil {
			t.Fatalf("unexpected stream error: %v", chunk.Err)
		}
		payload = append(payload, chunk.Payload...)
	}
	if string(payload) != "glm-5" {
		t.Fatalf("payload = %q, want %q", string(payload), "glm-5")
	}
	got := executor.StreamModels()
	want := []string{"qwen3.5-plus", "glm-5"}
	// Check the length first. With fewer recorded calls, the indexing loop below
	// would panic with index out of range instead of reporting a clear failure.
	if len(got) != len(want) {
		t.Fatalf("stream calls = %v, want %v", got, want)
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("stream call %d model = %q, want %q", i, got[i], want[i])
		}
	}
	if gotHeader := streamResult.Headers.Get("X-Model"); gotHeader != "glm-5" {
		t.Fatalf("header X-Model = %q, want %q", gotHeader, "glm-5")
	}
}
// TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest verifies
// that a 422 delivered as the first stream chunk is surfaced as the
// ExecuteStream error, and that no other pooled model is tried.
func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidRequest(t *testing.T) {
	const alias = "claude-opus-4.66"
	wantErr := &Error{HTTPStatus: http.StatusUnprocessableEntity, Message: "unprocessable entity"}
	stub := &openAICompatPoolExecutor{
		id:                "pool",
		streamFirstErrors: map[string]error{"qwen3.5-plus": wantErr},
	}
	pool := []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}
	m := newOpenAICompatPoolTestManager(t, alias, pool, stub)

	_, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
	if err == nil || err.Error() != wantErr.Error() {
		t.Fatalf("execute stream error = %v, want %v", err, wantErr)
	}
	if calls := stub.StreamModels(); len(calls) != 1 || calls[0] != "qwen3.5-plus" {
		t.Fatalf("stream calls = %v, want only first invalid model", calls)
	}
}
// TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth verifies that
// successive ExecuteCount calls for one alias rotate through the pooled
// upstream models of a single auth.
func TestManagerExecuteCount_OpenAICompatAliasPoolRotatesWithinAuth(t *testing.T) {
	alias := "claude-opus-4.66"
	executor := &openAICompatPoolExecutor{id: "pool"}
	m := newOpenAICompatPoolTestManager(t, alias, []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}, executor)
	for i := 0; i < 2; i++ {
		resp, err := m.ExecuteCount(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
		if err != nil {
			t.Fatalf("execute count %d: %v", i, err)
		}
		if len(resp.Payload) == 0 {
			t.Fatalf("execute count %d returned empty payload", i)
		}
	}
	got := executor.CountModels()
	want := []string{"qwen3.5-plus", "glm-5"}
	// Check the length first. With fewer recorded calls, the indexing loop below
	// would panic with index out of range instead of reporting a clear failure.
	if len(got) != len(want) {
		t.Fatalf("count calls = %v, want %v", got, want)
	}
	for i := range want {
		if got[i] != want[i] {
			t.Fatalf("count call %d model = %q, want %q", i, got[i], want[i])
		}
	}
}
// TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap verifies
// that a 400 invalid-request first chunk makes ExecuteStream return that exact
// error value and a nil stream result, without trying other pooled models.
func TestManagerExecuteStream_OpenAICompatAliasPoolStopsOnInvalidBootstrap(t *testing.T) {
	const alias = "claude-opus-4.66"
	invalidErr := &Error{HTTPStatus: http.StatusBadRequest, Message: "invalid_request_error: malformed payload"}
	stub := &openAICompatPoolExecutor{
		id:                "pool",
		streamFirstErrors: map[string]error{"qwen3.5-plus": invalidErr},
	}
	pool := []internalconfig.OpenAICompatibilityModel{
		{Name: "qwen3.5-plus", Alias: alias},
		{Name: "glm-5", Alias: alias},
	}
	m := newOpenAICompatPoolTestManager(t, alias, pool, stub)

	streamResult, err := m.ExecuteStream(context.Background(), []string{"pool"}, cliproxyexecutor.Request{Model: alias}, cliproxyexecutor.Options{})
	switch {
	case err == nil:
		t.Fatal("expected invalid request error")
	case err != invalidErr:
		t.Fatalf("error = %v, want %v", err, invalidErr)
	}
	if streamResult != nil {
		t.Fatalf("streamResult = %#v, want nil on invalid bootstrap", streamResult)
	}
	calls := stub.StreamModels()
	if len(calls) != 1 || calls[0] != "qwen3.5-plus" {
		t.Fatalf("stream calls = %v, want only first upstream model", calls)
	}
}
================================================
FILE: sdk/cliproxy/auth/persist_policy.go
================================================
package auth
import "context"
// skipPersistContextKey is the context key under which WithSkipPersist stores
// its flag and from which shouldSkipPersist reads it. Because it is an
// unexported type, no other package can create a colliding key.
type skipPersistContextKey struct{}
// WithSkipPersist derives a context that turns off persistence for Manager
// Update/Register calls made with it. A nil ctx is treated as
// context.Background().
//
// It is meant for code paths that react to file watcher events. There, the file
// on disk is already the source of truth, and persisting again would create a
// write-back loop.
func WithSkipPersist(ctx context.Context) context.Context {
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	return context.WithValue(parent, skipPersistContextKey{}, true)
}
// shouldSkipPersist reports whether ctx carries a true skip-persist flag set by
// WithSkipPersist. It returns false for a nil context, a missing flag, or a
// non-bool value stored under the key.
func shouldSkipPersist(ctx context.Context) bool {
	if ctx == nil {
		return false
	}
	skip, _ := ctx.Value(skipPersistContextKey{}).(bool)
	return skip
}
================================================
FILE: sdk/cliproxy/auth/persist_policy_test.go
================================================
package auth
import (
"context"
"sync/atomic"
"testing"
)
// countingStore is a test store that keeps nothing and only counts Save calls.
type countingStore struct {
	saveCount atomic.Int32 // number of Save invocations, updated atomically
}
// List reports an empty store.
func (s *countingStore) List(context.Context) ([]*Auth, error) {
	var none []*Auth
	return none, nil
}
// Save increments the call counter and reports success with an empty path,
// without storing anything.
func (s *countingStore) Save(_ context.Context, _ *Auth) (path string, err error) {
	s.saveCount.Add(1)
	return path, err
}
// Delete accepts any id and reports success without touching storage.
func (s *countingStore) Delete(_ context.Context, _ string) error {
	return nil
}
// TestWithSkipPersist_DisablesUpdatePersistence verifies that Update saves
// through the store normally but skips the store when the context was derived
// with WithSkipPersist.
func TestWithSkipPersist_DisablesUpdatePersistence(t *testing.T) {
	store := &countingStore{}
	mgr := NewManager(store, nil, nil)
	auth := &Auth{ID: "auth-1", Provider: "antigravity", Metadata: map[string]any{"type": "antigravity"}}

	expectSaves := func(want int32, format string) {
		t.Helper()
		if got := store.saveCount.Load(); got != want {
			t.Fatalf(format, got)
		}
	}

	if _, err := mgr.Update(context.Background(), auth); err != nil {
		t.Fatalf("Update returned error: %v", err)
	}
	expectSaves(1, "expected 1 Save call, got %d")

	if _, err := mgr.Update(WithSkipPersist(context.Background()), auth); err != nil {
		t.Fatalf("Update(skipPersist) returned error: %v", err)
	}
	expectSaves(1, "expected Save call count to remain 1, got %d")
}
// TestWithSkipPersist_DisablesRegisterPersistence verifies that Register does
// not save through the store when the context was derived with WithSkipPersist.
func TestWithSkipPersist_DisablesRegisterPersistence(t *testing.T) {
	store := &countingStore{}
	mgr := NewManager(store, nil, nil)
	skipCtx := WithSkipPersist(context.Background())
	auth := &Auth{ID: "auth-1", Provider: "antigravity", Metadata: map[string]any{"type": "antigravity"}}

	if _, err := mgr.Register(skipCtx, auth); err != nil {
		t.Fatalf("Register(skipPersist) returned error: %v", err)
	}
	if saves := store.saveCount.Load(); saves != 0 {
		t.Fatalf("expected 0 Save calls, got %d", saves)
	}
}
================================================
FILE: sdk/cliproxy/auth/scheduler.go
================================================
package auth
import (
"context"
"sort"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// schedulerStrategy identifies which built-in routing semantics the scheduler should apply.
// selectorStrategy derives it from the configured Selector.
type schedulerStrategy int

// The declaration order below fixes the iota values; do not reorder.
const (
	// schedulerStrategyCustom is used for any Selector that selectorStrategy does not recognize as built-in.
	schedulerStrategyCustom schedulerStrategy = iota
	// schedulerStrategyRoundRobin is used for *RoundRobinSelector and for a nil Selector.
	schedulerStrategyRoundRobin
	// schedulerStrategyFillFirst is used for *FillFirstSelector.
	schedulerStrategyFillFirst
)
// scheduledState describes how an auth currently participates in a model shard.
// NOTE(review): the transitions between these states are implemented outside
// this view; the per-value notes below are inferred from the names and should
// be confirmed.
type scheduledState int

// The declaration order below fixes the iota values; do not reorder.
const (
	scheduledStateReady    scheduledState = iota // presumably eligible for selection now
	scheduledStateCooldown                       // presumably waiting until scheduledAuth.nextRetryAt — confirm
	scheduledStateBlocked                        // presumably unavailable for this model — confirm
	scheduledStateDisabled                       // presumably excluded entirely — confirm
)
// authScheduler keeps the incremental provider/model scheduling state used by Manager.
type authScheduler struct {
	// mu guards every field below. The scheduler methods in this file acquire
	// it before reading or mutating scheduler state.
	mu sync.Mutex
	// strategy is derived from the configured Selector by selectorStrategy.
	strategy schedulerStrategy
	// providers is keyed by the trimmed, lower-cased provider name.
	providers map[string]*providerScheduler
	// authProviders maps an auth ID to its provider key; used for pinned-auth lookups.
	authProviders map[string]string
	// mixedCursors holds selection cursors for mixed-provider requests. It is
	// reset whenever the selector changes or the scheduler is rebuilt.
	// NOTE(review): key format is defined outside this view — confirm.
	mixedCursors map[string]int
}
// providerScheduler stores auth metadata and model shards for a single provider.
type providerScheduler struct {
	providerKey string                        // normalized provider name this scheduler serves
	auths       map[string]*scheduledAuthMeta // presumably keyed by auth ID — confirm
	// modelShards holds one shard per model; lookups use canonicalModelKey(model),
	// and ensureModelLocked returns the shard for a key (possibly creating it — confirm).
	modelShards map[string]*modelScheduler
}
// scheduledAuthMeta stores the immutable scheduling fields derived from an auth snapshot.
type scheduledAuthMeta struct {
	auth        *Auth  // the auth snapshot these fields were derived from
	providerKey string // normalized provider this auth belongs to
	priority    int    // scheduling priority; pickMixed treats larger values as preferred
	// virtualParent presumably names the parent used to group virtual auths
	// (see childBucket) — confirm.
	virtualParent string
	// websocketEnabled presumably marks auths eligible when a pick prefers
	// websocket-capable auths (pickSingle does so for codex) — confirm.
	websocketEnabled  bool
	supportedModelSet map[string]struct{} // set of models this auth supports; key normalization is defined elsewhere
}
// modelScheduler tracks ready and blocked auths for one provider/model combination.
type modelScheduler struct {
	modelKey        string                    // canonical model key of this shard
	entries         map[string]*scheduledAuth // presumably keyed by auth ID — confirm
	priorityOrder   []int                     // priority levels present in readyByPriority; ordering is maintained elsewhere
	readyByPriority map[int]*readyBucket      // ready auths grouped by priority level
	blocked         cooldownQueue             // auths that are not currently ready
}
// scheduledAuth stores the runtime scheduling state for a single auth inside a model shard.
type scheduledAuth struct {
	meta        *scheduledAuthMeta // shared immutable scheduling fields
	auth        *Auth              // auth returned to callers when this entry is picked
	state       scheduledState     // current participation state in this shard
	nextRetryAt time.Time          // presumably when a cooling-down/blocked entry may become ready — confirm
}
// readyBucket keeps the ready views for one priority level.
type readyBucket struct {
	all readyView // every ready auth at this priority
	ws  readyView // presumably the websocket-enabled subset, used when websocket is preferred — confirm
}
// readyView holds the selection order for flat or grouped round-robin traversal.
type readyView struct {
	flat   []*scheduledAuth // candidates for flat traversal
	cursor int              // rotation position within flat
	// Grouped traversal: rotate over parents, then over each parent's children.
	parentOrder  []string
	parentCursor int
	children     map[string]*childBucket // keyed by parent name
}
// childBucket keeps the per-parent rotation state for grouped Gemini virtual auths.
type childBucket struct {
	items  []*scheduledAuth // the virtual auths grouped under one parent
	cursor int              // rotation position within items
}
// cooldownQueue is the blocked auth collection ordered by next retry time during rebuilds.
// NOTE(review): the ordering is established by code outside this view; the type
// itself is a plain slice and does not enforce it.
type cooldownQueue []*scheduledAuth
// newAuthScheduler returns an empty scheduler whose strategy matches selector.
// All internal maps are allocated up front, so the result is ready for use.
func newAuthScheduler(selector Selector) *authScheduler {
	s := &authScheduler{strategy: selectorStrategy(selector)}
	s.providers = make(map[string]*providerScheduler)
	s.authProviders = make(map[string]string)
	s.mixedCursors = make(map[string]int)
	return s
}
// selectorStrategy reports which built-in scheduling semantics a selector
// corresponds to. *FillFirstSelector maps to fill-first. A nil selector and
// *RoundRobinSelector map to round-robin. Any other selector is treated as custom.
func selectorStrategy(selector Selector) schedulerStrategy {
	if _, fillFirst := selector.(*FillFirstSelector); fillFirst {
		return schedulerStrategyFillFirst
	}
	if _, roundRobin := selector.(*RoundRobinSelector); roundRobin || selector == nil {
		return schedulerStrategyRoundRobin
	}
	return schedulerStrategyCustom
}
// setSelector switches the scheduler to the strategy implied by selector and
// clears all mixed-provider cursors. It is a no-op on a nil scheduler.
func (s *authScheduler) setSelector(selector Selector) {
	if s == nil {
		return
	}
	strategy := selectorStrategy(selector)

	s.mu.Lock()
	defer s.mu.Unlock()
	s.strategy = strategy
	clear(s.mixedCursors)
}
// rebuild discards all scheduler state and re-inserts every auth in auths,
// using one shared timestamp. It is a no-op on a nil scheduler.
func (s *authScheduler) rebuild(auths []*Auth) {
	if s == nil {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()

	s.providers, s.authProviders, s.mixedCursors = make(map[string]*providerScheduler), make(map[string]string), make(map[string]int)
	now := time.Now()
	for i := range auths {
		s.upsertAuthLocked(auths[i], now)
	}
}
// upsertAuth inserts or refreshes a single auth in the scheduler while holding
// the scheduler lock. It is a no-op on a nil scheduler.
func (s *authScheduler) upsertAuth(auth *Auth) {
	if s != nil {
		s.mu.Lock()
		defer s.mu.Unlock()
		s.upsertAuthLocked(auth, time.Now())
	}
}
// removeAuth drops the auth with the given ID from every shard that references
// it. Surrounding whitespace in authID is ignored. Blank IDs and nil schedulers
// are no-ops.
func (s *authScheduler) removeAuth(authID string) {
	if s == nil {
		return
	}
	id := strings.TrimSpace(authID)
	if id == "" {
		return
	}
	s.mu.Lock()
	defer s.mu.Unlock()
	s.removeAuthLocked(id)
}
// pickSingle selects the next ready auth for one provider/model pair.
//
// Auths listed in tried are skipped. If opts.Metadata pins an auth ID, only that
// auth is eligible. Websocket-capable auths are preferred only for unpinned
// codex requests that arrived over a downstream websocket. When no provider
// state or shard exists, it returns an auth_not_found *Error. When the shard
// has no eligible ready auth, it returns the shard's unavailability error.
func (s *authScheduler) pickSingle(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, error) {
	if s == nil {
		return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	providerKey := strings.ToLower(strings.TrimSpace(provider))
	modelKey := canonicalModelKey(model)
	pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
	preferWebsocket := cliproxyexecutor.DownstreamWebsocket(ctx) && providerKey == "codex" && pinnedAuthID == ""

	s.mu.Lock()
	defer s.mu.Unlock()

	var shard *modelScheduler
	if state := s.providers[providerKey]; state != nil {
		shard = state.ensureModelLocked(modelKey, time.Now())
	}
	if shard == nil {
		return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
	}

	eligible := func(entry *scheduledAuth) bool {
		if entry == nil || entry.auth == nil {
			return false
		}
		id := entry.auth.ID
		if pinnedAuthID != "" && id != pinnedAuthID {
			return false
		}
		// Looking up a key in a nil or empty map is safe and reports absence.
		_, alreadyTried := tried[id]
		return !alreadyTried
	}

	picked := shard.pickReadyLocked(preferWebsocket, s.strategy, eligible)
	if picked == nil {
		return nil, shard.unavailableErrorLocked(provider, model, eligible)
	}
	return picked, nil
}
// pickMixed returns the next auth, together with its normalized provider key,
// for a request that may be served by any of several providers.
//
// Provider keys are normalized (trimmed, lowercased, de-duplicated). Auth IDs
// present in tried are never returned.
//
// When opts.Metadata pins an auth, only that auth is considered, and only if it
// is registered under one of the requested providers.
//
// Otherwise selection happens in two passes:
//  1. Find the highest priority at which any provider has a ready, untried auth.
//  2. Choose a provider at that priority. Fill-first takes providers in request
//     order. Round-robin rotates the starting provider using a cursor keyed by
//     the provider set and model.
//
// Errors: "provider_not_found" when no usable provider key was supplied;
// otherwise the not-found / cooldown / unavailable error from the shard(s).
func (s *authScheduler) pickMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, string, error) {
	if s == nil {
		return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	normalized := normalizeProviderKeys(providers)
	if len(normalized) == 0 {
		return nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
	}
	pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
	modelKey := canonicalModelKey(model)
	s.mu.Lock()
	defer s.mu.Unlock()
	if pinnedAuthID != "" {
		// The pinned auth bypasses cross-provider selection; it must belong to one
		// of the requested providers.
		providerKey := s.authProviders[pinnedAuthID]
		if providerKey == "" || !containsProvider(normalized, providerKey) {
			return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
		}
		providerState := s.providers[providerKey]
		if providerState == nil {
			return nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
		}
		shard := providerState.ensureModelLocked(modelKey, time.Now())
		// Accept only the pinned auth, and only if it has not been tried yet.
		predicate := func(entry *scheduledAuth) bool {
			if entry == nil || entry.auth == nil || entry.auth.ID != pinnedAuthID {
				return false
			}
			if len(tried) == 0 {
				return true
			}
			_, ok := tried[pinnedAuthID]
			return !ok
		}
		if picked := shard.pickReadyLocked(false, s.strategy, predicate); picked != nil {
			return picked, providerKey, nil
		}
		// "mixed" makes the cooldown error omit a specific provider.
		return nil, "", shard.unavailableErrorLocked("mixed", model, predicate)
	}
	predicate := triedPredicate(tried)
	// Pass 1: find the best ready priority across all providers.
	// candidateShards is indexed like normalized; an entry stays nil when the
	// provider has no scheduler state.
	candidateShards := make([]*modelScheduler, len(normalized))
	bestPriority := 0
	hasCandidate := false
	now := time.Now()
	for providerIndex, providerKey := range normalized {
		providerState := s.providers[providerKey]
		if providerState == nil {
			continue
		}
		shard := providerState.ensureModelLocked(modelKey, now)
		candidateShards[providerIndex] = shard
		if shard == nil {
			continue
		}
		priorityReady, okPriority := shard.highestReadyPriorityLocked(false, predicate)
		if !okPriority {
			continue
		}
		if !hasCandidate || priorityReady > bestPriority {
			bestPriority = priorityReady
			hasCandidate = true
		}
	}
	if !hasCandidate {
		return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
	}
	// Pass 2a (fill-first): the first provider, in request order, that has a
	// ready auth at bestPriority wins.
	if s.strategy == schedulerStrategyFillFirst {
		for providerIndex, providerKey := range normalized {
			shard := candidateShards[providerIndex]
			if shard == nil {
				continue
			}
			picked := shard.pickReadyAtPriorityLocked(false, bestPriority, s.strategy, predicate)
			if picked != nil {
				return picked, providerKey, nil
			}
		}
		return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
	}
	// Pass 2b (round-robin): rotate the starting provider per (provider set,
	// model). Within the chosen provider, the shard's own round-robin cursor
	// picks the auth.
	cursorKey := strings.Join(normalized, ",") + ":" + modelKey
	start := 0
	if len(normalized) > 0 {
		start = s.mixedCursors[cursorKey] % len(normalized)
	}
	for offset := 0; offset < len(normalized); offset++ {
		providerIndex := (start + offset) % len(normalized)
		providerKey := normalized[providerIndex]
		shard := candidateShards[providerIndex]
		if shard == nil {
			continue
		}
		picked := shard.pickReadyAtPriorityLocked(false, bestPriority, schedulerStrategyRoundRobin, predicate)
		if picked == nil {
			continue
		}
		// The next request for this provider set starts after the provider just used.
		s.mixedCursors[cursorKey] = providerIndex + 1
		return picked, providerKey, nil
	}
	return nil, "", s.mixedUnavailableErrorLocked(normalized, model, tried)
}
// mixedUnavailableErrorLocked builds the error returned when no auth from any
// of providers could be picked for model. It aggregates, over every provider
// shard, the candidates not present in tried:
//   - no candidates at all: "auth_not_found"
//   - every candidate cooling down with a known retry time: a model cooldown
//     error (with no specific provider) carrying the time until the earliest
//     retry, clamped at zero
//   - otherwise: "auth_unavailable"
//
// The canonical model key and the tried filter do not depend on the provider,
// so they are computed once rather than per iteration. s.mu must be held.
func (s *authScheduler) mixedUnavailableErrorLocked(providers []string, model string, tried map[string]struct{}) error {
	now := time.Now()
	modelKey := canonicalModelKey(model)
	predicate := triedPredicate(tried)
	total := 0
	cooldownCount := 0
	earliest := time.Time{}
	for _, providerKey := range providers {
		providerState := s.providers[providerKey]
		if providerState == nil {
			continue
		}
		shard := providerState.ensureModelLocked(modelKey, now)
		if shard == nil {
			continue
		}
		localTotal, localCooldownCount, localEarliest := shard.availabilitySummaryLocked(predicate)
		total += localTotal
		cooldownCount += localCooldownCount
		if !localEarliest.IsZero() && (earliest.IsZero() || localEarliest.Before(earliest)) {
			earliest = localEarliest
		}
	}
	if total == 0 {
		return &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	if cooldownCount == total && !earliest.IsZero() {
		resetIn := earliest.Sub(now)
		if resetIn < 0 {
			resetIn = 0
		}
		return newModelCooldownError(model, "", resetIn)
	}
	return &Error{Code: "auth_unavailable", Message: "no auth available"}
}
// triedPredicate returns a candidate filter that accepts non-nil entries with a
// non-nil auth whose ID is absent from tried, i.e. auths not yet attempted for
// the current request. When tried is empty, the filter only checks for nil.
func triedPredicate(tried map[string]struct{}) func(*scheduledAuth) bool {
	if len(tried) > 0 {
		return func(candidate *scheduledAuth) bool {
			if candidate == nil || candidate.auth == nil {
				return false
			}
			_, attempted := tried[candidate.auth.ID]
			return !attempted
		}
	}
	return func(candidate *scheduledAuth) bool {
		return candidate != nil && candidate.auth != nil
	}
}
// normalizeProviderKeys canonicalizes provider keys. Each key is trimmed and
// lowercased, blank keys are dropped, and repeats are removed while keeping the
// position of the first occurrence. The result is never nil.
func normalizeProviderKeys(providers []string) []string {
	result := make([]string, 0, len(providers))
	seen := make(map[string]struct{}, len(providers))
	for _, raw := range providers {
		key := strings.ToLower(strings.TrimSpace(raw))
		if key == "" {
			continue
		}
		if _, duplicate := seen[key]; !duplicate {
			seen[key] = struct{}{}
			result = append(result, key)
		}
	}
	return result
}
// containsProvider reports whether provider appears verbatim in providers,
// which is expected to be an already-normalized provider list.
func containsProvider(providers []string, provider string) bool {
	for i := range providers {
		if providers[i] == provider {
			return true
		}
	}
	return false
}
// upsertAuthLocked synchronizes one auth into the scheduler; s.mu must be held.
//
// An auth with a blank ID or provider, or one that is disabled, is removed
// instead. If the auth was previously indexed under a different provider, it is
// first dropped from that provider's shards. It is then (re)indexed under its
// current provider.
func (s *authScheduler) upsertAuthLocked(auth *Auth, now time.Time) {
	if auth == nil {
		return
	}
	id := strings.TrimSpace(auth.ID)
	key := strings.ToLower(strings.TrimSpace(auth.Provider))
	switch {
	case id == "", key == "", auth.Disabled:
		s.removeAuthLocked(id)
		return
	}
	if oldKey := s.authProviders[id]; oldKey != "" && oldKey != key {
		if oldState := s.providers[oldKey]; oldState != nil {
			oldState.removeAuthLocked(id)
		}
	}
	meta := buildScheduledAuthMeta(auth)
	s.authProviders[id] = key
	target := s.ensureProviderLocked(key)
	target.upsertAuthLocked(meta, now)
}
// removeAuthLocked forgets authID. It removes the auth from the shards of the
// provider it was last indexed under and from the auth-to-provider index.
// s.mu must be held. Blank or unknown IDs are ignored.
func (s *authScheduler) removeAuthLocked(authID string) {
	if authID == "" {
		return
	}
	owner := s.authProviders[authID]
	if owner == "" {
		return
	}
	if state := s.providers[owner]; state != nil {
		state.removeAuthLocked(authID)
	}
	delete(s.authProviders, authID)
}
// ensureProviderLocked looks up the providerScheduler for providerKey. It lazily
// allocates the providers map, and an empty scheduler for the key, on first use.
func (s *authScheduler) ensureProviderLocked(providerKey string) *providerScheduler {
	if s.providers == nil {
		s.providers = make(map[string]*providerScheduler)
	}
	if existing := s.providers[providerKey]; existing != nil {
		return existing
	}
	created := &providerScheduler{
		providerKey: providerKey,
		auths:       make(map[string]*scheduledAuthMeta),
		modelShards: make(map[string]*modelScheduler),
	}
	s.providers[providerKey] = created
	return created
}
// buildScheduledAuthMeta snapshots the fields of auth that the scheduler uses
// for shard membership and ordering:
//   - normalized provider key
//   - priority
//   - trimmed "gemini_virtual_parent" attribute
//   - websocket flag
//   - set of models the global registry currently lists for the auth
func buildScheduledAuthMeta(auth *Auth) *scheduledAuthMeta {
	var parent string
	if attrs := auth.Attributes; attrs != nil {
		parent = strings.TrimSpace(attrs["gemini_virtual_parent"])
	}
	meta := &scheduledAuthMeta{
		auth:              auth,
		providerKey:       strings.ToLower(strings.TrimSpace(auth.Provider)),
		priority:          authPriority(auth),
		virtualParent:     parent,
		websocketEnabled:  authWebsocketsEnabled(auth),
		supportedModelSet: supportedModelSetForAuth(auth.ID),
	}
	return meta
}
// supportedModelSetForAuth returns the canonical keys of the models the global
// registry currently lists for authID.
//
// It returns nil for a blank ID or when the registry lists no models. Nil model
// entries, and models whose canonical key is empty, are skipped.
func supportedModelSetForAuth(authID string) map[string]struct{} {
	id := strings.TrimSpace(authID)
	if id == "" {
		return nil
	}
	registered := registry.GetGlobalRegistry().GetModelsForClient(id)
	if len(registered) == 0 {
		return nil
	}
	keys := make(map[string]struct{}, len(registered))
	for _, info := range registered {
		if info == nil {
			continue
		}
		if key := canonicalModelKey(info.ID); key != "" {
			keys[key] = struct{}{}
		}
	}
	return keys
}
// upsertAuthLocked records meta under its auth ID and propagates it to every
// model shard that already exists.
//
// Shards for models the auth supports get the entry inserted or refreshed;
// every other shard has any stale entry for the auth removed. Shards created
// later are seeded from p.auths.
func (p *providerScheduler) upsertAuthLocked(meta *scheduledAuthMeta, now time.Time) {
	if p == nil || meta == nil || meta.auth == nil {
		return
	}
	id := meta.auth.ID
	p.auths[id] = meta
	for key, shard := range p.modelShards {
		switch {
		case shard == nil:
			// Nothing to update.
		case meta.supportsModel(key):
			shard.upsertEntryLocked(meta, now)
		default:
			shard.removeEntryLocked(id)
		}
	}
}
// removeAuthLocked drops authID from the provider's auth index and from every
// model shard the provider owns. A nil receiver or blank ID is ignored.
func (p *providerScheduler) removeAuthLocked(authID string) {
	if p == nil || authID == "" {
		return
	}
	delete(p.auths, authID)
	for _, shard := range p.modelShards {
		if shard == nil {
			continue
		}
		shard.removeEntryLocked(authID)
	}
}
// ensureModelLocked returns the shard for the canonical form of modelKey.
//
// An existing shard first re-evaluates its expired blocked entries at now. A
// missing (or nil) shard is built from every provider auth that supports the
// model, then cached. A nil receiver returns nil.
func (p *providerScheduler) ensureModelLocked(modelKey string, now time.Time) *modelScheduler {
	if p == nil {
		return nil
	}
	key := canonicalModelKey(modelKey)
	if cached := p.modelShards[key]; cached != nil {
		cached.promoteExpiredLocked(now)
		return cached
	}
	fresh := &modelScheduler{
		modelKey:        key,
		entries:         make(map[string]*scheduledAuth),
		readyByPriority: make(map[int]*readyBucket),
	}
	for _, meta := range p.auths {
		if meta != nil && meta.supportsModel(key) {
			fresh.upsertEntryLocked(meta, now)
		}
	}
	p.modelShards[key] = fresh
	return fresh
}
// supportsModel reports whether this auth can serve modelKey. An empty
// canonical key matches every auth. Otherwise the key must be in the auth's
// supported-model snapshot, so an auth without a snapshot supports no model.
func (m *scheduledAuthMeta) supportsModel(modelKey string) bool {
	key := canonicalModelKey(modelKey)
	if key == "" {
		return true
	}
	// Indexing a nil or empty set yields false, which covers the "no snapshot" case.
	_, ok := m.supportedModelSet[key]
	return ok
}
// upsertEntryLocked inserts or refreshes the shard entry for meta.auth.
//
// It recomputes the entry's state for this shard's model at now, as classified
// by isAuthBlockedForModel:
//   - ready
//   - cooldown, with a retry time
//   - disabled
//   - blocked, with a retry time
//
// The ready/blocked indexes are rebuilt unless the entry already existed and
// its state, retry time, priority, virtual parent and websocket flag are all
// unchanged.
func (m *modelScheduler) upsertEntryLocked(meta *scheduledAuthMeta, now time.Time) {
	if m == nil || meta == nil || meta.auth == nil {
		return
	}
	entry, ok := m.entries[meta.auth.ID]
	if !ok || entry == nil {
		entry = &scheduledAuth{}
		m.entries[meta.auth.ID] = entry
	}
	// Snapshot everything that affects index placement or ordering before it is
	// overwritten, so the rebuild can be skipped when nothing relevant changed.
	previousState := entry.state
	previousNextRetryAt := entry.nextRetryAt
	previousPriority := 0
	previousParent := ""
	previousWebsocketEnabled := false
	if entry.meta != nil {
		previousPriority = entry.meta.priority
		previousParent = entry.meta.virtualParent
		previousWebsocketEnabled = entry.meta.websocketEnabled
	}
	entry.meta = meta
	entry.auth = meta.auth
	entry.nextRetryAt = time.Time{}
	blocked, reason, next := isAuthBlockedForModel(meta.auth, m.modelKey, now)
	switch {
	case !blocked:
		entry.state = scheduledStateReady
	case reason == blockReasonCooldown:
		entry.state = scheduledStateCooldown
		entry.nextRetryAt = next
	case reason == blockReasonDisabled:
		entry.state = scheduledStateDisabled
	default:
		entry.state = scheduledStateBlocked
		entry.nextRetryAt = next
	}
	// The indexes hold *scheduledAuth pointers, so the in-place meta/auth update
	// above is already visible to them; a rebuild is only needed when ordering
	// or membership can change.
	if ok && previousState == entry.state && previousNextRetryAt.Equal(entry.nextRetryAt) && previousPriority == meta.priority && previousParent == meta.virtualParent && previousWebsocketEnabled == meta.websocketEnabled {
		return
	}
	m.rebuildIndexesLocked()
}
// removeEntryLocked deletes the entry for authID. The indexes are rebuilt only
// when an entry was actually present. A nil receiver or blank ID is ignored.
func (m *modelScheduler) removeEntryLocked(authID string) {
	if m == nil || authID == "" {
		return
	}
	if _, present := m.entries[authID]; present {
		delete(m.entries, authID)
		m.rebuildIndexesLocked()
	}
}
// promoteExpiredLocked re-evaluates cooldown/blocked entries whose retry time is
// at or before now. Each one moves to whatever state isAuthBlockedForModel now
// reports, and the indexes are rebuilt if any entry was re-evaluated.
//
// rebuildIndexesLocked keeps m.blocked sorted by nextRetryAt: earliest first,
// zero (unknown) retry times last. The scan can therefore stop at the first
// entry that is not yet due, instead of walking every blocked entry on each
// pick.
func (m *modelScheduler) promoteExpiredLocked(now time.Time) {
	if m == nil || len(m.blocked) == 0 {
		return
	}
	changed := false
	for _, entry := range m.blocked {
		if entry == nil || entry.auth == nil {
			continue
		}
		if entry.nextRetryAt.IsZero() || entry.nextRetryAt.After(now) {
			// Every later entry retries no earlier than this one (or has no
			// retry time at all), so none of them can be due either.
			break
		}
		blocked, reason, next := isAuthBlockedForModel(entry.auth, m.modelKey, now)
		switch {
		case !blocked:
			entry.state = scheduledStateReady
			entry.nextRetryAt = time.Time{}
		case reason == blockReasonCooldown:
			entry.state = scheduledStateCooldown
			entry.nextRetryAt = next
		case reason == blockReasonDisabled:
			entry.state = scheduledStateDisabled
			entry.nextRetryAt = time.Time{}
		default:
			entry.state = scheduledStateBlocked
			entry.nextRetryAt = next
		}
		changed = true
	}
	if changed {
		m.rebuildIndexesLocked()
	}
}
// pickReadyLocked first promotes entries whose retry time has passed. It then
// picks an auth from the highest priority bucket that holds a candidate
// accepted by predicate, applying strategy within that bucket. preferWebsocket
// is forwarded to both lookup steps. Returns nil when nothing is eligible.
func (m *modelScheduler) pickReadyLocked(preferWebsocket bool, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
	if m == nil {
		return nil
	}
	m.promoteExpiredLocked(time.Now())
	if priority, found := m.highestReadyPriorityLocked(preferWebsocket, predicate); found {
		return m.pickReadyAtPriorityLocked(preferWebsocket, priority, strategy, predicate)
	}
	return nil
}
// highestReadyPriorityLocked returns the highest priority bucket that still has a matching ready auth.
// The caller must ensure expired entries are already promoted when needed.
func (m *modelScheduler) highestReadyPriorityLocked(preferWebsocket bool, predicate func(*scheduledAuth) bool) (int, bool) {
if m == nil {
return 0, false
}
for _, priority := range m.priorityOrder {
bucket := m.readyByPriority[priority]
if bucket == nil {
continue
}
view := &bucket.all
if preferWebsocket && len(bucket.ws.flat) > 0 {
view = &bucket.ws
}
if view.pickFirst(predicate) != nil {
return priority, true
}
}
return 0, false
}
// pickReadyAtPriorityLocked selects the next ready auth from a specific priority bucket.
// The caller must ensure expired entries are already promoted when needed.
func (m *modelScheduler) pickReadyAtPriorityLocked(preferWebsocket bool, priority int, strategy schedulerStrategy, predicate func(*scheduledAuth) bool) *Auth {
if m == nil {
return nil
}
bucket := m.readyByPriority[priority]
if bucket == nil {
return nil
}
view := &bucket.all
if preferWebsocket && len(bucket.ws.flat) > 0 {
view = &bucket.ws
}
var picked *scheduledAuth
if strategy == schedulerStrategyFillFirst {
picked = view.pickFirst(predicate)
} else {
picked = view.pickRoundRobin(predicate)
}
if picked == nil || picked.auth == nil {
return nil
}
return picked.auth
}
// unavailableErrorLocked explains why no auth in this shard could be picked
// among the entries accepted by predicate:
//   - no matching entries: "auth_not_found"
//   - all matching entries cooling down with a known retry time: a model
//     cooldown error carrying the time until the earliest retry (clamped at
//     zero); provider "mixed" is reported as an empty provider
//   - anything else: "auth_unavailable"
func (m *modelScheduler) unavailableErrorLocked(provider, model string, predicate func(*scheduledAuth) bool) error {
	now := time.Now()
	total, cooldownCount, earliest := m.availabilitySummaryLocked(predicate)
	switch {
	case total == 0:
		return &Error{Code: "auth_not_found", Message: "no auth available"}
	case cooldownCount != total || earliest.IsZero():
		return &Error{Code: "auth_unavailable", Message: "no auth available"}
	}
	if provider == "mixed" {
		provider = ""
	}
	wait := earliest.Sub(now)
	if wait < 0 {
		wait = 0
	}
	return newModelCooldownError(model, provider, wait)
}
// availabilitySummaryLocked walks the entries accepted by predicate (all entries
// when predicate is nil) and reports three values:
//   - how many there are
//   - how many are in cooldown
//   - the earliest non-zero retry time among the cooling-down ones
//
// A nil receiver reports zeros.
func (m *modelScheduler) availabilitySummaryLocked(predicate func(*scheduledAuth) bool) (int, int, time.Time) {
	var (
		total, cooling int
		earliest       time.Time
	)
	if m == nil {
		return total, cooling, earliest
	}
	for _, entry := range m.entries {
		if predicate != nil && !predicate(entry) {
			continue
		}
		total++
		if entry == nil || entry.auth == nil || entry.state != scheduledStateCooldown {
			continue
		}
		cooling++
		retry := entry.nextRetryAt
		if retry.IsZero() {
			continue
		}
		if earliest.IsZero() || retry.Before(earliest) {
			earliest = retry
		}
	}
	return total, cooling, earliest
}
// rebuildIndexesLocked recomputes the shard's derived indexes from m.entries:
//   - readyByPriority: one readyBucket per priority holding the ready entries,
//     each bucket sorted by auth ID;
//   - priorityOrder: the priorities present in readyByPriority, highest first;
//   - blocked: cooldown and blocked entries, sorted by nextRetryAt ascending,
//     with zero retry times last and ties broken by auth ID.
//
// Disabled entries appear in neither index.
//
// NOTE: readyByPriority is replaced wholesale with freshly built views, so the
// round-robin cursors held by the previous views are discarded and rotation
// restarts from the start of each new view.
func (m *modelScheduler) rebuildIndexesLocked() {
	m.readyByPriority = make(map[int]*readyBucket)
	// Truncate rather than reallocate so the backing arrays are reused.
	m.priorityOrder = m.priorityOrder[:0]
	m.blocked = m.blocked[:0]
	priorityBuckets := make(map[int][]*scheduledAuth)
	for _, entry := range m.entries {
		if entry == nil || entry.auth == nil {
			continue
		}
		switch entry.state {
		case scheduledStateReady:
			priority := entry.meta.priority
			priorityBuckets[priority] = append(priorityBuckets[priority], entry)
		case scheduledStateCooldown, scheduledStateBlocked:
			m.blocked = append(m.blocked, entry)
		}
	}
	for priority, entries := range priorityBuckets {
		// Map iteration order is random; sorting by ID makes rotation deterministic.
		sort.Slice(entries, func(i, j int) bool {
			return entries[i].auth.ID < entries[j].auth.ID
		})
		m.readyByPriority[priority] = buildReadyBucket(entries)
		m.priorityOrder = append(m.priorityOrder, priority)
	}
	sort.Slice(m.priorityOrder, func(i, j int) bool {
		return m.priorityOrder[i] > m.priorityOrder[j]
	})
	// Earliest retry first; entries without a retry time sink to the end.
	sort.Slice(m.blocked, func(i, j int) bool {
		left := m.blocked[i]
		right := m.blocked[j]
		if left == nil || right == nil {
			return left != nil
		}
		if left.nextRetryAt.Equal(right.nextRetryAt) {
			return left.auth.ID < right.auth.ID
		}
		if left.nextRetryAt.IsZero() {
			return false
		}
		if right.nextRetryAt.IsZero() {
			return true
		}
		return left.nextRetryAt.Before(right.nextRetryAt)
	})
}
// buildReadyBucket constructs the two ready views for one priority level: one
// over all entries, and one over only the entries whose auth has websockets
// enabled.
func buildReadyBucket(entries []*scheduledAuth) *readyBucket {
	websocketOnly := make([]*scheduledAuth, 0, len(entries))
	for _, candidate := range entries {
		if candidate == nil || candidate.meta == nil || !candidate.meta.websocketEnabled {
			continue
		}
		websocketOnly = append(websocketOnly, candidate)
	}
	return &readyBucket{
		all: buildReadyView(entries),
		ws:  buildReadyView(websocketOnly),
	}
}
// buildReadyView wraps entries in a readyView. The view always holds a flat
// copy of entries.
//
// When every entry names a virtual parent and there are at least two distinct
// parents, the view additionally groups entries by parent for two-level
// rotation. Parents are sorted; each group keeps the input order.
func buildReadyView(entries []*scheduledAuth) readyView {
	view := readyView{flat: append([]*scheduledAuth(nil), entries...)}
	if len(entries) == 0 {
		return view
	}
	byParent := make(map[string][]*scheduledAuth)
	for _, candidate := range entries {
		if candidate == nil || candidate.meta == nil {
			return view
		}
		parent := candidate.meta.virtualParent
		if parent == "" {
			return view
		}
		byParent[parent] = append(byParent[parent], candidate)
	}
	if len(byParent) < 2 {
		return view
	}
	parents := make([]string, 0, len(byParent))
	for parent := range byParent {
		parents = append(parents, parent)
	}
	sort.Strings(parents)
	children := make(map[string]*childBucket, len(parents))
	for _, parent := range parents {
		children[parent] = &childBucket{items: append([]*scheduledAuth(nil), byParent[parent]...)}
	}
	view.parentOrder = parents
	view.children = children
	return view
}
// pickFirst scans the flat list in order and returns the first entry accepted
// by predicate (any entry when predicate is nil). It never moves a cursor.
func (v *readyView) pickFirst(predicate func(*scheduledAuth) bool) *scheduledAuth {
	for i := range v.flat {
		if candidate := v.flat[i]; predicate == nil || predicate(candidate) {
			return candidate
		}
	}
	return nil
}
// pickRoundRobin returns the next accepted entry in rotation.
//
// Views grouped by virtual parent delegate to pickGroupedRoundRobin. Flat
// views scan at most one full lap starting at the cursor, and leave the cursor
// just past the returned entry.
func (v *readyView) pickRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
	if len(v.parentOrder) > 1 && len(v.children) > 0 {
		return v.pickGroupedRoundRobin(predicate)
	}
	size := len(v.flat)
	if size == 0 {
		return nil
	}
	base := v.cursor % size
	for step := 0; step < size; step++ {
		idx := (base + step) % size
		candidate := v.flat[idx]
		if predicate == nil || predicate(candidate) {
			v.cursor = idx + 1
			return candidate
		}
	}
	return nil
}
// pickGroupedRoundRobin rotates first across virtual parents, then across the
// entries within the chosen parent. Each parent keeps its own item cursor.
// After a successful pick, both the parent cursor and that parent's item
// cursor point just past the selection.
func (v *readyView) pickGroupedRoundRobin(predicate func(*scheduledAuth) bool) *scheduledAuth {
	parentCount := len(v.parentOrder)
	if parentCount == 0 {
		return nil
	}
	firstParent := v.parentCursor % parentCount
	for p := 0; p < parentCount; p++ {
		parentIndex := (firstParent + p) % parentCount
		child := v.children[v.parentOrder[parentIndex]]
		if child == nil {
			continue
		}
		itemCount := len(child.items)
		if itemCount == 0 {
			continue
		}
		firstItem := child.cursor % itemCount
		for i := 0; i < itemCount; i++ {
			itemIndex := (firstItem + i) % itemCount
			candidate := child.items[itemIndex]
			if predicate != nil && !predicate(candidate) {
				continue
			}
			child.cursor = itemIndex + 1
			v.parentCursor = parentIndex + 1
			return candidate
		}
	}
	return nil
}
================================================
FILE: sdk/cliproxy/auth/scheduler_benchmark_test.go
================================================
package auth
import (
"context"
"fmt"
"net/http"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// schedulerBenchmarkExecutor is a no-op executor for the scheduler benchmarks.
// It reports id as its identifier. Refresh echoes the auth back; every other
// method returns zero values with a nil error.
type schedulerBenchmarkExecutor struct {
	id string
}

// Identifier returns the configured executor ID.
func (e schedulerBenchmarkExecutor) Identifier() string {
	return e.id
}

// Execute returns an empty response without doing any work.
func (schedulerBenchmarkExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

// ExecuteStream returns no stream and no error.
func (schedulerBenchmarkExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	return nil, nil
}

// Refresh returns the auth unchanged.
func (schedulerBenchmarkExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
	return auth, nil
}

// CountTokens returns an empty response without doing any work.
func (schedulerBenchmarkExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

// HttpRequest returns no response and no error.
func (schedulerBenchmarkExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
	return nil, nil
}
// benchmarkManagerSetup builds a Manager populated with total auths for the
// scheduler benchmarks. It registers one shared model ("bench-model") for every
// auth in the global model registry.
//
// When mixed is true, even-indexed auths use "gemini" and odd-indexed auths use
// "claude"; otherwise every auth uses "gemini". When withPriority is true,
// even-indexed auths get priority "10" and odd-indexed auths "0".
//
// It returns the manager, the provider list, and the model name. The registry
// cleanup is registered before any client is added and only covers the IDs
// actually registered. Clients added before a mid-setup failure (b.Fatalf) are
// therefore still removed from the shared global registry.
func benchmarkManagerSetup(b *testing.B, total int, mixed bool, withPriority bool) (*Manager, []string, string) {
	b.Helper()
	manager := NewManager(nil, &RoundRobinSelector{}, nil)
	providers := []string{"gemini"}
	manager.executors["gemini"] = schedulerBenchmarkExecutor{id: "gemini"}
	if mixed {
		providers = []string{"gemini", "claude"}
		manager.executors["claude"] = schedulerBenchmarkExecutor{id: "claude"}
	}
	reg := registry.GetGlobalRegistry()
	model := "bench-model"
	registered := make([]string, 0, total)
	b.Cleanup(func() {
		for _, authID := range registered {
			reg.UnregisterClient(authID)
		}
	})
	for index := 0; index < total; index++ {
		provider := providers[0]
		if mixed && index%2 == 1 {
			provider = providers[1]
		}
		auth := &Auth{ID: fmt.Sprintf("bench-%s-%04d", provider, index), Provider: provider}
		if withPriority {
			priority := "0"
			if index%2 == 0 {
				priority = "10"
			}
			auth.Attributes = map[string]string{"priority": priority}
		}
		if _, errRegister := manager.Register(context.Background(), auth); errRegister != nil {
			b.Fatalf("Register(%s) error = %v", auth.ID, errRegister)
		}
		reg.RegisterClient(auth.ID, provider, []*registry.ModelInfo{{ID: model}})
		registered = append(registered, auth.ID)
	}
	manager.syncScheduler()
	return manager, providers, model
}
// runPickNextBenchmark measures Manager.pickNext for the single "gemini"
// provider over total auths, optionally split across two priority levels.
// A warm-up pick runs before the timer is reset, so one-time work triggered by
// the first pick is not measured.
func runPickNextBenchmark(b *testing.B, total int, withPriority bool) {
	b.Helper()
	manager, _, model := benchmarkManagerSetup(b, total, false, withPriority)
	ctx := context.Background()
	opts := cliproxyexecutor.Options{}
	tried := map[string]struct{}{}
	if _, _, errWarm := manager.pickNext(ctx, "gemini", model, opts, tried); errWarm != nil {
		b.Fatalf("warmup pickNext error = %v", errWarm)
	}
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		auth, exec, errPick := manager.pickNext(ctx, "gemini", model, opts, tried)
		if errPick != nil || auth == nil || exec == nil {
			b.Fatalf("pickNext failed: auth=%v exec=%v err=%v", auth, exec, errPick)
		}
	}
}

// BenchmarkManagerPickNext500 benchmarks pickNext over 500 equal-priority auths.
func BenchmarkManagerPickNext500(b *testing.B) {
	runPickNextBenchmark(b, 500, false)
}

// BenchmarkManagerPickNext1000 benchmarks pickNext over 1000 equal-priority auths.
func BenchmarkManagerPickNext1000(b *testing.B) {
	runPickNextBenchmark(b, 1000, false)
}

// BenchmarkManagerPickNextPriority500 benchmarks pickNext over 500 auths split
// across two priority levels.
func BenchmarkManagerPickNextPriority500(b *testing.B) {
	runPickNextBenchmark(b, 500, true)
}

// BenchmarkManagerPickNextPriority1000 benchmarks pickNext over 1000 auths
// split across two priority levels.
func BenchmarkManagerPickNextPriority1000(b *testing.B) {
	runPickNextBenchmark(b, 1000, true)
}
// runPickNextMixedBenchmark measures Manager.pickNextMixed across the "gemini"
// and "claude" providers, with total auths alternating between them and
// optionally split across two priority levels. A warm-up pick runs before the
// timer is reset, so one-time work triggered by the first pick is not
// measured.
func runPickNextMixedBenchmark(b *testing.B, total int, withPriority bool) {
	b.Helper()
	manager, providers, model := benchmarkManagerSetup(b, total, true, withPriority)
	ctx := context.Background()
	opts := cliproxyexecutor.Options{}
	tried := map[string]struct{}{}
	if _, _, _, errWarm := manager.pickNextMixed(ctx, providers, model, opts, tried); errWarm != nil {
		b.Fatalf("warmup pickNextMixed error = %v", errWarm)
	}
	b.ReportAllocs()
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		auth, exec, provider, errPick := manager.pickNextMixed(ctx, providers, model, opts, tried)
		if errPick != nil || auth == nil || exec == nil || provider == "" {
			b.Fatalf("pickNextMixed failed: auth=%v exec=%v provider=%q err=%v", auth, exec, provider, errPick)
		}
	}
}

// BenchmarkManagerPickNextMixed500 benchmarks mixed-provider picks over 500
// equal-priority auths.
func BenchmarkManagerPickNextMixed500(b *testing.B) {
	runPickNextMixedBenchmark(b, 500, false)
}

// BenchmarkManagerPickNextMixedPriority500 benchmarks mixed-provider picks over
// 500 auths split across two priority levels.
func BenchmarkManagerPickNextMixedPriority500(b *testing.B) {
	runPickNextMixedBenchmark(b, 500, true)
}
// BenchmarkManagerPickNextAndMarkResult1000 measures a pick immediately
// followed by a successful MarkResult for the chosen auth, over 1000
// single-provider auths.
func BenchmarkManagerPickNextAndMarkResult1000(b *testing.B) {
	const provider = "gemini"
	manager, _, model := benchmarkManagerSetup(b, 1000, false, false)
	ctx := context.Background()
	var opts cliproxyexecutor.Options
	tried := make(map[string]struct{})
	if _, _, err := manager.pickNext(ctx, provider, model, opts, tried); err != nil {
		b.Fatalf("warmup pickNext error = %v", err)
	}
	b.ReportAllocs()
	b.ResetTimer()
	for n := 0; n < b.N; n++ {
		picked, _, err := manager.pickNext(ctx, provider, model, opts, tried)
		if err != nil || picked == nil {
			b.Fatalf("pickNext failed: auth=%v err=%v", picked, err)
		}
		manager.MarkResult(ctx, Result{AuthID: picked.ID, Provider: provider, Model: model, Success: true})
	}
}
================================================
FILE: sdk/cliproxy/auth/scheduler_test.go
================================================
package auth
import (
"context"
"net/http"
"testing"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// schedulerTestExecutor is a stateless executor stub for scheduler tests. It
// identifies as "test". Refresh echoes the auth back; every other method
// returns zero values with a nil error.
type schedulerTestExecutor struct{}

// Identifier returns the fixed executor name.
func (schedulerTestExecutor) Identifier() string {
	return "test"
}

// Execute returns an empty response.
func (schedulerTestExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

// ExecuteStream returns no stream and no error.
func (schedulerTestExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (*cliproxyexecutor.StreamResult, error) {
	return nil, nil
}

// Refresh returns the auth unchanged.
func (schedulerTestExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
	return auth, nil
}

// CountTokens returns an empty response.
func (schedulerTestExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	return cliproxyexecutor.Response{}, nil
}

// HttpRequest returns no response and no error.
func (schedulerTestExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
	return nil, nil
}
// trackingSelector is a selector test double. It counts Pick calls, records the
// IDs of the candidates offered on the most recent call, and always returns the
// last candidate (nil, nil when there are none).
type trackingSelector struct {
	calls      int
	lastAuthID []string
}

// Pick records the offered candidates and returns the last one.
func (s *trackingSelector) Pick(_ context.Context, _, _ string, _ cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
	s.calls++
	ids := s.lastAuthID[:0]
	for _, candidate := range auths {
		ids = append(ids, candidate.ID)
	}
	s.lastAuthID = ids
	if n := len(auths); n > 0 {
		return auths[n-1], nil
	}
	return nil, nil
}
// newSchedulerForTest creates a scheduler for selector and seeds it with auths
// through a full rebuild.
func newSchedulerForTest(selector Selector, auths ...*Auth) *authScheduler {
	s := newAuthScheduler(selector)
	s.rebuild(auths)
	return s
}
// registerSchedulerModels registers model under provider for each of authIDs in
// the global registry. Those IDs are unregistered again when the test finishes.
func registerSchedulerModels(t *testing.T, provider string, model string, authIDs ...string) {
	t.Helper()
	reg := registry.GetGlobalRegistry()
	for i := range authIDs {
		reg.RegisterClient(authIDs[i], provider, []*registry.ModelInfo{{ID: model}})
	}
	t.Cleanup(func() {
		for i := range authIDs {
			reg.UnregisterClient(authIDs[i])
		}
	})
}
// TestSchedulerPick_RoundRobinHighestPriority verifies that round-robin rotates
// only among the highest-priority auths, in ID order, and never reaches the
// lower-priority one.
func TestSchedulerPick_RoundRobinHighestPriority(t *testing.T) {
	t.Parallel()
	scheduler := newSchedulerForTest(
		&RoundRobinSelector{},
		&Auth{ID: "low", Provider: "gemini", Attributes: map[string]string{"priority": "0"}},
		&Auth{ID: "high-b", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
		&Auth{ID: "high-a", Provider: "gemini", Attributes: map[string]string{"priority": "10"}},
	)
	for index, wantID := range []string{"high-a", "high-b", "high-a"} {
		got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
		switch {
		case errPick != nil:
			t.Fatalf("pickSingle() #%d error = %v", index, errPick)
		case got == nil:
			t.Fatalf("pickSingle() #%d auth = nil", index)
		case got.ID != wantID:
			t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", index, got.ID, wantID)
		}
	}
}
// TestSchedulerPick_FillFirstSticksToFirstReady verifies that fill-first keeps
// returning the first ready auth (lowest ID) on repeated picks.
func TestSchedulerPick_FillFirstSticksToFirstReady(t *testing.T) {
	t.Parallel()
	scheduler := newSchedulerForTest(
		&FillFirstSelector{},
		&Auth{ID: "b", Provider: "gemini"},
		&Auth{ID: "a", Provider: "gemini"},
		&Auth{ID: "c", Provider: "gemini"},
	)
	const wantID = "a"
	for attempt := 0; attempt < 3; attempt++ {
		got, errPick := scheduler.pickSingle(context.Background(), "gemini", "", cliproxyexecutor.Options{}, nil)
		switch {
		case errPick != nil:
			t.Fatalf("pickSingle() #%d error = %v", attempt, errPick)
		case got == nil:
			t.Fatalf("pickSingle() #%d auth = nil", attempt)
		case got.ID != wantID:
			t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", attempt, got.ID, wantID)
		}
	}
}
// TestSchedulerPick_PromotesExpiredCooldownBeforePick verifies that an auth
// whose model cooldown has already expired is treated as ready at pick time.
func TestSchedulerPick_PromotesExpiredCooldownBeforePick(t *testing.T) {
	t.Parallel()
	const (
		model  = "gemini-2.5-pro"
		authID = "cooldown-expired"
	)
	registerSchedulerModels(t, "gemini", model, authID)
	expired := &Auth{
		ID:       authID,
		Provider: "gemini",
		ModelStates: map[string]*ModelState{
			model: {
				Status:         StatusError,
				Unavailable:    true,
				NextRetryAfter: time.Now().Add(-time.Second),
			},
		},
	}
	scheduler := newSchedulerForTest(&RoundRobinSelector{}, expired)
	got, errPick := scheduler.pickSingle(context.Background(), "gemini", model, cliproxyexecutor.Options{}, nil)
	switch {
	case errPick != nil:
		t.Fatalf("pickSingle() error = %v", errPick)
	case got == nil:
		t.Fatalf("pickSingle() auth = nil")
	case got.ID != authID:
		t.Fatalf("pickSingle() auth.ID = %q, want %q", got.ID, authID)
	}
}
// TestSchedulerPick_GeminiVirtualParentUsesTwoLevelRotation registers four
// gemini-cli virtual auths split across two parents. It expects picks to
// alternate parents first (a, b, a, b) and advance the project within each
// parent second.
func TestSchedulerPick_GeminiVirtualParentUsesTwoLevelRotation(t *testing.T) {
	t.Parallel()
	const (
		provider = "gemini-cli"
		model    = "gemini-2.5-pro"
	)
	registerSchedulerModels(t, provider, model, "cred-a::proj-1", "cred-a::proj-2", "cred-b::proj-1", "cred-b::proj-2")

	// virtual builds an auth tagged with the given gemini_virtual_parent.
	virtual := func(id, parent string) *Auth {
		return &Auth{ID: id, Provider: provider, Attributes: map[string]string{"gemini_virtual_parent": parent}}
	}
	sched := newSchedulerForTest(
		&RoundRobinSelector{},
		virtual("cred-a::proj-1", "cred-a"),
		virtual("cred-a::proj-2", "cred-a"),
		virtual("cred-b::proj-1", "cred-b"),
		virtual("cred-b::proj-2", "cred-b"),
	)

	expected := []struct{ id, parent string }{
		{"cred-a::proj-1", "cred-a"},
		{"cred-b::proj-1", "cred-b"},
		{"cred-a::proj-2", "cred-a"},
		{"cred-b::proj-2", "cred-b"},
	}
	for step, want := range expected {
		picked, errPick := sched.pickSingle(context.Background(), provider, model, cliproxyexecutor.Options{}, nil)
		if errPick != nil {
			t.Fatalf("pickSingle() #%d error = %v", step, errPick)
		}
		if picked == nil {
			t.Fatalf("pickSingle() #%d auth = nil", step)
		}
		if picked.ID != want.id {
			t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", step, picked.ID, want.id)
		}
		if gotParent := picked.Attributes["gemini_virtual_parent"]; gotParent != want.parent {
			t.Fatalf("pickSingle() #%d parent = %q, want %q", step, gotParent, want.parent)
		}
	}
}
// TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset marks the
// context as a downstream websocket request. It expects the codex scheduler to
// rotate only across the two websocket-enabled auths and never pick the plain
// HTTP one.
func TestSchedulerPick_CodexWebsocketPrefersWebsocketEnabledSubset(t *testing.T) {
	t.Parallel()
	// wsAttrs returns a fresh attribute map enabling websockets.
	wsAttrs := func() map[string]string {
		return map[string]string{"websockets": "true"}
	}
	sched := newSchedulerForTest(
		&RoundRobinSelector{},
		&Auth{ID: "codex-http", Provider: "codex"},
		&Auth{ID: "codex-ws-a", Provider: "codex", Attributes: wsAttrs()},
		&Auth{ID: "codex-ws-b", Provider: "codex", Attributes: wsAttrs()},
	)
	wsCtx := cliproxyexecutor.WithDownstreamWebsocket(context.Background())
	for step, wantID := range []string{"codex-ws-a", "codex-ws-b", "codex-ws-a"} {
		picked, errPick := sched.pickSingle(wsCtx, "codex", "", cliproxyexecutor.Options{}, nil)
		if errPick != nil {
			t.Fatalf("pickSingle() #%d error = %v", step, errPick)
		}
		if picked == nil {
			t.Fatalf("pickSingle() #%d auth = nil", step)
		}
		if picked.ID != wantID {
			t.Fatalf("pickSingle() #%d auth.ID = %q, want %q", step, picked.ID, wantID)
		}
	}
}
// TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates
// expects pickMixed to alternate providers (gemini, claude, …) while rotating
// credentials within gemini. Claude, with a single auth, always yields
// claude-a.
func TestSchedulerPick_MixedProvidersUsesProviderRotationOverReadyCandidates(t *testing.T) {
	t.Parallel()
	sched := newSchedulerForTest(
		&RoundRobinSelector{},
		&Auth{ID: "gemini-a", Provider: "gemini"},
		&Auth{ID: "gemini-b", Provider: "gemini"},
		&Auth{ID: "claude-a", Provider: "claude"},
	)
	steps := []struct{ provider, id string }{
		{"gemini", "gemini-a"},
		{"claude", "claude-a"},
		{"gemini", "gemini-b"},
		{"claude", "claude-a"},
	}
	for i, want := range steps {
		// A fresh provider slice per call mirrors independent requests.
		picked, provider, errPick := sched.pickMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
		if errPick != nil {
			t.Fatalf("pickMixed() #%d error = %v", i, errPick)
		}
		if picked == nil {
			t.Fatalf("pickMixed() #%d auth = nil", i)
		}
		if provider != want.provider {
			t.Fatalf("pickMixed() #%d provider = %q, want %q", i, provider, want.provider)
		}
		if picked.ID != want.id {
			t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", i, picked.ID, want.id)
		}
	}
}
// TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier gives one
// provider priority 4 and two providers priority 7. It expects pickMixed to
// alternate between the two priority-7 providers and never select the
// priority-4 one.
func TestSchedulerPick_MixedProvidersPrefersHighestPriorityTier(t *testing.T) {
	t.Parallel()
	const model = "gpt-default"
	registerSchedulerModels(t, "provider-low", model, "low")
	registerSchedulerModels(t, "provider-high-a", model, "high-a")
	registerSchedulerModels(t, "provider-high-b", model, "high-b")
	sched := newSchedulerForTest(
		&RoundRobinSelector{},
		&Auth{ID: "low", Provider: "provider-low", Attributes: map[string]string{"priority": "4"}},
		&Auth{ID: "high-a", Provider: "provider-high-a", Attributes: map[string]string{"priority": "7"}},
		&Auth{ID: "high-b", Provider: "provider-high-b", Attributes: map[string]string{"priority": "7"}},
	)
	providers := []string{"provider-low", "provider-high-a", "provider-high-b"}
	steps := []struct{ provider, id string }{
		{"provider-high-a", "high-a"},
		{"provider-high-b", "high-b"},
		{"provider-high-a", "high-a"},
		{"provider-high-b", "high-b"},
	}
	for i, want := range steps {
		picked, provider, errPick := sched.pickMixed(context.Background(), providers, model, cliproxyexecutor.Options{}, nil)
		if errPick != nil {
			t.Fatalf("pickMixed() #%d error = %v", i, errPick)
		}
		if picked == nil {
			t.Fatalf("pickMixed() #%d auth = nil", i)
		}
		if provider != want.provider {
			t.Fatalf("pickMixed() #%d provider = %q, want %q", i, provider, want.provider)
		}
		if picked.ID != want.id {
			t.Fatalf("pickMixed() #%d auth.ID = %q, want %q", i, picked.ID, want.id)
		}
	}
}
// TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation
// registers two gemini auths and one claude auth through the Manager. It
// expects pickNextMixed (given an empty tried-set each time) to alternate
// providers before rotating gemini credentials.
func TestManager_PickNextMixed_UsesProviderRotationBeforeCredentialRotation(t *testing.T) {
	t.Parallel()
	mgr := NewManager(nil, &RoundRobinSelector{}, nil)
	mgr.executors["gemini"] = schedulerTestExecutor{}
	mgr.executors["claude"] = schedulerTestExecutor{}
	for _, candidate := range []*Auth{
		{ID: "gemini-a", Provider: "gemini"},
		{ID: "gemini-b", Provider: "gemini"},
		{ID: "claude-a", Provider: "claude"},
	} {
		if _, errRegister := mgr.Register(context.Background(), candidate); errRegister != nil {
			t.Fatalf("Register(%s) error = %v", candidate.ID, errRegister)
		}
	}
	steps := []struct{ provider, id string }{
		{"gemini", "gemini-a"},
		{"claude", "claude-a"},
		{"gemini", "gemini-b"},
		{"claude", "claude-a"},
	}
	for i, want := range steps {
		picked, _, provider, errPick := mgr.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, map[string]struct{}{})
		if errPick != nil {
			t.Fatalf("pickNextMixed() #%d error = %v", i, errPick)
		}
		if picked == nil {
			t.Fatalf("pickNextMixed() #%d auth = nil", i)
		}
		if provider != want.provider {
			t.Fatalf("pickNextMixed() #%d provider = %q, want %q", i, provider, want.provider)
		}
		if picked.ID != want.id {
			t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", i, picked.ID, want.id)
		}
	}
}
// TestManagerCustomSelector_FallsBackToLegacyPath installs a non-built-in
// selector and expects pickNext to consult it exactly once. The selector
// should see both auths, and the auth it reports last is the one returned.
func TestManagerCustomSelector_FallsBackToLegacyPath(t *testing.T) {
	t.Parallel()
	tracker := &trackingSelector{}
	mgr := NewManager(nil, tracker, nil)
	mgr.executors["gemini"] = schedulerTestExecutor{}
	for _, id := range []string{"auth-a", "auth-b"} {
		mgr.auths[id] = &Auth{ID: id, Provider: "gemini"}
	}

	picked, _, errPick := mgr.pickNext(context.Background(), "gemini", "", cliproxyexecutor.Options{}, map[string]struct{}{})
	if errPick != nil {
		t.Fatalf("pickNext() error = %v", errPick)
	}
	if picked == nil {
		t.Fatalf("pickNext() auth = nil")
	}
	if tracker.calls != 1 {
		t.Fatalf("selector.calls = %d, want %d", tracker.calls, 1)
	}
	observed := tracker.lastAuthID
	if len(observed) != 2 {
		t.Fatalf("len(selector.lastAuthID) = %d, want %d", len(observed), 2)
	}
	if selectorPick := observed[len(observed)-1]; picked.ID != selectorPick {
		t.Fatalf("pickNext() auth.ID = %q, want selector-picked %q", picked.ID, selectorPick)
	}
}
// TestManager_InitializesSchedulerForBuiltInSelector checks two things.
// NewManager with a RoundRobinSelector creates a round-robin scheduler.
// SetSelector with a FillFirstSelector then switches its strategy to fill-first.
func TestManager_InitializesSchedulerForBuiltInSelector(t *testing.T) {
	t.Parallel()
	mgr := NewManager(nil, &RoundRobinSelector{}, nil)
	if mgr.scheduler == nil {
		t.Fatalf("manager.scheduler = nil")
	}
	if mgr.scheduler.strategy != schedulerStrategyRoundRobin {
		t.Fatalf("manager.scheduler.strategy = %v, want %v", mgr.scheduler.strategy, schedulerStrategyRoundRobin)
	}

	// Swapping the selector must retarget the existing scheduler's strategy.
	mgr.SetSelector(&FillFirstSelector{})
	if mgr.scheduler.strategy != schedulerStrategyFillFirst {
		t.Fatalf("manager.scheduler.strategy = %v, want %v", mgr.scheduler.strategy, schedulerStrategyFillFirst)
	}
}
// TestManager_SchedulerTracksRegisterAndUpdate registers auth-b and then
// auth-a, and expects the scheduler to pick auth-a first. After auth-a is
// updated to Disabled, the scheduler is expected to pick auth-b.
func TestManager_SchedulerTracksRegisterAndUpdate(t *testing.T) {
	t.Parallel()
	ctx := context.Background()
	mgr := NewManager(nil, &RoundRobinSelector{}, nil)
	for _, id := range []string{"auth-b", "auth-a"} {
		if _, errRegister := mgr.Register(ctx, &Auth{ID: id, Provider: "gemini"}); errRegister != nil {
			t.Fatalf("Register(%s) error = %v", id, errRegister)
		}
	}

	first, errPick := mgr.scheduler.pickSingle(ctx, "gemini", "", cliproxyexecutor.Options{}, nil)
	if errPick != nil {
		t.Fatalf("scheduler.pickSingle() error = %v", errPick)
	}
	if first == nil || first.ID != "auth-a" {
		t.Fatalf("scheduler.pickSingle() auth = %v, want auth-a", first)
	}

	disabled := &Auth{ID: "auth-a", Provider: "gemini", Disabled: true}
	if _, errUpdate := mgr.Update(ctx, disabled); errUpdate != nil {
		t.Fatalf("Update(auth-a) error = %v", errUpdate)
	}

	afterUpdate, errPickAfter := mgr.scheduler.pickSingle(ctx, "gemini", "", cliproxyexecutor.Options{}, nil)
	if errPickAfter != nil {
		t.Fatalf("scheduler.pickSingle() after update error = %v", errPickAfter)
	}
	if afterUpdate == nil || afterUpdate.ID != "auth-b" {
		t.Fatalf("scheduler.pickSingle() after update auth = %v, want auth-b", afterUpdate)
	}
}
// TestManager_PickNextMixed_UsesSchedulerRotation mirrors the provider-first
// rotation check but passes a nil tried-set. It expects the same
// gemini/claude alternation with gemini credentials advancing.
func TestManager_PickNextMixed_UsesSchedulerRotation(t *testing.T) {
	t.Parallel()
	mgr := NewManager(nil, &RoundRobinSelector{}, nil)
	for _, provider := range []string{"gemini", "claude"} {
		mgr.executors[provider] = schedulerTestExecutor{}
	}
	for _, candidate := range []*Auth{
		{ID: "gemini-a", Provider: "gemini"},
		{ID: "gemini-b", Provider: "gemini"},
		{ID: "claude-a", Provider: "claude"},
	} {
		if _, errRegister := mgr.Register(context.Background(), candidate); errRegister != nil {
			t.Fatalf("Register(%s) error = %v", candidate.ID, errRegister)
		}
	}
	wantProviders := []string{"gemini", "claude", "gemini", "claude"}
	wantIDs := []string{"gemini-a", "claude-a", "gemini-b", "claude-a"}
	for i := 0; i < len(wantProviders); i++ {
		picked, _, provider, errPick := mgr.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
		if errPick != nil {
			t.Fatalf("pickNextMixed() #%d error = %v", i, errPick)
		}
		if picked == nil {
			t.Fatalf("pickNextMixed() #%d auth = nil", i)
		}
		if provider != wantProviders[i] {
			t.Fatalf("pickNextMixed() #%d provider = %q, want %q", i, provider, wantProviders[i])
		}
		if picked.ID != wantIDs[i] {
			t.Fatalf("pickNextMixed() #%d auth.ID = %q, want %q", i, picked.ID, wantIDs[i])
		}
	}
}
// TestManager_PickNextMixed_SkipsProvidersWithoutExecutors registers a gemini
// auth and a claude auth, but only claude has an executor. It expects
// pickNextMixed to return the claude auth.
func TestManager_PickNextMixed_SkipsProvidersWithoutExecutors(t *testing.T) {
	t.Parallel()
	const (
		wantProvider = "claude"
		wantID       = "claude-a"
	)
	mgr := NewManager(nil, &RoundRobinSelector{}, nil)
	mgr.executors[wantProvider] = schedulerTestExecutor{}
	for _, candidate := range []*Auth{
		{ID: "gemini-a", Provider: "gemini"},
		{ID: wantID, Provider: wantProvider},
	} {
		if _, errRegister := mgr.Register(context.Background(), candidate); errRegister != nil {
			t.Fatalf("Register(%s) error = %v", candidate.ID, errRegister)
		}
	}

	picked, _, provider, errPick := mgr.pickNextMixed(context.Background(), []string{"gemini", "claude"}, "", cliproxyexecutor.Options{}, nil)
	switch {
	case errPick != nil:
		t.Fatalf("pickNextMixed() error = %v", errPick)
	case picked == nil:
		t.Fatalf("pickNextMixed() auth = nil")
	case provider != wantProvider:
		t.Fatalf("pickNextMixed() provider = %q, want %q", provider, wantProvider)
	case picked.ID != wantID:
		t.Fatalf("pickNextMixed() auth.ID = %q, want %q", picked.ID, wantID)
	}
}
// TestManager_SchedulerTracksMarkResultCooldownAndRecovery runs in three steps.
//  1. It registers auth-a and auth-b for "test-model" in the global registry
//     (removed on cleanup).
//  2. It reports a 429 failure for auth-a and expects the scheduler to pick
//     auth-b.
//  3. It reports a success for auth-a and expects two consecutive picks to
//     cover both auths.
func TestManager_SchedulerTracksMarkResultCooldownAndRecovery(t *testing.T) {
	t.Parallel()
	const model = "test-model"
	ctx := context.Background()
	mgr := NewManager(nil, &RoundRobinSelector{}, nil)

	reg := registry.GetGlobalRegistry()
	for _, id := range []string{"auth-a", "auth-b"} {
		reg.RegisterClient(id, "gemini", []*registry.ModelInfo{{ID: model}})
	}
	t.Cleanup(func() {
		reg.UnregisterClient("auth-a")
		reg.UnregisterClient("auth-b")
	})

	for _, id := range []string{"auth-a", "auth-b"} {
		if _, errRegister := mgr.Register(ctx, &Auth{ID: id, Provider: "gemini"}); errRegister != nil {
			t.Fatalf("Register(%s) error = %v", id, errRegister)
		}
	}

	// A 429 on auth-a should put it into cooldown for this model.
	mgr.MarkResult(ctx, Result{
		AuthID:   "auth-a",
		Provider: "gemini",
		Model:    model,
		Success:  false,
		Error:    &Error{HTTPStatus: 429, Message: "quota"},
	})
	cooled, errPick := mgr.scheduler.pickSingle(ctx, "gemini", model, cliproxyexecutor.Options{}, nil)
	if errPick != nil {
		t.Fatalf("scheduler.pickSingle() after cooldown error = %v", errPick)
	}
	if cooled == nil || cooled.ID != "auth-b" {
		t.Fatalf("scheduler.pickSingle() after cooldown auth = %v, want auth-b", cooled)
	}

	// A subsequent success should make auth-a eligible again.
	mgr.MarkResult(ctx, Result{
		AuthID:   "auth-a",
		Provider: "gemini",
		Model:    model,
		Success:  true,
	})
	seen := make(map[string]struct{}, 2)
	for round := 0; round < 2; round++ {
		recovered, errRecover := mgr.scheduler.pickSingle(ctx, "gemini", model, cliproxyexecutor.Options{}, nil)
		if errRecover != nil {
			t.Fatalf("scheduler.pickSingle() after recovery #%d error = %v", round, errRecover)
		}
		if recovered == nil {
			t.Fatalf("scheduler.pickSingle() after recovery #%d auth = nil", round)
		}
		seen[recovered.ID] = struct{}{}
	}
	if len(seen) != 2 {
		t.Fatalf("len(seen) = %d, want %d", len(seen), 2)
	}
}
================================================
FILE: sdk/cliproxy/auth/selector.go
================================================
package auth
import (
"context"
"encoding/json"
"fmt"
"math"
"math/rand/v2"
"net/http"
"sort"
"strconv"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// RoundRobinSelector provides a simple provider scoped round-robin selection strategy.
//
// The zero value is ready to use: Pick lazily allocates the cursor map.
type RoundRobinSelector struct {
	// mu guards cursors; Pick holds it while reading and advancing a cursor.
	mu sync.Mutex
	// cursors maps a rotation key (provider, canonical model, and for
	// virtual-parent grouping a group/credential suffix) to the next index
	// to hand out.
	cursors map[string]int
	// maxKeys bounds the number of cursor entries; when a new key arrives at
	// the bound the map is cleared. Values <= 0 mean 4096.
	maxKeys int
}
// FillFirstSelector selects the first available credential (deterministic ordering).
// This "burns" one account before moving to the next, which can help stagger
// rolling-window subscription caps (e.g. chat message limits).
//
// "First" means the lowest auth ID within the highest available priority tier
// (see getAvailableAuths). The selector is stateless.
type FillFirstSelector struct{}
// blockReason classifies why isAuthBlockedForModel considered an auth unusable.
type blockReason int

const (
	// blockReasonNone: the auth is usable.
	blockReasonNone blockReason = iota
	// blockReasonCooldown: a retry window is active and the quota is marked
	// exceeded. Only this reason counts toward the "all credentials cooling
	// down" error built by getAvailableAuths.
	blockReasonCooldown
	// blockReasonDisabled: the auth, or its state for the requested model, is disabled.
	blockReasonDisabled
	// blockReasonOther: a nil auth, or an active retry window without quota exhaustion.
	blockReasonOther
)
// modelCooldownError reports that every candidate credential for a model is
// blocked by a quota cooldown. It renders as a JSON error body, maps to HTTP
// 429, and carries a Retry-After header.
type modelCooldownError struct {
	model    string        // requested model name; may be empty
	resetIn  time.Duration // wait until the earliest credential recovers
	provider string        // provider to report; empty to omit the field
}

// newModelCooldownError builds a modelCooldownError, clamping a negative
// resetIn to zero.
func newModelCooldownError(model, provider string, resetIn time.Duration) *modelCooldownError {
	if resetIn < 0 {
		resetIn = 0
	}
	return &modelCooldownError{
		model:    model,
		provider: provider,
		resetIn:  resetIn,
	}
}

// retryAfterSeconds returns resetIn rounded up to whole seconds, never
// negative. It is the single source for reset_seconds, reset_time and the
// Retry-After header, so the three values always agree.
func (e *modelCooldownError) retryAfterSeconds() int {
	seconds := int(math.Ceil(e.resetIn.Seconds()))
	if seconds < 0 {
		return 0
	}
	return seconds
}

// Error returns a JSON document of the form
//
//	{"error":{"code":"model_cooldown","message":...,"model":...,
//	          "reset_time":...,"reset_seconds":...[,"provider":...]}}
//
// "provider" is present only when e.provider is non-empty.
func (e *modelCooldownError) Error() string {
	modelName := e.model
	if modelName == "" {
		modelName = "requested model"
	}
	message := fmt.Sprintf("All credentials for model %s are cooling down", modelName)
	if e.provider != "" {
		message = fmt.Sprintf("%s via provider %s", message, e.provider)
	}
	resetSeconds := e.retryAfterSeconds()
	// Render reset_time from the same rounded-up value. Rounding it separately
	// (as before) made e.g. 1.4s show "1s" next to reset_seconds=2 / Retry-After: 2.
	displayDuration := time.Duration(resetSeconds) * time.Second
	errorBody := map[string]any{
		"code":          "model_cooldown",
		"message":       message,
		"model":         e.model,
		"reset_time":    displayDuration.String(),
		"reset_seconds": resetSeconds,
	}
	if e.provider != "" {
		errorBody["provider"] = e.provider
	}
	payload := map[string]any{"error": errorBody}
	data, err := json.Marshal(payload)
	if err != nil {
		return fmt.Sprintf(`{"error":{"code":"model_cooldown","message":"%s"}}`, message)
	}
	return string(data)
}

// StatusCode maps the cooldown to HTTP 429 Too Many Requests.
func (e *modelCooldownError) StatusCode() int {
	return http.StatusTooManyRequests
}

// Headers returns a JSON Content-Type and a Retry-After header whose value
// is the cooldown in whole seconds, rounded up.
func (e *modelCooldownError) Headers() http.Header {
	headers := make(http.Header)
	headers.Set("Content-Type", "application/json")
	headers.Set("Retry-After", strconv.Itoa(e.retryAfterSeconds()))
	return headers
}
// authPriority returns the integer value of the auth's "priority" attribute.
// Surrounding whitespace is ignored. It returns 0 when the auth or its
// attributes are nil, or when the attribute is missing or not an integer.
func authPriority(auth *Auth) int {
	if auth == nil || auth.Attributes == nil {
		return 0
	}
	if value, err := strconv.Atoi(strings.TrimSpace(auth.Attributes["priority"])); err == nil {
		return value
	}
	// Empty or malformed values fall back to the default tier.
	return 0
}
// canonicalModelKey trims the model name and strips any suffix recognized by
// thinking.ParseSuffix, so variants such as "test-model(high)" and
// "test-model(low)" share one key. The trimmed input is returned when parsing
// yields an empty base name. Blank input yields "".
func canonicalModelKey(model string) string {
	trimmed := strings.TrimSpace(model)
	if trimmed == "" {
		return ""
	}
	if base := strings.TrimSpace(thinking.ParseSuffix(trimmed).ModelName); base != "" {
		return base
	}
	return trimmed
}
// authWebsocketsEnabled reports whether the auth opts into websockets.
//
// A non-blank, parseable "websockets" attribute decides the result. Otherwise
// the "websockets" metadata entry is consulted; it may be a bool or a
// parseable string. Anything else, including a nil auth, yields false.
func authWebsocketsEnabled(auth *Auth) bool {
	if auth == nil {
		return false
	}
	// Indexing a nil map yields "" here, so no separate nil check is needed.
	if attr := strings.TrimSpace(auth.Attributes["websockets"]); attr != "" {
		if enabled, err := strconv.ParseBool(attr); err == nil {
			return enabled
		}
	}
	switch value := auth.Metadata["websockets"].(type) {
	case bool:
		return value
	case string:
		enabled, err := strconv.ParseBool(strings.TrimSpace(value))
		return err == nil && enabled
	}
	return false
}
// preferCodexWebsocketAuths narrows the candidates to websocket-enabled auths
// when three conditions hold: the request arrived over a downstream
// websocket, the provider is "codex" (case-insensitive, trimmed), and at least
// one candidate has websockets enabled. In every other case the input slice
// is returned unchanged.
func preferCodexWebsocketAuths(ctx context.Context, provider string, available []*Auth) []*Auth {
	if len(available) == 0 || !cliproxyexecutor.DownstreamWebsocket(ctx) {
		return available
	}
	if !strings.EqualFold(strings.TrimSpace(provider), "codex") {
		return available
	}
	capable := make([]*Auth, 0, len(available))
	for _, candidate := range available {
		if authWebsocketsEnabled(candidate) {
			capable = append(capable, candidate)
		}
	}
	if len(capable) == 0 {
		return available
	}
	return capable
}
// collectAvailableByPriority buckets the unblocked auths by authPriority,
// preserving input order within each bucket. It also counts the auths
// blocked for a quota cooldown and returns the earliest non-zero time at
// which one of them becomes usable again.
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
	available = make(map[int][]*Auth)
	for _, candidate := range auths {
		blocked, reason, next := isAuthBlockedForModel(candidate, model, now)
		switch {
		case !blocked:
			tier := authPriority(candidate)
			available[tier] = append(available[tier], candidate)
		case reason == blockReasonCooldown:
			cooldownCount++
			if next.IsZero() {
				continue
			}
			if earliest.IsZero() || next.Before(earliest) {
				earliest = next
			}
		}
	}
	return available, cooldownCount, earliest
}
// getAvailableAuths returns the unblocked auths in the highest priority tier,
// sorted by ID.
//
// Errors:
//   - an *Error with code "auth_not_found" for an empty input;
//   - a *modelCooldownError when every auth is in a quota cooldown; the
//     pseudo-provider "mixed" is not reported;
//   - an *Error with code "auth_unavailable" otherwise.
func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]*Auth, error) {
	if len(auths) == 0 {
		return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"}
	}
	byTier, cooling, earliest := collectAvailableByPriority(auths, model, now)
	if len(byTier) == 0 {
		allCooling := cooling == len(auths) && !earliest.IsZero()
		if !allCooling {
			return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
		}
		reported := provider
		if reported == "mixed" {
			reported = ""
		}
		wait := earliest.Sub(now)
		if wait < 0 {
			wait = 0
		}
		return nil, newModelCooldownError(model, reported, wait)
	}

	topTier, haveTier := 0, false
	for tier := range byTier {
		if !haveTier || tier > topTier {
			topTier, haveTier = tier, true
		}
	}
	selected := byTier[topTier]
	if len(selected) > 1 {
		sort.Slice(selected, func(a, b int) bool { return selected[a].ID < selected[b].ID })
	}
	return selected, nil
}
// Pick selects the next available auth for the provider in a round-robin manner.
// For gemini-cli virtual auths (identified by the gemini_virtual_parent attribute),
// a two-level round-robin is used: first cycling across credential groups (parent
// accounts), then cycling within each group's project auths.
//
// Cursors are keyed by provider plus canonicalModelKey(model), so model names
// that differ only by a thinking suffix share one rotation. The cursor map is
// bounded by maxKeys (4096 when unset): inserting a new key at the bound clears
// it. In the two-level path the map may briefly hold one entry more than the
// bound.
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
	_ = opts
	// Cursors are reset to zero at this value so index+1 stays below math.MaxInt32.
	const cursorWrap = 2_147_483_640
	now := time.Now()
	available, err := getAvailableAuths(auths, provider, model, now)
	if err != nil {
		return nil, err
	}
	available = preferCodexWebsocketAuths(ctx, provider, available)
	key := provider + ":" + canonicalModelKey(model)
	// Grouping only reads the candidate slice, so compute it outside the lock.
	groups, parentOrder := groupByVirtualParent(available)

	s.mu.Lock()
	defer s.mu.Unlock()
	if s.cursors == nil {
		s.cursors = make(map[string]int)
	}
	limit := s.maxKeys
	if limit <= 0 {
		limit = 4096
	}

	if len(parentOrder) > 1 {
		// First level: choose a credential group (parent account).
		groupKey := key + "::group"
		s.ensureCursorKey(groupKey, limit)
		groupIndex, seen := s.cursors[groupKey]
		if !seen {
			// Seed with a random initial offset so the starting credential is randomized.
			groupIndex = rand.IntN(len(parentOrder))
		}
		if groupIndex >= cursorWrap {
			groupIndex = 0
		}
		selectedParent := parentOrder[groupIndex%len(parentOrder)]
		group := groups[selectedParent]

		// Second level: round-robin within the selected credential group.
		innerKey := key + "::cred:" + selectedParent
		s.ensureCursorKey(innerKey, limit)
		innerIndex := s.cursors[innerKey]
		if innerIndex >= cursorWrap {
			innerIndex = 0
		}
		// Advance both cursors only after the inner key's capacity check.
		// ensureCursorKey may clear the map; writing the group cursor earlier
		// let that clear discard the group advance and re-randomize the next pick.
		s.cursors[groupKey] = groupIndex + 1
		s.cursors[innerKey] = innerIndex + 1
		return group[innerIndex%len(group)], nil
	}

	// Flat round-robin for non-grouped auths (original behavior).
	s.ensureCursorKey(key, limit)
	index := s.cursors[key]
	if index >= cursorWrap {
		index = 0
	}
	s.cursors[key] = index + 1
	return available[index%len(available)], nil
}
// ensureCursorKey makes room for key in the cursor map. If key is already
// tracked, nothing happens. Otherwise the whole map is replaced with an empty
// one when it already holds limit or more entries.
// Must be called with s.mu held.
func (s *RoundRobinSelector) ensureCursorKey(key string, limit int) {
	if _, tracked := s.cursors[key]; tracked {
		return
	}
	if len(s.cursors) >= limit {
		s.cursors = make(map[string]int)
	}
}
// groupByVirtualParent groups auths by their trimmed gemini_virtual_parent
// attribute. It returns the parentID -> auths map, which keeps input order
// within each group, together with the parent IDs sorted ascending for stable
// cursor indexing.
//
// It returns nil, nil for empty input, or as soon as any auth lacks the
// attribute; callers then fall back to flat round-robin.
func groupByVirtualParent(auths []*Auth) (map[string][]*Auth, []string) {
	if len(auths) == 0 {
		return nil, nil
	}
	groups := make(map[string][]*Auth)
	var parents []string
	for _, candidate := range auths {
		// A nil Attributes map indexes to "", which is handled below.
		parent := strings.TrimSpace(candidate.Attributes["gemini_virtual_parent"])
		if parent == "" {
			return nil, nil
		}
		if _, known := groups[parent]; !known {
			parents = append(parents, parent)
		}
		groups[parent] = append(groups[parent], candidate)
	}
	sort.Strings(parents)
	return groups, parents
}
// Pick selects the first available auth for the provider in a deterministic manner:
// the lowest ID in the highest available priority tier. For codex downstream
// websocket requests it picks from the websocket-enabled subset when one
// exists.
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
	_ = opts
	candidates, err := getAvailableAuths(auths, provider, model, time.Now())
	if err != nil {
		return nil, err
	}
	// getAvailableAuths never returns an empty slice without an error, and the
	// websocket filter only narrows to a non-empty subset.
	return preferCodexWebsocketAuths(ctx, provider, candidates)[0], nil
}
// isAuthBlockedForModel reports whether auth is unusable at now. It also
// returns the reason and, for retry windows, the time the auth becomes usable.
//
// When model is non-empty, only the per-model state is consulted. The state is
// looked up by exact name, then by canonicalModelKey(model). Auth-wide cooldown
// fields are not checked on this path. When model is empty, the auth-wide
// Unavailable/NextRetryAfter fields are used instead. In a retry window the
// returned time is Quota.NextRecoverAt if that is later than now, else
// NextRetryAfter. The reason is blockReasonCooldown only when the quota is
// marked exceeded.
func isAuthBlockedForModel(auth *Auth, model string, now time.Time) (bool, blockReason, time.Time) {
	switch {
	case auth == nil:
		return true, blockReasonOther, time.Time{}
	case auth.Disabled || auth.Status == StatusDisabled:
		return true, blockReasonDisabled, time.Time{}
	}

	if model == "" {
		if !auth.Unavailable || !auth.NextRetryAfter.After(now) {
			return false, blockReasonNone, time.Time{}
		}
		until := auth.NextRetryAfter
		if recoverAt := auth.Quota.NextRecoverAt; !recoverAt.IsZero() && recoverAt.After(now) {
			until = recoverAt
		}
		if auth.Quota.Exceeded {
			return true, blockReasonCooldown, until
		}
		return true, blockReasonOther, until
	}

	if len(auth.ModelStates) == 0 {
		return false, blockReasonNone, time.Time{}
	}
	state, found := auth.ModelStates[model]
	if !found || state == nil {
		// Fall back to the base model so suffixed variants share cooldowns.
		if base := canonicalModelKey(model); base != "" && base != model {
			state, found = auth.ModelStates[base]
		}
	}
	if !found || state == nil {
		return false, blockReasonNone, time.Time{}
	}
	if state.Status == StatusDisabled {
		return true, blockReasonDisabled, time.Time{}
	}
	// Unavailable without a retry time, or with one already passed, is usable.
	if !state.Unavailable || state.NextRetryAfter.IsZero() || !state.NextRetryAfter.After(now) {
		return false, blockReasonNone, time.Time{}
	}
	until := state.NextRetryAfter
	if recoverAt := state.Quota.NextRecoverAt; !recoverAt.IsZero() && recoverAt.After(now) {
		until = recoverAt
	}
	if state.Quota.Exceeded {
		return true, blockReasonCooldown, until
	}
	return true, blockReasonOther, until
}
================================================
FILE: sdk/cliproxy/auth/selector_test.go
================================================
package auth
import (
"context"
"encoding/json"
"errors"
"net/http"
"sync"
"testing"
"time"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
func TestFillFirstSelectorPick_Deterministic(t *testing.T) {
t.Parallel()
selector := &FillFirstSelector{}
auths := []*Auth{
{ID: "b"},
{ID: "a"},
{ID: "c"},
}
got, err := selector.Pick(context.Background(), "gemini", "", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() error = %v", err)
}
if got == nil {
t.Fatalf("Pick() auth = nil")
}
if got.ID != "a" {
t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "a")
}
}
// TestRoundRobinSelectorPick_CyclesDeterministic expects RoundRobinSelector
// to cycle through the candidates in ID order (a, b, c, a, b), regardless of
// input order.
func TestRoundRobinSelectorPick_CyclesDeterministic(t *testing.T) {
	t.Parallel()
	rr := &RoundRobinSelector{}
	candidates := []*Auth{{ID: "b"}, {ID: "a"}, {ID: "c"}}
	sequence := []string{"a", "b", "c", "a", "b"}
	for step := 0; step < len(sequence); step++ {
		picked, err := rr.Pick(context.Background(), "gemini", "", cliproxyexecutor.Options{}, candidates)
		if err != nil {
			t.Fatalf("Pick() #%d error = %v", step, err)
		}
		if picked == nil {
			t.Fatalf("Pick() #%d auth = nil", step)
		}
		if want := sequence[step]; picked.ID != want {
			t.Fatalf("Pick() #%d auth.ID = %q, want %q", step, picked.ID, want)
		}
	}
}
// TestRoundRobinSelectorPick_PriorityBuckets expects round-robin to alternate
// only between the two priority-10 auths and never return the priority-0 one.
func TestRoundRobinSelectorPick_PriorityBuckets(t *testing.T) {
	t.Parallel()
	rr := &RoundRobinSelector{}
	withPriority := func(id, priority string) *Auth {
		return &Auth{ID: id, Attributes: map[string]string{"priority": priority}}
	}
	candidates := []*Auth{
		withPriority("c", "0"),
		withPriority("a", "10"),
		withPriority("b", "10"),
	}
	for step, wantID := range []string{"a", "b", "a", "b"} {
		picked, err := rr.Pick(context.Background(), "mixed", "", cliproxyexecutor.Options{}, candidates)
		if err != nil {
			t.Fatalf("Pick() #%d error = %v", step, err)
		}
		if picked == nil {
			t.Fatalf("Pick() #%d auth = nil", step)
		}
		if picked.ID != wantID {
			t.Fatalf("Pick() #%d auth.ID = %q, want %q", step, picked.ID, wantID)
		}
		if picked.ID == "c" {
			t.Fatalf("Pick() #%d unexpectedly selected lower priority auth", step)
		}
	}
}
func TestFillFirstSelectorPick_PriorityFallbackCooldown(t *testing.T) {
t.Parallel()
selector := &FillFirstSelector{}
now := time.Now()
model := "test-model"
high := &Auth{
ID: "high",
Attributes: map[string]string{"priority": "10"},
ModelStates: map[string]*ModelState{
model: {
Status: StatusActive,
Unavailable: true,
NextRetryAfter: now.Add(30 * time.Minute),
Quota: QuotaState{
Exceeded: true,
},
},
},
}
low := &Auth{ID: "low", Attributes: map[string]string{"priority": "0"}}
got, err := selector.Pick(context.Background(), "mixed", model, cliproxyexecutor.Options{}, []*Auth{high, low})
if err != nil {
t.Fatalf("Pick() error = %v", err)
}
if got == nil {
t.Fatalf("Pick() auth = nil")
}
if got.ID != "low" {
t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "low")
}
}
// TestRoundRobinSelectorPick_Concurrent releases 32 goroutines at once, each
// calling Pick 100 times on a shared selector. It fails on the first error,
// nil auth, or empty ID reported by any worker. Run with -race to also catch
// data races.
func TestRoundRobinSelectorPick_Concurrent(t *testing.T) {
	const (
		workers = 32
		rounds  = 100
	)
	rr := &RoundRobinSelector{}
	candidates := []*Auth{{ID: "b"}, {ID: "a"}, {ID: "c"}}
	gate := make(chan struct{})
	firstErr := make(chan error, 1)
	// report keeps only the first failure; later ones are dropped.
	report := func(err error) {
		select {
		case firstErr <- err:
		default:
		}
	}

	var wg sync.WaitGroup
	for w := 0; w < workers; w++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			<-gate
			for r := 0; r < rounds; r++ {
				picked, err := rr.Pick(context.Background(), "gemini", "", cliproxyexecutor.Options{}, candidates)
				switch {
				case err != nil:
					report(err)
					return
				case picked == nil:
					report(errors.New("Pick() returned nil auth"))
					return
				case picked.ID == "":
					report(errors.New("Pick() returned auth with empty ID"))
					return
				}
			}
		}()
	}
	close(gate)
	wg.Wait()

	select {
	case err := <-firstErr:
		t.Fatalf("concurrent Pick() error = %v", err)
	default:
	}
}
// TestSelectorPick_AllCooldownReturnsModelCooldownError puts every auth in a
// 60s quota cooldown. It expects Pick to return a *modelCooldownError with a
// 429 status and a Retry-After header. The JSON body must omit "provider" for
// the "mixed" pseudo-provider and include it for a concrete provider.
func TestSelectorPick_AllCooldownReturnsModelCooldownError(t *testing.T) {
	t.Parallel()
	const model = "test-model"
	retryAt := time.Now().Add(60 * time.Second)
	// coolingAuth builds an auth whose quota for model is exhausted until retryAt.
	coolingAuth := func(id string) *Auth {
		return &Auth{
			ID: id,
			ModelStates: map[string]*ModelState{
				model: {
					Status:         StatusActive,
					Unavailable:    true,
					NextRetryAfter: retryAt,
					Quota: QuotaState{
						Exceeded:      true,
						NextRecoverAt: retryAt,
					},
				},
			},
		}
	}
	auths := []*Auth{coolingAuth("a"), coolingAuth("b")}

	// requireCooldown asserts err is a *modelCooldownError and returns it.
	requireCooldown := func(t *testing.T, err error) *modelCooldownError {
		t.Helper()
		if err == nil {
			t.Fatalf("Pick() error = nil")
		}
		var mce *modelCooldownError
		if !errors.As(err, &mce) {
			t.Fatalf("Pick() error = %T, want *modelCooldownError", err)
		}
		return mce
	}
	// errorObject decodes mce.Error() and returns its "error" object.
	errorObject := func(t *testing.T, mce *modelCooldownError) map[string]any {
		t.Helper()
		var payload map[string]any
		if err := json.Unmarshal([]byte(mce.Error()), &payload); err != nil {
			t.Fatalf("json.Unmarshal(Error()) error = %v", err)
		}
		rawErr, ok := payload["error"].(map[string]any)
		if !ok {
			t.Fatalf("Error() payload missing error object: %v", payload)
		}
		return rawErr
	}

	t.Run("mixed provider redacts provider field", func(t *testing.T) {
		t.Parallel()
		_, err := (&FillFirstSelector{}).Pick(context.Background(), "mixed", model, cliproxyexecutor.Options{}, auths)
		mce := requireCooldown(t, err)
		if mce.StatusCode() != http.StatusTooManyRequests {
			t.Fatalf("StatusCode() = %d, want %d", mce.StatusCode(), http.StatusTooManyRequests)
		}
		if mce.Headers().Get("Retry-After") == "" {
			t.Fatalf("Headers().Get(Retry-After) = empty")
		}
		rawErr := errorObject(t, mce)
		if code, _ := rawErr["code"].(string); code != "model_cooldown" {
			t.Fatalf("Error().error.code = %q, want %q", code, "model_cooldown")
		}
		if provider, present := rawErr["provider"]; present {
			t.Fatalf("Error().error.provider exists for mixed provider: %v", provider)
		}
	})

	t.Run("non-mixed provider includes provider field", func(t *testing.T) {
		t.Parallel()
		_, err := (&FillFirstSelector{}).Pick(context.Background(), "gemini", model, cliproxyexecutor.Options{}, auths)
		rawErr := errorObject(t, requireCooldown(t, err))
		if provider, _ := rawErr["provider"].(string); provider != "gemini" {
			t.Fatalf("Error().error.provider = %q, want %q", provider, "gemini")
		}
	})
}
// TestIsAuthBlockedForModel_UnavailableWithoutNextRetryIsNotBlocked checks a
// model state that is Unavailable with an exceeded quota but a zero
// NextRetryAfter. It expects isAuthBlockedForModel to report the auth as not
// blocked, with no reason and a zero time.
func TestIsAuthBlockedForModel_UnavailableWithoutNextRetryIsNotBlocked(t *testing.T) {
	t.Parallel()
	const model = "test-model"
	subject := &Auth{
		ID: "a",
		ModelStates: map[string]*ModelState{
			model: {
				Status:      StatusActive,
				Unavailable: true,
				Quota:       QuotaState{Exceeded: true},
			},
		},
	}
	blocked, reason, next := isAuthBlockedForModel(subject, model, time.Now())
	switch {
	case blocked:
		t.Fatalf("blocked = true, want false")
	case reason != blockReasonNone:
		t.Fatalf("reason = %v, want %v", reason, blockReasonNone)
	case !next.IsZero():
		t.Fatalf("next = %v, want zero", next)
	}
}
func TestFillFirstSelectorPick_ThinkingSuffixFallsBackToBaseModelState(t *testing.T) {
t.Parallel()
selector := &FillFirstSelector{}
now := time.Now()
baseModel := "test-model"
requestedModel := "test-model(high)"
high := &Auth{
ID: "high",
Attributes: map[string]string{"priority": "10"},
ModelStates: map[string]*ModelState{
baseModel: {
Status: StatusActive,
Unavailable: true,
NextRetryAfter: now.Add(30 * time.Minute),
Quota: QuotaState{
Exceeded: true,
},
},
},
}
low := &Auth{
ID: "low",
Attributes: map[string]string{"priority": "0"},
}
got, err := selector.Pick(context.Background(), "mixed", requestedModel, cliproxyexecutor.Options{}, []*Auth{high, low})
if err != nil {
t.Fatalf("Pick() error = %v", err)
}
if got == nil {
t.Fatalf("Pick() auth = nil")
}
if got.ID != "low" {
t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "low")
}
}
func TestRoundRobinSelectorPick_ThinkingSuffixSharesCursor(t *testing.T) {
t.Parallel()
selector := &RoundRobinSelector{}
auths := []*Auth{
{ID: "b"},
{ID: "a"},
}
first, err := selector.Pick(context.Background(), "gemini", "test-model(high)", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() first error = %v", err)
}
second, err := selector.Pick(context.Background(), "gemini", "test-model(low)", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() second error = %v", err)
}
if first == nil || second == nil {
t.Fatalf("Pick() returned nil auth")
}
if first.ID != "a" {
t.Fatalf("Pick() first auth.ID = %q, want %q", first.ID, "a")
}
if second.ID != "b" {
t.Fatalf("Pick() second auth.ID = %q, want %q", second.ID, "b")
}
}
func TestRoundRobinSelectorPick_CursorKeyCap(t *testing.T) {
t.Parallel()
selector := &RoundRobinSelector{maxKeys: 2}
auths := []*Auth{{ID: "a"}}
_, _ = selector.Pick(context.Background(), "gemini", "m1", cliproxyexecutor.Options{}, auths)
_, _ = selector.Pick(context.Background(), "gemini", "m2", cliproxyexecutor.Options{}, auths)
_, _ = selector.Pick(context.Background(), "gemini", "m3", cliproxyexecutor.Options{}, auths)
selector.mu.Lock()
defer selector.mu.Unlock()
if selector.cursors == nil {
t.Fatalf("selector.cursors = nil")
}
if len(selector.cursors) != 1 {
t.Fatalf("len(selector.cursors) = %d, want %d", len(selector.cursors), 1)
}
if _, ok := selector.cursors["gemini:m3"]; !ok {
t.Fatalf("selector.cursors missing key %q", "gemini:m3")
}
}
func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) {
t.Parallel()
selector := &RoundRobinSelector{}
// Simulate two gemini-cli credentials, each with multiple projects:
// Credential A (parent = "cred-a.json") has 3 projects
// Credential B (parent = "cred-b.json") has 2 projects
auths := []*Auth{
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-b.json::proj-b1", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
{ID: "cred-b.json::proj-b2", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
}
// Two-level round-robin: consecutive picks must alternate between credentials.
// Credential group order is randomized, but within each call the group cursor
// advances by 1, so consecutive picks should cycle through different parents.
picks := make([]string, 6)
parents := make([]string, 6)
for i := 0; i < 6; i++ {
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() #%d error = %v", i, err)
}
if got == nil {
t.Fatalf("Pick() #%d auth = nil", i)
}
picks[i] = got.ID
parents[i] = got.Attributes["gemini_virtual_parent"]
}
// Verify property: consecutive picks must alternate between credential groups.
for i := 1; i < len(parents); i++ {
if parents[i] == parents[i-1] {
t.Fatalf("Pick() #%d and #%d both from same parent %q (IDs: %q, %q); expected alternating credentials",
i-1, i, parents[i], picks[i-1], picks[i])
}
}
// Verify property: each credential's projects are picked in sequence (round-robin within group).
credPicks := map[string][]string{}
for i, id := range picks {
credPicks[parents[i]] = append(credPicks[parents[i]], id)
}
for parent, ids := range credPicks {
for i := 1; i < len(ids); i++ {
if ids[i] == ids[i-1] {
t.Fatalf("Credential %q picked same project %q twice in a row", parent, ids[i])
}
}
}
}
func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
t.Parallel()
selector := &RoundRobinSelector{}
// All auths from the same parent - should fall back to flat round-robin
// because there's only one credential group (no benefit from two-level).
auths := []*Auth{
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
}
// With single parent group, parentOrder has length 1, so it uses flat round-robin.
// Sorted by ID: proj-a1, proj-a2, proj-a3
want := []string{
"cred-a.json::proj-a1",
"cred-a.json::proj-a2",
"cred-a.json::proj-a3",
"cred-a.json::proj-a1",
}
for i, expectedID := range want {
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() #%d error = %v", i, err)
}
if got == nil {
t.Fatalf("Pick() #%d auth = nil", i)
}
if got.ID != expectedID {
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
}
}
}
func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) {
t.Parallel()
selector := &RoundRobinSelector{}
// Mix of virtual and non-virtual auths (e.g., a regular gemini-cli auth without projects
// alongside virtual ones). Should fall back to flat round-robin.
auths := []*Auth{
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
{ID: "cred-regular.json"}, // no gemini_virtual_parent
}
// groupByVirtualParent returns nil when any auth lacks the attribute,
// so flat round-robin is used. Sorted by ID: cred-a.json::proj-a1, cred-regular.json
want := []string{
"cred-a.json::proj-a1",
"cred-regular.json",
"cred-a.json::proj-a1",
}
for i, expectedID := range want {
got, err := selector.Pick(context.Background(), "gemini-cli", "", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() #%d error = %v", i, err)
}
if got == nil {
t.Fatalf("Pick() #%d auth = nil", i)
}
if got.ID != expectedID {
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
}
}
}
================================================
FILE: sdk/cliproxy/auth/status.go
================================================
package auth
// Status represents the lifecycle state of an Auth entry.
type Status string
const (
// StatusUnknown means the auth state could not be determined.
StatusUnknown Status = "unknown"
// StatusActive indicates the auth is valid and ready for execution.
StatusActive Status = "active"
// StatusPending indicates the auth is waiting for an external action, such as MFA.
StatusPending Status = "pending"
// StatusRefreshing indicates the auth is undergoing a refresh flow.
StatusRefreshing Status = "refreshing"
// StatusError indicates the auth is temporarily unavailable due to errors.
StatusError Status = "error"
// StatusDisabled marks the auth as intentionally disabled.
StatusDisabled Status = "disabled"
)
================================================
FILE: sdk/cliproxy/auth/store.go
================================================
package auth
import "context"
// Store abstracts persistence of Auth state across restarts.
type Store interface {
// List returns all auth records stored in the backend.
List(ctx context.Context) ([]*Auth, error)
// Save persists the provided auth record, replacing any existing one with same ID.
Save(ctx context.Context, auth *Auth) (string, error)
// Delete removes the auth record identified by id.
Delete(ctx context.Context, id string) error
}
================================================
FILE: sdk/cliproxy/auth/types.go
================================================
package auth
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"net/http"
"net/url"
"strconv"
"strings"
"sync"
"time"
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
)
// PostAuthHook defines a function that is called after an Auth record is created
// but before it is persisted to storage. This allows for modification of the
// Auth record (e.g., injecting metadata) based on external context.
type PostAuthHook func(context.Context, *Auth) error
// RequestInfo holds information extracted from the HTTP request.
// It is injected into the context passed to PostAuthHook.
type RequestInfo struct {
Query url.Values
Headers http.Header
}
type requestInfoKey struct{}
// WithRequestInfo returns a new context with the given RequestInfo attached.
func WithRequestInfo(ctx context.Context, info *RequestInfo) context.Context {
return context.WithValue(ctx, requestInfoKey{}, info)
}
// GetRequestInfo retrieves the RequestInfo from the context, if present.
func GetRequestInfo(ctx context.Context) *RequestInfo {
if val, ok := ctx.Value(requestInfoKey{}).(*RequestInfo); ok {
return val
}
return nil
}
// Auth encapsulates the runtime state and metadata associated with a single credential.
type Auth struct {
// ID uniquely identifies the auth record across restarts.
ID string `json:"id"`
// Index is a stable runtime identifier derived from auth metadata (not persisted).
Index string `json:"-"`
// Provider is the upstream provider key (e.g. "gemini", "claude").
Provider string `json:"provider"`
// Prefix optionally namespaces models for routing (e.g., "teamA/gemini-3-pro-preview").
Prefix string `json:"prefix,omitempty"`
// FileName stores the relative or absolute path of the backing auth file.
FileName string `json:"-"`
// Storage holds the token persistence implementation used during login flows.
Storage baseauth.TokenStorage `json:"-"`
// Label is an optional human readable label for logging.
Label string `json:"label,omitempty"`
// Status is the lifecycle status managed by the AuthManager.
Status Status `json:"status"`
// StatusMessage holds a short description for the current status.
StatusMessage string `json:"status_message,omitempty"`
// Disabled indicates the auth is intentionally disabled by operator.
Disabled bool `json:"disabled"`
// Unavailable flags transient provider unavailability (e.g. quota exceeded).
Unavailable bool `json:"unavailable"`
// ProxyURL overrides the global proxy setting for this auth if provided.
ProxyURL string `json:"proxy_url,omitempty"`
// Attributes stores provider specific metadata needed by executors (immutable configuration).
Attributes map[string]string `json:"attributes,omitempty"`
// Metadata stores runtime mutable provider state (e.g. tokens, cookies).
Metadata map[string]any `json:"metadata,omitempty"`
// Quota captures recent quota information for load balancers.
Quota QuotaState `json:"quota"`
// LastError stores the last failure encountered while executing or refreshing.
LastError *Error `json:"last_error,omitempty"`
// CreatedAt is the creation timestamp in UTC.
CreatedAt time.Time `json:"created_at"`
// UpdatedAt is the last modification timestamp in UTC.
UpdatedAt time.Time `json:"updated_at"`
// LastRefreshedAt records the last successful refresh time in UTC.
LastRefreshedAt time.Time `json:"last_refreshed_at"`
// NextRefreshAfter is the earliest time a refresh should retrigger.
NextRefreshAfter time.Time `json:"next_refresh_after"`
// NextRetryAfter is the earliest time a retry should retrigger.
NextRetryAfter time.Time `json:"next_retry_after"`
// ModelStates tracks per-model runtime availability data.
ModelStates map[string]*ModelState `json:"model_states,omitempty"`
// Runtime carries non-serialisable data used during execution (in-memory only).
Runtime any `json:"-"`
indexAssigned bool `json:"-"`
}
// QuotaState contains limiter tracking data for a credential.
type QuotaState struct {
// Exceeded indicates the credential recently hit a quota error.
Exceeded bool `json:"exceeded"`
// Reason provides an optional provider specific human readable description.
Reason string `json:"reason,omitempty"`
// NextRecoverAt is when the credential may become available again.
NextRecoverAt time.Time `json:"next_recover_at"`
// BackoffLevel stores the progressive cooldown exponent used for rate limits.
BackoffLevel int `json:"backoff_level,omitempty"`
}
// ModelState captures the execution state for a specific model under an auth entry.
type ModelState struct {
// Status reflects the lifecycle status for this model.
Status Status `json:"status"`
// StatusMessage provides an optional short description of the status.
StatusMessage string `json:"status_message,omitempty"`
// Unavailable mirrors whether the model is temporarily blocked for retries.
Unavailable bool `json:"unavailable"`
// NextRetryAfter defines the per-model retry time.
NextRetryAfter time.Time `json:"next_retry_after"`
// LastError records the latest error observed for this model.
LastError *Error `json:"last_error,omitempty"`
// Quota retains quota information if this model hit rate limits.
Quota QuotaState `json:"quota"`
// UpdatedAt tracks the last update timestamp for this model state.
UpdatedAt time.Time `json:"updated_at"`
}
// Clone shallow copies the Auth structure, duplicating maps to avoid accidental mutation.
func (a *Auth) Clone() *Auth {
if a == nil {
return nil
}
copyAuth := *a
if len(a.Attributes) > 0 {
copyAuth.Attributes = make(map[string]string, len(a.Attributes))
for key, value := range a.Attributes {
copyAuth.Attributes[key] = value
}
}
if len(a.Metadata) > 0 {
copyAuth.Metadata = make(map[string]any, len(a.Metadata))
for key, value := range a.Metadata {
copyAuth.Metadata[key] = value
}
}
if len(a.ModelStates) > 0 {
copyAuth.ModelStates = make(map[string]*ModelState, len(a.ModelStates))
for key, state := range a.ModelStates {
copyAuth.ModelStates[key] = state.Clone()
}
}
copyAuth.Runtime = a.Runtime
return ©Auth
}
func stableAuthIndex(seed string) string {
seed = strings.TrimSpace(seed)
if seed == "" {
return ""
}
sum := sha256.Sum256([]byte(seed))
return hex.EncodeToString(sum[:8])
}
func (a *Auth) indexSeed() string {
if a == nil {
return ""
}
if fileName := strings.TrimSpace(a.FileName); fileName != "" {
return "file:" + fileName
}
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
compatName := ""
baseURL := ""
apiKey := ""
source := ""
if a.Attributes != nil {
if value := strings.TrimSpace(a.Attributes["provider_key"]); value != "" {
providerKey = strings.ToLower(value)
}
compatName = strings.ToLower(strings.TrimSpace(a.Attributes["compat_name"]))
baseURL = strings.TrimSpace(a.Attributes["base_url"])
apiKey = strings.TrimSpace(a.Attributes["api_key"])
source = strings.TrimSpace(a.Attributes["source"])
}
proxyURL := strings.TrimSpace(a.ProxyURL)
hasCredentialIdentity := compatName != "" || baseURL != "" || proxyURL != "" || apiKey != "" || source != ""
if providerKey != "" && hasCredentialIdentity {
parts := []string{"provider=" + providerKey}
if compatName != "" {
parts = append(parts, "compat="+compatName)
}
if baseURL != "" {
parts = append(parts, "base="+baseURL)
}
if proxyURL != "" {
parts = append(parts, "proxy="+proxyURL)
}
if apiKey != "" {
parts = append(parts, "api_key="+apiKey)
}
if source != "" {
parts = append(parts, "source="+source)
}
return "config:" + strings.Join(parts, "\x00")
}
if id := strings.TrimSpace(a.ID); id != "" {
return "id:" + id
}
return ""
}
// EnsureIndex returns a stable index derived from the auth file name or credential identity.
func (a *Auth) EnsureIndex() string {
if a == nil {
return ""
}
if a.indexAssigned && a.Index != "" {
return a.Index
}
seed := a.indexSeed()
if seed == "" {
return ""
}
idx := stableAuthIndex(seed)
a.Index = idx
a.indexAssigned = true
return idx
}
// Clone duplicates a model state including nested error details.
func (m *ModelState) Clone() *ModelState {
if m == nil {
return nil
}
copyState := *m
if m.LastError != nil {
copyState.LastError = &Error{
Code: m.LastError.Code,
Message: m.LastError.Message,
Retryable: m.LastError.Retryable,
HTTPStatus: m.LastError.HTTPStatus,
}
}
return ©State
}
func (a *Auth) ProxyInfo() string {
if a == nil {
return ""
}
proxyStr := strings.TrimSpace(a.ProxyURL)
if proxyStr == "" {
return ""
}
if idx := strings.Index(proxyStr, "://"); idx > 0 {
return "via " + proxyStr[:idx] + " proxy"
}
return "via proxy"
}
// DisableCoolingOverride returns the auth-file scoped disable_cooling override when present.
// The value is read from metadata key "disable_cooling" (or legacy "disable-cooling").
func (a *Auth) DisableCoolingOverride() (bool, bool) {
if a == nil || a.Metadata == nil {
return false, false
}
if val, ok := a.Metadata["disable_cooling"]; ok {
if parsed, okParse := parseBoolAny(val); okParse {
return parsed, true
}
}
if val, ok := a.Metadata["disable-cooling"]; ok {
if parsed, okParse := parseBoolAny(val); okParse {
return parsed, true
}
}
return false, false
}
// ToolPrefixDisabled returns whether the proxy_ tool name prefix should be
// skipped for this auth. When true, tool names are sent to Anthropic unchanged.
// The value is read from metadata key "tool_prefix_disabled" (or "tool-prefix-disabled").
func (a *Auth) ToolPrefixDisabled() bool {
if a == nil || a.Metadata == nil {
return false
}
for _, key := range []string{"tool_prefix_disabled", "tool-prefix-disabled"} {
if val, ok := a.Metadata[key]; ok {
if parsed, okParse := parseBoolAny(val); okParse {
return parsed
}
}
}
return false
}
// RequestRetryOverride returns the auth-file scoped request_retry override when present.
// The value is read from metadata key "request_retry" (or legacy "request-retry").
func (a *Auth) RequestRetryOverride() (int, bool) {
if a == nil || a.Metadata == nil {
return 0, false
}
if val, ok := a.Metadata["request_retry"]; ok {
if parsed, okParse := parseIntAny(val); okParse {
if parsed < 0 {
parsed = 0
}
return parsed, true
}
}
if val, ok := a.Metadata["request-retry"]; ok {
if parsed, okParse := parseIntAny(val); okParse {
if parsed < 0 {
parsed = 0
}
return parsed, true
}
}
return 0, false
}
func parseBoolAny(val any) (bool, bool) {
switch typed := val.(type) {
case bool:
return typed, true
case string:
trimmed := strings.TrimSpace(typed)
if trimmed == "" {
return false, false
}
parsed, err := strconv.ParseBool(trimmed)
if err != nil {
return false, false
}
return parsed, true
case float64:
return typed != 0, true
case json.Number:
parsed, err := typed.Int64()
if err != nil {
return false, false
}
return parsed != 0, true
default:
return false, false
}
}
func parseIntAny(val any) (int, bool) {
switch typed := val.(type) {
case int:
return typed, true
case int32:
return int(typed), true
case int64:
return int(typed), true
case float64:
return int(typed), true
case json.Number:
parsed, err := typed.Int64()
if err != nil {
return 0, false
}
return int(parsed), true
case string:
trimmed := strings.TrimSpace(typed)
if trimmed == "" {
return 0, false
}
parsed, err := strconv.Atoi(trimmed)
if err != nil {
return 0, false
}
return parsed, true
default:
return 0, false
}
}
func (a *Auth) AccountInfo() (string, string) {
if a == nil {
return "", ""
}
// For Gemini CLI, include project ID in the OAuth account info if present.
if strings.ToLower(a.Provider) == "gemini-cli" {
if a.Metadata != nil {
email, _ := a.Metadata["email"].(string)
email = strings.TrimSpace(email)
if email != "" {
if p, ok := a.Metadata["project_id"].(string); ok {
p = strings.TrimSpace(p)
if p != "" {
return "oauth", email + " (" + p + ")"
}
}
return "oauth", email
}
}
}
// For iFlow provider, prioritize OAuth type if email is present
if strings.ToLower(a.Provider) == "iflow" {
if a.Metadata != nil {
if email, ok := a.Metadata["email"].(string); ok {
email = strings.TrimSpace(email)
if email != "" {
return "oauth", email
}
}
}
}
// Check metadata for email first (OAuth-style auth)
if a.Metadata != nil {
if v, ok := a.Metadata["email"].(string); ok {
email := strings.TrimSpace(v)
if email != "" {
return "oauth", email
}
}
}
// Fall back to API key (API-key auth)
if a.Attributes != nil {
if v := a.Attributes["api_key"]; v != "" {
return "api_key", v
}
}
return "", ""
}
// ExpirationTime attempts to extract the credential expiration timestamp from metadata.
// It inspects common keys such as "expired", "expire", "expires_at", and also
// nested "token" objects to remain compatible with legacy auth file formats.
func (a *Auth) ExpirationTime() (time.Time, bool) {
if a == nil {
return time.Time{}, false
}
if ts, ok := expirationFromMap(a.Metadata); ok {
return ts, true
}
return time.Time{}, false
}
var (
refreshLeadMu sync.RWMutex
refreshLeadFactories = make(map[string]func() *time.Duration)
)
func RegisterRefreshLeadProvider(provider string, factory func() *time.Duration) {
provider = strings.ToLower(strings.TrimSpace(provider))
if provider == "" || factory == nil {
return
}
refreshLeadMu.Lock()
refreshLeadFactories[provider] = factory
refreshLeadMu.Unlock()
}
var expireKeys = [...]string{"expired", "expire", "expires_at", "expiresAt", "expiry", "expires"}
func expirationFromMap(meta map[string]any) (time.Time, bool) {
if meta == nil {
return time.Time{}, false
}
for _, key := range expireKeys {
if v, ok := meta[key]; ok {
if ts, ok1 := parseTimeValue(v); ok1 {
return ts, true
}
}
}
for _, nestedKey := range []string{"token", "Token"} {
if nested, ok := meta[nestedKey]; ok {
switch val := nested.(type) {
case map[string]any:
if ts, ok1 := expirationFromMap(val); ok1 {
return ts, true
}
case map[string]string:
temp := make(map[string]any, len(val))
for k, v := range val {
temp[k] = v
}
if ts, ok1 := expirationFromMap(temp); ok1 {
return ts, true
}
}
}
}
return time.Time{}, false
}
func ProviderRefreshLead(provider string, runtime any) *time.Duration {
provider = strings.ToLower(strings.TrimSpace(provider))
if runtime != nil {
if eval, ok := runtime.(interface{ RefreshLead() *time.Duration }); ok {
if lead := eval.RefreshLead(); lead != nil && *lead > 0 {
return lead
}
}
}
refreshLeadMu.RLock()
factory := refreshLeadFactories[provider]
refreshLeadMu.RUnlock()
if factory == nil {
return nil
}
if lead := factory(); lead != nil && *lead > 0 {
return lead
}
return nil
}
func parseTimeValue(v any) (time.Time, bool) {
switch value := v.(type) {
case string:
s := strings.TrimSpace(value)
if s == "" {
return time.Time{}, false
}
layouts := []string{
time.RFC3339,
time.RFC3339Nano,
"2006-01-02 15:04:05",
"2006-01-02 15:04",
"2006-01-02T15:04:05Z07:00",
}
for _, layout := range layouts {
if ts, err := time.Parse(layout, s); err == nil {
return ts, true
}
}
if unix, err := strconv.ParseInt(s, 10, 64); err == nil {
return normaliseUnix(unix), true
}
case float64:
return normaliseUnix(int64(value)), true
case int64:
return normaliseUnix(value), true
case json.Number:
if i, err := value.Int64(); err == nil {
return normaliseUnix(i), true
}
if f, err := value.Float64(); err == nil {
return normaliseUnix(int64(f)), true
}
}
return time.Time{}, false
}
func normaliseUnix(raw int64) time.Time {
if raw <= 0 {
return time.Time{}
}
// Heuristic: treat values with millisecond precision (>1e12) accordingly.
if raw > 1_000_000_000_000 {
return time.UnixMilli(raw)
}
return time.Unix(raw, 0)
}
================================================
FILE: sdk/cliproxy/auth/types_test.go
================================================
package auth
import "testing"
func TestToolPrefixDisabled(t *testing.T) {
var a *Auth
if a.ToolPrefixDisabled() {
t.Error("nil auth should return false")
}
a = &Auth{}
if a.ToolPrefixDisabled() {
t.Error("empty auth should return false")
}
a = &Auth{Metadata: map[string]any{"tool_prefix_disabled": true}}
if !a.ToolPrefixDisabled() {
t.Error("should return true when set to true")
}
a = &Auth{Metadata: map[string]any{"tool_prefix_disabled": "true"}}
if !a.ToolPrefixDisabled() {
t.Error("should return true when set to string 'true'")
}
a = &Auth{Metadata: map[string]any{"tool-prefix-disabled": true}}
if !a.ToolPrefixDisabled() {
t.Error("should return true with kebab-case key")
}
a = &Auth{Metadata: map[string]any{"tool_prefix_disabled": false}}
if a.ToolPrefixDisabled() {
t.Error("should return false when set to false")
}
}
func TestEnsureIndexUsesCredentialIdentity(t *testing.T) {
t.Parallel()
geminiAuth := &Auth{
Provider: "gemini",
Attributes: map[string]string{
"api_key": "shared-key",
"source": "config:gemini[abc123]",
},
}
compatAuth := &Auth{
Provider: "bohe",
Attributes: map[string]string{
"api_key": "shared-key",
"compat_name": "bohe",
"provider_key": "bohe",
"source": "config:bohe[def456]",
},
}
geminiAltBase := &Auth{
Provider: "gemini",
Attributes: map[string]string{
"api_key": "shared-key",
"base_url": "https://alt.example.com",
"source": "config:gemini[ghi789]",
},
}
geminiDuplicate := &Auth{
Provider: "gemini",
Attributes: map[string]string{
"api_key": "shared-key",
"source": "config:gemini[abc123-1]",
},
}
geminiIndex := geminiAuth.EnsureIndex()
compatIndex := compatAuth.EnsureIndex()
altBaseIndex := geminiAltBase.EnsureIndex()
duplicateIndex := geminiDuplicate.EnsureIndex()
if geminiIndex == "" {
t.Fatal("gemini index should not be empty")
}
if compatIndex == "" {
t.Fatal("compat index should not be empty")
}
if altBaseIndex == "" {
t.Fatal("alt base index should not be empty")
}
if duplicateIndex == "" {
t.Fatal("duplicate index should not be empty")
}
if geminiIndex == compatIndex {
t.Fatalf("shared api key produced duplicate auth_index %q", geminiIndex)
}
if geminiIndex == altBaseIndex {
t.Fatalf("same provider/key with different base_url produced duplicate auth_index %q", geminiIndex)
}
if geminiIndex == duplicateIndex {
t.Fatalf("duplicate config entries should be separated by source-derived seed, got %q", geminiIndex)
}
}
================================================
FILE: sdk/cliproxy/builder.go
================================================
// Package cliproxy provides the core service implementation for the CLI Proxy API.
// It includes service lifecycle management, authentication handling, file watching,
// and integration with various AI service providers through a unified interface.
package cliproxy
import (
"fmt"
"strings"
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// Builder constructs a Service instance with customizable providers.
// It provides a fluent interface for configuring all aspects of the service
// including authentication, file watching, HTTP server options, and lifecycle hooks.
type Builder struct {
// cfg holds the application configuration.
cfg *config.Config
// configPath is the path to the configuration file.
configPath string
// tokenProvider handles loading token-based clients.
tokenProvider TokenClientProvider
// apiKeyProvider handles loading API key-based clients.
apiKeyProvider APIKeyClientProvider
// watcherFactory creates file watcher instances.
watcherFactory WatcherFactory
// hooks provides lifecycle callbacks.
hooks Hooks
// authManager handles legacy authentication operations.
authManager *sdkAuth.Manager
// accessManager handles request authentication providers.
accessManager *sdkaccess.Manager
// coreManager handles core authentication and execution.
coreManager *coreauth.Manager
// serverOptions contains additional server configuration options.
serverOptions []api.ServerOption
}
// Hooks allows callers to plug into service lifecycle stages.
// These callbacks provide opportunities to perform custom initialization
// and cleanup operations during service startup and shutdown.
type Hooks struct {
// OnBeforeStart is called before the service starts, allowing configuration
// modifications or additional setup.
OnBeforeStart func(*config.Config)
// OnAfterStart is called after the service has started successfully,
// providing access to the service instance for additional operations.
OnAfterStart func(*Service)
}
// NewBuilder creates a Builder with default dependencies left unset.
// Use the fluent interface methods to configure the service before calling Build().
//
// Returns:
// - *Builder: A new builder instance ready for configuration
func NewBuilder() *Builder {
return &Builder{}
}
// WithConfig sets the configuration instance used by the service.
//
// Parameters:
// - cfg: The application configuration
//
// Returns:
// - *Builder: The builder instance for method chaining
func (b *Builder) WithConfig(cfg *config.Config) *Builder {
b.cfg = cfg
return b
}
// WithConfigPath sets the absolute configuration file path used for reload watching.
//
// Parameters:
// - path: The absolute path to the configuration file
//
// Returns:
// - *Builder: The builder instance for method chaining
func (b *Builder) WithConfigPath(path string) *Builder {
b.configPath = path
return b
}
// WithTokenClientProvider overrides the provider responsible for token-backed clients.
func (b *Builder) WithTokenClientProvider(provider TokenClientProvider) *Builder {
b.tokenProvider = provider
return b
}
// WithAPIKeyClientProvider overrides the provider responsible for API key-backed clients.
func (b *Builder) WithAPIKeyClientProvider(provider APIKeyClientProvider) *Builder {
b.apiKeyProvider = provider
return b
}
// WithWatcherFactory allows customizing the watcher factory that handles reloads.
func (b *Builder) WithWatcherFactory(factory WatcherFactory) *Builder {
b.watcherFactory = factory
return b
}
// WithHooks registers lifecycle hooks executed around service startup.
func (b *Builder) WithHooks(h Hooks) *Builder {
b.hooks = h
return b
}
// WithAuthManager overrides the authentication manager used for token lifecycle operations.
func (b *Builder) WithAuthManager(mgr *sdkAuth.Manager) *Builder {
b.authManager = mgr
return b
}
// WithRequestAccessManager overrides the request authentication manager.
func (b *Builder) WithRequestAccessManager(mgr *sdkaccess.Manager) *Builder {
b.accessManager = mgr
return b
}
// WithCoreAuthManager overrides the runtime auth manager responsible for request execution.
func (b *Builder) WithCoreAuthManager(mgr *coreauth.Manager) *Builder {
b.coreManager = mgr
return b
}
// WithServerOptions appends server configuration options used during construction.
func (b *Builder) WithServerOptions(opts ...api.ServerOption) *Builder {
b.serverOptions = append(b.serverOptions, opts...)
return b
}
// WithLocalManagementPassword configures a password that is only accepted from localhost management requests.
func (b *Builder) WithLocalManagementPassword(password string) *Builder {
if password == "" {
return b
}
b.serverOptions = append(b.serverOptions, api.WithLocalManagementPassword(password))
return b
}
// WithPostAuthHook registers a hook to be called after an Auth record is created
// but before it is persisted to storage.
func (b *Builder) WithPostAuthHook(hook coreauth.PostAuthHook) *Builder {
if hook == nil {
return b
}
b.serverOptions = append(b.serverOptions, api.WithPostAuthHook(hook))
return b
}
// Build validates inputs, applies defaults, and returns a ready-to-run service.
//
// A configuration and a configuration path are mandatory. Build returns an
// error if either is missing. Every other dependency falls back to a default
// when it was not supplied:
//   - token provider: NewFileTokenClientProvider()
//   - API key provider: NewAPIKeyClientProvider()
//   - watcher factory: defaultWatcherFactory
//   - auth manager: newDefaultAuthManager()
//   - access manager: sdkaccess.NewManager()
//   - core auth manager: coreauth.NewManager over sdkAuth.GetTokenStore(), using
//     a fill-first selector when cfg.Routing.Strategy is "fill-first",
//     "fillfirst" or "ff" (case-insensitive, trimmed), and round-robin otherwise.
//
// Side effects: configaccess.Register is called with the SDK config, and the
// access manager's providers are set to sdkaccess.RegisteredProviders(). The
// core manager (including a caller-supplied one) receives a new default
// RoundTripper provider, the config, and the OAuth model aliases.
func (b *Builder) Build() (*Service, error) {
	if b.cfg == nil {
		return nil, fmt.Errorf("cliproxy: configuration is required")
	}
	if b.configPath == "" {
		return nil, fmt.Errorf("cliproxy: configuration path is required")
	}

	tokenProvider := b.tokenProvider
	if tokenProvider == nil {
		tokenProvider = NewFileTokenClientProvider()
	}
	apiKeyProvider := b.apiKeyProvider
	if apiKeyProvider == nil {
		apiKeyProvider = NewAPIKeyClientProvider()
	}
	watcherFactory := b.watcherFactory
	if watcherFactory == nil {
		watcherFactory = defaultWatcherFactory
	}
	authManager := b.authManager
	if authManager == nil {
		authManager = newDefaultAuthManager()
	}

	accessManager := b.accessManager
	if accessManager == nil {
		accessManager = sdkaccess.NewManager()
	}
	configaccess.Register(&b.cfg.SDKConfig)
	accessManager.SetProviders(sdkaccess.RegisteredProviders())

	coreManager := b.coreManager
	if coreManager == nil {
		// b.cfg was validated as non-nil above, so it can be dereferenced
		// directly when deriving the store directory and routing strategy.
		tokenStore := sdkAuth.GetTokenStore()
		if dirSetter, ok := tokenStore.(interface{ SetBaseDir(string) }); ok {
			dirSetter.SetBaseDir(b.cfg.AuthDir)
		}
		var selector coreauth.Selector
		switch strings.ToLower(strings.TrimSpace(b.cfg.Routing.Strategy)) {
		case "fill-first", "fillfirst", "ff":
			selector = &coreauth.FillFirstSelector{}
		default:
			selector = &coreauth.RoundRobinSelector{}
		}
		coreManager = coreauth.NewManager(tokenStore, selector, nil)
	}
	// Attach a default RoundTripper provider so providers can opt-in per-auth transports.
	coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider())
	coreManager.SetConfig(b.cfg)
	coreManager.SetOAuthModelAlias(b.cfg.OAuthModelAlias)

	return &Service{
		cfg:            b.cfg,
		configPath:     b.configPath,
		tokenProvider:  tokenProvider,
		apiKeyProvider: apiKeyProvider,
		watcherFactory: watcherFactory,
		hooks:          b.hooks,
		authManager:    authManager,
		accessManager:  accessManager,
		coreManager:    coreManager,
		// Copy so later builder calls cannot alias the service's option slice.
		serverOptions: append([]api.ServerOption(nil), b.serverOptions...),
	}, nil
}
================================================
FILE: sdk/cliproxy/executor/context.go
================================================
package executor
import "context"
// downstreamWebsocketContextKey is the private context key under which the
// downstream-websocket marker is stored. Using an unexported struct type keeps
// it from colliding with keys defined in other packages.
type downstreamWebsocketContextKey struct{}

// WithDownstreamWebsocket marks the current request as coming from a downstream
// websocket connection. A nil ctx is treated as context.Background().
func WithDownstreamWebsocket(ctx context.Context) context.Context {
	parent := ctx
	if parent == nil {
		parent = context.Background()
	}
	return context.WithValue(parent, downstreamWebsocketContextKey{}, true)
}

// DownstreamWebsocket reports whether the current request originates from a
// downstream websocket connection. It returns false for a nil ctx and for any
// context that was not marked by WithDownstreamWebsocket.
func DownstreamWebsocket(ctx context.Context) bool {
	if ctx == nil {
		return false
	}
	// The comma-ok assertion yields false when the value is absent or not a bool.
	marked, _ := ctx.Value(downstreamWebsocketContextKey{}).(bool)
	return marked
}
================================================
FILE: sdk/cliproxy/executor/types.go
================================================
package executor
import (
"net/http"
"net/url"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
)
// Metadata keys shared between the scheduler, handlers and provider executors.
const (
	// RequestedModelMetadataKey stores the client-requested model name in Options.Metadata.
	RequestedModelMetadataKey = "requested_model"

	// PinnedAuthMetadataKey locks execution to a specific auth ID.
	PinnedAuthMetadataKey = "pinned_auth_id"
	// SelectedAuthMetadataKey stores the auth ID selected by the scheduler.
	SelectedAuthMetadataKey = "selected_auth_id"
	// SelectedAuthCallbackMetadataKey carries an optional callback invoked with the selected auth ID.
	SelectedAuthCallbackMetadataKey = "selected_auth_callback"
	// ExecutionSessionMetadataKey identifies a long-lived downstream execution session.
	ExecutionSessionMetadataKey = "execution_session_id"
)
// Request encapsulates the translated payload that will be sent to a provider executor.
// It describes the provider-side view of a call; the client-side view (original bytes,
// source format, streaming flag) lives in Options.
type Request struct {
// Model is the upstream model identifier after translation.
Model string
// Payload is the provider specific JSON payload.
Payload []byte
// Format represents the provider payload schema.
Format sdktranslator.Format
// Metadata carries optional provider specific execution hints.
Metadata map[string]any
}
// Options controls execution behavior for both streaming and non-streaming calls.
type Options struct {
// Stream toggles streaming mode.
Stream bool
// Alt carries optional alternate format hint (e.g. SSE JSON key).
Alt string
// Headers are forwarded to the provider request builder.
Headers http.Header
// Query contains optional query string parameters.
Query url.Values
// OriginalRequest preserves the inbound request bytes prior to translation.
OriginalRequest []byte
// SourceFormat identifies the inbound schema.
SourceFormat sdktranslator.Format
// Metadata carries extra execution hints shared across selection and executors.
// RequestedModelMetadataKey is documented as living here; the other *MetadataKey
// constants in this package are presumably stored here too (confirm with callers).
Metadata map[string]any
}
// Response wraps either a full provider response or metadata for streaming flows.
type Response struct {
// Payload is the provider response in the executor format.
Payload []byte
// Metadata exposes optional structured data for translators.
Metadata map[string]any
// Headers carries upstream HTTP response headers for passthrough to clients.
Headers http.Header
}
// StreamChunk represents a single streaming payload unit emitted by provider executors.
// A chunk carries either a Payload or a terminal Err.
type StreamChunk struct {
// Payload is the raw provider chunk payload.
Payload []byte
// Err reports any terminal error encountered while producing chunks.
Err error
}
// StreamResult wraps the streaming response, providing both the chunk channel
// and the upstream HTTP response headers captured before streaming begins.
type StreamResult struct {
// Headers carries upstream HTTP response headers from the initial connection.
Headers http.Header
// Chunks is the channel of streaming payload units.
// It is receive-only for consumers; closing it is presumably the producer's
// responsibility (NOTE(review): confirm against executor implementations).
Chunks <-chan StreamChunk
}
// StatusError represents an error that carries an HTTP-like status code.
// Provider executors should implement this when possible to enable
// better auth state updates on failures (e.g., 401/402/429).
type StatusError interface {
error
// StatusCode returns the HTTP-like status associated with the failure.
StatusCode() int
}
================================================
FILE: sdk/cliproxy/model_registry.go
================================================
package cliproxy
import "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
// ModelInfo re-exports the registry model info structure.
// It is a type alias, so values are interchangeable with registry.ModelInfo.
type ModelInfo = registry.ModelInfo
// ModelRegistryHook re-exports the registry hook interface for external integrations.
// Install one on the shared registry with SetGlobalModelRegistryHook.
type ModelRegistryHook = registry.ModelRegistryHook
// ModelRegistry describes registry operations consumed by external callers.
// GlobalModelRegistry returns the shared instance through this interface.
type ModelRegistry interface {
// RegisterClient records the models served by clientID under clientProvider.
// Within this package clientID is an auth ID.
RegisterClient(clientID, clientProvider string, models []*ModelInfo)
// UnregisterClient removes every model registration for clientID.
UnregisterClient(clientID string)
// SetModelQuotaExceeded / ClearModelQuotaExceeded toggle a per-client, per-model
// quota flag (semantics defined by the registry implementation).
SetModelQuotaExceeded(clientID, modelID string)
ClearModelQuotaExceeded(clientID, modelID string)
// ClientSupportsModel reports whether clientID has modelID registered.
ClientSupportsModel(clientID, modelID string) bool
// GetAvailableModels lists models formatted for the given handler type.
GetAvailableModels(handlerType string) []map[string]any
// GetAvailableModelsByProvider lists models registered under provider.
GetAvailableModelsByProvider(provider string) []*ModelInfo
}
// GlobalModelRegistry returns the shared registry instance, exposed through
// the ModelRegistry interface so callers need not import internal/registry.
func GlobalModelRegistry() ModelRegistry {
	shared := registry.GetGlobalRegistry()
	return shared
}

// SetGlobalModelRegistryHook registers an optional hook on the shared global
// registry instance. The hook is forwarded unchanged to the registry's SetHook.
func SetGlobalModelRegistryHook(hook ModelRegistryHook) {
	shared := registry.GetGlobalRegistry()
	shared.SetHook(hook)
}
================================================
FILE: sdk/cliproxy/pipeline/context.go
================================================
package pipeline
import (
"context"
"net/http"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
)
// Context encapsulates execution state shared across middleware, translators, and executors.
// Hooks receive it by pointer, so they can inspect or modify it in place.
type Context struct {
// Request encapsulates the provider facing request payload.
Request cliproxyexecutor.Request
// Options carries execution flags (streaming, headers, etc.).
Options cliproxyexecutor.Options
// Auth references the credential selected for execution.
Auth *cliproxyauth.Auth
// Translator represents the pipeline responsible for schema adaptation.
Translator *sdktranslator.Pipeline
// HTTPClient allows middleware to customise the outbound transport per request.
HTTPClient *http.Client
}
// Hook captures middleware callbacks around execution.
// NOTE(review): the call sites are not visible here; the intended ordering
// (BeforeExecute, then OnStreamChunk per chunk or AfterExecute) is inferred
// from the method names and should be confirmed against the pipeline runner.
type Hook interface {
BeforeExecute(ctx context.Context, execCtx *Context)
AfterExecute(ctx context.Context, execCtx *Context, resp cliproxyexecutor.Response, err error)
OnStreamChunk(ctx context.Context, execCtx *Context, chunk cliproxyexecutor.StreamChunk)
}
// HookFunc aggregates optional hook implementations.
// Any nil field turns the corresponding Hook method into a no-op.
type HookFunc struct {
Before func(context.Context, *Context)
After func(context.Context, *Context, cliproxyexecutor.Response, error)
Stream func(context.Context, *Context, cliproxyexecutor.StreamChunk)
}
// BeforeExecute implements Hook by delegating to h.Before; it does nothing
// when Before is nil.
func (h HookFunc) BeforeExecute(ctx context.Context, execCtx *Context) {
	if h.Before == nil {
		return
	}
	h.Before(ctx, execCtx)
}

// AfterExecute implements Hook by delegating to h.After; it does nothing
// when After is nil.
func (h HookFunc) AfterExecute(ctx context.Context, execCtx *Context, resp cliproxyexecutor.Response, err error) {
	if h.After == nil {
		return
	}
	h.After(ctx, execCtx, resp, err)
}

// OnStreamChunk implements Hook by delegating to h.Stream; it does nothing
// when Stream is nil.
func (h HookFunc) OnStreamChunk(ctx context.Context, execCtx *Context, chunk cliproxyexecutor.StreamChunk) {
	if h.Stream == nil {
		return
	}
	h.Stream(ctx, execCtx, chunk)
}
// RoundTripperProvider allows injection of custom HTTP transports per auth entry.
// NOTE(review): a nil return presumably means "use the default transport";
// confirm with the consumers of this interface.
type RoundTripperProvider interface {
RoundTripperFor(auth *cliproxyauth.Auth) http.RoundTripper
}
================================================
FILE: sdk/cliproxy/pprof_server.go
================================================
package cliproxy
import (
"context"
"errors"
"net/http"
"net/http/pprof"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
)
// pprofServer owns the optional net/http/pprof debug listener. Every access to
// server, addr and enabled in this file happens while holding mu.
type pprofServer struct {
	mu      sync.Mutex
	server  *http.Server
	addr    string
	enabled bool
}

// newPprofServer returns an idle pprofServer holding no server; Apply decides
// whether one should be started.
func newPprofServer() *pprofServer {
	return new(pprofServer)
}
// applyPprofConfig lazily creates the service's pprof server manager and
// reconciles it with cfg. Nil receivers and nil configs are ignored.
func (s *Service) applyPprofConfig(cfg *config.Config) {
	if s != nil && cfg != nil {
		if s.pprofServer == nil {
			s.pprofServer = newPprofServer()
		}
		s.pprofServer.Apply(cfg)
	}
}

// shutdownPprof stops the pprof server, if one was ever configured, and
// returns any error from its graceful shutdown.
func (s *Service) shutdownPprof(ctx context.Context) error {
	if s != nil && s.pprofServer != nil {
		return s.pprofServer.Shutdown(ctx)
	}
	return nil
}
// Apply reconciles the pprof listener with cfg.Pprof.
//
// It stops the tracked server when pprof is disabled. It restarts the server
// when the address changed or when no server is currently tracked (for
// example after an earlier start failure). It is a no-op when a server is
// already tracked for the requested address. A blank address falls back to
// config.DefaultPprofAddr.
//
// The desired state (addr/enabled) is recorded under p.mu, but stopping and
// starting happen after the lock is released, so a slow shutdown does not hold
// the mutex. startServer re-checks the recorded state before installing a new
// server.
func (p *pprofServer) Apply(cfg *config.Config) {
if p == nil || cfg == nil {
return
}
addr := strings.TrimSpace(cfg.Pprof.Addr)
if addr == "" {
addr = config.DefaultPprofAddr
}
enabled := cfg.Pprof.Enable
p.mu.Lock()
// Snapshot the current server, then record the new desired state.
currentServer := p.server
currentAddr := p.addr
p.addr = addr
p.enabled = enabled
if !enabled {
// Detach under the lock; the (possibly slow) shutdown runs unlocked.
p.server = nil
p.mu.Unlock()
if currentServer != nil {
p.stopServer(currentServer, currentAddr, "disabled")
}
return
}
if currentServer != nil && currentAddr == addr {
// Already tracking a server on the requested address: nothing to do.
p.mu.Unlock()
return
}
// Address changed or no server tracked: detach the old one (if any) and start anew.
p.server = nil
p.mu.Unlock()
if currentServer != nil {
p.stopServer(currentServer, currentAddr, "restarted")
}
p.startServer(addr)
}
// Shutdown disables pprof and gracefully stops the tracked server, if any,
// returning the shutdown error. The server is detached under the lock and
// stopped outside it.
func (p *pprofServer) Shutdown(ctx context.Context) error {
	if p == nil {
		return nil
	}
	p.mu.Lock()
	srv, listenAddr := p.server, p.addr
	p.server, p.enabled = nil, false
	p.mu.Unlock()
	if srv != nil {
		return p.stopServerWithContext(ctx, srv, listenAddr, "shutdown")
	}
	return nil
}
// startServer builds a pprof HTTP server for addr and serves it from a
// background goroutine.
//
// Because the caller released p.mu before calling, the recorded state is
// re-checked first. If pprof was disabled, a different address was requested,
// or another server is already tracked, it returns without starting anything.
//
// Listen/serve failures are only reported asynchronously: they are logged, and
// p.server is cleared (if it still points at this server) so a later Apply can
// try again.
func (p *pprofServer) startServer(addr string) {
mux := newPprofMux()
// NOTE(review): only ReadHeaderTimeout is set; Read/WriteTimeout are presumably
// left unset so long-running /debug/pprof/profile and /trace responses are not
// cut off — confirm.
server := &http.Server{
Addr: addr,
Handler: mux,
ReadHeaderTimeout: 5 * time.Second,
}
p.mu.Lock()
if !p.enabled || p.addr != addr || p.server != nil {
p.mu.Unlock()
return
}
p.server = server
p.mu.Unlock()
log.Infof("pprof server starting on %s", addr)
go func() {
// ErrServerClosed is the normal result of Shutdown/Close and is not an error here.
if errServe := server.ListenAndServe(); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
log.Errorf("pprof server failed on %s: %v", addr, errServe)
p.mu.Lock()
if p.server == server {
p.server = nil
}
p.mu.Unlock()
}
}()
}
// stopServer stops server using a background context. Any error is logged by
// stopServerWithContext and otherwise discarded, because the reconfiguration
// paths that call this have no way to propagate it.
func (p *pprofServer) stopServer(server *http.Server, addr string, reason string) {
	_ = p.stopServerWithContext(context.Background(), server, addr, reason)
}

// stopServerWithContext gracefully shuts server down. The shutdown is bounded
// by ctx and by an additional 5-second cap. A nil ctx is treated as
// context.Background().
//
// http.Server.Shutdown leaves active connections open when its context
// expires. That can happen while a long-running /debug/pprof/profile or
// /debug/pprof/trace request is in progress. When graceful shutdown fails, the
// server is therefore force-closed so those connections do not outlive it. The
// graceful-shutdown error is still returned to the caller.
func (p *pprofServer) stopServerWithContext(ctx context.Context, server *http.Server, addr string, reason string) error {
	if server == nil {
		return nil
	}
	if ctx == nil {
		ctx = context.Background()
	}
	stopCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
	defer cancel()
	if errStop := server.Shutdown(stopCtx); errStop != nil {
		log.Errorf("pprof server stop failed on %s: %v", addr, errStop)
		// Best-effort forced close. Shutdown already closed the listeners and the
		// failure is logged above, so Close's own error adds nothing actionable.
		_ = server.Close()
		return errStop
	}
	log.Infof("pprof server stopped on %s (%s)", addr, reason)
	return nil
}
// newPprofMux returns a dedicated ServeMux exposing the standard net/http/pprof
// endpoints under /debug/pprof/. It uses its own mux rather than
// http.DefaultServeMux, so the debug routes are reachable only through the
// pprof listener.
func newPprofMux() *http.ServeMux {
	const prefix = "/debug/pprof/"
	router := http.NewServeMux()

	// Endpoints backed by dedicated handler functions.
	router.HandleFunc(prefix, pprof.Index)
	router.HandleFunc(prefix+"cmdline", pprof.Cmdline)
	router.HandleFunc(prefix+"profile", pprof.Profile)
	router.HandleFunc(prefix+"symbol", pprof.Symbol)
	router.HandleFunc(prefix+"trace", pprof.Trace)

	// Named runtime profiles served through pprof.Handler.
	for _, profile := range []string{"allocs", "block", "goroutine", "heap", "mutex", "threadcreate"} {
		router.Handle(prefix+profile, pprof.Handler(profile))
	}
	return router
}
================================================
FILE: sdk/cliproxy/providers.go
================================================
package cliproxy
import (
"context"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// NewFileTokenClientProvider returns the default token-backed client loader.
func NewFileTokenClientProvider() TokenClientProvider {
	return new(fileTokenClientProvider)
}

// fileTokenClientProvider is a stateless no-op loader. Token handling is done
// by the stateless executors, so there is nothing to load here.
type fileTokenClientProvider struct{}

// Load ignores its arguments and always reports zero successfully authenticated
// token clients with a nil error.
func (p *fileTokenClientProvider) Load(ctx context.Context, cfg *config.Config) (*TokenClientResult, error) {
	return &TokenClientResult{}, nil
}
// NewAPIKeyClientProvider returns the default API key client loader that reuses existing logic.
func NewAPIKeyClientProvider() APIKeyClientProvider {
	return &apiKeyClientProvider{}
}

// apiKeyClientProvider is the default, stateless APIKeyClientProvider.
type apiKeyClientProvider struct{}

// Load counts the API-key backed clients configured in cfg via
// watcher.BuildAPIKeyClients and returns the per-provider totals.
//
// ctx may be nil. If ctx is already cancelled on entry, Load returns ctx.Err()
// without building anything. A cancellation observed after the clients were
// built is reported the same way, with a nil result.
func (p *apiKeyClientProvider) Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error) {
	// Skip the work entirely when the caller has already given up.
	if ctx != nil {
		if err := ctx.Err(); err != nil {
			return nil, err
		}
	}
	geminiCount, vertexCompatCount, claudeCount, codexCount, openAICompat := watcher.BuildAPIKeyClients(cfg)
	// Cancellation may have arrived while the clients were being built.
	if ctx != nil {
		if err := ctx.Err(); err != nil {
			return nil, err
		}
	}
	return &APIKeyClientResult{
		GeminiKeyCount:       geminiCount,
		VertexCompatKeyCount: vertexCompatCount,
		ClaudeKeyCount:       claudeCount,
		CodexKeyCount:        codexCount,
		OpenAICompatCount:    openAICompat,
	}, nil
}
================================================
FILE: sdk/cliproxy/rtprovider.go
================================================
package cliproxy
import (
"net/http"
"strings"
"sync"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/proxyutil"
log "github.com/sirupsen/logrus"
)
// defaultRoundTripperProvider hands out per-auth HTTP RoundTrippers derived
// from Auth.ProxyURL, caching one transport per trimmed proxy URL string.
// cache is guarded by mu.
type defaultRoundTripperProvider struct {
	mu    sync.RWMutex
	cache map[string]http.RoundTripper
}

// newDefaultRoundTripperProvider returns a provider with an empty, ready-to-use cache.
func newDefaultRoundTripperProvider() *defaultRoundTripperProvider {
	provider := new(defaultRoundTripperProvider)
	provider.cache = map[string]http.RoundTripper{}
	return provider
}
// RoundTripperFor implements coreauth.RoundTripperProvider.
//
// It returns nil when auth is nil, when auth.ProxyURL is blank, or when no
// transport can be built for it. Otherwise it returns the transport cached for
// the trimmed proxy URL, building and caching one on first use.
//
// The build runs without holding the lock. The result is stored with a
// re-check under the write lock, so concurrent first requests for the same
// proxy all receive the single transport that won the race. Without the
// re-check, a later store would silently replace an earlier one, leaving
// callers with different transports and connection pools.
func (p *defaultRoundTripperProvider) RoundTripperFor(auth *coreauth.Auth) http.RoundTripper {
	if auth == nil {
		return nil
	}
	proxyStr := strings.TrimSpace(auth.ProxyURL)
	if proxyStr == "" {
		return nil
	}

	p.mu.RLock()
	rt := p.cache[proxyStr]
	p.mu.RUnlock()
	if rt != nil {
		return rt
	}

	transport, _, errBuild := proxyutil.BuildHTTPTransport(proxyStr)
	if errBuild != nil {
		log.Errorf("%v", errBuild)
		return nil
	}
	if transport == nil {
		return nil
	}

	p.mu.Lock()
	defer p.mu.Unlock()
	if cached := p.cache[proxyStr]; cached != nil {
		// Another goroutine stored a transport while we were building ours.
		return cached
	}
	p.cache[proxyStr] = transport
	return transport
}
================================================
FILE: sdk/cliproxy/rtprovider_test.go
================================================
package cliproxy
import (
"net/http"
"testing"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// TestRoundTripperForDirectBypassesProxy checks that the special "direct"
// proxy URL yields a plain *http.Transport with no Proxy function configured.
func TestRoundTripperForDirectBypassesProxy(t *testing.T) {
	t.Parallel()
	rt := newDefaultRoundTripperProvider().RoundTripperFor(&coreauth.Auth{ProxyURL: "direct"})
	transport, isHTTPTransport := rt.(*http.Transport)
	switch {
	case !isHTTPTransport:
		t.Fatalf("transport type = %T, want *http.Transport", rt)
	case transport.Proxy != nil:
		t.Fatal("expected direct transport to disable proxy function")
	}
}
================================================
FILE: sdk/cliproxy/service.go
================================================
// Package cliproxy provides the core service implementation for the CLI Proxy API.
// It includes service lifecycle management, authentication handling, file watching,
// and integration with various AI service providers through a unified interface.
package cliproxy
import (
"context"
"errors"
"fmt"
"os"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
log "github.com/sirupsen/logrus"
)
// Service wraps the proxy server lifecycle so external programs can embed the CLI proxy.
// It manages the complete lifecycle including authentication, file watching, HTTP server,
// and integration with various AI service providers.
// Instances are normally produced by Builder.Build, which fills in defaults for every
// optional dependency.
type Service struct {
// cfg holds the current application configuration.
cfg *config.Config
// cfgMu protects concurrent access to the configuration.
// (handleAuthUpdate reads cfg under cfgMu.RLock.)
cfgMu sync.RWMutex
// configPath is the path to the configuration file.
configPath string
// tokenProvider handles loading token-based clients.
tokenProvider TokenClientProvider
// apiKeyProvider handles loading API key-based clients.
apiKeyProvider APIKeyClientProvider
// watcherFactory creates file watcher instances.
watcherFactory WatcherFactory
// hooks provides lifecycle callbacks.
hooks Hooks
// serverOptions contains additional server configuration options.
serverOptions []api.ServerOption
// server is the HTTP API server instance.
server *api.Server
// pprofServer manages the optional pprof HTTP debug server.
// It is created lazily by applyPprofConfig.
pprofServer *pprofServer
// serverErr channel for server startup/shutdown errors.
serverErr chan error
// watcher handles file system monitoring.
watcher *WatcherWrapper
// watcherCancel cancels the watcher context.
watcherCancel context.CancelFunc
// authUpdates channel for authentication updates.
// Created by ensureAuthUpdateQueue with a buffer of 256 and drained by consumeAuthUpdates.
authUpdates chan watcher.AuthUpdate
// authQueueStop cancels the auth update queue processing.
// A non-nil value means the consumer goroutine has already been started.
authQueueStop context.CancelFunc
// authManager handles legacy authentication operations.
authManager *sdkAuth.Manager
// accessManager handles request authentication providers.
accessManager *sdkaccess.Manager
// coreManager handles core authentication and execution.
coreManager *coreauth.Manager
// shutdownOnce ensures shutdown is called only once.
shutdownOnce sync.Once
// wsGateway manages websocket Gemini providers.
// It is created lazily by ensureWebsocketGateway.
wsGateway *wsrelay.Manager
}
// RegisterUsagePlugin registers a usage plugin on the global usage manager,
// allowing external code to monitor API usage and token consumption. The
// registration is global: it is forwarded to usage.RegisterPlugin and does
// not depend on the receiver's state.
//
// Parameters:
//   - plugin: The usage plugin to register
func (s *Service) RegisterUsagePlugin(plugin usage.Plugin) {
	usage.RegisterPlugin(plugin)
}

// newDefaultAuthManager builds an sdkAuth.Manager over the shared token store
// with the Gemini, Codex, Claude and Qwen authenticators enabled.
func newDefaultAuthManager() *sdkAuth.Manager {
	store := sdkAuth.GetTokenStore()
	return sdkAuth.NewManager(
		store,
		sdkAuth.NewGeminiAuthenticator(),
		sdkAuth.NewCodexAuthenticator(),
		sdkAuth.NewClaudeAuthenticator(),
		sdkAuth.NewQwenAuthenticator(),
	)
}
// ensureAuthUpdateQueue lazily creates the buffered auth update channel
// (capacity 256) and starts a single consumer goroutine bound to a child of
// ctx. Repeated calls are no-ops once the consumer is running.
func (s *Service) ensureAuthUpdateQueue(ctx context.Context) {
	if s == nil {
		return
	}
	if s.authUpdates == nil {
		s.authUpdates = make(chan watcher.AuthUpdate, 256)
	}
	if s.authQueueStop == nil {
		queueCtx, stop := context.WithCancel(ctx)
		s.authQueueStop = stop
		go s.consumeAuthUpdates(queueCtx)
	}
}
// consumeAuthUpdates applies queued auth updates until ctx is cancelled or the
// authUpdates channel is closed. The context passed to the handlers is wrapped
// with coreauth.WithSkipPersist.
//
// After each blocking receive, it drains whatever is already buffered without
// blocking, so bursts are applied back-to-back. Draining stops as soon as ctx
// is cancelled or the channel is closed. A closed channel is always ready and
// yields zero values, so ignoring the receive's ok flag there would spin forever.
func (s *Service) consumeAuthUpdates(ctx context.Context) {
	ctx = coreauth.WithSkipPersist(ctx)
	for {
		select {
		case <-ctx.Done():
			return
		case update, ok := <-s.authUpdates:
			if !ok {
				return
			}
			s.handleAuthUpdate(ctx, update)
		labelDrain:
			for {
				if ctx.Err() != nil {
					return
				}
				select {
				case nextUpdate, okNext := <-s.authUpdates:
					if !okNext {
						return
					}
					s.handleAuthUpdate(ctx, nextUpdate)
				default:
					break labelDrain
				}
			}
		}
	}
}
// emitAuthUpdate delivers update through the first available route:
//  1. the watcher's runtime dispatch, if a watcher is attached and accepts it;
//  2. a non-blocking send onto the auth update queue, if it exists;
//  3. applying it inline on the calling goroutine (also used when the queue is full).
//
// A nil ctx is replaced with context.Background().
func (s *Service) emitAuthUpdate(ctx context.Context, update watcher.AuthUpdate) {
	if s == nil {
		return
	}
	if ctx == nil {
		ctx = context.Background()
	}
	if w := s.watcher; w != nil && w.DispatchRuntimeAuthUpdate(update) {
		return
	}
	if queue := s.authUpdates; queue != nil {
		select {
		case queue <- update:
			return
		default:
			log.Debugf("auth update queue saturated, applying inline action=%v id=%s", update.Action, update.ID)
		}
	}
	s.handleAuthUpdate(ctx, update)
}
// handleAuthUpdate applies a single auth update to the core manager.
//
// Add/Modify updates need an Auth with a non-empty ID. Delete updates use
// update.ID, falling back to update.Auth.ID. Updates are ignored while no
// configuration or core manager is available, and unknown actions are only
// logged at debug level.
func (s *Service) handleAuthUpdate(ctx context.Context, update watcher.AuthUpdate) {
	if s == nil {
		return
	}
	s.cfgMu.RLock()
	hasConfig := s.cfg != nil
	s.cfgMu.RUnlock()
	if !hasConfig || s.coreManager == nil {
		return
	}

	switch update.Action {
	case watcher.AuthUpdateActionAdd, watcher.AuthUpdateActionModify:
		if update.Auth != nil && update.Auth.ID != "" {
			s.applyCoreAuthAddOrUpdate(ctx, update.Auth)
		}
	case watcher.AuthUpdateActionDelete:
		targetID := update.ID
		if targetID == "" && update.Auth != nil {
			targetID = update.Auth.ID
		}
		if targetID != "" {
			s.applyCoreAuthRemoval(ctx, targetID)
		}
	default:
		log.Debugf("received unknown auth update action: %v", update.Action)
	}
}
// ensureWebsocketGateway lazily creates the websocket relay manager on path
// /v1/ws, wiring its connect/disconnect callbacks to this service and its
// logging to logrus. Existing gateways are left untouched.
func (s *Service) ensureWebsocketGateway() {
	if s == nil || s.wsGateway != nil {
		return
	}
	s.wsGateway = wsrelay.NewManager(wsrelay.Options{
		Path:           "/v1/ws",
		OnConnected:    s.wsOnConnected,
		OnDisconnected: s.wsOnDisconnected,
		LogDebugf:      log.Debugf,
		LogInfof:       log.Infof,
		LogWarnf:       log.Warnf,
	})
}
// wsOnConnected turns a newly connected AI Studio websocket channel (ID
// prefixed "aistudio-", case-insensitive) into a runtime-only auth entry. It
// does nothing when an enabled, active auth with that ID already exists.
func (s *Service) wsOnConnected(channelID string) {
	if s == nil || channelID == "" || !strings.HasPrefix(strings.ToLower(channelID), "aistudio-") {
		return
	}
	if s.coreManager != nil {
		existing, found := s.coreManager.GetByID(channelID)
		if found && existing != nil && !existing.Disabled && existing.Status == coreauth.StatusActive {
			return
		}
	}

	now := time.Now().UTC()
	log.Infof("websocket provider connected: %s", channelID)
	s.emitAuthUpdate(context.Background(), watcher.AuthUpdate{
		Action: watcher.AuthUpdateActionAdd,
		ID:     channelID,
		Auth: &coreauth.Auth{
			ID:         channelID,  // keep channel identifier as ID
			Provider:   "aistudio", // logical provider for switch routing
			Label:      channelID,  // display original channel id
			Status:     coreauth.StatusActive,
			CreatedAt:  now,
			UpdatedAt:  now,
			Attributes: map[string]string{"runtime_only": "true"},
			Metadata:   map[string]any{"email": channelID}, // metadata drives logging and usage tracking
		},
	})
}
// wsOnDisconnected emits a delete update for a websocket channel that went
// away. A disconnect whose reason mentions "replaced by new connection" is only
// logged, because a newer connection for the same channel is now live.
func (s *Service) wsOnDisconnected(channelID string, reason error) {
	if s == nil || channelID == "" {
		return
	}
	switch {
	case reason == nil:
		log.Infof("websocket provider disconnected: %s", channelID)
	case strings.Contains(reason.Error(), "replaced by new connection"):
		log.Infof("websocket provider replaced: %s", channelID)
		return
	default:
		log.Warnf("websocket provider disconnected: %s (%v)", channelID, reason)
	}
	s.emitAuthUpdate(context.Background(), watcher.AuthUpdate{
		Action: watcher.AuthUpdateActionDelete,
		ID:     channelID,
	})
}
// applyCoreAuthAddOrUpdate registers auth with the core manager, or updates
// the existing entry with the same ID, and then (re)registers its models.
//
// auth.Clone() is taken first, so the field assignments below do not touch the
// caller's value. When an entry already exists, its CreatedAt and refresh
// timestamps are carried over, and so are its model states if the incoming
// auth has none. If the register/update call fails, model registration falls
// back to the currently stored entry. When no usable stored entry exists
// (missing, nil or disabled), the auth's models are unregistered instead.
//
// Lookups are nil-checked in addition to the ok flag, matching wsOnConnected
// and applyCoreAuthRemoval.
func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.Auth) {
	if s == nil || s.coreManager == nil || auth == nil || auth.ID == "" {
		return
	}
	auth = auth.Clone()
	s.ensureExecutorsForAuth(auth)
	// IMPORTANT: Update coreManager FIRST, before model registration.
	// This ensures that configuration changes (proxy_url, prefix, etc.) take effect
	// immediately for API calls, rather than waiting for model registration to complete.
	op := "register"
	var err error
	if existing, ok := s.coreManager.GetByID(auth.ID); ok && existing != nil {
		auth.CreatedAt = existing.CreatedAt
		auth.LastRefreshedAt = existing.LastRefreshedAt
		auth.NextRefreshAfter = existing.NextRefreshAfter
		if len(auth.ModelStates) == 0 && len(existing.ModelStates) > 0 {
			auth.ModelStates = existing.ModelStates
		}
		op = "update"
		_, err = s.coreManager.Update(ctx, auth)
	} else {
		_, err = s.coreManager.Register(ctx, auth)
	}
	if err != nil {
		log.Errorf("failed to %s auth %s: %v", op, auth.ID, err)
		current, ok := s.coreManager.GetByID(auth.ID)
		if !ok || current == nil || current.Disabled {
			GlobalModelRegistry().UnregisterClient(auth.ID)
			return
		}
		auth = current
	}
	// Register models after auth is updated in coreManager.
	// This operation may block on network calls, but the auth configuration
	// is already effective at this point.
	s.registerModelsForAuth(auth)
	// Refresh the scheduler entry so that the auth's supportedModelSet is rebuilt
	// from the now-populated global model registry. Without this, newly added auths
	// have an empty supportedModelSet (because Register/Update upserts into the
	// scheduler before registerModelsForAuth runs) and are invisible to the scheduler.
	s.coreManager.RefreshSchedulerEntry(auth.ID)
}
// applyCoreAuthRemoval unregisters the models for id and marks the matching
// core auth entry as disabled; the entry itself is kept. For Codex auths the
// executor binding is re-checked afterwards.
func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
	if s == nil || id == "" || s.coreManager == nil {
		return
	}
	GlobalModelRegistry().UnregisterClient(id)

	existing, found := s.coreManager.GetByID(id)
	if !found || existing == nil {
		return
	}
	existing.Disabled = true
	existing.Status = coreauth.StatusDisabled
	if _, err := s.coreManager.Update(ctx, existing); err != nil {
		log.Errorf("failed to disable auth %s: %v", id, err)
	}
	if strings.EqualFold(strings.TrimSpace(existing.Provider), "codex") {
		s.ensureExecutorsForAuth(existing)
	}
}
// applyRetryConfig pushes the retry settings from cfg into the core manager.
// cfg.MaxRetryInterval is interpreted as a number of seconds.
func (s *Service) applyRetryConfig(cfg *config.Config) {
	if s == nil || s.coreManager == nil || cfg == nil {
		return
	}
	s.coreManager.SetRetryConfig(
		cfg.RequestRetry,
		time.Second*time.Duration(cfg.MaxRetryInterval),
		cfg.MaxRetryCredentials,
	)
}
// openAICompatInfoFromAuth reports whether a is an OpenAI-compatible auth.
//
// An auth qualifies when it has a non-blank "compat_name" attribute. In that
// case providerKey is the lower-cased "provider_key" attribute, or the compat
// name when that attribute is blank. It also qualifies when its provider is
// "openai-compatibility" (case-insensitive). In that case the key is
// "openai-compatibility" and compatName is the trimmed label. All other auths,
// including nil, return ok=false.
func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName string, ok bool) {
	if a == nil {
		return "", "", false
	}
	// Indexing a nil or empty map yields "", so no length guard is needed.
	if name := strings.TrimSpace(a.Attributes["compat_name"]); name != "" {
		key := strings.TrimSpace(a.Attributes["provider_key"])
		if key == "" {
			key = name
		}
		return strings.ToLower(key), name, true
	}
	if strings.EqualFold(strings.TrimSpace(a.Provider), "openai-compatibility") {
		return "openai-compatibility", strings.TrimSpace(a.Label), true
	}
	return "", "", false
}
// ensureExecutorsForAuth makes sure an executor for a's provider is registered,
// keeping an already-registered Codex auto executor in place.
func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
	s.ensureExecutorsForAuthWithMode(a, false)
}

// ensureExecutorsForAuthWithMode registers the provider executor that serves a.
//
//   - Codex auths (even disabled ones) get a CodexAutoExecutor. Unless
//     forceReplace is set, an existing CodexAutoExecutor is kept.
//   - Other disabled auths are skipped, so they cannot override active executors.
//   - OpenAI-compatible auths (see openAICompatInfoFromAuth) get an OpenAI-compat
//     executor keyed by their provider key.
//   - Known providers get their dedicated executor. AI Studio executors are only
//     registered once the websocket gateway exists.
//   - Anything else falls back to an OpenAI-compat executor keyed by the
//     provider name ("openai-compatibility" when blank).
func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace bool) {
	if s == nil || s.coreManager == nil || a == nil {
		return
	}
	// Normalise the provider once (trim + lower-case) so every branch agrees on it.
	// The switch previously lower-cased without trimming, so a padded name such
	// as " gemini" missed its case. It then fell into the default branch and got
	// an OpenAI-compat executor for key "gemini" instead of the Gemini executor.
	provider := strings.ToLower(strings.TrimSpace(a.Provider))

	if provider == "codex" {
		if !forceReplace {
			if existingExecutor, hasExecutor := s.coreManager.Executor("codex"); hasExecutor {
				if _, isCodexAutoExecutor := existingExecutor.(*executor.CodexAutoExecutor); isCodexAutoExecutor {
					return
				}
			}
		}
		s.coreManager.RegisterExecutor(executor.NewCodexAutoExecutor(s.cfg))
		return
	}
	// Skip disabled auth entries when (re)binding executors.
	// Disabled auths can linger during config reloads (e.g., removed OpenAI-compat entries)
	// and must not override active provider executors (such as iFlow OAuth accounts).
	if a.Disabled {
		return
	}
	if compatProviderKey, _, isCompat := openAICompatInfoFromAuth(a); isCompat {
		if compatProviderKey == "" {
			compatProviderKey = provider
		}
		if compatProviderKey == "" {
			compatProviderKey = "openai-compatibility"
		}
		s.coreManager.RegisterExecutor(executor.NewOpenAICompatExecutor(compatProviderKey, s.cfg))
		return
	}
	switch provider {
	case "gemini":
		s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg))
	case "vertex":
		s.coreManager.RegisterExecutor(executor.NewGeminiVertexExecutor(s.cfg))
	case "gemini-cli":
		s.coreManager.RegisterExecutor(executor.NewGeminiCLIExecutor(s.cfg))
	case "aistudio":
		if s.wsGateway != nil {
			s.coreManager.RegisterExecutor(executor.NewAIStudioExecutor(s.cfg, a.ID, s.wsGateway))
		}
	case "antigravity":
		s.coreManager.RegisterExecutor(executor.NewAntigravityExecutor(s.cfg))
	case "claude":
		s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
	case "qwen":
		s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
	case "iflow":
		s.coreManager.RegisterExecutor(executor.NewIFlowExecutor(s.cfg))
	case "kimi":
		s.coreManager.RegisterExecutor(executor.NewKimiExecutor(s.cfg))
	default:
		providerKey := provider
		if providerKey == "" {
			providerKey = "openai-compatibility"
		}
		s.coreManager.RegisterExecutor(executor.NewOpenAICompatExecutor(providerKey, s.cfg))
	}
}
// registerResolvedModelsForAuth publishes models for a under providerKey in the
// global model registry. An empty model list unregisters the auth instead.
// Nil auths and auths without an ID are ignored.
func (s *Service) registerResolvedModelsForAuth(a *coreauth.Auth, providerKey string, models []*ModelInfo) {
	if a == nil || a.ID == "" {
		return
	}
	reg := GlobalModelRegistry()
	if len(models) > 0 {
		reg.RegisterClient(a.ID, providerKey, models)
		return
	}
	reg.UnregisterClient(a.ID)
}
// rebindExecutors refreshes provider executors so they observe the latest
// configuration. Every known auth is passed through ensureExecutorsForAuthWithMode
// with forceReplace set. Codex auths are looked up under the single "codex"
// executor key, so only the first Codex auth encountered triggers a rebind.
func (s *Service) rebindExecutors() {
	if s == nil || s.coreManager == nil {
		return
	}
	codexRebound := false
	for _, entry := range s.coreManager.List() {
		isCodex := entry != nil && strings.EqualFold(strings.TrimSpace(entry.Provider), "codex")
		if isCodex {
			if codexRebound {
				continue
			}
			codexRebound = true
		}
		s.ensureExecutorsForAuthWithMode(entry, true)
	}
}
// Run starts the service and blocks until the context is cancelled or the server stops.
// It initializes all components including authentication, file watching, HTTP server,
// and starts processing requests. The method blocks until ctx is cancelled or the
// HTTP server goroutine reports a result on s.serverErr.
//
// On every exit path a deferred call to Shutdown tears the service down using a
// 30-second timeout context that is created when shutdown begins, not at startup.
//
// Parameters:
// - ctx: The context for controlling the service lifecycle
//
// Returns:
// - error: An error if the service fails to start or run
func (s *Service) Run(ctx context.Context) error {
	if s == nil {
		return fmt.Errorf("cliproxy: service is nil")
	}
	if ctx == nil {
		ctx = context.Background()
	}
	usage.StartDefault(ctx)
	defer func() {
		// Start the shutdown deadline only when teardown begins. Creating it at the
		// top of Run would start the 30s timer at startup, so a service that ran
		// longer than that would be shut down with an already-expired context and
		// lose its graceful drain.
		shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
		defer shutdownCancel()
		if err := s.Shutdown(shutdownCtx); err != nil {
			log.Errorf("service shutdown returned error: %v", err)
		}
	}()
	if err := s.ensureAuthDir(); err != nil {
		return err
	}
	s.applyRetryConfig(s.cfg)
	if s.coreManager != nil {
		// A failed auth-store load is logged but does not abort startup.
		if errLoad := s.coreManager.Load(ctx); errLoad != nil {
			log.Warnf("failed to load auth store: %v", errLoad)
		}
	}
	// Cancellation while loading clients is tolerated; any other error aborts Run.
	// NOTE(review): the load results are not consumed further in this function.
	tokenResult, err := s.tokenProvider.Load(ctx, s.cfg)
	if err != nil && !errors.Is(err, context.Canceled) {
		return err
	}
	if tokenResult == nil {
		tokenResult = &TokenClientResult{}
	}
	apiKeyResult, err := s.apiKeyProvider.Load(ctx, s.cfg)
	if err != nil && !errors.Is(err, context.Canceled) {
		return err
	}
	if apiKeyResult == nil {
		apiKeyResult = &APIKeyClientResult{}
	}
	// legacy clients removed; no caches to refresh
	// handlers no longer depend on legacy clients; pass nil slice initially
	s.server = api.NewServer(s.cfg, s.coreManager, s.accessManager, s.configPath, s.serverOptions...)
	if s.authManager == nil {
		s.authManager = newDefaultAuthManager()
	}
	s.ensureWebsocketGateway()
	if s.server != nil && s.wsGateway != nil {
		s.server.AttachWebsocketRoute(s.wsGateway.Path(), s.wsGateway.Handler())
		// When ws-auth flips from disabled to enabled, drop existing sessions so
		// they must reconnect with authentication. Disabling leaves sessions alone.
		s.server.SetWebsocketAuthChangeHandler(func(oldEnabled, newEnabled bool) {
			if oldEnabled == newEnabled {
				return
			}
			if !oldEnabled && newEnabled {
				ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
				defer cancel()
				if errStop := s.wsGateway.Stop(ctx); errStop != nil {
					log.Warnf("failed to reset websocket connections after ws-auth change %t -> %t: %v", oldEnabled, newEnabled, errStop)
					return
				}
				log.Debugf("ws-auth enabled; existing websocket sessions terminated to enforce authentication")
				return
			}
			log.Debugf("ws-auth disabled; existing websocket sessions remain connected")
		})
	}
	if s.hooks.OnBeforeStart != nil {
		s.hooks.OnBeforeStart(s.cfg)
	}
	// Register callback for startup and periodic model catalog refresh.
	// When remote model definitions change, re-register models for affected providers.
	// This intentionally rebuilds per-auth model availability from the latest catalog
	// snapshot instead of preserving prior registry suppression state.
	registry.SetModelRefreshCallback(func(changedProviders []string) {
		if s == nil || s.coreManager == nil || len(changedProviders) == 0 {
			return
		}
		providerSet := make(map[string]bool, len(changedProviders))
		for _, p := range changedProviders {
			providerSet[strings.ToLower(strings.TrimSpace(p))] = true
		}
		auths := s.coreManager.List()
		refreshed := 0
		for _, item := range auths {
			if item == nil || item.ID == "" {
				continue
			}
			// Re-read by ID so the refresh works on the current snapshot.
			auth, ok := s.coreManager.GetByID(item.ID)
			if !ok || auth == nil || auth.Disabled {
				continue
			}
			provider := strings.ToLower(strings.TrimSpace(auth.Provider))
			if !providerSet[provider] {
				continue
			}
			if s.refreshModelRegistrationForAuth(auth) {
				refreshed++
			}
		}
		if refreshed > 0 {
			log.Infof("re-registered models for %d auth(s) due to model catalog changes: %v", refreshed, changedProviders)
		}
	})
	// Buffered so the server goroutine can always deliver its result and exit,
	// even if Run has already returned through another path.
	s.serverErr = make(chan error, 1)
	go func() {
		if errStart := s.server.Start(); errStart != nil {
			s.serverErr <- errStart
		} else {
			s.serverErr <- nil
		}
	}()
	// NOTE(review): this short pause is not a readiness check; the success message
	// below is printed even if Start has already failed.
	time.Sleep(100 * time.Millisecond)
	fmt.Printf("API server started successfully on: %s:%d\n", s.cfg.Host, s.cfg.Port)
	s.applyPprofConfig(s.cfg)
	if s.hooks.OnAfterStart != nil {
		s.hooks.OnAfterStart(s)
	}
	var watcherWrapper *WatcherWrapper
	// reloadCallback applies a new configuration. A nil newCfg re-applies the
	// current config. The routing selector is swapped only when the normalized
	// strategy actually changes.
	reloadCallback := func(newCfg *config.Config) {
		previousStrategy := ""
		s.cfgMu.RLock()
		if s.cfg != nil {
			previousStrategy = strings.ToLower(strings.TrimSpace(s.cfg.Routing.Strategy))
		}
		s.cfgMu.RUnlock()
		if newCfg == nil {
			s.cfgMu.RLock()
			newCfg = s.cfg
			s.cfgMu.RUnlock()
		}
		if newCfg == nil {
			return
		}
		nextStrategy := strings.ToLower(strings.TrimSpace(newCfg.Routing.Strategy))
		normalizeStrategy := func(strategy string) string {
			switch strategy {
			case "fill-first", "fillfirst", "ff":
				return "fill-first"
			default:
				return "round-robin"
			}
		}
		previousStrategy = normalizeStrategy(previousStrategy)
		nextStrategy = normalizeStrategy(nextStrategy)
		if s.coreManager != nil && previousStrategy != nextStrategy {
			var selector coreauth.Selector
			switch nextStrategy {
			case "fill-first":
				selector = &coreauth.FillFirstSelector{}
			default:
				selector = &coreauth.RoundRobinSelector{}
			}
			s.coreManager.SetSelector(selector)
		}
		s.applyRetryConfig(newCfg)
		s.applyPprofConfig(newCfg)
		if s.server != nil {
			s.server.UpdateClients(newCfg)
		}
		s.cfgMu.Lock()
		s.cfg = newCfg
		s.cfgMu.Unlock()
		if s.coreManager != nil {
			s.coreManager.SetConfig(newCfg)
			s.coreManager.SetOAuthModelAlias(newCfg.OAuthModelAlias)
		}
		s.rebindExecutors()
	}
	watcherWrapper, err = s.watcherFactory(s.configPath, s.cfg.AuthDir, reloadCallback)
	if err != nil {
		return fmt.Errorf("cliproxy: failed to create watcher: %w", err)
	}
	s.watcher = watcherWrapper
	s.ensureAuthUpdateQueue(ctx)
	if s.authUpdates != nil {
		watcherWrapper.SetAuthUpdateQueue(s.authUpdates)
	}
	watcherWrapper.SetConfig(s.cfg)
	// The watcher has its own cancellable context. Shutdown cancels it through
	// s.watcherCancel.
	watcherCtx, watcherCancel := context.WithCancel(context.Background())
	s.watcherCancel = watcherCancel
	if err = watcherWrapper.Start(watcherCtx); err != nil {
		return fmt.Errorf("cliproxy: failed to start watcher: %w", err)
	}
	log.Info("file watcher started for config and auth directory changes")
	// Prefer core auth manager auto refresh if available.
	if s.coreManager != nil {
		interval := 15 * time.Minute
		s.coreManager.StartAutoRefresh(context.Background(), interval)
		log.Infof("core auth auto-refresh started (interval=%s)", interval)
	}
	select {
	case <-ctx.Done():
		log.Debug("service context cancelled, shutting down...")
		return ctx.Err()
	case err = <-s.serverErr:
		return err
	}
}
// Shutdown gracefully stops background workers and the HTTP server.
// The teardown body runs at most once (guarded by s.shutdownOnce); later calls
// return nil without doing anything. Components are stopped in this order:
// watcher context, core auto-refresh, file watcher, websocket gateway, auth
// update queue, pprof server, API server, and finally the default usage recorder.
//
// Parameters:
// - ctx: The context bounding the shutdown; nil is treated as context.Background()
//
// Returns:
// - error: The first error reported by any component, or nil
func (s *Service) Shutdown(ctx context.Context) error {
	if s == nil {
		return nil
	}
	var firstErr error
	keepFirst := func(err error) {
		if firstErr == nil {
			firstErr = err
		}
	}
	s.shutdownOnce.Do(func() {
		if ctx == nil {
			ctx = context.Background()
		}
		// legacy refresh loop removed; only stopping core auth manager below
		if s.watcherCancel != nil {
			s.watcherCancel()
		}
		if s.coreManager != nil {
			s.coreManager.StopAutoRefresh()
		}
		if w := s.watcher; w != nil {
			if errWatcher := w.Stop(); errWatcher != nil {
				log.Errorf("failed to stop file watcher: %v", errWatcher)
				keepFirst(errWatcher)
			}
		}
		if gw := s.wsGateway; gw != nil {
			if errGateway := gw.Stop(ctx); errGateway != nil {
				log.Errorf("failed to stop websocket gateway: %v", errGateway)
				keepFirst(errGateway)
			}
		}
		if stopQueue := s.authQueueStop; stopQueue != nil {
			stopQueue()
			s.authQueueStop = nil
		}
		if errPprof := s.shutdownPprof(ctx); errPprof != nil {
			log.Errorf("failed to stop pprof server: %v", errPprof)
			keepFirst(errPprof)
		}
		// no legacy clients to persist
		if srv := s.server; srv != nil {
			stopCtx, cancelStop := context.WithTimeout(ctx, 30*time.Second)
			defer cancelStop()
			if errServer := srv.Stop(stopCtx); errServer != nil {
				log.Errorf("error stopping API server: %v", errServer)
				keepFirst(errServer)
			}
		}
		usage.StopDefault()
	})
	return firstErr
}
// ensureAuthDir makes sure s.cfg.AuthDir exists and is a directory. A missing
// path is created with mode 0o755. A path that exists but is not a directory,
// or that cannot be stat'ed, yields an error.
func (s *Service) ensureAuthDir() error {
	dir := s.cfg.AuthDir
	info, statErr := os.Stat(dir)
	switch {
	case statErr == nil:
		if info.IsDir() {
			return nil
		}
		return fmt.Errorf("cliproxy: auth path exists but is not a directory: %s", dir)
	case os.IsNotExist(statErr):
		if mkErr := os.MkdirAll(dir, 0o755); mkErr != nil {
			return fmt.Errorf("cliproxy: failed to create auth directory %s: %w", dir, mkErr)
		}
		log.Infof("created missing auth directory: %s", dir)
		return nil
	default:
		return fmt.Errorf("cliproxy: error checking auth directory %s: %w", dir, statErr)
	}
}
// registerModelsForAuth (re)binds provider models in the global registry using the core auth ID as client identifier.
//
// Flow, as implemented below:
//   - Disabled auths, and auths whose "gemini_virtual_primary" attribute is "true",
//     are unregistered and skipped.
//   - A differing legacy runtime client ID (from Runtime.GetClientID) is unregistered first.
//   - For built-in providers, the model list comes from the registry's static catalog,
//     optionally replaced by models from a matching config key entry. Exclusions are
//     then applied, and OAuth model aliases are layered on before registration.
//   - Other providers are treated as OpenAI-compatibility providers and resolved
//     against s.cfg.OpenAICompatibility. Those register directly and return early,
//     without exclusions or aliases.
//   - If no models remain, the client ID is unregistered.
func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
	if a == nil || a.ID == "" {
		return
	}
	if a.Disabled {
		GlobalModelRegistry().UnregisterClient(a.ID)
		return
	}
	// Prefer the explicit auth_kind attribute; otherwise derive "apikey" from AccountInfo.
	authKind := strings.ToLower(strings.TrimSpace(a.Attributes["auth_kind"]))
	if authKind == "" {
		if kind, _ := a.AccountInfo(); strings.EqualFold(kind, "api_key") {
			authKind = "apikey"
		}
	}
	if a.Attributes != nil {
		if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") {
			GlobalModelRegistry().UnregisterClient(a.ID)
			return
		}
	}
	// Unregister legacy client ID (if present) to avoid double counting
	if a.Runtime != nil {
		if idGetter, ok := a.Runtime.(interface{ GetClientID() string }); ok {
			if rid := idGetter.GetClientID(); rid != "" && rid != a.ID {
				GlobalModelRegistry().UnregisterClient(rid)
			}
		}
	}
	provider := strings.ToLower(strings.TrimSpace(a.Provider))
	compatProviderKey, compatDisplayName, compatDetected := openAICompatInfoFromAuth(a)
	if compatDetected {
		// Compat auths never match a built-in case below, so they fall into the default branch.
		provider = "openai-compatibility"
	}
	excluded := s.oauthExcludedModels(provider, authKind)
	// The synthesizer pre-merges per-account and global exclusions into the "excluded_models" attribute.
	// If this attribute is present, it represents the complete list of exclusions and overrides the global config.
	if a.Attributes != nil {
		if val, ok := a.Attributes["excluded_models"]; ok && strings.TrimSpace(val) != "" {
			excluded = strings.Split(val, ",")
		}
	}
	var models []*ModelInfo
	// For gemini/vertex/claude/codex, a matching config key entry can replace the
	// static model list. For API-key auths, the entry's ExcludedModels also replace
	// whatever exclusion list was computed above, including the attribute override.
	switch provider {
	case "gemini":
		models = registry.GetGeminiModels()
		if entry := s.resolveConfigGeminiKey(a); entry != nil {
			if len(entry.Models) > 0 {
				models = buildGeminiConfigModels(entry)
			}
			if authKind == "apikey" {
				excluded = entry.ExcludedModels
			}
		}
		models = applyExcludedModels(models, excluded)
	case "vertex":
		// Vertex AI Gemini supports the same model identifiers as Gemini.
		models = registry.GetGeminiVertexModels()
		if entry := s.resolveConfigVertexCompatKey(a); entry != nil {
			if len(entry.Models) > 0 {
				models = buildVertexCompatConfigModels(entry)
			}
			if authKind == "apikey" {
				excluded = entry.ExcludedModels
			}
		}
		models = applyExcludedModels(models, excluded)
	case "gemini-cli":
		models = registry.GetGeminiCLIModels()
		models = applyExcludedModels(models, excluded)
	case "aistudio":
		models = registry.GetAIStudioModels()
		models = applyExcludedModels(models, excluded)
	case "antigravity":
		models = registry.GetAntigravityModels()
		models = applyExcludedModels(models, excluded)
	case "claude":
		models = registry.GetClaudeModels()
		if entry := s.resolveConfigClaudeKey(a); entry != nil {
			if len(entry.Models) > 0 {
				models = buildClaudeConfigModels(entry)
			}
			if authKind == "apikey" {
				excluded = entry.ExcludedModels
			}
		}
		models = applyExcludedModels(models, excluded)
	case "codex":
		// The codex catalog depends on the "plan_type" attribute; unknown or empty plans use the pro list.
		codexPlanType := ""
		if a.Attributes != nil {
			codexPlanType = strings.TrimSpace(a.Attributes["plan_type"])
		}
		switch strings.ToLower(codexPlanType) {
		case "pro":
			models = registry.GetCodexProModels()
		case "plus":
			models = registry.GetCodexPlusModels()
		case "team", "business", "go":
			models = registry.GetCodexTeamModels()
		case "free":
			models = registry.GetCodexFreeModels()
		default:
			models = registry.GetCodexProModels()
		}
		if entry := s.resolveConfigCodexKey(a); entry != nil {
			if len(entry.Models) > 0 {
				models = buildCodexConfigModels(entry)
			}
			if authKind == "apikey" {
				excluded = entry.ExcludedModels
			}
		}
		models = applyExcludedModels(models, excluded)
	case "qwen":
		models = registry.GetQwenModels()
		models = applyExcludedModels(models, excluded)
	case "iflow":
		models = registry.GetIFlowModels()
		models = applyExcludedModels(models, excluded)
	case "kimi":
		models = registry.GetKimiModels()
		models = applyExcludedModels(models, excluded)
	default:
		// Handle OpenAI-compatibility providers by name using config
		if s.cfg != nil {
			// providerKey is the registry provider key to register under.
			// compatName is matched case-insensitively against OpenAICompatibility[].Name.
			providerKey := provider
			compatName := strings.TrimSpace(a.Provider)
			isCompatAuth := false
			if compatDetected {
				if compatProviderKey != "" {
					providerKey = compatProviderKey
				}
				if compatDisplayName != "" {
					compatName = compatDisplayName
				}
				isCompatAuth = true
			}
			if strings.EqualFold(providerKey, "openai-compatibility") {
				isCompatAuth = true
				if a.Attributes != nil {
					if v := strings.TrimSpace(a.Attributes["compat_name"]); v != "" {
						compatName = v
					}
					if v := strings.TrimSpace(a.Attributes["provider_key"]); v != "" {
						providerKey = strings.ToLower(v)
						isCompatAuth = true
					}
				}
				// Avoid registering under the generic key when a concrete compat name is known.
				if providerKey == "openai-compatibility" && compatName != "" {
					providerKey = strings.ToLower(compatName)
				}
			} else if a.Attributes != nil {
				if v := strings.TrimSpace(a.Attributes["compat_name"]); v != "" {
					compatName = v
					isCompatAuth = true
				}
				if v := strings.TrimSpace(a.Attributes["provider_key"]); v != "" {
					providerKey = strings.ToLower(v)
					isCompatAuth = true
				}
			}
			for i := range s.cfg.OpenAICompatibility {
				compat := &s.cfg.OpenAICompatibility[i]
				if strings.EqualFold(compat.Name, compatName) {
					isCompatAuth = true
					// Convert compatibility models to registry models
					ms := make([]*ModelInfo, 0, len(compat.Models))
					for j := range compat.Models {
						m := compat.Models[j]
						// Use alias as model ID, fallback to name if alias is empty
						modelID := m.Alias
						if modelID == "" {
							modelID = m.Name
						}
						ms = append(ms, &ModelInfo{
							ID:          modelID,
							Object:      "model",
							Created:     time.Now().Unix(),
							OwnedBy:     compat.Name,
							Type:        "openai-compatibility",
							DisplayName: modelID,
							UserDefined: true,
						})
					}
					// Register and return
					if len(ms) > 0 {
						if providerKey == "" {
							providerKey = "openai-compatibility"
						}
						s.registerResolvedModelsForAuth(a, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix))
					} else {
						// Ensure stale registrations are cleared when model list becomes empty.
						GlobalModelRegistry().UnregisterClient(a.ID)
					}
					return
				}
			}
			if isCompatAuth {
				// No matching provider found or models removed entirely; drop any prior registration.
				GlobalModelRegistry().UnregisterClient(a.ID)
				return
			}
		}
	}
	// Built-in providers (and unmatched non-compat providers, which end with nil models here).
	models = applyOAuthModelAlias(s.cfg, provider, authKind, models)
	if len(models) > 0 {
		key := provider
		if key == "" {
			key = strings.ToLower(strings.TrimSpace(a.Provider))
		}
		s.registerResolvedModelsForAuth(a, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
		return
	}
	GlobalModelRegistry().UnregisterClient(a.ID)
}
// refreshModelRegistrationForAuth re-registers models for one auth and then
// reconciles against the manager's latest snapshot of that auth, so a concurrent
// update cannot leave a stale registration behind. Callers are expected to have
// already filtered by provider.
//
// Re-registration is deliberate: registry cooldown/suspension state is treated
// as part of the previous registration snapshot and is cleared when the auth is
// rebound to the refreshed model catalog.
//
// It returns true when the auth still exists and is enabled after the refresh.
// Otherwise its registration is dropped and false is returned. In both cases the
// scheduler entry for the auth ID is refreshed.
func (s *Service) refreshModelRegistrationForAuth(current *coreauth.Auth) bool {
	if s == nil || s.coreManager == nil || current == nil || current.ID == "" {
		return false
	}
	authID := current.ID
	if !current.Disabled {
		s.ensureExecutorsForAuth(current)
	}
	s.registerModelsForAuth(current)

	latest, found := s.latestAuthForModelRegistration(authID)
	stillActive := found && !latest.Disabled
	if stillActive {
		// Re-apply the latest snapshot. This may duplicate work when nothing
		// changed, but keeps the refresh path simple and correct.
		s.ensureExecutorsForAuth(latest)
		s.registerModelsForAuth(latest)
	} else {
		GlobalModelRegistry().UnregisterClient(authID)
	}
	s.coreManager.RefreshSchedulerEntry(authID)
	return stillActive
}
// latestAuthForModelRegistration looks up the current auth snapshot for authID,
// regardless of provider. It reports false when the service or manager is
// missing, the ID is empty, or no auth with a non-empty ID is found.
func (s *Service) latestAuthForModelRegistration(authID string) (*coreauth.Auth, bool) {
	if s == nil || s.coreManager == nil || authID == "" {
		return nil, false
	}
	snapshot, found := s.coreManager.GetByID(authID)
	if found && snapshot != nil && snapshot.ID != "" {
		return snapshot, true
	}
	return nil, false
}
// resolveConfigClaudeKey finds the config.ClaudeKey entry matching the auth's
// "api_key" and "base_url" attributes. Keys and URLs are trimmed and compared
// case-insensitively. Matching rules, per entry:
//   - key and base both known: both must match;
//   - only key known: the key must match and the entry must not pin a base URL;
//   - only base known: the base must match.
//
// If no entry matches and a key is known, the first entry with that key is
// returned. It returns nil when nothing matches.
func (s *Service) resolveConfigClaudeKey(auth *coreauth.Auth) *config.ClaudeKey {
	if auth == nil || s.cfg == nil {
		return nil
	}
	wantKey, wantBase := "", ""
	if auth.Attributes != nil {
		wantKey = strings.TrimSpace(auth.Attributes["api_key"])
		wantBase = strings.TrimSpace(auth.Attributes["base_url"])
	}
	entries := s.cfg.ClaudeKey
	for idx := range entries {
		candidate := &entries[idx]
		candKey := strings.TrimSpace(candidate.APIKey)
		candBase := strings.TrimSpace(candidate.BaseURL)
		keyMatches := wantKey != "" && strings.EqualFold(candKey, wantKey)
		baseMatches := strings.EqualFold(candBase, wantBase)
		switch {
		case wantKey != "" && wantBase != "":
			if keyMatches && baseMatches {
				return candidate
			}
		case keyMatches:
			if candBase == "" || baseMatches {
				return candidate
			}
		case wantKey == "" && wantBase != "" && baseMatches:
			return candidate
		}
	}
	if wantKey == "" {
		return nil
	}
	// Fallback: accept the first entry whose key matches, ignoring base URLs.
	for idx := range entries {
		if strings.EqualFold(strings.TrimSpace(entries[idx].APIKey), wantKey) {
			return &entries[idx]
		}
	}
	return nil
}
// resolveConfigGeminiKey finds the config.GeminiKey entry matching the auth's
// "api_key"/"base_url" attributes (trimmed, compared case-insensitively).
// When a key is known, only entries with that key qualify, and only if they pin
// no base URL or the same base URL. When no key is known, the base URL alone
// must match. It returns nil when nothing matches.
func (s *Service) resolveConfigGeminiKey(auth *coreauth.Auth) *config.GeminiKey {
	if auth == nil || s.cfg == nil {
		return nil
	}
	wantKey, wantBase := "", ""
	if auth.Attributes != nil {
		wantKey = strings.TrimSpace(auth.Attributes["api_key"])
		wantBase = strings.TrimSpace(auth.Attributes["base_url"])
	}
	entries := s.cfg.GeminiKey
	for idx := range entries {
		candidate := &entries[idx]
		candBase := strings.TrimSpace(candidate.BaseURL)
		if wantKey != "" {
			if strings.EqualFold(strings.TrimSpace(candidate.APIKey), wantKey) &&
				(candBase == "" || strings.EqualFold(candBase, wantBase)) {
				return candidate
			}
			continue
		}
		if wantBase != "" && strings.EqualFold(candBase, wantBase) {
			return candidate
		}
	}
	return nil
}
// resolveConfigVertexCompatKey finds the config.VertexCompatKey entry matching
// the auth's "api_key"/"base_url" attributes (trimmed, compared
// case-insensitively). When a key is known, an entry qualifies if its key matches
// and it pins no base URL or the same base URL. When no key is known, the base
// URL alone must match. If nothing matches and a key is known, the first entry
// with that key is returned. It returns nil otherwise.
func (s *Service) resolveConfigVertexCompatKey(auth *coreauth.Auth) *config.VertexCompatKey {
	if auth == nil || s.cfg == nil {
		return nil
	}
	wantKey, wantBase := "", ""
	if auth.Attributes != nil {
		wantKey = strings.TrimSpace(auth.Attributes["api_key"])
		wantBase = strings.TrimSpace(auth.Attributes["base_url"])
	}
	entries := s.cfg.VertexCompatAPIKey
	for idx := range entries {
		candidate := &entries[idx]
		candBase := strings.TrimSpace(candidate.BaseURL)
		if wantKey == "" {
			if wantBase != "" && strings.EqualFold(candBase, wantBase) {
				return candidate
			}
			continue
		}
		if !strings.EqualFold(strings.TrimSpace(candidate.APIKey), wantKey) {
			continue
		}
		if candBase == "" || strings.EqualFold(candBase, wantBase) {
			return candidate
		}
	}
	if wantKey == "" {
		return nil
	}
	// Fallback: accept the first entry whose key matches, ignoring base URLs.
	for idx := range entries {
		if strings.EqualFold(strings.TrimSpace(entries[idx].APIKey), wantKey) {
			return &entries[idx]
		}
	}
	return nil
}
// resolveConfigCodexKey finds the config.CodexKey entry matching the auth's
// "api_key"/"base_url" attributes (trimmed, compared case-insensitively).
// When a key is known, an entry qualifies if its key matches and it pins no base
// URL or the same base URL. When no key is known, the base URL alone must match.
// It returns nil when nothing matches.
func (s *Service) resolveConfigCodexKey(auth *coreauth.Auth) *config.CodexKey {
	if auth == nil || s.cfg == nil {
		return nil
	}
	var wantKey, wantBase string
	if attrs := auth.Attributes; attrs != nil {
		wantKey = strings.TrimSpace(attrs["api_key"])
		wantBase = strings.TrimSpace(attrs["base_url"])
	}
	if wantKey == "" && wantBase == "" {
		// Nothing to match on.
		return nil
	}
	entries := s.cfg.CodexKey
	for idx := range entries {
		candidate := &entries[idx]
		candBase := strings.TrimSpace(candidate.BaseURL)
		switch {
		case wantKey != "":
			keyOK := strings.EqualFold(strings.TrimSpace(candidate.APIKey), wantKey)
			if keyOK && (candBase == "" || strings.EqualFold(candBase, wantBase)) {
				return candidate
			}
		case strings.EqualFold(candBase, wantBase):
			return candidate
		}
	}
	return nil
}
// oauthExcludedModels returns the globally configured OAuth excluded-model
// patterns for provider (trimmed, lower-cased lookup key). API-key auths
// (authKind "apikey" after trimming and lower-casing) receive nil, as does any
// call made while no configuration is loaded.
func (s *Service) oauthExcludedModels(provider, authKind string) []string {
	cfg := s.cfg
	if cfg == nil || strings.ToLower(strings.TrimSpace(authKind)) == "apikey" {
		return nil
	}
	return cfg.OAuthExcludedModels[strings.ToLower(strings.TrimSpace(provider))]
}
// applyExcludedModels removes every model whose ID matches one of the excluded
// wildcard patterns. Both patterns and IDs are trimmed and lower-cased first.
// When there is nothing to filter (no models, no usable patterns), the input
// slice is returned as-is. Otherwise a new slice is returned, and nil entries are
// dropped from it.
func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo {
	if len(models) == 0 || len(excluded) == 0 {
		return models
	}
	var patterns []string
	for _, raw := range excluded {
		if p := strings.ToLower(strings.TrimSpace(raw)); p != "" {
			patterns = append(patterns, p)
		}
	}
	if len(patterns) == 0 {
		return models
	}
	isBlocked := func(id string) bool {
		for _, p := range patterns {
			if matchWildcard(p, id) {
				return true
			}
		}
		return false
	}
	kept := make([]*ModelInfo, 0, len(models))
	for _, m := range models {
		if m != nil && !isBlocked(strings.ToLower(strings.TrimSpace(m.ID))) {
			kept = append(kept, m)
		}
	}
	return kept
}
// applyModelPrefixes adds a "<prefix>/<id>" copy of every model. With an empty
// (trimmed) prefix or no models, the input is returned unchanged. When
// forceModelPrefix is set, the unprefixed model is omitted unless its ID equals
// the prefix itself. Output IDs are de-duplicated (first occurrence wins), and
// nil models or models with blank IDs are skipped.
func applyModelPrefixes(models []*ModelInfo, prefix string, forceModelPrefix bool) []*ModelInfo {
	p := strings.TrimSpace(prefix)
	if p == "" || len(models) == 0 {
		return models
	}
	result := make([]*ModelInfo, 0, 2*len(models))
	used := make(map[string]struct{}, 2*len(models))
	push := func(m *ModelInfo) {
		if m == nil {
			return
		}
		id := strings.TrimSpace(m.ID)
		if id == "" {
			return
		}
		if _, dup := used[id]; dup {
			return
		}
		used[id] = struct{}{}
		result = append(result, m)
	}
	for _, m := range models {
		if m == nil {
			continue
		}
		base := strings.TrimSpace(m.ID)
		if base == "" {
			continue
		}
		keepUnprefixed := !forceModelPrefix || p == base
		if keepUnprefixed {
			push(m)
		}
		prefixed := *m
		prefixed.ID = p + "/" + base
		push(&prefixed)
	}
	return result
}
// matchWildcard reports whether value matches pattern, where '*' matches any
// (possibly empty) substring. Matching is case-insensitive: both inputs are
// lower-cased before comparison, so the result does not depend on callers
// normalizing case. An empty pattern never matches.
func matchWildcard(pattern, value string) bool {
	if pattern == "" {
		return false
	}
	pattern = strings.ToLower(pattern)
	value = strings.ToLower(value)
	// Fast path for exact match (no wildcard present).
	if !strings.Contains(pattern, "*") {
		return pattern == value
	}
	parts := strings.Split(pattern, "*")
	// The first part anchors the start of value, the last part anchors the end.
	prefix, suffix := parts[0], parts[len(parts)-1]
	if !strings.HasPrefix(value, prefix) {
		return false
	}
	value = value[len(prefix):]
	// Checked after the prefix is consumed, so prefix and suffix cannot overlap.
	if !strings.HasSuffix(value, suffix) {
		return false
	}
	value = value[:len(value)-len(suffix)]
	// Middle segments must appear in order. Matching each at its earliest
	// position is sufficient because '*' can absorb any gap.
	for _, segment := range parts[1 : len(parts)-1] {
		if segment == "" {
			continue
		}
		idx := strings.Index(value, segment)
		if idx < 0 {
			return false
		}
		value = value[idx+len(segment):]
	}
	return true
}
// modelEntry is implemented by per-key model configuration entries. It exposes
// the upstream model name and the optional client-facing alias that
// buildConfigModels turns into registry entries.
type modelEntry interface {
	// GetName returns the upstream model name.
	GetName() string
	// GetAlias returns the client-facing alias; it may be empty.
	GetAlias() string
}
// buildConfigModels converts configured model entries into user-defined
// registry models. Each model's ID is its trimmed alias, falling back to the
// trimmed name. Entries with neither are skipped, and IDs are de-duplicated
// case-insensitively (first wins). DisplayName prefers the upstream name.
// Thinking metadata is copied from the static catalog entry for that name, when
// one exists. It returns nil for empty input.
func buildConfigModels[T modelEntry](models []T, ownedBy, modelType string) []*ModelInfo {
	if len(models) == 0 {
		return nil
	}
	createdAt := time.Now().Unix()
	result := make([]*ModelInfo, 0, len(models))
	emitted := make(map[string]struct{}, len(models))
	for _, entry := range models {
		upstreamName := strings.TrimSpace(entry.GetName())
		id := strings.TrimSpace(entry.GetAlias())
		if id == "" {
			id = upstreamName
		}
		if id == "" {
			continue
		}
		dedupKey := strings.ToLower(id)
		if _, dup := emitted[dedupKey]; dup {
			continue
		}
		emitted[dedupKey] = struct{}{}
		displayName := id
		if upstreamName != "" {
			displayName = upstreamName
		}
		info := &ModelInfo{
			ID:          id,
			Object:      "model",
			Created:     createdAt,
			OwnedBy:     ownedBy,
			Type:        modelType,
			DisplayName: displayName,
			UserDefined: true,
		}
		if upstreamName != "" {
			if static := registry.LookupStaticModelInfo(upstreamName); static != nil && static.Thinking != nil {
				info.Thinking = static.Thinking
			}
		}
		result = append(result, info)
	}
	return result
}
// buildVertexCompatConfigModels builds registry models for a Vertex-compatible
// key entry (owner "google", type "vertex"). A nil entry yields nil.
func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
	if entry != nil {
		return buildConfigModels(entry.Models, "google", "vertex")
	}
	return nil
}
// buildGeminiConfigModels builds registry models for a Gemini key entry
// (owner "google", type "gemini"). A nil entry yields nil.
func buildGeminiConfigModels(entry *config.GeminiKey) []*ModelInfo {
	if entry != nil {
		return buildConfigModels(entry.Models, "google", "gemini")
	}
	return nil
}
// buildClaudeConfigModels builds registry models for a Claude key entry
// (owner "anthropic", type "claude"). A nil entry yields nil.
func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo {
	if entry != nil {
		return buildConfigModels(entry.Models, "anthropic", "claude")
	}
	return nil
}
// buildCodexConfigModels builds registry models for a Codex key entry
// (owner "openai", type "openai"). A nil entry yields nil.
func buildCodexConfigModels(entry *config.CodexKey) []*ModelInfo {
	if entry != nil {
		return buildConfigModels(entry.Models, "openai", "openai")
	}
	return nil
}
// rewriteModelInfoName rewrites a model Name so it refers to newID instead of
// oldID. Two forms are recognized:
//   - a bare ID equal to oldID (case-insensitive), which becomes newID;
//   - any slash-qualified name ending in "/"+oldID (case-sensitive), such as
//     "models/<id>", where only the trailing ID is replaced.
//
// The original name is returned unchanged when it is blank, when either ID is
// blank, when the IDs differ only by case, or when neither form matches.
func rewriteModelInfoName(name, oldID, newID string) string {
	trimmed := strings.TrimSpace(name)
	if trimmed == "" {
		return name
	}
	oldID = strings.TrimSpace(oldID)
	newID = strings.TrimSpace(newID)
	if oldID == "" || newID == "" || strings.EqualFold(oldID, newID) {
		return name
	}
	if strings.EqualFold(trimmed, oldID) {
		return newID
	}
	// This suffix check also covers the common "models/<id>" form, so no
	// separate case is needed for it.
	if strings.HasSuffix(trimmed, "/"+oldID) {
		return strings.TrimSuffix(trimmed, oldID) + newID
	}
	return name
}
// applyOAuthModelAlias layers the configured OAuth model aliases for the
// provider/auth-kind channel onto models.
//
// The channel comes from coreauth.OAuthModelAliasChannel. The input slice is
// returned unchanged if cfg is nil, models is empty, the channel is empty, or no
// usable aliases exist for it. Otherwise a new slice is built in input order.
// For each model:
//   - without aliases, the model is kept once (de-duplicated by lower-cased ID);
//   - with aliases, each alias becomes a shallow copy whose ID is the alias and
//     whose Name is rewritten via rewriteModelInfoName;
//   - the original is kept alongside its aliases only if at least one matching
//     alias entry has Fork set, or as a fallback when no alias could be added.
func applyOAuthModelAlias(cfg *config.Config, provider, authKind string, models []*ModelInfo) []*ModelInfo {
	if cfg == nil || len(models) == 0 {
		return models
	}
	channel := coreauth.OAuthModelAliasChannel(provider, authKind)
	if channel == "" || len(cfg.OAuthModelAlias) == 0 {
		return models
	}
	aliases := cfg.OAuthModelAlias[channel]
	if len(aliases) == 0 {
		return models
	}
	type aliasEntry struct {
		alias string
		fork  bool
	}
	// forward maps a lower-cased upstream model name to every alias configured for it.
	// Blank entries and self-aliases (case-insensitive) are ignored.
	forward := make(map[string][]aliasEntry, len(aliases))
	for i := range aliases {
		name := strings.TrimSpace(aliases[i].Name)
		alias := strings.TrimSpace(aliases[i].Alias)
		if name == "" || alias == "" {
			continue
		}
		if strings.EqualFold(name, alias) {
			continue
		}
		key := strings.ToLower(name)
		forward[key] = append(forward[key], aliasEntry{alias: alias, fork: aliases[i].Fork})
	}
	if len(forward) == 0 {
		return models
	}
	out := make([]*ModelInfo, 0, len(models))
	// seen tracks lower-cased IDs already emitted, covering originals and aliases.
	seen := make(map[string]struct{}, len(models))
	for _, model := range models {
		if model == nil {
			continue
		}
		id := strings.TrimSpace(model.ID)
		if id == "" {
			continue
		}
		key := strings.ToLower(id)
		entries := forward[key]
		if len(entries) == 0 {
			if _, exists := seen[key]; exists {
				continue
			}
			seen[key] = struct{}{}
			out = append(out, model)
			continue
		}
		// A single fork entry is enough to keep the original ID visible next to its aliases.
		keepOriginal := false
		for _, entry := range entries {
			if entry.fork {
				keepOriginal = true
				break
			}
		}
		if keepOriginal {
			if _, exists := seen[key]; !exists {
				seen[key] = struct{}{}
				out = append(out, model)
			}
		}
		addedAlias := false
		for _, entry := range entries {
			mappedID := strings.TrimSpace(entry.alias)
			if mappedID == "" {
				continue
			}
			if strings.EqualFold(mappedID, id) {
				continue
			}
			aliasKey := strings.ToLower(mappedID)
			if _, exists := seen[aliasKey]; exists {
				continue
			}
			seen[aliasKey] = struct{}{}
			// Value copy of the model struct; only ID and Name are changed.
			clone := *model
			clone.ID = mappedID
			if clone.Name != "" {
				clone.Name = rewriteModelInfoName(clone.Name, id, mappedID)
			}
			out = append(out, &clone)
			addedAlias = true
		}
		// Rename-only aliases all collided with existing IDs: keep the original so the model is not lost.
		if !keepOriginal && !addedAlias {
			if _, exists := seen[key]; exists {
				continue
			}
			seen[key] = struct{}{}
			out = append(out, model)
		}
	}
	return out
}
================================================
FILE: sdk/cliproxy/service_codex_executor_binding_test.go
================================================
package cliproxy
import (
"testing"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// TestEnsureExecutorsForAuth_CodexDoesNotReplaceInNormalMode checks that
// repeated non-forced binds for the same codex auth keep the executor instance
// registered by the first bind.
func TestEnsureExecutorsForAuth_CodexDoesNotReplaceInNormalMode(t *testing.T) {
	svc := &Service{
		cfg:         &config.Config{},
		coreManager: coreauth.NewManager(nil, nil, nil),
	}
	codexAuth := &coreauth.Auth{
		ID:       "codex-auth-1",
		Provider: "codex",
		Status:   coreauth.StatusActive,
	}

	svc.ensureExecutorsForAuth(codexAuth)
	before, found := svc.coreManager.Executor("codex")
	if !found || before == nil {
		t.Fatal("expected codex executor after first bind")
	}

	svc.ensureExecutorsForAuth(codexAuth)
	after, found := svc.coreManager.Executor("codex")
	if !found || after == nil {
		t.Fatal("expected codex executor after second bind")
	}

	if before != after {
		t.Fatal("expected codex executor to stay unchanged in normal mode")
	}
}
// TestEnsureExecutorsForAuthWithMode_CodexForceReplace checks that a forced
// rebind installs a new codex executor instance.
func TestEnsureExecutorsForAuthWithMode_CodexForceReplace(t *testing.T) {
	svc := &Service{
		cfg:         &config.Config{},
		coreManager: coreauth.NewManager(nil, nil, nil),
	}
	codexAuth := &coreauth.Auth{
		ID:       "codex-auth-2",
		Provider: "codex",
		Status:   coreauth.StatusActive,
	}

	svc.ensureExecutorsForAuth(codexAuth)
	initial, found := svc.coreManager.Executor("codex")
	if !found || initial == nil {
		t.Fatal("expected codex executor after first bind")
	}

	svc.ensureExecutorsForAuthWithMode(codexAuth, true)
	replaced, found := svc.coreManager.Executor("codex")
	if !found || replaced == nil {
		t.Fatal("expected codex executor after forced rebind")
	}

	if initial == replaced {
		t.Fatal("expected codex executor replacement in force mode")
	}
}
================================================
FILE: sdk/cliproxy/service_excluded_models_test.go
================================================
package cliproxy
import (
"strings"
"testing"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// TestRegisterModelsForAuth_UsesPreMergedExcludedModelsAttribute checks that the
// auth's "excluded_models" attribute replaces the global OAuth exclusion list.
// The attribute-excluded model must be absent, and the globally excluded one
// must still be registered.
func TestRegisterModelsForAuth_UsesPreMergedExcludedModelsAttribute(t *testing.T) {
	svc := &Service{
		cfg: &config.Config{
			OAuthExcludedModels: map[string][]string{
				"gemini-cli": {"gemini-2.5-pro"},
			},
		},
	}
	geminiAuth := &coreauth.Auth{
		ID:       "auth-gemini-cli",
		Provider: "gemini-cli",
		Status:   coreauth.StatusActive,
		Attributes: map[string]string{
			"auth_kind":       "oauth",
			"excluded_models": "gemini-2.5-flash",
		},
	}

	reg := GlobalModelRegistry()
	reg.UnregisterClient(geminiAuth.ID)
	t.Cleanup(func() { reg.UnregisterClient(geminiAuth.ID) })

	svc.registerModelsForAuth(geminiAuth)

	registered := reg.GetAvailableModelsByProvider("gemini-cli")
	if len(registered) == 0 {
		t.Fatal("expected gemini-cli models to be registered")
	}

	sawGlobalExcluded := false
	for _, m := range registered {
		if m == nil {
			continue
		}
		id := strings.TrimSpace(m.ID)
		if strings.EqualFold(id, "gemini-2.5-flash") {
			t.Fatalf("expected model %q to be excluded by auth attribute", id)
		}
		if strings.EqualFold(id, "gemini-2.5-pro") {
			sawGlobalExcluded = true
		}
	}
	if !sawGlobalExcluded {
		t.Fatal("expected global excluded model to be present when attribute override is set")
	}
}
================================================
FILE: sdk/cliproxy/service_oauth_model_alias_test.go
================================================
package cliproxy
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// TestApplyOAuthModelAlias_Rename checks that a non-fork alias replaces the
// original model and rewrites its "models/" name.
func TestApplyOAuthModelAlias_Rename(t *testing.T) {
	cfg := &config.Config{OAuthModelAlias: map[string][]config.OAuthModelAlias{
		"codex": {{Name: "gpt-5", Alias: "g5"}},
	}}
	input := []*ModelInfo{{ID: "gpt-5", Name: "models/gpt-5"}}

	got := applyOAuthModelAlias(cfg, "codex", "oauth", input)

	if n := len(got); n != 1 {
		t.Fatalf("expected 1 model, got %d", n)
	}
	if id := got[0].ID; id != "g5" {
		t.Fatalf("expected model id %q, got %q", "g5", id)
	}
	if name := got[0].Name; name != "models/g5" {
		t.Fatalf("expected model name %q, got %q", "models/g5", name)
	}
}
// TestApplyOAuthModelAlias_ForkAddsAlias checks that a fork alias keeps the
// original model first and appends the renamed copy after it.
func TestApplyOAuthModelAlias_ForkAddsAlias(t *testing.T) {
	cfg := &config.Config{OAuthModelAlias: map[string][]config.OAuthModelAlias{
		"codex": {{Name: "gpt-5", Alias: "g5", Fork: true}},
	}}
	input := []*ModelInfo{{ID: "gpt-5", Name: "models/gpt-5"}}

	got := applyOAuthModelAlias(cfg, "codex", "oauth", input)

	if n := len(got); n != 2 {
		t.Fatalf("expected 2 models, got %d", n)
	}
	if id := got[0].ID; id != "gpt-5" {
		t.Fatalf("expected first model id %q, got %q", "gpt-5", id)
	}
	if id := got[1].ID; id != "g5" {
		t.Fatalf("expected second model id %q, got %q", "g5", id)
	}
	if name := got[1].Name; name != "models/g5" {
		t.Fatalf("expected forked model name %q, got %q", "models/g5", name)
	}
}
// TestApplyOAuthModelAlias_ForkAddsMultipleAliases checks that several fork
// aliases for one model are all emitted, in configuration order, after the
// original.
func TestApplyOAuthModelAlias_ForkAddsMultipleAliases(t *testing.T) {
	cfg := &config.Config{OAuthModelAlias: map[string][]config.OAuthModelAlias{
		"codex": {
			{Name: "gpt-5", Alias: "g5", Fork: true},
			{Name: "gpt-5", Alias: "g5-2", Fork: true},
		},
	}}
	input := []*ModelInfo{{ID: "gpt-5", Name: "models/gpt-5"}}

	got := applyOAuthModelAlias(cfg, "codex", "oauth", input)

	if n := len(got); n != 3 {
		t.Fatalf("expected 3 models, got %d", n)
	}
	if id := got[0].ID; id != "gpt-5" {
		t.Fatalf("expected first model id %q, got %q", "gpt-5", id)
	}
	if id := got[1].ID; id != "g5" {
		t.Fatalf("expected second model id %q, got %q", "g5", id)
	}
	if name := got[1].Name; name != "models/g5" {
		t.Fatalf("expected forked model name %q, got %q", "models/g5", name)
	}
	if id := got[2].ID; id != "g5-2" {
		t.Fatalf("expected third model id %q, got %q", "g5-2", id)
	}
	if name := got[2].Name; name != "models/g5-2" {
		t.Fatalf("expected forked model name %q, got %q", "models/g5-2", name)
	}
}
================================================
FILE: sdk/cliproxy/types.go
================================================
// Package cliproxy provides the core service implementation for the CLI Proxy API.
// It includes service lifecycle management, authentication handling, file watching,
// and integration with various AI service providers through a unified interface.
package cliproxy
import (
"context"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// TokenClientProvider is implemented by sources that load clients backed by
// persisted authentication tokens.
type TokenClientProvider interface {
	// Load reads token-backed clients using the supplied configuration.
	//
	// Parameters:
	// - ctx: Context governing the load operation
	// - cfg: Application configuration to read from
	//
	// Returns:
	// - *TokenClientResult: Summary of what was loaded
	// - error: Non-nil when loading fails
	Load(ctx context.Context, cfg *config.Config) (*TokenClientResult, error)
}
// TokenClientResult represents clients generated from persisted tokens.
// It contains metadata about the loading operation and the number of successful authentications.
type TokenClientResult struct {
// SuccessfulAuthed is the number of successfully authenticated clients.
SuccessfulAuthed int
}
type (
	// APIKeyClientProvider loads clients backed directly by API keys listed in
	// the configuration.
	APIKeyClientProvider interface {
		// Load loads API key-based clients from cfg. ctx scopes the loading
		// operation; a non-nil error means loading failed, otherwise the
		// returned *APIKeyClientResult reports per-provider key counts.
		Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error)
	}

	// APIKeyClientResult is returned by APIKeyClientProvider.Load and reports
	// how many keys of each kind were loaded.
	APIKeyClientResult struct {
		// GeminiKeyCount is the number of Gemini API keys loaded.
		GeminiKeyCount int
		// VertexCompatKeyCount is the number of Vertex-compatible API keys loaded.
		VertexCompatKeyCount int
		// ClaudeKeyCount is the number of Claude API keys loaded.
		ClaudeKeyCount int
		// CodexKeyCount is the number of Codex API keys loaded.
		CodexKeyCount int
		// OpenAICompatCount is the number of OpenAI compatibility API keys loaded.
		OpenAICompatCount int
	}
)
// WatcherFactory creates a watcher for configuration and token changes.
// The reload callback receives the updated configuration when changes are detected.
// The SDK's stock implementation, defaultWatcherFactory, forwards these same
// arguments to watcher.NewWatcher.
//
// Parameters:
//   - configPath: The path to the configuration file to watch
//   - authDir: The directory containing authentication tokens to watch
//   - reload: The callback function to call when changes are detected
//
// Returns:
//   - *WatcherWrapper: A watcher wrapper instance
//   - error: An error if watcher creation fails
type WatcherFactory func(configPath, authDir string, reload func(*config.Config)) (*WatcherWrapper, error)
// WatcherWrapper exposes the subset of watcher methods required by the SDK.
// Every field is an optional hook. The exported methods on *WatcherWrapper
// treat a nil receiver or a nil hook as a no-op, returning nil (or false for
// DispatchRuntimeAuthUpdate) instead of panicking.
type WatcherWrapper struct {
	// start backs Start.
	start func(ctx context.Context) error
	// stop backs Stop.
	stop func() error
	// setConfig backs SetConfig.
	setConfig func(cfg *config.Config)
	// snapshotAuths backs SnapshotAuths.
	snapshotAuths func() []*coreauth.Auth
	// setUpdateQueue backs SetAuthUpdateQueue.
	setUpdateQueue func(queue chan<- watcher.AuthUpdate)
	// dispatchRuntimeUpdate backs DispatchRuntimeAuthUpdate.
	dispatchRuntimeUpdate func(update watcher.AuthUpdate) bool
}
// Start runs the wrapped watcher's start hook with ctx. On a nil wrapper, or
// one without a start hook, it is a no-op that returns nil.
func (w *WatcherWrapper) Start(ctx context.Context) error {
	if w != nil && w.start != nil {
		return w.start(ctx)
	}
	return nil
}
// Stop runs the wrapped watcher's stop hook and returns its error. On a nil
// wrapper, or one without a stop hook, it is a no-op that returns nil.
func (w *WatcherWrapper) Stop() error {
	if w != nil && w.stop != nil {
		return w.stop()
	}
	return nil
}
// SetConfig hands cfg to the wrapped watcher so it can refresh its
// configuration cache. It does nothing without a wrapper or a setConfig hook.
func (w *WatcherWrapper) SetConfig(cfg *config.Config) {
	if w != nil && w.setConfig != nil {
		w.setConfig(cfg)
	}
}
// DispatchRuntimeAuthUpdate forwards a runtime auth update into the
// watcher-managed auth update queue and reports whether it was enqueued.
// Without a wrapper or a dispatch hook it returns false.
func (w *WatcherWrapper) DispatchRuntimeAuthUpdate(update watcher.AuthUpdate) bool {
	if w != nil && w.dispatchRuntimeUpdate != nil {
		return w.dispatchRuntimeUpdate(update)
	}
	return false
}
// SetClients, SetAPIKeyClients, and SnapshotClients were removed from this
// wrapper; the watcher now manages its own client caches. Use SnapshotAuths
// to read the current auth state.

// SnapshotAuths returns the auth entries reported by the watcher's snapshot
// hook, or nil when the wrapper or the hook is missing.
func (w *WatcherWrapper) SnapshotAuths() []*coreauth.Auth {
	if w != nil && w.snapshotAuths != nil {
		return w.snapshotAuths()
	}
	return nil
}
// SetAuthUpdateQueue registers the channel the watcher uses to propagate auth
// updates. It does nothing without a wrapper or a setUpdateQueue hook.
func (w *WatcherWrapper) SetAuthUpdateQueue(queue chan<- watcher.AuthUpdate) {
	if w != nil && w.setUpdateQueue != nil {
		w.setUpdateQueue(queue)
	}
}
================================================
FILE: sdk/cliproxy/usage/manager.go
================================================
package usage
import (
"context"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// Record contains the usage statistics captured for a single provider request.
// Records are queued by Manager.Publish (or PublishRecord) and handed, by
// value, to every registered Plugin on the dispatcher goroutine.
type Record struct {
	Provider string
	Model    string
	// NOTE(review): APIKey may carry a credential; plugins should avoid
	// persisting or logging it verbatim — confirm what publishers store here.
	APIKey      string
	AuthID      string
	AuthIndex   string
	Source      string
	RequestedAt time.Time
	Failed      bool
	Detail      Detail
}
// Detail holds the token usage breakdown.
// NOTE(review): nothing in this package derives TotalTokens from the other
// counters or checks them for consistency; publishers fill every field.
// Confirm per provider whether TotalTokens includes reasoning/cached tokens.
type Detail struct {
	InputTokens     int64
	OutputTokens    int64
	ReasoningTokens int64
	CachedTokens    int64
	TotalTokens     int64
}
// Plugin consumes usage records emitted by the proxy runtime.
// HandleUsage is called from the Manager's single dispatcher goroutine, one
// record at a time; a panic inside it is recovered and logged (safeInvoke).
// ctx is the context given to Publish; because delivery is asynchronous it
// may already be cancelled by the time HandleUsage runs.
type Plugin interface {
	HandleUsage(ctx context.Context, record Record)
}
// queueItem pairs a published record with the context it was published under.
type queueItem struct {
	ctx    context.Context
	record Record
}
// Manager maintains a queue of usage records and delivers them to registered plugins.
// The queue is an unbounded slice guarded by mu; a single background goroutine
// (started lazily by Start or Publish) pops records in FIFO order and calls
// each plugin in turn.
type Manager struct {
	// once guards launching the dispatcher goroutine.
	once sync.Once
	// stopOnce makes Stop take effect only once.
	stopOnce sync.Once
	// cancel cancels the dispatcher's context; set inside Start.
	cancel context.CancelFunc
	// mu guards queue and closed; cond is bound to it (see NewManager).
	mu sync.Mutex
	// cond is signalled by Publish and broadcast by Stop.
	cond *sync.Cond
	// queue holds records awaiting delivery, oldest first.
	queue []queueItem
	// closed is set by Stop; afterwards Publish drops records.
	closed bool
	// pluginsMu guards plugins.
	pluginsMu sync.RWMutex
	plugins   []Plugin
}
// NewManager constructs an idle manager.
//
// buffer is used as the initial capacity of the pending-record queue. The
// queue itself is unbounded and grows as needed, so buffer is a sizing hint,
// not a limit. Values <= 0 leave the queue unallocated until the first Publish.
// The dispatcher goroutine is not started until Start or Publish is called.
func NewManager(buffer int) *Manager {
	m := &Manager{}
	m.cond = sync.NewCond(&m.mu)
	if buffer > 0 {
		// Previously buffer was silently ignored; pre-size the queue so the
		// first bursts of records do not trigger repeated slice growth.
		m.queue = make([]queueItem, 0, buffer)
	}
	return m
}
// Start launches the background dispatcher goroutine. Only the first call has
// any effect; later calls, including the implicit ones made by Publish, are
// no-ops. A nil ctx is replaced with context.Background(), and a nil *Manager
// is ignored.
func (m *Manager) Start(ctx context.Context) {
	if m == nil {
		return
	}
	m.once.Do(func() {
		parent := ctx
		if parent == nil {
			parent = context.Background()
		}
		workerCtx, cancel := context.WithCancel(parent)
		m.cancel = cancel
		go m.run(workerCtx)
	})
}
// Stop shuts the manager down. It does the following:
//   - cancels the dispatcher's context;
//   - marks the manager closed, so later Publish calls drop their records;
//   - wakes the dispatcher.
//
// Stop does not wait. Records already queued are still delivered by the
// dispatcher goroutine, which exits once the queue is empty. Only the first
// call has an effect, and a nil *Manager is ignored.
func (m *Manager) Stop() {
	if m == nil {
		return
	}
	m.stopOnce.Do(func() {
		if cancel := m.cancel; cancel != nil {
			cancel()
		}
		m.mu.Lock()
		m.closed = true
		m.mu.Unlock()
		// Broadcast rather than Signal so every waiter re-checks closed.
		m.cond.Broadcast()
	})
}
// Register appends plugin to the delivery list. A nil plugin or a nil
// *Manager is ignored. A plugin only sees records dispatched after it was
// registered, because dispatch snapshots the list per record.
func (m *Manager) Register(plugin Plugin) {
	if m == nil || plugin == nil {
		return
	}
	m.pluginsMu.Lock()
	defer m.pluginsMu.Unlock()
	m.plugins = append(m.plugins, plugin)
}
// Publish queues record, together with ctx, for asynchronous delivery to the
// registered plugins.
//
// It starts the dispatcher lazily if needed. Once Stop has been called,
// records are silently dropped. If no plugin is registered when the record is
// dispatched, the record is discarded downstream.
func (m *Manager) Publish(ctx context.Context, record Record) {
	if m == nil {
		return
	}
	// Make sure a worker exists even if Start was never called explicitly.
	m.Start(context.Background())

	item := queueItem{ctx: ctx, record: record}
	m.mu.Lock()
	accepted := !m.closed
	if accepted {
		m.queue = append(m.queue, item)
	}
	m.mu.Unlock()

	if accepted {
		m.cond.Signal()
	}
}
// run is the dispatcher loop. It waits on cond until a record is queued or the
// manager is closed, then delivers records one at a time in FIFO order. It
// exits only once the manager is closed AND the queue has been drained.
// The context argument is not consulted; shutdown is driven by Stop through
// the closed flag and cond.
func (m *Manager) run(_ context.Context) {
	for {
		m.mu.Lock()
		for !m.closed && len(m.queue) == 0 {
			m.cond.Wait()
		}
		if len(m.queue) == 0 {
			// Only reachable when closed: everything has been delivered.
			m.mu.Unlock()
			return
		}
		item := m.queue[0]
		// Clear the vacated slot before advancing. Otherwise the backing
		// array keeps the record and its (possibly request-scoped) context
		// reachable until the slice happens to be reallocated.
		m.queue[0] = queueItem{}
		m.queue = m.queue[1:]
		m.mu.Unlock()
		m.dispatch(item)
	}
}
// dispatch delivers item to a snapshot of the registered plugins. The
// snapshot is taken under the read lock, so Register may run concurrently
// with delivery. nil entries are skipped, and each call goes through
// safeInvoke so that a panicking plugin cannot stop the others.
func (m *Manager) dispatch(item queueItem) {
	m.pluginsMu.RLock()
	snapshot := append([]Plugin(nil), m.plugins...)
	m.pluginsMu.RUnlock()

	for _, p := range snapshot {
		if p != nil {
			safeInvoke(p, item.ctx, item.record)
		}
	}
}
// safeInvoke calls plugin.HandleUsage and turns any panic into an error log
// entry, so one misbehaving plugin cannot kill the dispatcher goroutine.
func safeInvoke(plugin Plugin, ctx context.Context, record Record) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		log.Errorf("usage: plugin panic recovered: %v", recovered)
	}()
	plugin.HandleUsage(ctx, record)
}
// defaultManager backs the package-level helpers below. Its dispatcher starts
// lazily on the first Publish, or explicitly via StartDefault.
var defaultManager = NewManager(512)

// DefaultManager returns the global usage manager instance.
func DefaultManager() *Manager {
	return defaultManager
}

// RegisterPlugin registers plugin on the default manager.
func RegisterPlugin(plugin Plugin) {
	defaultManager.Register(plugin)
}

// PublishRecord publishes record on the default manager.
func PublishRecord(ctx context.Context, record Record) {
	defaultManager.Publish(ctx, record)
}

// StartDefault starts the default manager's dispatcher.
func StartDefault(ctx context.Context) {
	defaultManager.Start(ctx)
}

// StopDefault stops the default manager's dispatcher.
func StopDefault() {
	defaultManager.Stop()
}
================================================
FILE: sdk/cliproxy/watcher.go
================================================
package cliproxy
import (
"context"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
// defaultWatcherFactory is the stock WatcherFactory. It builds a watcher via
// watcher.NewWatcher and exposes it through a WatcherWrapper whose hooks
// forward directly to the watcher's methods.
func defaultWatcherFactory(configPath, authDir string, reload func(*config.Config)) (*WatcherWrapper, error) {
	inner, errWatcher := watcher.NewWatcher(configPath, authDir, reload)
	if errWatcher != nil {
		return nil, errWatcher
	}

	wrapper := &WatcherWrapper{}
	wrapper.start = func(ctx context.Context) error { return inner.Start(ctx) }
	wrapper.stop = func() error { return inner.Stop() }
	wrapper.setConfig = func(cfg *config.Config) { inner.SetConfig(cfg) }
	wrapper.snapshotAuths = func() []*coreauth.Auth { return inner.SnapshotCoreAuths() }
	wrapper.setUpdateQueue = func(queue chan<- watcher.AuthUpdate) { inner.SetAuthUpdateQueue(queue) }
	wrapper.dispatchRuntimeUpdate = func(update watcher.AuthUpdate) bool {
		return inner.DispatchRuntimeAuthUpdate(update)
	}
	return wrapper, nil
}
================================================
FILE: sdk/config/config.go
================================================
// Package config provides the public SDK configuration API.
//
// It re-exports the server configuration types and helpers so external projects can
// embed CLIProxyAPI without importing internal packages.
package config
import internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
// Re-exported configuration types. These are type aliases, not new types, so
// values are interchangeable with the corresponding internal/config types.
type (
	SDKConfig                 = internalconfig.SDKConfig
	Config                    = internalconfig.Config
	StreamingConfig           = internalconfig.StreamingConfig
	TLSConfig                 = internalconfig.TLSConfig
	RemoteManagement          = internalconfig.RemoteManagement
	AmpCode                   = internalconfig.AmpCode
	OAuthModelAlias           = internalconfig.OAuthModelAlias
	PayloadConfig             = internalconfig.PayloadConfig
	PayloadRule               = internalconfig.PayloadRule
	PayloadFilterRule         = internalconfig.PayloadFilterRule
	PayloadModelRule          = internalconfig.PayloadModelRule
	GeminiKey                 = internalconfig.GeminiKey
	CodexKey                  = internalconfig.CodexKey
	ClaudeKey                 = internalconfig.ClaudeKey
	VertexCompatKey           = internalconfig.VertexCompatKey
	VertexCompatModel         = internalconfig.VertexCompatModel
	OpenAICompatibility       = internalconfig.OpenAICompatibility
	OpenAICompatibilityAPIKey = internalconfig.OpenAICompatibilityAPIKey
	OpenAICompatibilityModel  = internalconfig.OpenAICompatibilityModel

	// TLS is an alternate name for the same type as TLSConfig.
	TLS = internalconfig.TLSConfig
)

// DefaultPanelGitHubRepository re-exports internalconfig.DefaultPanelGitHubRepository.
const DefaultPanelGitHubRepository = internalconfig.DefaultPanelGitHubRepository
// LoadConfig loads the configuration file at configFile by delegating to
// internalconfig.LoadConfig.
func LoadConfig(configFile string) (*Config, error) {
	return internalconfig.LoadConfig(configFile)
}

// LoadConfigOptional delegates to internalconfig.LoadConfigOptional.
// NOTE(review): presumably optional=true tolerates a missing file — confirm in
// internal/config.
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
	return internalconfig.LoadConfigOptional(configFile, optional)
}

// SaveConfigPreserveComments writes cfg to configFile via
// internalconfig.SaveConfigPreserveComments (which, per its name, keeps the
// existing file's comments).
func SaveConfigPreserveComments(configFile string, cfg *Config) error {
	return internalconfig.SaveConfigPreserveComments(configFile, cfg)
}

// SaveConfigPreserveCommentsUpdateNestedScalar sets the scalar at path to
// value in configFile via the internal helper of the same name.
func SaveConfigPreserveCommentsUpdateNestedScalar(configFile string, path []string, value string) error {
	return internalconfig.SaveConfigPreserveCommentsUpdateNestedScalar(configFile, path, value)
}

// NormalizeCommentIndentation delegates to internalconfig.NormalizeCommentIndentation.
func NormalizeCommentIndentation(data []byte) []byte {
	return internalconfig.NormalizeCommentIndentation(data)
}
================================================
FILE: sdk/logging/request_logger.go
================================================
// Package logging re-exports request logging primitives for SDK consumers.
package logging
import internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
// defaultErrorLogsMaxFiles is the error-log retention NewFileRequestLogger
// passes to the internal logger.
const defaultErrorLogsMaxFiles = 10

// Aliases for the internal logging types, so SDK consumers can use them
// without importing internal packages.
type (
	// RequestLogger defines the interface for logging HTTP requests and responses.
	RequestLogger = internallogging.RequestLogger
	// StreamingLogWriter handles real-time logging of streaming response chunks.
	StreamingLogWriter = internallogging.StreamingLogWriter
	// FileRequestLogger implements RequestLogger using file-based storage.
	FileRequestLogger = internallogging.FileRequestLogger
)
// NewFileRequestLogger creates a new file-based request logger with the
// default error log retention (defaultErrorLogsMaxFiles, i.e. 10 files).
// It is shorthand for NewFileRequestLoggerWithOptions with that value.
func NewFileRequestLogger(enabled bool, logsDir string, configDir string) *FileRequestLogger {
	return NewFileRequestLoggerWithOptions(enabled, logsDir, configDir, defaultErrorLogsMaxFiles)
}

// NewFileRequestLoggerWithOptions creates a new file-based request logger
// with configurable error log retention (errorLogsMaxFiles).
func NewFileRequestLoggerWithOptions(enabled bool, logsDir string, configDir string, errorLogsMaxFiles int) *FileRequestLogger {
	return internallogging.NewFileRequestLogger(enabled, logsDir, configDir, errorLogsMaxFiles)
}
================================================
FILE: sdk/proxyutil/proxy.go
================================================
package proxyutil
import (
"context"
"fmt"
"net"
"net/http"
"net/url"
"strings"
"golang.org/x/net/proxy"
)
// Mode describes how a proxy setting should be interpreted.
type Mode int

// Proxy modes. The numeric values are spelled out because tests print them
// with %d.
const (
	// ModeInherit means no explicit proxy behavior was configured.
	ModeInherit Mode = 0
	// ModeDirect means outbound requests must bypass proxies explicitly.
	ModeDirect Mode = 1
	// ModeProxy means a concrete proxy URL was configured.
	ModeProxy Mode = 2
	// ModeInvalid means the proxy setting is present but malformed or unsupported.
	ModeInvalid Mode = 3
)

// Setting is the normalized interpretation of a proxy configuration value.
type Setting struct {
	// Raw is the configured value with surrounding whitespace trimmed.
	Raw string
	// Mode says how the value should be applied.
	Mode Mode
	// URL is the parsed proxy URL; Parse sets it only for ModeProxy.
	URL *url.URL
}

// Parse normalizes a proxy configuration value into inherit, direct, or
// proxy modes. Surrounding whitespace is trimmed first. Then:
//   - "" means inherit;
//   - "direct" or "none" (case-insensitive) means direct;
//   - a URL with a socks5, http or https scheme and a non-empty host means proxy.
//
// Anything else yields ModeInvalid together with an error. The returned
// Setting always carries the trimmed Raw value and its Mode, even on error.
func Parse(raw string) (Setting, error) {
	value := strings.TrimSpace(raw)
	result := Setting{Raw: value, Mode: ModeInvalid}

	switch {
	case value == "":
		result.Mode = ModeInherit
		return result, nil
	case strings.EqualFold(value, "direct"), strings.EqualFold(value, "none"):
		result.Mode = ModeDirect
		return result, nil
	}

	proxyURL, errParse := url.Parse(value)
	if errParse != nil {
		return result, fmt.Errorf("parse proxy URL failed: %w", errParse)
	}
	if proxyURL.Scheme == "" || proxyURL.Host == "" {
		return result, fmt.Errorf("proxy URL missing scheme/host")
	}
	if proxyURL.Scheme != "socks5" && proxyURL.Scheme != "http" && proxyURL.Scheme != "https" {
		return result, fmt.Errorf("unsupported proxy scheme: %s", proxyURL.Scheme)
	}

	result.Mode = ModeProxy
	result.URL = proxyURL
	return result, nil
}
// NewDirectTransport returns a transport that never consults a proxy, not
// even the environment-derived one that http.DefaultTransport uses.
//
// When http.DefaultTransport is a *http.Transport, it is cloned so that its
// timeouts and pooling settings carry over. Otherwise a bare transport is
// returned. The shared default transport is never mutated.
func NewDirectTransport() *http.Transport {
	base, ok := http.DefaultTransport.(*http.Transport)
	if !ok || base == nil {
		return &http.Transport{Proxy: nil}
	}
	direct := base.Clone()
	direct.Proxy = nil
	return direct
}
// BuildHTTPTransport constructs an HTTP transport for the provided proxy setting.
//
// Results by mode:
//   - ModeInherit: nil transport, so the caller keeps its own default.
//   - ModeDirect: a transport that bypasses all proxies (NewDirectTransport).
//   - ModeProxy: a transport routed through the configured proxy. For socks5
//     URLs the proxy is applied at dial time; for http/https URLs it becomes
//     the transport's Proxy function.
//   - ModeInvalid: nil transport plus the parse error.
//
// Proxy transports start from NewDirectTransport, a clone of
// http.DefaultTransport. They therefore inherit its TLS-handshake and idle
// timeouts, connection-pool limits and HTTP/2 settings. A zero-value
// http.Transport would impose no timeouts at all.
func BuildHTTPTransport(raw string) (*http.Transport, Mode, error) {
	setting, errParse := Parse(raw)
	if errParse != nil {
		return nil, setting.Mode, errParse
	}
	switch setting.Mode {
	case ModeInherit:
		return nil, setting.Mode, nil
	case ModeDirect:
		return NewDirectTransport(), setting.Mode, nil
	case ModeProxy:
		if setting.URL.Scheme == "socks5" {
			var proxyAuth *proxy.Auth
			if setting.URL.User != nil {
				username := setting.URL.User.Username()
				// An absent password is deliberately treated as empty.
				password, _ := setting.URL.User.Password()
				proxyAuth = &proxy.Auth{User: username, Password: password}
			}
			dialer, errSOCKS5 := proxy.SOCKS5("tcp", setting.URL.Host, proxyAuth, proxy.Direct)
			if errSOCKS5 != nil {
				return nil, setting.Mode, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
			}
			// Proxy stays nil (from NewDirectTransport) so environment proxies
			// are not stacked on top of the SOCKS5 tunnel.
			transport := NewDirectTransport()
			transport.DialContext = func(ctx context.Context, network, addr string) (net.Conn, error) {
				// Prefer the context-aware dial so that request cancellation and
				// deadlines also abort the connection to the SOCKS5 proxy.
				if contextDialer, ok := dialer.(proxy.ContextDialer); ok {
					return contextDialer.DialContext(ctx, network, addr)
				}
				return dialer.Dial(network, addr)
			}
			return transport, setting.Mode, nil
		}
		transport := NewDirectTransport()
		transport.Proxy = http.ProxyURL(setting.URL)
		return transport, setting.Mode, nil
	default:
		return nil, setting.Mode, nil
	}
}
// BuildDialer constructs a proxy dialer for settings that operate at the
// connection layer.
//
//   - ModeInherit yields a nil dialer.
//   - ModeDirect yields proxy.Direct.
//   - ModeProxy yields the dialer from proxy.FromURL.
//
// Parse errors are returned unchanged, together with the invalid mode.
// NOTE(review): proxy.FromURL only knows socks5/socks5h plus registered
// schemes, so http/https URLs presumably fail here — confirm.
func BuildDialer(raw string) (proxy.Dialer, Mode, error) {
	setting, errParse := Parse(raw)
	if errParse != nil {
		return nil, setting.Mode, errParse
	}
	mode := setting.Mode
	if mode == ModeDirect {
		return proxy.Direct, mode, nil
	}
	if mode != ModeProxy {
		return nil, mode, nil
	}
	dialer, errDialer := proxy.FromURL(setting.URL, proxy.Direct)
	if errDialer != nil {
		return nil, mode, fmt.Errorf("create proxy dialer failed: %w", errDialer)
	}
	return dialer, mode, nil
}
================================================
FILE: sdk/proxyutil/proxy_test.go
================================================
package proxyutil
import (
"net/http"
"testing"
)
// TestParse checks the Mode that Parse assigns to representative inputs, and
// whether it reports an error.
func TestParse(t *testing.T) {
	t.Parallel()
	type parseCase struct {
		name    string
		input   string
		want    Mode
		wantErr bool
	}
	cases := []parseCase{
		{name: "inherit", input: "", want: ModeInherit},
		{name: "direct", input: "direct", want: ModeDirect},
		{name: "none", input: "none", want: ModeDirect},
		{name: "http", input: "http://proxy.example.com:8080", want: ModeProxy},
		{name: "https", input: "https://proxy.example.com:8443", want: ModeProxy},
		{name: "socks5", input: "socks5://proxy.example.com:1080", want: ModeProxy},
		{name: "invalid", input: "bad-value", want: ModeInvalid, wantErr: true},
	}
	for _, tc := range cases {
		tc := tc // per-iteration copy for the parallel subtest (pre-Go 1.22 semantics)
		t.Run(tc.name, func(t *testing.T) {
			t.Parallel()
			setting, errParse := Parse(tc.input)
			switch {
			case tc.wantErr && errParse == nil:
				t.Fatal("expected error, got nil")
			case !tc.wantErr && errParse != nil:
				t.Fatalf("unexpected error: %v", errParse)
			}
			if got := setting.Mode; got != tc.want {
				t.Fatalf("mode = %d, want %d", got, tc.want)
			}
		})
	}
}
// TestBuildHTTPTransportDirectBypassesProxy checks that "direct" yields a
// non-nil transport whose Proxy function is disabled.
func TestBuildHTTPTransportDirectBypassesProxy(t *testing.T) {
	t.Parallel()
	tr, gotMode, errBuild := BuildHTTPTransport("direct")
	switch {
	case errBuild != nil:
		t.Fatalf("BuildHTTPTransport returned error: %v", errBuild)
	case gotMode != ModeDirect:
		t.Fatalf("mode = %d, want %d", gotMode, ModeDirect)
	case tr == nil:
		t.Fatal("expected transport, got nil")
	case tr.Proxy != nil:
		t.Fatal("expected direct transport to disable proxy function")
	}
}
// TestBuildHTTPTransportHTTPProxy checks that an http:// proxy URL yields a
// transport whose Proxy function returns that URL for an HTTPS request.
func TestBuildHTTPTransportHTTPProxy(t *testing.T) {
	t.Parallel()
	const proxyAddr = "http://proxy.example.com:8080"
	tr, gotMode, errBuild := BuildHTTPTransport(proxyAddr)
	switch {
	case errBuild != nil:
		t.Fatalf("BuildHTTPTransport returned error: %v", errBuild)
	case gotMode != ModeProxy:
		t.Fatalf("mode = %d, want %d", gotMode, ModeProxy)
	case tr == nil:
		t.Fatal("expected transport, got nil")
	}

	req, errRequest := http.NewRequest(http.MethodGet, "https://example.com", nil)
	if errRequest != nil {
		t.Fatalf("http.NewRequest returned error: %v", errRequest)
	}
	resolved, errProxy := tr.Proxy(req)
	if errProxy != nil {
		t.Fatalf("transport.Proxy returned error: %v", errProxy)
	}
	if resolved == nil || resolved.String() != proxyAddr {
		t.Fatalf("proxy URL = %v, want http://proxy.example.com:8080", resolved)
	}
}
================================================
FILE: sdk/translator/builtin/builtin.go
================================================
// Package builtin exposes the built-in translator registrations for SDK users.
package builtin
import (
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
)
// Registry returns the SDK's package-level default translator registry. The
// blank import of internal/translator in this file is what registers the
// built-in translators into it.
func Registry() *sdktranslator.Registry {
	return sdktranslator.Default()
}

// Pipeline returns a new pipeline bound to Registry(), so the built-in
// translators are already available. Each call returns a fresh pipeline
// with no middleware.
func Pipeline() *sdktranslator.Pipeline {
	return sdktranslator.NewPipeline(Registry())
}
================================================
FILE: sdk/translator/format.go
================================================
package translator
// Format identifies a request/response schema used inside the proxy.
type Format string

// FromString converts an arbitrary identifier to a translator format. The
// value is not validated; any string is accepted as-is.
func FromString(v string) Format { return Format(v) }

// String returns the raw schema identifier, satisfying fmt.Stringer.
func (f Format) String() string { return string(f) }
================================================
FILE: sdk/translator/formats.go
================================================
package translator
// Common format identifiers exposed for SDK users. Each constant's String
// method returns the quoted identifier verbatim.
const (
	FormatAntigravity    Format = "antigravity"
	FormatClaude         Format = "claude"
	FormatCodex          Format = "codex"
	FormatGemini         Format = "gemini"
	FormatGeminiCLI      Format = "gemini-cli"
	FormatOpenAI         Format = "openai"
	FormatOpenAIResponse Format = "openai-response"
)
================================================
FILE: sdk/translator/helpers.go
================================================
package translator
import "context"
// TranslateRequestByFormatName converts a request payload between schemas
// using the default registry.
func TranslateRequestByFormatName(from, to Format, model string, rawJSON []byte, stream bool) []byte {
	return defaultRegistry.TranslateRequest(from, to, model, rawJSON, stream)
}

// HasResponseTransformerByFormatName reports whether the default registry
// holds a response translator registered for (from, to).
func HasResponseTransformerByFormatName(from, to Format) bool {
	return defaultRegistry.HasResponseTransformer(from, to)
}

// TranslateStreamByFormatName converts a streaming response chunk between
// schemas using the default registry.
func TranslateStreamByFormatName(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	return defaultRegistry.TranslateStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}

// TranslateNonStreamByFormatName converts a non-streaming response between
// schemas using the default registry.
func TranslateNonStreamByFormatName(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
	return defaultRegistry.TranslateNonStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}

// TranslateTokenCountByFormatName converts a token count between schemas
// using the default registry.
func TranslateTokenCountByFormatName(ctx context.Context, from, to Format, count int64, rawJSON []byte) string {
	return defaultRegistry.TranslateTokenCount(ctx, from, to, count, rawJSON)
}
================================================
FILE: sdk/translator/pipeline.go
================================================
package translator
import "context"
type (
	// RequestEnvelope represents a request in the translation pipeline.
	RequestEnvelope struct {
		// Format is the schema Body is encoded in; translation sets it to the target.
		Format Format
		Model  string
		Stream bool
		Body   []byte
	}

	// ResponseEnvelope represents a response in the translation pipeline.
	// After translation, Chunks holds the translated output when Stream is
	// true; otherwise Body is replaced with the translated payload.
	ResponseEnvelope struct {
		Format Format
		Model  string
		Stream bool
		Body   []byte
		Chunks []string
	}

	// RequestMiddleware decorates request translation; it decides whether
	// and how to call next.
	RequestMiddleware func(ctx context.Context, req RequestEnvelope, next RequestHandler) (RequestEnvelope, error)

	// ResponseMiddleware decorates response translation; it decides whether
	// and how to call next.
	ResponseMiddleware func(ctx context.Context, resp ResponseEnvelope, next ResponseHandler) (ResponseEnvelope, error)

	// RequestHandler performs request translation between formats.
	RequestHandler func(ctx context.Context, req RequestEnvelope) (RequestEnvelope, error)

	// ResponseHandler performs response translation between formats.
	ResponseHandler func(ctx context.Context, resp ResponseEnvelope) (ResponseEnvelope, error)
)
// Pipeline orchestrates request/response transformation with middleware support.
// UseRequest/UseResponse append to the middleware slices without locking, so
// register middleware before sharing a pipeline across goroutines.
type Pipeline struct {
	registry           *Registry
	requestMiddleware  []RequestMiddleware
	responseMiddleware []ResponseMiddleware
}

// NewPipeline constructs a pipeline bound to registry. If registry is nil,
// the pipeline uses the package default registry.
func NewPipeline(registry *Registry) *Pipeline {
	p := &Pipeline{registry: Default()}
	if registry != nil {
		p.registry = registry
	}
	return p
}
// UseRequest adds request middleware. Middleware runs in registration order:
// the first one registered is the outermost wrapper around the translation.
// A nil middleware is ignored.
func (p *Pipeline) UseRequest(mw RequestMiddleware) {
	if mw == nil {
		return
	}
	p.requestMiddleware = append(p.requestMiddleware, mw)
}

// UseResponse adds response middleware. Middleware runs in registration
// order: the first one registered is the outermost wrapper. A nil middleware
// is ignored.
func (p *Pipeline) UseResponse(mw ResponseMiddleware) {
	if mw == nil {
		return
	}
	p.responseMiddleware = append(p.responseMiddleware, mw)
}
// TranslateRequest runs req through the request middleware chain and finally
// through the registry's request transform for from→to.
//
// The terminal step replaces Body with the translated payload and sets Format
// to `to`. It never returns an error itself, so any error comes from
// middleware.
//
// A zero-value Pipeline (nil registry) falls back to the package default
// registry, matching NewPipeline(nil), instead of panicking on a nil *Registry.
func (p *Pipeline) TranslateRequest(ctx context.Context, from, to Format, req RequestEnvelope) (RequestEnvelope, error) {
	registry := p.registry
	if registry == nil {
		registry = Default()
	}
	terminal := func(_ context.Context, input RequestEnvelope) (RequestEnvelope, error) {
		input.Body = registry.TranslateRequest(from, to, input.Model, input.Body, input.Stream)
		input.Format = to
		return input, nil
	}
	handler := RequestHandler(terminal)
	// Wrap from the last middleware to the first so the first registered runs outermost.
	for i := len(p.requestMiddleware) - 1; i >= 0; i-- {
		mw := p.requestMiddleware[i]
		next := handler
		handler = func(ctx context.Context, r RequestEnvelope) (RequestEnvelope, error) {
			return mw(ctx, r, next)
		}
	}
	return handler(ctx, req)
}
// TranslateResponse runs resp through the response middleware chain and then
// through the registry's response transform.
//
// The terminal step behaves as follows:
//   - when resp.Stream is true, it stores the translated output in Chunks and
//     leaves Body untouched;
//   - otherwise it replaces Body with the translated payload.
//
// In both cases it sets Format to `to`. from, to, originalReq, translatedReq
// and param are passed unchanged to Registry.TranslateStream /
// TranslateNonStream, so from/to follow those methods' key order.
//
// A zero-value Pipeline (nil registry) falls back to the package default
// registry, matching NewPipeline(nil), instead of panicking on a nil *Registry.
func (p *Pipeline) TranslateResponse(ctx context.Context, from, to Format, resp ResponseEnvelope, originalReq, translatedReq []byte, param *any) (ResponseEnvelope, error) {
	registry := p.registry
	if registry == nil {
		registry = Default()
	}
	terminal := func(ctx context.Context, input ResponseEnvelope) (ResponseEnvelope, error) {
		if input.Stream {
			input.Chunks = registry.TranslateStream(ctx, from, to, input.Model, originalReq, translatedReq, input.Body, param)
		} else {
			input.Body = []byte(registry.TranslateNonStream(ctx, from, to, input.Model, originalReq, translatedReq, input.Body, param))
		}
		input.Format = to
		return input, nil
	}
	handler := ResponseHandler(terminal)
	// Wrap from the last middleware to the first so the first registered runs outermost.
	for i := len(p.responseMiddleware) - 1; i >= 0; i-- {
		mw := p.responseMiddleware[i]
		next := handler
		handler = func(ctx context.Context, r ResponseEnvelope) (ResponseEnvelope, error) {
			return mw(ctx, r, next)
		}
	}
	return handler(ctx, resp)
}
================================================
FILE: sdk/translator/registry.go
================================================
package translator
import (
"context"
"sync"
)
// Registry manages translation functions across schemas. Both maps are
// guarded by mu. Transforms are stored under the (from, to) pair passed to
// Register: requests[from][to] and responses[from][to].
type Registry struct {
	mu        sync.RWMutex
	requests  map[Format]map[Format]RequestTransform
	responses map[Format]map[Format]ResponseTransform
}

// NewRegistry constructs an empty translator registry.
func NewRegistry() *Registry {
	r := new(Registry)
	r.requests = make(map[Format]map[Format]RequestTransform)
	r.responses = make(map[Format]map[Format]ResponseTransform)
	return r
}
// Register stores request/response transforms for the (from, to) pair,
// replacing any previous entry.
//
// A nil request transform is not stored, although the outer requests[from]
// map is still created. The response transform is always stored, even when
// all of its function fields are nil.
func (r *Registry) Register(from, to Format, request RequestTransform, response ResponseTransform) {
	r.mu.Lock()
	defer r.mu.Unlock()

	reqByTarget, ok := r.requests[from]
	if !ok {
		reqByTarget = make(map[Format]RequestTransform)
		r.requests[from] = reqByTarget
	}
	if request != nil {
		reqByTarget[to] = request
	}

	respByTarget, ok := r.responses[from]
	if !ok {
		respByTarget = make(map[Format]ResponseTransform)
		r.responses[from] = respByTarget
	}
	respByTarget[to] = response
}
// TranslateRequest converts rawJSON from schema `from` to schema `to` using
// the transform registered for that pair. If no translator is registered, it
// returns the original payload.
//
// The transform is looked up under the read lock but invoked after the lock
// is released. Holding a sync.RWMutex read lock across arbitrary callback
// code can deadlock when the callback re-enters the registry while a Register
// call is waiting (recursive read locking is not allowed). It would also
// stall writers for the whole translation.
func (r *Registry) TranslateRequest(from, to Format, model string, rawJSON []byte, stream bool) []byte {
	r.mu.RLock()
	// Indexing a missing inner map yields nil, i.e. "not registered".
	fn := r.requests[from][to]
	r.mu.RUnlock()
	if fn == nil {
		return rawJSON
	}
	return fn(model, rawJSON, stream)
}
// HasResponseTransformer reports whether a response transform was stored via
// Register(from, to, ...). Note the key order: this checks responses[from][to],
// whereas TranslateStream, TranslateNonStream and TranslateTokenCount look up
// responses[to][from].
func (r *Registry) HasResponseTransformer(from, to Format) bool {
	r.mu.RLock()
	defer r.mu.RUnlock()
	_, found := r.responses[from][to]
	return found
}
// TranslateStream applies the registered streaming response translator.
//
// It looks up responses[to][from], i.e. the response half of
// Register(to, from, ...). If no Stream function is registered, it returns
// rawJSON as a single chunk.
//
// The transform is fetched under the read lock and invoked after the lock is
// released. Per-chunk callbacks therefore never hold the registry lock, and a
// re-entrant translator cannot deadlock against a pending Register.
func (r *Registry) TranslateStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	r.mu.RLock()
	transform := r.responses[to][from]
	r.mu.RUnlock()
	if transform.Stream == nil {
		return []string{string(rawJSON)}
	}
	return transform.Stream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
// TranslateNonStream applies the registered non-stream response translator.
//
// It looks up responses[to][from], i.e. the response half of
// Register(to, from, ...). If no NonStream function is registered, it returns
// string(rawJSON).
//
// The transform is fetched under the read lock and invoked after the lock is
// released, so a slow or re-entrant translator cannot stall or deadlock
// Register.
func (r *Registry) TranslateNonStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
	r.mu.RLock()
	transform := r.responses[to][from]
	r.mu.RUnlock()
	if transform.NonStream == nil {
		return string(rawJSON)
	}
	return transform.NonStream(ctx, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}
// TranslateTokenCount applies the registered token-count response translator.
//
// Like TranslateStream, it looks up responses[to][from]. If no TokenCount
// function is registered, it returns string(rawJSON).
//
// The transform is fetched under the read lock and invoked after the lock is
// released, so a slow or re-entrant translator cannot stall or deadlock
// Register.
func (r *Registry) TranslateTokenCount(ctx context.Context, from, to Format, count int64, rawJSON []byte) string {
	r.mu.RLock()
	transform := r.responses[to][from]
	r.mu.RUnlock()
	if transform.TokenCount == nil {
		return string(rawJSON)
	}
	return transform.TokenCount(ctx, count)
}
// defaultRegistry is the shared registry behind Default and the package-level helpers.
var defaultRegistry = NewRegistry()

// Default exposes the package-level registry for shared use.
func Default() *Registry { return defaultRegistry }

// Register attaches transforms to the default registry.
func Register(from, to Format, request RequestTransform, response ResponseTransform) {
	Default().Register(from, to, request, response)
}

// TranslateRequest translates a request payload with the default registry.
func TranslateRequest(from, to Format, model string, rawJSON []byte, stream bool) []byte {
	return Default().TranslateRequest(from, to, model, rawJSON, stream)
}

// HasResponseTransformer inspects the default registry.
func HasResponseTransformer(from, to Format) bool {
	return Default().HasResponseTransformer(from, to)
}

// TranslateStream translates a streaming response chunk with the default registry.
func TranslateStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string {
	return Default().TranslateStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}

// TranslateNonStream translates a non-streaming response with the default registry.
func TranslateNonStream(ctx context.Context, from, to Format, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string {
	return Default().TranslateNonStream(ctx, from, to, model, originalRequestRawJSON, requestRawJSON, rawJSON, param)
}

// TranslateTokenCount translates a token count with the default registry.
func TranslateTokenCount(ctx context.Context, from, to Format, count int64, rawJSON []byte) string {
	return Default().TranslateTokenCount(ctx, from, to, count, rawJSON)
}
================================================
FILE: sdk/translator/types.go
================================================
// Package translator provides types and functions for converting chat requests and responses between different schemas.
package translator
import "context"
type (
	// RequestTransform converts a request payload from a source schema to a
	// target schema. It receives the model name, the raw JSON request and
	// whether a streaming response was requested, and returns the converted
	// payload.
	RequestTransform func(model string, rawJSON []byte, stream bool) []byte

	// ResponseStreamTransform converts one streaming response chunk from a
	// source schema to a target schema.
	//
	// It receives the model name, the original and converted request JSON,
	// the current chunk, and an optional shared parameter pointer. It returns
	// zero or more converted chunks.
	ResponseStreamTransform func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) []string

	// ResponseNonStreamTransform converts a complete non-streaming response
	// from a source schema to a target schema and returns it as one string.
	// Its inputs mirror ResponseStreamTransform.
	ResponseNonStreamTransform func(ctx context.Context, model string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, param *any) string

	// ResponseTokenCountTransform renders a token count in the target format.
	ResponseTokenCountTransform func(ctx context.Context, count int64) string

	// ResponseTransform bundles the response-side converters for one schema
	// pair. Any field may be nil when that conversion is not supported.
	ResponseTransform struct {
		// Stream transforms streaming responses.
		Stream ResponseStreamTransform
		// NonStream transforms non-streaming responses.
		NonStream ResponseNonStreamTransform
		// TokenCount transforms token counts.
		TokenCount ResponseTokenCountTransform
	}
)
================================================
FILE: test/amp_management_test.go
================================================
package test
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// init puts gin into test mode once for the whole test package, before any
// router in these tests is built.
func init() {
	mode := gin.TestMode
	gin.SetMode(mode)
}
// newAmpTestHandler builds a management handler around an in-memory config
// whose ampcode section holds known fixture values, backed by a minimal
// config file in a per-test temporary directory. It returns the handler and
// the path of that file.
func newAmpTestHandler(t *testing.T) (*management.Handler, string) {
	t.Helper()

	path := filepath.Join(t.TempDir(), "config.yaml")
	if err := os.WriteFile(path, []byte("port: 8080\n"), 0o644); err != nil {
		t.Fatalf("failed to write config file: %v", err)
	}

	amp := config.AmpCode{
		UpstreamURL:                   "https://example.com",
		UpstreamAPIKey:                "test-api-key-12345",
		RestrictManagementToLocalhost: true,
		ForceModelMappings:            false,
		ModelMappings: []config.AmpModelMapping{
			{From: "gpt-4", To: "gemini-pro"},
		},
	}
	handler := management.NewHandler(&config.Config{AmpCode: amp}, path, nil)
	return handler, path
}
// setupAmpRouter returns a fresh gin engine exposing every ampcode management
// endpoint of h under /v0/management. Routes are registered in table order.
func setupAmpRouter(h *management.Handler) *gin.Engine {
	engine := gin.New()
	group := engine.Group("/v0/management")

	routes := []struct {
		method  string
		path    string
		handler gin.HandlerFunc
	}{
		{http.MethodGet, "/ampcode", h.GetAmpCode},
		{http.MethodGet, "/ampcode/upstream-url", h.GetAmpUpstreamURL},
		{http.MethodPut, "/ampcode/upstream-url", h.PutAmpUpstreamURL},
		{http.MethodDelete, "/ampcode/upstream-url", h.DeleteAmpUpstreamURL},
		{http.MethodGet, "/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey},
		{http.MethodPut, "/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey},
		{http.MethodDelete, "/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey},
		{http.MethodGet, "/ampcode/upstream-api-keys", h.GetAmpUpstreamAPIKeys},
		{http.MethodPut, "/ampcode/upstream-api-keys", h.PutAmpUpstreamAPIKeys},
		{http.MethodPatch, "/ampcode/upstream-api-keys", h.PatchAmpUpstreamAPIKeys},
		{http.MethodDelete, "/ampcode/upstream-api-keys", h.DeleteAmpUpstreamAPIKeys},
		{http.MethodGet, "/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost},
		{http.MethodPut, "/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost},
		{http.MethodGet, "/ampcode/model-mappings", h.GetAmpModelMappings},
		{http.MethodPut, "/ampcode/model-mappings", h.PutAmpModelMappings},
		{http.MethodPatch, "/ampcode/model-mappings", h.PatchAmpModelMappings},
		{http.MethodDelete, "/ampcode/model-mappings", h.DeleteAmpModelMappings},
		{http.MethodGet, "/ampcode/force-model-mappings", h.GetAmpForceModelMappings},
		{http.MethodPut, "/ampcode/force-model-mappings", h.PutAmpForceModelMappings},
	}
	for _, route := range routes {
		group.Handle(route.method, route.path, route.handler)
	}
	return engine
}
// TestGetAmpCode checks that GET /v0/management/ampcode returns the whole
// ampcode section, including the fixture upstream URL and its single mapping.
func TestGetAmpCode(t *testing.T) {
	handler, _ := newAmpTestHandler(t)
	router := setupAmpRouter(handler)

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil))
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, rec.Code)
	}

	var payload map[string]config.AmpCode
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	got := payload["ampcode"]
	if want := "https://example.com"; got.UpstreamURL != want {
		t.Errorf("expected upstream-url %q, got %q", want, got.UpstreamURL)
	}
	if n := len(got.ModelMappings); n != 1 {
		t.Errorf("expected 1 model mapping, got %d", n)
	}
}
// TestGetAmpUpstreamURL checks that GET /v0/management/ampcode/upstream-url
// reports the fixture upstream URL.
func TestGetAmpUpstreamURL(t *testing.T) {
	handler, _ := newAmpTestHandler(t)
	router := setupAmpRouter(handler)

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil))
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, rec.Code)
	}

	var payload map[string]string
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	const want = "https://example.com"
	if got := payload["upstream-url"]; got != want {
		t.Errorf("expected %q, got %q", want, got)
	}
}
// TestPutAmpUpstreamURL checks that PUT /v0/management/ampcode/upstream-url
// accepts a new URL with status 200.
func TestPutAmpUpstreamURL(t *testing.T) {
	handler, _ := newAmpTestHandler(t)
	router := setupAmpRouter(handler)

	payload := []byte(`{"value": "https://new-upstream.com"}`)
	req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("expected status %d, got %d: %s", http.StatusOK, got, rec.Body.String())
	}
}
// TestDeleteAmpUpstreamURL checks that DELETE
// /v0/management/ampcode/upstream-url responds with status 200.
func TestDeleteAmpUpstreamURL(t *testing.T) {
	handler, _ := newAmpTestHandler(t)
	router := setupAmpRouter(handler)

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil))
	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, got)
	}
}
// TestGetAmpUpstreamAPIKey verifies GET /v0/management/ampcode/upstream-api-key returns the API key.
// The field is read with a checked type assertion so a missing or non-string
// value fails the test with a diagnostic instead of panicking.
func TestGetAmpUpstreamAPIKey(t *testing.T) {
	h, _ := newAmpTestHandler(t)
	r := setupAmpRouter(h)
	req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
	}
	var resp map[string]any
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	key, ok := resp["upstream-api-key"].(string)
	if !ok {
		t.Fatalf("expected upstream-api-key to be a string, got %#v (body: %s)", resp["upstream-api-key"], w.Body.String())
	}
	if key != "test-api-key-12345" {
		t.Errorf("expected key %q, got %q", "test-api-key-12345", key)
	}
}
// TestPutAmpUpstreamAPIKey checks that PUT
// /v0/management/ampcode/upstream-api-key accepts a new key with status 200.
func TestPutAmpUpstreamAPIKey(t *testing.T) {
	handler, _ := newAmpTestHandler(t)
	router := setupAmpRouter(handler)

	payload := []byte(`{"value": "new-secret-key"}`)
	req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, got)
	}
}
// TestPutAmpUpstreamAPIKeys_PersistsAndReturns verifies that PUT
// /v0/management/ampcode/upstream-api-keys stores a trimmed entry with empty
// client keys dropped, that the stored entry can be reloaded from the config
// file, and that GET /v0/management/ampcode then reports it.
func TestPutAmpUpstreamAPIKeys_PersistsAndReturns(t *testing.T) {
	handler, cfgPath := newAmpTestHandler(t)
	router := setupAmpRouter(handler)

	// Padded and empty values exercise the handler's normalization.
	payload := []byte(`{"value":[{"upstream-api-key":" u1 ","api-keys":[" k1 ","","k2"]}]}`)
	putReq := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewReader(payload))
	putReq.Header.Set("Content-Type", "application/json")
	putRec := httptest.NewRecorder()
	router.ServeHTTP(putRec, putReq)
	if putRec.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d: %s", http.StatusOK, putRec.Code, putRec.Body.String())
	}

	// The update must be readable back from disk.
	onDisk, err := config.LoadConfig(cfgPath)
	if err != nil {
		t.Fatalf("failed to load config from disk: %v", err)
	}
	entries := onDisk.AmpCode.UpstreamAPIKeys
	if len(entries) != 1 {
		t.Fatalf("expected 1 upstream-api-keys entry, got %d", len(entries))
	}
	first := entries[0]
	if first.UpstreamAPIKey != "u1" {
		t.Fatalf("expected upstream-api-key u1, got %q", first.UpstreamAPIKey)
	}
	keys := first.APIKeys
	if !(len(keys) == 2 && keys[0] == "k1" && keys[1] == "k2") {
		t.Fatalf("expected api-keys [k1 k2], got %#v", keys)
	}

	// The live handler must serve the updated value as well.
	getRec := httptest.NewRecorder()
	router.ServeHTTP(getRec, httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil))
	if getRec.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, getRec.Code)
	}
	var decoded map[string]config.AmpCode
	if err := json.Unmarshal(getRec.Body.Bytes(), &decoded); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	served := decoded["ampcode"].UpstreamAPIKeys
	if !(len(served) == 1 && served[0].UpstreamAPIKey == "u1") {
		t.Fatalf("expected upstream-api-keys to be present after update, got %#v", served)
	}
}
// TestDeleteAmpUpstreamAPIKeys_ClearsAll verifies that DELETE
// /v0/management/ampcode/upstream-api-keys with an empty value list leaves
// no entries behind after one was seeded with PUT.
func TestDeleteAmpUpstreamAPIKeys_ClearsAll(t *testing.T) {
	h, _ := newAmpTestHandler(t)
	r := setupAmpRouter(h)
	// Seed with one entry
	putBody := `{"value":[{"upstream-api-key":"u1","api-keys":["k1"]}]}`
	req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(putBody))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
	}
	deleteBody := `{"value":[]}`
	req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-keys", bytes.NewBufferString(deleteBody))
	req.Header.Set("Content-Type", "application/json")
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
	}
	req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-keys", nil)
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
	}
	var resp map[string][]config.AmpUpstreamAPIKeyEntry
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	// len is 0 for a nil slice too, so JSON null, a missing key and [] are
	// all accepted as "cleared" without a separate nil check.
	if got := resp["upstream-api-keys"]; len(got) != 0 {
		t.Fatalf("expected cleared list, got %#v", got)
	}
}
// TestDeleteAmpUpstreamAPIKey checks that DELETE
// /v0/management/ampcode/upstream-api-key responds with status 200.
func TestDeleteAmpUpstreamAPIKey(t *testing.T) {
	handler, _ := newAmpTestHandler(t)
	router := setupAmpRouter(handler)

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil))
	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, got)
	}
}
// TestGetAmpRestrictManagementToLocalhost checks that the GET endpoint
// reports the fixture's localhost restriction (true).
func TestGetAmpRestrictManagementToLocalhost(t *testing.T) {
	handler, _ := newAmpTestHandler(t)
	router := setupAmpRouter(handler)

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil))
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, rec.Code)
	}

	var payload map[string]bool
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if !payload["restrict-management-to-localhost"] {
		t.Error("expected restrict-management-to-localhost to be true")
	}
}
// TestPutAmpRestrictManagementToLocalhost checks that the PUT endpoint
// accepts a boolean value with status 200.
func TestPutAmpRestrictManagementToLocalhost(t *testing.T) {
	handler, _ := newAmpTestHandler(t)
	router := setupAmpRouter(handler)

	payload := []byte(`{"value": false}`)
	req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, got)
	}
}
// TestGetAmpModelMappings checks that GET /v0/management/ampcode/model-mappings
// returns exactly the fixture mapping gpt-4 -> gemini-pro.
func TestGetAmpModelMappings(t *testing.T) {
	handler, _ := newAmpTestHandler(t)
	router := setupAmpRouter(handler)

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil))
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, rec.Code)
	}

	var payload map[string][]config.AmpModelMapping
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	list := payload["model-mappings"]
	if len(list) != 1 {
		t.Fatalf("expected 1 mapping, got %d", len(list))
	}
	if m := list[0]; m.From != "gpt-4" || m.To != "gemini-pro" {
		t.Errorf("unexpected mapping: %+v", m)
	}
}
// TestPutAmpModelMappings checks that PUT /v0/management/ampcode/model-mappings
// accepts a replacement list with status 200.
func TestPutAmpModelMappings(t *testing.T) {
	handler, _ := newAmpTestHandler(t)
	router := setupAmpRouter(handler)

	payload := []byte(`{"value": [{"from": "claude-3", "to": "gpt-4o"}, {"from": "gemini", "to": "claude"}]}`)
	req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("expected status %d, got %d: %s", http.StatusOK, got, rec.Body.String())
	}
}
// TestPatchAmpModelMappings checks that PATCH
// /v0/management/ampcode/model-mappings accepts a mix of updated and new
// mappings with status 200.
func TestPatchAmpModelMappings(t *testing.T) {
	handler, _ := newAmpTestHandler(t)
	router := setupAmpRouter(handler)

	payload := []byte(`{"value": [{"from": "gpt-4", "to": "updated-model"}, {"from": "new-model", "to": "target"}]}`)
	req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("expected status %d, got %d: %s", http.StatusOK, got, rec.Body.String())
	}
}
// TestDeleteAmpModelMappings_Specific checks that DELETE
// /v0/management/ampcode/model-mappings with a list of "from" names responds
// with status 200.
func TestDeleteAmpModelMappings_Specific(t *testing.T) {
	handler, _ := newAmpTestHandler(t)
	router := setupAmpRouter(handler)

	payload := []byte(`{"value": ["gpt-4"]}`)
	req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, got)
	}
}
// TestDeleteAmpModelMappings_All checks that DELETE
// /v0/management/ampcode/model-mappings without a body responds with status
// 200.
func TestDeleteAmpModelMappings_All(t *testing.T) {
	handler, _ := newAmpTestHandler(t)
	router := setupAmpRouter(handler)

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", nil))
	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, got)
	}
}
// TestGetAmpForceModelMappings checks that the GET endpoint reports the
// fixture's force-model-mappings value (false).
func TestGetAmpForceModelMappings(t *testing.T) {
	handler, _ := newAmpTestHandler(t)
	router := setupAmpRouter(handler)

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil))
	if rec.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, rec.Code)
	}

	var payload map[string]bool
	if err := json.Unmarshal(rec.Body.Bytes(), &payload); err != nil {
		t.Fatalf("failed to unmarshal response: %v", err)
	}
	if payload["force-model-mappings"] {
		t.Error("expected force-model-mappings to be false")
	}
}
// TestPutAmpForceModelMappings checks that the PUT endpoint accepts a boolean
// value with status 200.
func TestPutAmpForceModelMappings(t *testing.T) {
	handler, _ := newAmpTestHandler(t)
	router := setupAmpRouter(handler)

	payload := []byte(`{"value": true}`)
	req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if got := rec.Code; got != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, got)
	}
}
// TestPutAmpModelMappings_VerifyState verifies PUT replaces mappings and that a
// follow-up GET returns exactly the new set. The GET status is checked so an
// error response cannot be mistaken for state, and every expected "from" key
// must be present with the right target (duplicates or unknown keys fail).
func TestPutAmpModelMappings_VerifyState(t *testing.T) {
	h, _ := newAmpTestHandler(t)
	r := setupAmpRouter(h)
	body := `{"value": [{"from": "model-a", "to": "model-b"}, {"from": "model-c", "to": "model-d"}, {"from": "model-e", "to": "model-f"}]}`
	req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("PUT failed: status %d, body: %s", w.Code, w.Body.String())
	}
	req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("GET failed: status %d, body: %s", w.Code, w.Body.String())
	}
	var resp map[string][]config.AmpModelMapping
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}
	mappings := resp["model-mappings"]
	if len(mappings) != 3 {
		t.Fatalf("expected 3 mappings, got %d", len(mappings))
	}
	found := make(map[string]string, len(mappings))
	for _, m := range mappings {
		found[m.From] = m.To
	}
	// With exactly 3 entries and 3 expected keys, checking every expected key
	// also rules out duplicates and unexpected "from" values.
	expected := map[string]string{"model-a": "model-b", "model-c": "model-d", "model-e": "model-f"}
	for from, to := range expected {
		got, ok := found[from]
		if !ok {
			t.Errorf("mapping %q missing from response", from)
			continue
		}
		if got != to {
			t.Errorf("mapping %q -> expected %q, got %q", from, to, got)
		}
	}
}
// TestPatchAmpModelMappings_VerifyState verifies PATCH merges mappings: the
// seeded gpt-4 entry is overwritten and new-model is added, leaving exactly
// two mappings. Both request statuses are checked so a failed GET is not
// mistaken for state.
func TestPatchAmpModelMappings_VerifyState(t *testing.T) {
	h, _ := newAmpTestHandler(t)
	r := setupAmpRouter(h)
	body := `{"value": [{"from": "gpt-4", "to": "updated-target"}, {"from": "new-model", "to": "new-target"}]}`
	req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("PATCH failed: status %d, body: %s", w.Code, w.Body.String())
	}
	req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("GET failed: status %d, body: %s", w.Code, w.Body.String())
	}
	var resp map[string][]config.AmpModelMapping
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}
	mappings := resp["model-mappings"]
	if len(mappings) != 2 {
		t.Fatalf("expected 2 mappings (1 updated + 1 new), got %d", len(mappings))
	}
	found := make(map[string]string, len(mappings))
	for _, m := range mappings {
		found[m.From] = m.To
	}
	if found["gpt-4"] != "updated-target" {
		t.Errorf("gpt-4 should map to updated-target, got %q", found["gpt-4"])
	}
	if found["new-model"] != "new-target" {
		t.Errorf("new-model should map to new-target, got %q", found["new-model"])
	}
}
// TestDeleteAmpModelMappings_VerifyState verifies DELETE removes specific mappings and keeps others.
// The seeding PUT and the final GET are status-checked too, so a failed setup
// step is reported where it happens instead of as a wrong final count.
func TestDeleteAmpModelMappings_VerifyState(t *testing.T) {
	h, _ := newAmpTestHandler(t)
	r := setupAmpRouter(h)
	putBody := `{"value": [{"from": "a", "to": "1"}, {"from": "b", "to": "2"}, {"from": "c", "to": "3"}]}`
	req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("PUT failed: status %d, body: %s", w.Code, w.Body.String())
	}
	delBody := `{"value": ["a", "c"]}`
	req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody))
	req.Header.Set("Content-Type", "application/json")
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("DELETE failed: status %d", w.Code)
	}
	req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("GET failed: status %d, body: %s", w.Code, w.Body.String())
	}
	var resp map[string][]config.AmpModelMapping
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}
	mappings := resp["model-mappings"]
	if len(mappings) != 1 {
		t.Fatalf("expected 1 mapping remaining, got %d", len(mappings))
	}
	if mappings[0].From != "b" || mappings[0].To != "2" {
		t.Errorf("expected b->2, got %s->%s", mappings[0].From, mappings[0].To)
	}
}
// TestDeleteAmpModelMappings_NonExistent verifies DELETE with non-existent mapping doesn't affect existing ones.
// It checks the GET status and that the surviving entry is the seeded
// gpt-4 -> gemini-pro mapping, not merely that one entry remains.
func TestDeleteAmpModelMappings_NonExistent(t *testing.T) {
	h, _ := newAmpTestHandler(t)
	r := setupAmpRouter(h)
	delBody := `{"value": ["non-existent-model"]}`
	req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
	}
	req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("GET failed: status %d, body: %s", w.Code, w.Body.String())
	}
	var resp map[string][]config.AmpModelMapping
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}
	mappings := resp["model-mappings"]
	if len(mappings) != 1 {
		t.Fatalf("original mapping should remain, got %d mappings", len(mappings))
	}
	if mappings[0].From != "gpt-4" || mappings[0].To != "gemini-pro" {
		t.Errorf("expected gpt-4->gemini-pro to remain, got %s->%s", mappings[0].From, mappings[0].To)
	}
}
// TestPutAmpModelMappings_Empty verifies PUT with empty array clears all mappings.
// The GET status is checked so an error response cannot pass as "no mappings".
func TestPutAmpModelMappings_Empty(t *testing.T) {
	h, _ := newAmpTestHandler(t)
	r := setupAmpRouter(h)
	body := `{"value": []}`
	req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
	}
	req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("GET failed: status %d, body: %s", w.Code, w.Body.String())
	}
	var resp map[string][]config.AmpModelMapping
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}
	if len(resp["model-mappings"]) != 0 {
		t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"]))
	}
}
// TestPutAmpUpstreamURL_VerifyState verifies PUT updates upstream URL and persists state.
// The GET status is checked before decoding so a failure is reported directly.
func TestPutAmpUpstreamURL_VerifyState(t *testing.T) {
	h, _ := newAmpTestHandler(t)
	r := setupAmpRouter(h)
	body := `{"value": "https://new-api.example.com"}`
	req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("PUT failed: status %d", w.Code)
	}
	req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil)
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("GET failed: status %d, body: %s", w.Code, w.Body.String())
	}
	var resp map[string]string
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}
	if resp["upstream-url"] != "https://new-api.example.com" {
		t.Errorf("expected %q, got %q", "https://new-api.example.com", resp["upstream-url"])
	}
}
// TestDeleteAmpUpstreamURL_VerifyState verifies DELETE clears upstream URL.
// Checking the GET status matters here: a missing key decodes to "", so an
// error response in JSON form would otherwise satisfy the empty-string check.
func TestDeleteAmpUpstreamURL_VerifyState(t *testing.T) {
	h, _ := newAmpTestHandler(t)
	r := setupAmpRouter(h)
	req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("DELETE failed: status %d", w.Code)
	}
	req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil)
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("GET failed: status %d, body: %s", w.Code, w.Body.String())
	}
	var resp map[string]string
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}
	if resp["upstream-url"] != "" {
		t.Errorf("expected empty string, got %q", resp["upstream-url"])
	}
}
// TestPutAmpUpstreamAPIKey_VerifyState verifies PUT updates API key and persists state.
// The GET status is checked before decoding so a failure is reported directly.
func TestPutAmpUpstreamAPIKey_VerifyState(t *testing.T) {
	h, _ := newAmpTestHandler(t)
	r := setupAmpRouter(h)
	body := `{"value": "new-secret-api-key-xyz"}`
	req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("PUT failed: status %d", w.Code)
	}
	req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil)
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("GET failed: status %d, body: %s", w.Code, w.Body.String())
	}
	var resp map[string]string
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}
	if resp["upstream-api-key"] != "new-secret-api-key-xyz" {
		t.Errorf("expected %q, got %q", "new-secret-api-key-xyz", resp["upstream-api-key"])
	}
}
// TestDeleteAmpUpstreamAPIKey_VerifyState verifies DELETE clears API key.
// Checking the GET status matters here: a missing key decodes to "", so an
// error response in JSON form would otherwise satisfy the empty-string check.
func TestDeleteAmpUpstreamAPIKey_VerifyState(t *testing.T) {
	h, _ := newAmpTestHandler(t)
	r := setupAmpRouter(h)
	req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil)
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("DELETE failed: status %d", w.Code)
	}
	req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil)
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("GET failed: status %d, body: %s", w.Code, w.Body.String())
	}
	var resp map[string]string
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}
	if resp["upstream-api-key"] != "" {
		t.Errorf("expected empty string, got %q", resp["upstream-api-key"])
	}
}
// TestPutAmpRestrictManagementToLocalhost_VerifyState verifies PUT updates localhost restriction.
// The GET status is checked so a missing key (decoded as false) cannot make an
// error response look like a successful update.
func TestPutAmpRestrictManagementToLocalhost_VerifyState(t *testing.T) {
	h, _ := newAmpTestHandler(t)
	r := setupAmpRouter(h)
	body := `{"value": false}`
	req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("PUT failed: status %d", w.Code)
	}
	req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil)
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("GET failed: status %d, body: %s", w.Code, w.Body.String())
	}
	var resp map[string]bool
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}
	if resp["restrict-management-to-localhost"] {
		t.Error("expected false after update")
	}
}
// TestPutAmpForceModelMappings_VerifyState verifies PUT updates force-model-mappings setting.
// The GET status is checked before decoding so a failure is reported directly.
func TestPutAmpForceModelMappings_VerifyState(t *testing.T) {
	h, _ := newAmpTestHandler(t)
	r := setupAmpRouter(h)
	body := `{"value": true}`
	req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body))
	req.Header.Set("Content-Type", "application/json")
	w := httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("PUT failed: status %d", w.Code)
	}
	req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil)
	w = httptest.NewRecorder()
	r.ServeHTTP(w, req)
	if w.Code != http.StatusOK {
		t.Fatalf("GET failed: status %d, body: %s", w.Code, w.Body.String())
	}
	var resp map[string]bool
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}
	if !resp["force-model-mappings"] {
		t.Error("expected true after update")
	}
}
// TestPutBoolField_EmptyObject checks that a boolean PUT endpoint rejects a
// body without a "value" field with 400 Bad Request.
func TestPutBoolField_EmptyObject(t *testing.T) {
	handler, _ := newAmpTestHandler(t)
	router := setupAmpRouter(handler)

	payload := []byte(`{}`)
	req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewReader(payload))
	req.Header.Set("Content-Type", "application/json")

	rec := httptest.NewRecorder()
	router.ServeHTTP(rec, req)
	if got := rec.Code; got != http.StatusBadRequest {
		t.Fatalf("expected status %d for empty object, got %d", http.StatusBadRequest, got)
	}
}
// TestComplexMappingsWorkflow tests a full workflow: PUT, PATCH, DELETE, and GET.
// Every step must return 200; otherwise the test stops at the failing step
// instead of reporting a confusing final-state mismatch.
func TestComplexMappingsWorkflow(t *testing.T) {
	h, _ := newAmpTestHandler(t)
	r := setupAmpRouter(h)

	// do sends one request to the model-mappings endpoint (with a JSON body
	// when body is non-empty) and fails the test unless it returns 200.
	do := func(method, body string) *httptest.ResponseRecorder {
		t.Helper()
		var req *http.Request
		if body == "" {
			req = httptest.NewRequest(method, "/v0/management/ampcode/model-mappings", nil)
		} else {
			req = httptest.NewRequest(method, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
			req.Header.Set("Content-Type", "application/json")
		}
		w := httptest.NewRecorder()
		r.ServeHTTP(w, req)
		if w.Code != http.StatusOK {
			t.Fatalf("%s failed: status %d, body: %s", method, w.Code, w.Body.String())
		}
		return w
	}

	do(http.MethodPut, `{"value": [{"from": "m1", "to": "t1"}, {"from": "m2", "to": "t2"}, {"from": "m3", "to": "t3"}, {"from": "m4", "to": "t4"}]}`)
	do(http.MethodPatch, `{"value": [{"from": "m2", "to": "t2-updated"}, {"from": "m5", "to": "t5"}]}`)
	do(http.MethodDelete, `{"value": ["m1", "m3"]}`)
	w := do(http.MethodGet, "")

	var resp map[string][]config.AmpModelMapping
	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
		t.Fatalf("failed to unmarshal: %v", err)
	}
	mappings := resp["model-mappings"]
	if len(mappings) != 3 {
		t.Fatalf("expected 3 mappings (m2, m4, m5), got %d", len(mappings))
	}
	expected := map[string]string{"m2": "t2-updated", "m4": "t4", "m5": "t5"}
	found := make(map[string]string, len(mappings))
	for _, m := range mappings {
		found[m.From] = m.To
	}
	for from, to := range expected {
		if found[from] != to {
			t.Errorf("mapping %s: expected %q, got %q", from, to, found[from])
		}
	}
}
// TestNilHandlerGetAmpCode verifies handler works with empty config.
func TestNilHandlerGetAmpCode(t *testing.T) {
cfg := &config.Config{}
h := management.NewHandler(cfg, "", nil)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestEmptyConfigGetAmpModelMappings verifies GET returns empty array for fresh config.
func TestEmptyConfigGetAmpModelMappings(t *testing.T) {
cfg := &config.Config{}
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil {
t.Fatalf("failed to write config: %v", err)
}
h := management.NewHandler(cfg, configPath, nil)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string][]config.AmpModelMapping
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if len(resp["model-mappings"]) != 0 {
t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"]))
}
}
================================================
FILE: test/builtin_tools_translation_test.go
================================================
package test
import (
"testing"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
)
// TestOpenAIToCodex_PreservesBuiltinTools translates an OpenAI Chat Completions
// request that carries a built-in web_search tool into Codex format. It asserts
// that exactly one tool survives and that its type, search_context_size and the
// tool_choice type are copied through unchanged.
func TestOpenAIToCodex_PreservesBuiltinTools(t *testing.T) {
	in := []byte(`{
"model":"gpt-5",
"messages":[{"role":"user","content":"hi"}],
"tools":[{"type":"web_search","search_context_size":"high"}],
"tool_choice":{"type":"web_search"}
}`)
	out := sdktranslator.TranslateRequest(sdktranslator.FormatOpenAI, sdktranslator.FormatCodex, "gpt-5", in, false)

	if toolCount := gjson.GetBytes(out, "tools.#").Int(); toolCount != 1 {
		t.Fatalf("expected 1 tool, got %d: %s", toolCount, string(out))
	}

	// Checked in order; the first mismatch aborts the test.
	stringChecks := []struct {
		path, want, format string
	}{
		{"tools.0.type", "web_search", "expected tools[0].type=web_search, got %q: %s"},
		{"tools.0.search_context_size", "high", "expected tools[0].search_context_size=high, got %q: %s"},
		{"tool_choice.type", "web_search", "expected tool_choice.type=web_search, got %q: %s"},
	}
	for _, check := range stringChecks {
		if got := gjson.GetBytes(out, check.path).String(); got != check.want {
			t.Fatalf(check.format, got, string(out))
		}
	}
}
// TestOpenAIResponsesToOpenAI_IgnoresBuiltinTools translates an OpenAI
// Responses request that carries a built-in web_search tool into Chat
// Completions format. Built-in tools are not supported there, so the output's
// "tools" array must be empty or absent ("tools.#" evaluates to 0 in both cases).
func TestOpenAIResponsesToOpenAI_IgnoresBuiltinTools(t *testing.T) {
	request := []byte(`{
"model":"gpt-5",
"input":[{"role":"user","content":[{"type":"input_text","text":"hi"}]}],
"tools":[{"type":"web_search","search_context_size":"low"}]
}`)
	translated := sdktranslator.TranslateRequest(
		sdktranslator.FormatOpenAIResponse,
		sdktranslator.FormatOpenAI,
		"gpt-5",
		request,
		false,
	)

	toolCount := gjson.GetBytes(translated, "tools.#").Int()
	if toolCount != 0 {
		t.Fatalf("expected 0 tools (builtin tools not supported in Chat Completions), got %d: %s", toolCount, string(translated))
	}
}
================================================
FILE: test/thinking_conversion_test.go
================================================
package test
import (
"fmt"
"strings"
"testing"
"time"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/translator"
// Import provider packages to trigger init() registration of ProviderAppliers
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/antigravity"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/claude"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/codex"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/gemini"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/geminicli"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/iflow"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/kimi"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/openai"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// thinkingTestCase is one row of the thinking-conversion matrices, shared by
// the suffix-based and body-based tests. A request for model, written in the
// `from` format, is translated to the `to` format. The resulting JSON is then
// validated against the expectations below.
type thinkingTestCase struct {
	// name labels the case (the matrices use the case number).
	name string
	// from and to are the source and target format names, e.g. "openai", "gemini".
	from, to string
	// model is the requested model, optionally with a thinking suffix
	// such as "(8192)" or "(high)".
	model string
	// inputJSON is the raw request body in the `from` format.
	inputJSON string
	// expectField is a JSON path into the translated output, and expectValue
	// is its expected string form. The cases use an empty expectField when no
	// thinking configuration should be produced.
	expectField, expectValue string
	// expectField2/expectValue2 are an optional second path/value pair to check.
	expectField2, expectValue2 string
	// includeThoughts is "true", "false", or "" (presumably "" means not
	// checked; confirm against the assertion loop).
	includeThoughts string
	// expectErr marks cases where the conversion is expected to fail.
	expectErr bool
}
// TestThinkingE2EMatrix_Suffix tests the thinking configuration transformation using model name suffix.
// Data flow: Input JSON → TranslateRequest → ApplyThinking → Validate Output
// All case data is inline; only the registered model fixtures come from getTestModels().
func TestThinkingE2EMatrix_Suffix(t *testing.T) {
reg := registry.GetGlobalRegistry()
uid := fmt.Sprintf("thinking-e2e-suffix-%d", time.Now().UnixNano())
reg.RegisterClient(uid, "test", getTestModels())
defer reg.UnregisterClient(uid)
cases := []thinkingTestCase{
// level-model (Levels=minimal/low/medium/high, ZeroAllowed=false, DynamicAllowed=false)
// Case 1: No suffix → injected default → medium
{
name: "1",
from: "openai",
to: "codex",
model: "level-model",
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning.effort",
expectValue: "medium",
expectErr: false,
},
// Case 2: Specified medium → medium
{
name: "2",
from: "openai",
to: "codex",
model: "level-model(medium)",
inputJSON: `{"model":"level-model(medium)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning.effort",
expectValue: "medium",
expectErr: false,
},
// Case 3: Specified xhigh → out of range error
{
name: "3",
from: "openai",
to: "codex",
model: "level-model(xhigh)",
inputJSON: `{"model":"level-model(xhigh)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: true,
},
// Case 4: Level none → clamped to minimal (ZeroAllowed=false)
{
name: "4",
from: "openai",
to: "codex",
model: "level-model(none)",
inputJSON: `{"model":"level-model(none)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning.effort",
expectValue: "minimal",
expectErr: false,
},
// Case 5: Level auto → DynamicAllowed=false → medium (mid-range)
{
name: "5",
from: "openai",
to: "codex",
model: "level-model(auto)",
inputJSON: `{"model":"level-model(auto)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning.effort",
expectValue: "medium",
expectErr: false,
},
// Case 6: No suffix from gemini → injected default → medium
{
name: "6",
from: "gemini",
to: "codex",
model: "level-model",
inputJSON: `{"model":"level-model","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "reasoning.effort",
expectValue: "medium",
expectErr: false,
},
// Case 7: Budget 8192 → medium
{
name: "7",
from: "gemini",
to: "codex",
model: "level-model(8192)",
inputJSON: `{"model":"level-model(8192)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "reasoning.effort",
expectValue: "medium",
expectErr: false,
},
// Case 8: Budget 64000 → clamped to high
{
name: "8",
from: "gemini",
to: "codex",
model: "level-model(64000)",
inputJSON: `{"model":"level-model(64000)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "reasoning.effort",
expectValue: "high",
expectErr: false,
},
// Case 9: Budget 0 → clamped to minimal (ZeroAllowed=false)
{
name: "9",
from: "gemini",
to: "codex",
model: "level-model(0)",
inputJSON: `{"model":"level-model(0)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "reasoning.effort",
expectValue: "minimal",
expectErr: false,
},
// Case 10: Budget -1 → auto → DynamicAllowed=false → medium (mid-range)
{
name: "10",
from: "gemini",
to: "codex",
model: "level-model(-1)",
inputJSON: `{"model":"level-model(-1)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "reasoning.effort",
expectValue: "medium",
expectErr: false,
},
// Case 11: Claude source no suffix → passthrough (no thinking)
{
name: "11",
from: "claude",
to: "openai",
model: "level-model",
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 12: Budget 8192 → medium
{
name: "12",
from: "claude",
to: "openai",
model: "level-model(8192)",
inputJSON: `{"model":"level-model(8192)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning_effort",
expectValue: "medium",
expectErr: false,
},
// Case 13: Budget 64000 → clamped to high
{
name: "13",
from: "claude",
to: "openai",
model: "level-model(64000)",
inputJSON: `{"model":"level-model(64000)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning_effort",
expectValue: "high",
expectErr: false,
},
// Case 14: Budget 0 → clamped to minimal (ZeroAllowed=false)
{
name: "14",
from: "claude",
to: "openai",
model: "level-model(0)",
inputJSON: `{"model":"level-model(0)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning_effort",
expectValue: "minimal",
expectErr: false,
},
// Case 15: Budget -1 → auto → DynamicAllowed=false → medium (mid-range)
{
name: "15",
from: "claude",
to: "openai",
model: "level-model(-1)",
inputJSON: `{"model":"level-model(-1)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning_effort",
expectValue: "medium",
expectErr: false,
},
// level-subset-model (Levels=low/high, ZeroAllowed=false, DynamicAllowed=false)
// Case 16: Budget 8192 → medium → rounded down to low
{
name: "16",
from: "gemini",
to: "openai",
model: "level-subset-model(8192)",
inputJSON: `{"model":"level-subset-model(8192)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "reasoning_effort",
expectValue: "low",
expectErr: false,
},
// Case 17: Budget 1 → minimal → clamped to low (min supported)
{
name: "17",
from: "claude",
to: "gemini",
model: "level-subset-model(1)",
inputJSON: `{"model":"level-subset-model(1)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingLevel",
expectValue: "low",
includeThoughts: "true",
expectErr: false,
},
// gemini-budget-model (Min=128, Max=20000, ZeroAllowed=false, DynamicAllowed=true)
// Case 18: No suffix → passthrough
{
name: "18",
from: "openai",
to: "gemini",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 19: Effort medium → 8192
{
name: "19",
from: "openai",
to: "gemini",
model: "gemini-budget-model(medium)",
inputJSON: `{"model":"gemini-budget-model(medium)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 20: Effort xhigh → clamped to 20000 (max)
{
name: "20",
from: "openai",
to: "gemini",
model: "gemini-budget-model(xhigh)",
inputJSON: `{"model":"gemini-budget-model(xhigh)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "20000",
includeThoughts: "true",
expectErr: false,
},
// Case 21: Effort none → clamped to 128 (min) → includeThoughts=false
{
name: "21",
from: "openai",
to: "gemini",
model: "gemini-budget-model(none)",
inputJSON: `{"model":"gemini-budget-model(none)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "128",
includeThoughts: "false",
expectErr: false,
},
// Case 22: Effort auto → DynamicAllowed=true → -1
{
name: "22",
from: "openai",
to: "gemini",
model: "gemini-budget-model(auto)",
inputJSON: `{"model":"gemini-budget-model(auto)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "-1",
includeThoughts: "true",
expectErr: false,
},
// Case 23: Claude source no suffix → passthrough
{
name: "23",
from: "claude",
to: "gemini",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 24: Budget 8192 → 8192
{
name: "24",
from: "claude",
to: "gemini",
model: "gemini-budget-model(8192)",
inputJSON: `{"model":"gemini-budget-model(8192)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 25: Budget 64000 → clamped to 20000 (max)
{
name: "25",
from: "claude",
to: "gemini",
model: "gemini-budget-model(64000)",
inputJSON: `{"model":"gemini-budget-model(64000)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "20000",
includeThoughts: "true",
expectErr: false,
},
// Case 26: Budget 0 → clamped to 128 (min) → includeThoughts=false
{
name: "26",
from: "claude",
to: "gemini",
model: "gemini-budget-model(0)",
inputJSON: `{"model":"gemini-budget-model(0)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "128",
includeThoughts: "false",
expectErr: false,
},
// Case 27: Budget -1 → DynamicAllowed=true → -1
{
name: "27",
from: "claude",
to: "gemini",
model: "gemini-budget-model(-1)",
inputJSON: `{"model":"gemini-budget-model(-1)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "-1",
includeThoughts: "true",
expectErr: false,
},
// gemini-mixed-model (Min=128, Max=32768, Levels=low/high, ZeroAllowed=false, DynamicAllowed=true)
// Case 28: OpenAI source no suffix → passthrough
{
name: "28",
from: "openai",
to: "gemini",
model: "gemini-mixed-model",
inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 29: Effort high → low/high supported → high
{
name: "29",
from: "openai",
to: "gemini",
model: "gemini-mixed-model(high)",
inputJSON: `{"model":"gemini-mixed-model(high)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingLevel",
expectValue: "high",
includeThoughts: "true",
expectErr: false,
},
// Case 30: Effort xhigh → clamped to high
{
name: "30",
from: "openai",
to: "gemini",
model: "gemini-mixed-model(xhigh)",
inputJSON: `{"model":"gemini-mixed-model(xhigh)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingLevel",
expectValue: "high",
includeThoughts: "true",
expectErr: false,
},
// Case 31: Effort none → clamped to low (min supported) → includeThoughts=false
{
name: "31",
from: "openai",
to: "gemini",
model: "gemini-mixed-model(none)",
inputJSON: `{"model":"gemini-mixed-model(none)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingLevel",
expectValue: "low",
includeThoughts: "false",
expectErr: false,
},
// Case 32: Effort auto → DynamicAllowed=true → -1 (budget)
{
name: "32",
from: "openai",
to: "gemini",
model: "gemini-mixed-model(auto)",
inputJSON: `{"model":"gemini-mixed-model(auto)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "-1",
includeThoughts: "true",
expectErr: false,
},
// Case 33: Claude source no suffix → passthrough
{
name: "33",
from: "claude",
to: "gemini",
model: "gemini-mixed-model",
inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 34: Budget 8192 → 8192 (keep budget)
{
name: "34",
from: "claude",
to: "gemini",
model: "gemini-mixed-model(8192)",
inputJSON: `{"model":"gemini-mixed-model(8192)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 35: Budget 64000 → clamped to 32768 (max)
{
name: "35",
from: "claude",
to: "gemini",
model: "gemini-mixed-model(64000)",
inputJSON: `{"model":"gemini-mixed-model(64000)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "32768",
includeThoughts: "true",
expectErr: false,
},
// Case 36: Budget 0 → minimal → clamped to low (min level) → includeThoughts=false
{
name: "36",
from: "claude",
to: "gemini",
model: "gemini-mixed-model(0)",
inputJSON: `{"model":"gemini-mixed-model(0)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingLevel",
expectValue: "low",
includeThoughts: "false",
expectErr: false,
},
// Case 37: Budget -1 → DynamicAllowed=true → -1 (budget)
{
name: "37",
from: "claude",
to: "gemini",
model: "gemini-mixed-model(-1)",
inputJSON: `{"model":"gemini-mixed-model(-1)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "-1",
includeThoughts: "true",
expectErr: false,
},
// claude-budget-model (Min=1024, Max=128000, ZeroAllowed=true, DynamicAllowed=false)
// Case 38: OpenAI source no suffix → passthrough
{
name: "38",
from: "openai",
to: "claude",
model: "claude-budget-model",
inputJSON: `{"model":"claude-budget-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 39: Effort medium → 8192
{
name: "39",
from: "openai",
to: "claude",
model: "claude-budget-model(medium)",
inputJSON: `{"model":"claude-budget-model(medium)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "thinking.budget_tokens",
expectValue: "8192",
expectErr: false,
},
// Case 40: Effort xhigh → clamped to 32768 (matrix value)
{
name: "40",
from: "openai",
to: "claude",
model: "claude-budget-model(xhigh)",
inputJSON: `{"model":"claude-budget-model(xhigh)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "thinking.budget_tokens",
expectValue: "32768",
expectErr: false,
},
// Case 41: Effort none → ZeroAllowed=true → disabled
{
name: "41",
from: "openai",
to: "claude",
model: "claude-budget-model(none)",
inputJSON: `{"model":"claude-budget-model(none)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "thinking.type",
expectValue: "disabled",
expectErr: false,
},
// Case 42: Effort auto → DynamicAllowed=false → 64512 (mid-range)
{
name: "42",
from: "openai",
to: "claude",
model: "claude-budget-model(auto)",
inputJSON: `{"model":"claude-budget-model(auto)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "thinking.budget_tokens",
expectValue: "64512",
expectErr: false,
},
// Case 43: Gemini source no suffix → passthrough
{
name: "43",
from: "gemini",
to: "claude",
model: "claude-budget-model",
inputJSON: `{"model":"claude-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "",
expectErr: false,
},
// Case 44: Budget 8192 → 8192
{
name: "44",
from: "gemini",
to: "claude",
model: "claude-budget-model(8192)",
inputJSON: `{"model":"claude-budget-model(8192)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "thinking.budget_tokens",
expectValue: "8192",
expectErr: false,
},
// Case 45: Budget 200000 → clamped to 128000 (max)
{
name: "45",
from: "gemini",
to: "claude",
model: "claude-budget-model(200000)",
inputJSON: `{"model":"claude-budget-model(200000)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "thinking.budget_tokens",
expectValue: "128000",
expectErr: false,
},
// Case 46: Budget 0 → ZeroAllowed=true → disabled
{
name: "46",
from: "gemini",
to: "claude",
model: "claude-budget-model(0)",
inputJSON: `{"model":"claude-budget-model(0)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "thinking.type",
expectValue: "disabled",
expectErr: false,
},
// Case 47: Budget -1 → auto → DynamicAllowed=false → 64512 (mid-range)
{
name: "47",
from: "gemini",
to: "claude",
model: "claude-budget-model(-1)",
inputJSON: `{"model":"claude-budget-model(-1)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "thinking.budget_tokens",
expectValue: "64512",
expectErr: false,
},
// antigravity-budget-model (Min=128, Max=20000, ZeroAllowed=true, DynamicAllowed=true)
// Case 48: Gemini to Antigravity no suffix → passthrough
{
name: "48",
from: "gemini",
to: "antigravity",
model: "antigravity-budget-model",
inputJSON: `{"model":"antigravity-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "",
expectErr: false,
},
// Case 49: Effort medium → 8192
{
name: "49",
from: "gemini",
to: "antigravity",
model: "antigravity-budget-model(medium)",
inputJSON: `{"model":"antigravity-budget-model(medium)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 50: Effort xhigh → clamped to 20000 (max)
{
name: "50",
from: "gemini",
to: "antigravity",
model: "antigravity-budget-model(xhigh)",
inputJSON: `{"model":"antigravity-budget-model(xhigh)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "20000",
includeThoughts: "true",
expectErr: false,
},
// Case 51: Effort none → ZeroAllowed=true → 0 → includeThoughts=false
{
name: "51",
from: "gemini",
to: "antigravity",
model: "antigravity-budget-model(none)",
inputJSON: `{"model":"antigravity-budget-model(none)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "0",
includeThoughts: "false",
expectErr: false,
},
// Case 52: Effort auto → DynamicAllowed=true → -1
{
name: "52",
from: "gemini",
to: "antigravity",
model: "antigravity-budget-model(auto)",
inputJSON: `{"model":"antigravity-budget-model(auto)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "-1",
includeThoughts: "true",
expectErr: false,
},
// Case 53: Claude to Antigravity no suffix → passthrough
{
name: "53",
from: "claude",
to: "antigravity",
model: "antigravity-budget-model",
inputJSON: `{"model":"antigravity-budget-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 54: Budget 8192 → 8192
{
name: "54",
from: "claude",
to: "antigravity",
model: "antigravity-budget-model(8192)",
inputJSON: `{"model":"antigravity-budget-model(8192)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 55: Budget 64000 → clamped to 20000 (max)
{
name: "55",
from: "claude",
to: "antigravity",
model: "antigravity-budget-model(64000)",
inputJSON: `{"model":"antigravity-budget-model(64000)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "20000",
includeThoughts: "true",
expectErr: false,
},
// Case 56: Budget 0 → ZeroAllowed=true → 0 → includeThoughts=false
{
name: "56",
from: "claude",
to: "antigravity",
model: "antigravity-budget-model(0)",
inputJSON: `{"model":"antigravity-budget-model(0)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "0",
includeThoughts: "false",
expectErr: false,
},
// Case 57: Budget -1 → DynamicAllowed=true → -1
{
name: "57",
from: "claude",
to: "antigravity",
model: "antigravity-budget-model(-1)",
inputJSON: `{"model":"antigravity-budget-model(-1)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "-1",
includeThoughts: "true",
expectErr: false,
},
// no-thinking-model (Thinking=nil)
// Case 58: No thinking support → no configuration
{
name: "58",
from: "gemini",
to: "openai",
model: "no-thinking-model",
inputJSON: `{"model":"no-thinking-model","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "",
expectErr: false,
},
// Case 59: Budget 8192 → no thinking support → suffix stripped → no configuration
{
name: "59",
from: "gemini",
to: "openai",
model: "no-thinking-model(8192)",
inputJSON: `{"model":"no-thinking-model(8192)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "",
expectErr: false,
},
// Case 60: Budget 0 → suffix stripped → no configuration
{
name: "60",
from: "gemini",
to: "openai",
model: "no-thinking-model(0)",
inputJSON: `{"model":"no-thinking-model(0)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "",
expectErr: false,
},
// Case 61: Budget -1 → suffix stripped → no configuration
{
name: "61",
from: "gemini",
to: "openai",
model: "no-thinking-model(-1)",
inputJSON: `{"model":"no-thinking-model(-1)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "",
expectErr: false,
},
// Case 62: Claude source no suffix → no configuration
{
name: "62",
from: "claude",
to: "openai",
model: "no-thinking-model",
inputJSON: `{"model":"no-thinking-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 63: Budget 8192 → suffix stripped → no configuration
{
name: "63",
from: "claude",
to: "openai",
model: "no-thinking-model(8192)",
inputJSON: `{"model":"no-thinking-model(8192)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 64: Budget 0 → suffix stripped → no configuration
{
name: "64",
from: "claude",
to: "openai",
model: "no-thinking-model(0)",
inputJSON: `{"model":"no-thinking-model(0)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 65: Budget -1 → suffix stripped → no configuration
{
name: "65",
from: "claude",
to: "openai",
model: "no-thinking-model(-1)",
inputJSON: `{"model":"no-thinking-model(-1)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// user-defined-model (UserDefined=true, Thinking=nil)
// Case 66: User defined model no suffix → passthrough
{
name: "66",
from: "gemini",
to: "openai",
model: "user-defined-model",
inputJSON: `{"model":"user-defined-model","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "",
expectErr: false,
},
// Case 67: Budget 8192 → passthrough logic → medium
{
name: "67",
from: "gemini",
to: "openai",
model: "user-defined-model(8192)",
inputJSON: `{"model":"user-defined-model(8192)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "reasoning_effort",
expectValue: "medium",
expectErr: false,
},
// Case 68: Budget 64000 → passthrough logic → xhigh
{
name: "68",
from: "gemini",
to: "openai",
model: "user-defined-model(64000)",
inputJSON: `{"model":"user-defined-model(64000)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "reasoning_effort",
expectValue: "xhigh",
expectErr: false,
},
// Case 69: Budget 0 → passthrough logic → none
{
name: "69",
from: "gemini",
to: "openai",
model: "user-defined-model(0)",
inputJSON: `{"model":"user-defined-model(0)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "reasoning_effort",
expectValue: "none",
expectErr: false,
},
// Case 70: Budget -1 → passthrough logic → auto
{
name: "70",
from: "gemini",
to: "openai",
model: "user-defined-model(-1)",
inputJSON: `{"model":"user-defined-model(-1)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "reasoning_effort",
expectValue: "auto",
expectErr: false,
},
// Case 71: Claude to Codex no suffix → injected default → medium
{
name: "71",
from: "claude",
to: "codex",
model: "user-defined-model",
inputJSON: `{"model":"user-defined-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning.effort",
expectValue: "medium",
expectErr: false,
},
// Case 72: Budget 8192 → passthrough logic → medium
{
name: "72",
from: "claude",
to: "codex",
model: "user-defined-model(8192)",
inputJSON: `{"model":"user-defined-model(8192)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning.effort",
expectValue: "medium",
expectErr: false,
},
// Case 73: Budget 64000 → passthrough logic → xhigh
{
name: "73",
from: "claude",
to: "codex",
model: "user-defined-model(64000)",
inputJSON: `{"model":"user-defined-model(64000)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning.effort",
expectValue: "xhigh",
expectErr: false,
},
// Case 74: Budget 0 → passthrough logic → none
{
name: "74",
from: "claude",
to: "codex",
model: "user-defined-model(0)",
inputJSON: `{"model":"user-defined-model(0)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning.effort",
expectValue: "none",
expectErr: false,
},
// Case 75: Budget -1 → passthrough logic → auto
{
name: "75",
from: "claude",
to: "codex",
model: "user-defined-model(-1)",
inputJSON: `{"model":"user-defined-model(-1)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning.effort",
expectValue: "auto",
expectErr: false,
},
// Case 76: OpenAI to Gemini budget 8192 → passthrough → 8192
{
name: "76",
from: "openai",
to: "gemini",
model: "user-defined-model(8192)",
inputJSON: `{"model":"user-defined-model(8192)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 77: OpenAI to Claude budget 8192 → passthrough → 8192
{
name: "77",
from: "openai",
to: "claude",
model: "user-defined-model(8192)",
inputJSON: `{"model":"user-defined-model(8192)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "thinking.budget_tokens",
expectValue: "8192",
expectErr: false,
},
// Case 78: OpenAI-Response to Gemini budget 8192 → passthrough → 8192
{
name: "78",
from: "openai-response",
to: "gemini",
model: "user-defined-model(8192)",
inputJSON: `{"model":"user-defined-model(8192)","input":[{"role":"user","content":"hi"}]}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 79: OpenAI-Response to Claude budget 8192 → passthrough → 8192
{
name: "79",
from: "openai-response",
to: "claude",
model: "user-defined-model(8192)",
inputJSON: `{"model":"user-defined-model(8192)","input":[{"role":"user","content":"hi"}]}`,
expectField: "thinking.budget_tokens",
expectValue: "8192",
expectErr: false,
},
// Same-protocol passthrough tests (80-89)
// Case 80: OpenAI to OpenAI, level high → passthrough reasoning_effort
{
name: "80",
from: "openai",
to: "openai",
model: "level-model(high)",
inputJSON: `{"model":"level-model(high)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning_effort",
expectValue: "high",
expectErr: false,
},
// Case 81: OpenAI to OpenAI, level xhigh → out of range error
{
name: "81",
from: "openai",
to: "openai",
model: "level-model(xhigh)",
inputJSON: `{"model":"level-model(xhigh)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: true,
},
// Case 82: OpenAI-Response to Codex, level high → passthrough reasoning.effort
{
name: "82",
from: "openai-response",
to: "codex",
model: "level-model(high)",
inputJSON: `{"model":"level-model(high)","input":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning.effort",
expectValue: "high",
expectErr: false,
},
// Case 83: OpenAI-Response to Codex, level xhigh → out of range error
{
name: "83",
from: "openai-response",
to: "codex",
model: "level-model(xhigh)",
inputJSON: `{"model":"level-model(xhigh)","input":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: true,
},
// Case 84: Gemini to Gemini, budget 8192 → passthrough thinkingBudget
{
name: "84",
from: "gemini",
to: "gemini",
model: "gemini-budget-model(8192)",
inputJSON: `{"model":"gemini-budget-model(8192)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 85: Gemini to Gemini, budget 64000 → clamped to Max
{
name: "85",
from: "gemini",
to: "gemini",
model: "gemini-budget-model(64000)",
inputJSON: `{"model":"gemini-budget-model(64000)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "20000",
includeThoughts: "true",
expectErr: false,
},
// Case 86: Claude to Claude, budget 8192 → passthrough thinking.budget_tokens
{
name: "86",
from: "claude",
to: "claude",
model: "claude-budget-model(8192)",
inputJSON: `{"model":"claude-budget-model(8192)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "thinking.budget_tokens",
expectValue: "8192",
expectErr: false,
},
// Case 87: Claude to Claude, budget 200000 → clamped to Max
{
name: "87",
from: "claude",
to: "claude",
model: "claude-budget-model(200000)",
inputJSON: `{"model":"claude-budget-model(200000)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "thinking.budget_tokens",
expectValue: "128000",
expectErr: false,
},
// Case 88: Gemini-CLI to Antigravity, budget 8192 → passthrough thinkingBudget
{
name: "88",
from: "gemini-cli",
to: "antigravity",
model: "antigravity-budget-model(8192)",
inputJSON: `{"model":"antigravity-budget-model(8192)","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 89: Gemini-CLI to Antigravity, budget 64000 → clamped to Max
{
name: "89",
from: "gemini-cli",
to: "antigravity",
model: "antigravity-budget-model(64000)",
inputJSON: `{"model":"antigravity-budget-model(64000)","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "20000",
includeThoughts: "true",
expectErr: false,
},
// iflow tests: glm-test and minimax-test (Cases 90-105)
// glm-test (from: openai, claude)
// Case 90: OpenAI to iflow, no suffix → passthrough
{
name: "90",
from: "openai",
to: "iflow",
model: "glm-test",
inputJSON: `{"model":"glm-test","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 91: OpenAI to iflow, (medium) → enable_thinking=true
{
name: "91",
from: "openai",
to: "iflow",
model: "glm-test(medium)",
inputJSON: `{"model":"glm-test(medium)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "chat_template_kwargs.enable_thinking",
expectValue: "true",
expectErr: false,
},
// Case 92: OpenAI to iflow, (auto) → enable_thinking=true
{
name: "92",
from: "openai",
to: "iflow",
model: "glm-test(auto)",
inputJSON: `{"model":"glm-test(auto)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "chat_template_kwargs.enable_thinking",
expectValue: "true",
expectErr: false,
},
// Case 93: OpenAI to iflow, (none) → enable_thinking=false
{
name: "93",
from: "openai",
to: "iflow",
model: "glm-test(none)",
inputJSON: `{"model":"glm-test(none)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "chat_template_kwargs.enable_thinking",
expectValue: "false",
expectErr: false,
},
// Case 94: Claude to iflow, no suffix → passthrough
{
name: "94",
from: "claude",
to: "iflow",
model: "glm-test",
inputJSON: `{"model":"glm-test","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 95: Claude to iflow, (8192) → enable_thinking=true
{
name: "95",
from: "claude",
to: "iflow",
model: "glm-test(8192)",
inputJSON: `{"model":"glm-test(8192)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "chat_template_kwargs.enable_thinking",
expectValue: "true",
expectErr: false,
},
// Case 96: Claude to iflow, (-1) → enable_thinking=true
{
name: "96",
from: "claude",
to: "iflow",
model: "glm-test(-1)",
inputJSON: `{"model":"glm-test(-1)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "chat_template_kwargs.enable_thinking",
expectValue: "true",
expectErr: false,
},
// Case 97: Claude to iflow, (0) → enable_thinking=false
{
name: "97",
from: "claude",
to: "iflow",
model: "glm-test(0)",
inputJSON: `{"model":"glm-test(0)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "chat_template_kwargs.enable_thinking",
expectValue: "false",
expectErr: false,
},
// minimax-test (from: openai, gemini)
// Case 98: OpenAI to iflow, no suffix → passthrough
{
name: "98",
from: "openai",
to: "iflow",
model: "minimax-test",
inputJSON: `{"model":"minimax-test","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 99: OpenAI to iflow, (medium) → reasoning_split=true
{
name: "99",
from: "openai",
to: "iflow",
model: "minimax-test(medium)",
inputJSON: `{"model":"minimax-test(medium)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning_split",
expectValue: "true",
expectErr: false,
},
// Case 100: OpenAI to iflow, (auto) → reasoning_split=true
{
name: "100",
from: "openai",
to: "iflow",
model: "minimax-test(auto)",
inputJSON: `{"model":"minimax-test(auto)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning_split",
expectValue: "true",
expectErr: false,
},
// Case 101: OpenAI to iflow, (none) → reasoning_split=false
{
name: "101",
from: "openai",
to: "iflow",
model: "minimax-test(none)",
inputJSON: `{"model":"minimax-test(none)","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning_split",
expectValue: "false",
expectErr: false,
},
// Case 102: Gemini to iflow, no suffix → passthrough
{
name: "102",
from: "gemini",
to: "iflow",
model: "minimax-test",
inputJSON: `{"model":"minimax-test","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "",
expectErr: false,
},
// Case 103: Gemini to iflow, (8192) → reasoning_split=true
{
name: "103",
from: "gemini",
to: "iflow",
model: "minimax-test(8192)",
inputJSON: `{"model":"minimax-test(8192)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "reasoning_split",
expectValue: "true",
expectErr: false,
},
// Case 104: Gemini to iflow, (-1) → reasoning_split=true
{
name: "104",
from: "gemini",
to: "iflow",
model: "minimax-test(-1)",
inputJSON: `{"model":"minimax-test(-1)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "reasoning_split",
expectValue: "true",
expectErr: false,
},
// Case 105: Gemini to iflow, (0) → reasoning_split=false
{
name: "105",
from: "gemini",
to: "iflow",
model: "minimax-test(0)",
inputJSON: `{"model":"minimax-test(0)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "reasoning_split",
expectValue: "false",
expectErr: false,
},
// Gemini Family Cross-Channel Consistency (Cases 106-114)
// Tests that gemini/gemini-cli/antigravity as same API family should have consistent validation behavior
// Case 106: Gemini to Antigravity, budget 64000 (suffix) → clamped to Max
{
name: "106",
from: "gemini",
to: "antigravity",
model: "gemini-budget-model(64000)",
inputJSON: `{"model":"gemini-budget-model(64000)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "20000",
includeThoughts: "true",
expectErr: false,
},
// Case 107: Gemini to Gemini-CLI, budget 64000 (suffix) → clamped to Max
{
name: "107",
from: "gemini",
to: "gemini-cli",
model: "gemini-budget-model(64000)",
inputJSON: `{"model":"gemini-budget-model(64000)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "20000",
includeThoughts: "true",
expectErr: false,
},
// Case 108: Gemini-CLI to Antigravity, budget 64000 (suffix) → clamped to Max
{
name: "108",
from: "gemini-cli",
to: "antigravity",
model: "gemini-budget-model(64000)",
inputJSON: `{"model":"gemini-budget-model(64000)","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "20000",
includeThoughts: "true",
expectErr: false,
},
// Case 109: Gemini-CLI to Gemini, budget 64000 (suffix) → clamped to Max
{
name: "109",
from: "gemini-cli",
to: "gemini",
model: "gemini-budget-model(64000)",
inputJSON: `{"model":"gemini-budget-model(64000)","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "20000",
includeThoughts: "true",
expectErr: false,
},
// Case 110: Gemini to Antigravity, budget 8192 → passthrough (normal value)
{
name: "110",
from: "gemini",
to: "antigravity",
model: "gemini-budget-model(8192)",
inputJSON: `{"model":"gemini-budget-model(8192)","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 111: Gemini-CLI to Antigravity, budget 8192 → passthrough (normal value)
{
name: "111",
from: "gemini-cli",
to: "antigravity",
model: "gemini-budget-model(8192)",
inputJSON: `{"model":"gemini-budget-model(8192)","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}]}}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
}
runThinkingTests(t, cases)
}
// TestThinkingE2EMatrix_Body tests the thinking configuration transformation using request body parameters.
// Data flow: Input JSON with thinking params → TranslateRequest → ApplyThinking → Validate Output
func TestThinkingE2EMatrix_Body(t *testing.T) {
reg := registry.GetGlobalRegistry()
uid := fmt.Sprintf("thinking-e2e-body-%d", time.Now().UnixNano())
reg.RegisterClient(uid, "test", getTestModels())
defer reg.UnregisterClient(uid)
cases := []thinkingTestCase{
// level-model (Levels=minimal/low/medium/high, ZeroAllowed=false, DynamicAllowed=false)
// Case 1: No param → injected default → medium
{
name: "1",
from: "openai",
to: "codex",
model: "level-model",
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning.effort",
expectValue: "medium",
expectErr: false,
},
// Case 2: reasoning_effort=medium → medium
{
name: "2",
from: "openai",
to: "codex",
model: "level-model",
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"medium"}`,
expectField: "reasoning.effort",
expectValue: "medium",
expectErr: false,
},
// Case 3: reasoning_effort=xhigh → out of range error
{
name: "3",
from: "openai",
to: "codex",
model: "level-model",
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"xhigh"}`,
expectField: "",
expectErr: true,
},
// Case 4: reasoning_effort=none → clamped to minimal
{
name: "4",
from: "openai",
to: "codex",
model: "level-model",
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"none"}`,
expectField: "reasoning.effort",
expectValue: "minimal",
expectErr: false,
},
// Case 5: reasoning_effort=auto → medium (DynamicAllowed=false)
{
name: "5",
from: "openai",
to: "codex",
model: "level-model",
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"auto"}`,
expectField: "reasoning.effort",
expectValue: "medium",
expectErr: false,
},
// Case 6: No param from gemini → injected default → medium
{
name: "6",
from: "gemini",
to: "codex",
model: "level-model",
inputJSON: `{"model":"level-model","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "reasoning.effort",
expectValue: "medium",
expectErr: false,
},
// Case 7: thinkingBudget=8192 → medium
{
name: "7",
from: "gemini",
to: "codex",
model: "level-model",
inputJSON: `{"model":"level-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":8192}}}`,
expectField: "reasoning.effort",
expectValue: "medium",
expectErr: false,
},
// Case 8: thinkingBudget=64000 → clamped to high
{
name: "8",
from: "gemini",
to: "codex",
model: "level-model",
inputJSON: `{"model":"level-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":64000}}}`,
expectField: "reasoning.effort",
expectValue: "high",
expectErr: false,
},
// Case 9: thinkingBudget=0 → clamped to minimal
{
name: "9",
from: "gemini",
to: "codex",
model: "level-model",
inputJSON: `{"model":"level-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":0}}}`,
expectField: "reasoning.effort",
expectValue: "minimal",
expectErr: false,
},
// Case 10: thinkingBudget=-1 → medium (DynamicAllowed=false)
{
name: "10",
from: "gemini",
to: "codex",
model: "level-model",
inputJSON: `{"model":"level-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":-1}}}`,
expectField: "reasoning.effort",
expectValue: "medium",
expectErr: false,
},
// Case 11: Claude no param → passthrough (no thinking)
{
name: "11",
from: "claude",
to: "openai",
model: "level-model",
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 12: thinking.budget_tokens=8192 → medium
{
name: "12",
from: "claude",
to: "openai",
model: "level-model",
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":8192}}`,
expectField: "reasoning_effort",
expectValue: "medium",
expectErr: false,
},
// Case 13: thinking.budget_tokens=64000 → clamped to high
{
name: "13",
from: "claude",
to: "openai",
model: "level-model",
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":64000}}`,
expectField: "reasoning_effort",
expectValue: "high",
expectErr: false,
},
// Case 14: thinking.budget_tokens=0 → clamped to minimal
{
name: "14",
from: "claude",
to: "openai",
model: "level-model",
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":0}}`,
expectField: "reasoning_effort",
expectValue: "minimal",
expectErr: false,
},
// Case 15: thinking.budget_tokens=-1 → medium (DynamicAllowed=false)
{
name: "15",
from: "claude",
to: "openai",
model: "level-model",
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":-1}}`,
expectField: "reasoning_effort",
expectValue: "medium",
expectErr: false,
},
// level-subset-model (Levels=low/high, ZeroAllowed=false, DynamicAllowed=false)
// Case 16: thinkingBudget=8192 → medium → rounded down to low
{
name: "16",
from: "gemini",
to: "openai",
model: "level-subset-model",
inputJSON: `{"model":"level-subset-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":8192}}}`,
expectField: "reasoning_effort",
expectValue: "low",
expectErr: false,
},
// Case 17: thinking.budget_tokens=1 → minimal → clamped to low
{
name: "17",
from: "claude",
to: "gemini",
model: "level-subset-model",
inputJSON: `{"model":"level-subset-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":1}}`,
expectField: "generationConfig.thinkingConfig.thinkingLevel",
expectValue: "low",
includeThoughts: "true",
expectErr: false,
},
// gemini-budget-model (Min=128, Max=20000, ZeroAllowed=false, DynamicAllowed=true)
// Case 18: No param → passthrough
{
name: "18",
from: "openai",
to: "gemini",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 19: reasoning_effort=medium → 8192
{
name: "19",
from: "openai",
to: "gemini",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"medium"}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 20: reasoning_effort=xhigh → clamped to 20000
{
name: "20",
from: "openai",
to: "gemini",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"xhigh"}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "20000",
includeThoughts: "true",
expectErr: false,
},
// Case 21: reasoning_effort=none → clamped to 128 → includeThoughts=false
{
name: "21",
from: "openai",
to: "gemini",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"none"}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "128",
includeThoughts: "false",
expectErr: false,
},
// Case 22: reasoning_effort=auto → -1 (DynamicAllowed=true)
{
name: "22",
from: "openai",
to: "gemini",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"auto"}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "-1",
includeThoughts: "true",
expectErr: false,
},
// Case 23: Claude no param → passthrough
{
name: "23",
from: "claude",
to: "gemini",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 24: thinking.budget_tokens=8192 → 8192
{
name: "24",
from: "claude",
to: "gemini",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":8192}}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 25: thinking.budget_tokens=64000 → clamped to 20000
{
name: "25",
from: "claude",
to: "gemini",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":64000}}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "20000",
includeThoughts: "true",
expectErr: false,
},
// Case 26: thinking.budget_tokens=0 → clamped to 128 → includeThoughts=false
{
name: "26",
from: "claude",
to: "gemini",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":0}}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "128",
includeThoughts: "false",
expectErr: false,
},
// Case 27: thinking.budget_tokens=-1 → -1 (DynamicAllowed=true)
{
name: "27",
from: "claude",
to: "gemini",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":-1}}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "-1",
includeThoughts: "true",
expectErr: false,
},
// gemini-mixed-model (Min=128, Max=32768, Levels=low/high, ZeroAllowed=false, DynamicAllowed=true)
// Case 28: No param → passthrough
{
name: "28",
from: "openai",
to: "gemini",
model: "gemini-mixed-model",
inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 29: reasoning_effort=high → high
{
name: "29",
from: "openai",
to: "gemini",
model: "gemini-mixed-model",
inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"high"}`,
expectField: "generationConfig.thinkingConfig.thinkingLevel",
expectValue: "high",
includeThoughts: "true",
expectErr: false,
},
// Case 30: reasoning_effort=xhigh → clamped to high
{
name: "30",
from: "openai",
to: "gemini",
model: "gemini-mixed-model",
inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"xhigh"}`,
expectField: "generationConfig.thinkingConfig.thinkingLevel",
expectValue: "high",
includeThoughts: "true",
expectErr: false,
},
// Case 31: reasoning_effort=none → clamped to low → includeThoughts=false
{
name: "31",
from: "openai",
to: "gemini",
model: "gemini-mixed-model",
inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"none"}`,
expectField: "generationConfig.thinkingConfig.thinkingLevel",
expectValue: "low",
includeThoughts: "false",
expectErr: false,
},
// Case 32: reasoning_effort=auto → -1 (DynamicAllowed=true)
{
name: "32",
from: "openai",
to: "gemini",
model: "gemini-mixed-model",
inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"auto"}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "-1",
includeThoughts: "true",
expectErr: false,
},
// Case 33: Claude no param → passthrough
{
name: "33",
from: "claude",
to: "gemini",
model: "gemini-mixed-model",
inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 34: thinking.budget_tokens=8192 → 8192 (keeps budget)
{
name: "34",
from: "claude",
to: "gemini",
model: "gemini-mixed-model",
inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":8192}}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 35: thinking.budget_tokens=64000 → clamped to 32768 (keeps budget)
{
name: "35",
from: "claude",
to: "gemini",
model: "gemini-mixed-model",
inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":64000}}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "32768",
includeThoughts: "true",
expectErr: false,
},
// Case 36: thinking.budget_tokens=0 → clamped to low → includeThoughts=false
{
name: "36",
from: "claude",
to: "gemini",
model: "gemini-mixed-model",
inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":0}}`,
expectField: "generationConfig.thinkingConfig.thinkingLevel",
expectValue: "low",
includeThoughts: "false",
expectErr: false,
},
// Case 37: thinking.budget_tokens=-1 → -1 (DynamicAllowed=true)
{
name: "37",
from: "claude",
to: "gemini",
model: "gemini-mixed-model",
inputJSON: `{"model":"gemini-mixed-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":-1}}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "-1",
includeThoughts: "true",
expectErr: false,
},
// claude-budget-model (Min=1024, Max=128000, ZeroAllowed=true, DynamicAllowed=false)
// Case 38: No param → passthrough
{
name: "38",
from: "openai",
to: "claude",
model: "claude-budget-model",
inputJSON: `{"model":"claude-budget-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 39: reasoning_effort=medium → 8192
{
name: "39",
from: "openai",
to: "claude",
model: "claude-budget-model",
inputJSON: `{"model":"claude-budget-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"medium"}`,
expectField: "thinking.budget_tokens",
expectValue: "8192",
expectErr: false,
},
// Case 40: reasoning_effort=xhigh → 32768 (within Max=128000, no clamping)
{
name: "40",
from: "openai",
to: "claude",
model: "claude-budget-model",
inputJSON: `{"model":"claude-budget-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"xhigh"}`,
expectField: "thinking.budget_tokens",
expectValue: "32768",
expectErr: false,
},
// Case 41: reasoning_effort=none → disabled
{
name: "41",
from: "openai",
to: "claude",
model: "claude-budget-model",
inputJSON: `{"model":"claude-budget-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"none"}`,
expectField: "thinking.type",
expectValue: "disabled",
expectErr: false,
},
// Case 42: reasoning_effort=auto → 64512 (mid-range)
{
name: "42",
from: "openai",
to: "claude",
model: "claude-budget-model",
inputJSON: `{"model":"claude-budget-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"auto"}`,
expectField: "thinking.budget_tokens",
expectValue: "64512",
expectErr: false,
},
// Case 43: Gemini no param → passthrough
{
name: "43",
from: "gemini",
to: "claude",
model: "claude-budget-model",
inputJSON: `{"model":"claude-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "",
expectErr: false,
},
// Case 44: thinkingBudget=8192 → 8192
{
name: "44",
from: "gemini",
to: "claude",
model: "claude-budget-model",
inputJSON: `{"model":"claude-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":8192}}}`,
expectField: "thinking.budget_tokens",
expectValue: "8192",
expectErr: false,
},
// Case 45: thinkingBudget=200000 → clamped to 128000
{
name: "45",
from: "gemini",
to: "claude",
model: "claude-budget-model",
inputJSON: `{"model":"claude-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":200000}}}`,
expectField: "thinking.budget_tokens",
expectValue: "128000",
expectErr: false,
},
// Case 46: thinkingBudget=0 → disabled
{
name: "46",
from: "gemini",
to: "claude",
model: "claude-budget-model",
inputJSON: `{"model":"claude-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":0}}}`,
expectField: "thinking.type",
expectValue: "disabled",
expectErr: false,
},
// Case 47: thinkingBudget=-1 → 64512 (mid-range)
{
name: "47",
from: "gemini",
to: "claude",
model: "claude-budget-model",
inputJSON: `{"model":"claude-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":-1}}}`,
expectField: "thinking.budget_tokens",
expectValue: "64512",
expectErr: false,
},
// antigravity-budget-model (Min=128, Max=20000, ZeroAllowed=true, DynamicAllowed=true)
// Case 48: Gemini no param → passthrough
{
name: "48",
from: "gemini",
to: "antigravity",
model: "antigravity-budget-model",
inputJSON: `{"model":"antigravity-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "",
expectErr: false,
},
// Case 49: thinkingLevel=medium → 8192
{
name: "49",
from: "gemini",
to: "antigravity",
model: "antigravity-budget-model",
inputJSON: `{"model":"antigravity-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingLevel":"medium"}}}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 50: thinkingLevel=xhigh → clamped to 20000
{
name: "50",
from: "gemini",
to: "antigravity",
model: "antigravity-budget-model",
inputJSON: `{"model":"antigravity-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingLevel":"xhigh"}}}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "20000",
includeThoughts: "true",
expectErr: false,
},
// Case 51: thinkingLevel=none → 0 (ZeroAllowed=true)
{
name: "51",
from: "gemini",
to: "antigravity",
model: "antigravity-budget-model",
inputJSON: `{"model":"antigravity-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingLevel":"none"}}}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "0",
includeThoughts: "false",
expectErr: false,
},
// Case 52: thinkingBudget=-1 → -1 (DynamicAllowed=true)
{
name: "52",
from: "gemini",
to: "antigravity",
model: "antigravity-budget-model",
inputJSON: `{"model":"antigravity-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":-1}}}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "-1",
includeThoughts: "true",
expectErr: false,
},
// Case 53: Claude no param → passthrough
{
name: "53",
from: "claude",
to: "antigravity",
model: "antigravity-budget-model",
inputJSON: `{"model":"antigravity-budget-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 54: thinking.budget_tokens=8192 → 8192
{
name: "54",
from: "claude",
to: "antigravity",
model: "antigravity-budget-model",
inputJSON: `{"model":"antigravity-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":8192}}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 55: thinking.budget_tokens=64000 → clamped to 20000
{
name: "55",
from: "claude",
to: "antigravity",
model: "antigravity-budget-model",
inputJSON: `{"model":"antigravity-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":64000}}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "20000",
includeThoughts: "true",
expectErr: false,
},
// Case 56: thinking.budget_tokens=0 → 0 (ZeroAllowed=true)
{
name: "56",
from: "claude",
to: "antigravity",
model: "antigravity-budget-model",
inputJSON: `{"model":"antigravity-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":0}}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "0",
includeThoughts: "false",
expectErr: false,
},
// Case 57: thinking.budget_tokens=-1 → -1 (DynamicAllowed=true)
{
name: "57",
from: "claude",
to: "antigravity",
model: "antigravity-budget-model",
inputJSON: `{"model":"antigravity-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":-1}}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "-1",
includeThoughts: "true",
expectErr: false,
},
// no-thinking-model (Thinking=nil)
// Case 58: Gemini no param → passthrough
{
name: "58",
from: "gemini",
to: "openai",
model: "no-thinking-model",
inputJSON: `{"model":"no-thinking-model","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "",
expectErr: false,
},
// Case 59: thinkingBudget=8192 → stripped
{
name: "59",
from: "gemini",
to: "openai",
model: "no-thinking-model",
inputJSON: `{"model":"no-thinking-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":8192}}}`,
expectField: "",
expectErr: false,
},
// Case 60: thinkingBudget=0 → stripped
{
name: "60",
from: "gemini",
to: "openai",
model: "no-thinking-model",
inputJSON: `{"model":"no-thinking-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":0}}}`,
expectField: "",
expectErr: false,
},
// Case 61: thinkingBudget=-1 → stripped
{
name: "61",
from: "gemini",
to: "openai",
model: "no-thinking-model",
inputJSON: `{"model":"no-thinking-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":-1}}}`,
expectField: "",
expectErr: false,
},
// Case 62: Claude no param → passthrough
{
name: "62",
from: "claude",
to: "openai",
model: "no-thinking-model",
inputJSON: `{"model":"no-thinking-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 63: thinking.budget_tokens=8192 → stripped
{
name: "63",
from: "claude",
to: "openai",
model: "no-thinking-model",
inputJSON: `{"model":"no-thinking-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":8192}}`,
expectField: "",
expectErr: false,
},
// Case 64: thinking.budget_tokens=0 → stripped
{
name: "64",
from: "claude",
to: "openai",
model: "no-thinking-model",
inputJSON: `{"model":"no-thinking-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":0}}`,
expectField: "",
expectErr: false,
},
// Case 65: thinking.budget_tokens=-1 → stripped
{
name: "65",
from: "claude",
to: "openai",
model: "no-thinking-model",
inputJSON: `{"model":"no-thinking-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":-1}}`,
expectField: "",
expectErr: false,
},
// user-defined-model (UserDefined=true, Thinking=nil)
// Case 66: Gemini no param → passthrough
{
name: "66",
from: "gemini",
to: "openai",
model: "user-defined-model",
inputJSON: `{"model":"user-defined-model","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "",
expectErr: false,
},
// Case 67: thinkingBudget=8192 → medium
{
name: "67",
from: "gemini",
to: "openai",
model: "user-defined-model",
inputJSON: `{"model":"user-defined-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":8192}}}`,
expectField: "reasoning_effort",
expectValue: "medium",
expectErr: false,
},
// Case 68: thinkingBudget=64000 → xhigh (passthrough)
{
name: "68",
from: "gemini",
to: "openai",
model: "user-defined-model",
inputJSON: `{"model":"user-defined-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":64000}}}`,
expectField: "reasoning_effort",
expectValue: "xhigh",
expectErr: false,
},
// Case 69: thinkingBudget=0 → none
{
name: "69",
from: "gemini",
to: "openai",
model: "user-defined-model",
inputJSON: `{"model":"user-defined-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":0}}}`,
expectField: "reasoning_effort",
expectValue: "none",
expectErr: false,
},
// Case 70: thinkingBudget=-1 → auto
{
name: "70",
from: "gemini",
to: "openai",
model: "user-defined-model",
inputJSON: `{"model":"user-defined-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":-1}}}`,
expectField: "reasoning_effort",
expectValue: "auto",
expectErr: false,
},
// Case 71: Claude no param → injected default → medium
{
name: "71",
from: "claude",
to: "codex",
model: "user-defined-model",
inputJSON: `{"model":"user-defined-model","messages":[{"role":"user","content":"hi"}]}`,
expectField: "reasoning.effort",
expectValue: "medium",
expectErr: false,
},
// Case 72: thinking.budget_tokens=8192 → medium
{
name: "72",
from: "claude",
to: "codex",
model: "user-defined-model",
inputJSON: `{"model":"user-defined-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":8192}}`,
expectField: "reasoning.effort",
expectValue: "medium",
expectErr: false,
},
// Case 73: thinking.budget_tokens=64000 → xhigh (passthrough)
{
name: "73",
from: "claude",
to: "codex",
model: "user-defined-model",
inputJSON: `{"model":"user-defined-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":64000}}`,
expectField: "reasoning.effort",
expectValue: "xhigh",
expectErr: false,
},
// Case 74: thinking.budget_tokens=0 → none
{
name: "74",
from: "claude",
to: "codex",
model: "user-defined-model",
inputJSON: `{"model":"user-defined-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":0}}`,
expectField: "reasoning.effort",
expectValue: "none",
expectErr: false,
},
// Case 75: thinking.budget_tokens=-1 → auto
{
name: "75",
from: "claude",
to: "codex",
model: "user-defined-model",
inputJSON: `{"model":"user-defined-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":-1}}`,
expectField: "reasoning.effort",
expectValue: "auto",
expectErr: false,
},
// Case 76: OpenAI reasoning_effort=medium to Gemini → 8192
{
name: "76",
from: "openai",
to: "gemini",
model: "user-defined-model",
inputJSON: `{"model":"user-defined-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"medium"}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 77: OpenAI reasoning_effort=medium to Claude → 8192
{
name: "77",
from: "openai",
to: "claude",
model: "user-defined-model",
inputJSON: `{"model":"user-defined-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"medium"}`,
expectField: "thinking.budget_tokens",
expectValue: "8192",
expectErr: false,
},
// Case 78: OpenAI-Response reasoning.effort=medium to Gemini → 8192
{
name: "78",
from: "openai-response",
to: "gemini",
model: "user-defined-model",
inputJSON: `{"model":"user-defined-model","input":[{"role":"user","content":"hi"}],"reasoning":{"effort":"medium"}}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 79: OpenAI-Response reasoning.effort=medium to Claude → 8192
{
name: "79",
from: "openai-response",
to: "claude",
model: "user-defined-model",
inputJSON: `{"model":"user-defined-model","input":[{"role":"user","content":"hi"}],"reasoning":{"effort":"medium"}}`,
expectField: "thinking.budget_tokens",
expectValue: "8192",
expectErr: false,
},
// Same-protocol passthrough tests (80-89)
// Case 80: OpenAI to OpenAI, reasoning_effort=high → passthrough
{
name: "80",
from: "openai",
to: "openai",
model: "level-model",
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"high"}`,
expectField: "reasoning_effort",
expectValue: "high",
expectErr: false,
},
// Case 81: OpenAI to OpenAI, reasoning_effort=xhigh → out of range error
{
name: "81",
from: "openai",
to: "openai",
model: "level-model",
inputJSON: `{"model":"level-model","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"xhigh"}`,
expectField: "",
expectErr: true,
},
// Case 82: OpenAI-Response to Codex, reasoning.effort=high → passthrough
{
name: "82",
from: "openai-response",
to: "codex",
model: "level-model",
inputJSON: `{"model":"level-model","input":[{"role":"user","content":"hi"}],"reasoning":{"effort":"high"}}`,
expectField: "reasoning.effort",
expectValue: "high",
expectErr: false,
},
// Case 83: OpenAI-Response to Codex, reasoning.effort=xhigh → out of range error
{
name: "83",
from: "openai-response",
to: "codex",
model: "level-model",
inputJSON: `{"model":"level-model","input":[{"role":"user","content":"hi"}],"reasoning":{"effort":"xhigh"}}`,
expectField: "",
expectErr: true,
},
// Case 84: Gemini to Gemini, thinkingBudget=8192 → passthrough
{
name: "84",
from: "gemini",
to: "gemini",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":8192}}}`,
expectField: "generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 85: Gemini to Gemini, thinkingBudget=64000 → exceeds Max error
{
name: "85",
from: "gemini",
to: "gemini",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":64000}}}`,
expectField: "",
expectErr: true,
},
// Case 86: Claude to Claude, thinking.budget_tokens=8192 → passthrough
{
name: "86",
from: "claude",
to: "claude",
model: "claude-budget-model",
inputJSON: `{"model":"claude-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":8192}}`,
expectField: "thinking.budget_tokens",
expectValue: "8192",
expectErr: false,
},
// Case 87: Claude to Claude, thinking.budget_tokens=200000 → exceeds Max error
{
name: "87",
from: "claude",
to: "claude",
model: "claude-budget-model",
inputJSON: `{"model":"claude-budget-model","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":200000}}`,
expectField: "",
expectErr: true,
},
// Case 88: Gemini-CLI to Antigravity, thinkingBudget=8192 → passthrough
{
name: "88",
from: "gemini-cli",
to: "antigravity",
model: "antigravity-budget-model",
inputJSON: `{"model":"antigravity-budget-model","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":8192}}}}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 89: Gemini-CLI to Antigravity, thinkingBudget=64000 → exceeds Max error
{
name: "89",
from: "gemini-cli",
to: "antigravity",
model: "antigravity-budget-model",
inputJSON: `{"model":"antigravity-budget-model","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":64000}}}}`,
expectField: "",
expectErr: true,
},
// iflow tests: glm-test and minimax-test (Cases 90-105)
// glm-test (from: openai, claude)
// Case 90: OpenAI to iflow, no param → passthrough
{
name: "90",
from: "openai",
to: "iflow",
model: "glm-test",
inputJSON: `{"model":"glm-test","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 91: OpenAI to iflow, reasoning_effort=medium → enable_thinking=true
{
name: "91",
from: "openai",
to: "iflow",
model: "glm-test",
inputJSON: `{"model":"glm-test","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"medium"}`,
expectField: "chat_template_kwargs.enable_thinking",
expectValue: "true",
expectErr: false,
},
// Case 92: OpenAI to iflow, reasoning_effort=auto → enable_thinking=true
{
name: "92",
from: "openai",
to: "iflow",
model: "glm-test",
inputJSON: `{"model":"glm-test","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"auto"}`,
expectField: "chat_template_kwargs.enable_thinking",
expectValue: "true",
expectErr: false,
},
// Case 93: OpenAI to iflow, reasoning_effort=none → enable_thinking=false
{
name: "93",
from: "openai",
to: "iflow",
model: "glm-test",
inputJSON: `{"model":"glm-test","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"none"}`,
expectField: "chat_template_kwargs.enable_thinking",
expectValue: "false",
expectErr: false,
},
// Case 94: Claude to iflow, no param → passthrough
{
name: "94",
from: "claude",
to: "iflow",
model: "glm-test",
inputJSON: `{"model":"glm-test","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 95: Claude to iflow, thinking.budget_tokens=8192 → enable_thinking=true
{
name: "95",
from: "claude",
to: "iflow",
model: "glm-test",
inputJSON: `{"model":"glm-test","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":8192}}`,
expectField: "chat_template_kwargs.enable_thinking",
expectValue: "true",
expectErr: false,
},
// Case 96: Claude to iflow, thinking.budget_tokens=-1 → enable_thinking=true
{
name: "96",
from: "claude",
to: "iflow",
model: "glm-test",
inputJSON: `{"model":"glm-test","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":-1}}`,
expectField: "chat_template_kwargs.enable_thinking",
expectValue: "true",
expectErr: false,
},
// Case 97: Claude to iflow, thinking.budget_tokens=0 → enable_thinking=false
{
name: "97",
from: "claude",
to: "iflow",
model: "glm-test",
inputJSON: `{"model":"glm-test","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"enabled","budget_tokens":0}}`,
expectField: "chat_template_kwargs.enable_thinking",
expectValue: "false",
expectErr: false,
},
// minimax-test (from: openai, gemini)
// Case 98: OpenAI to iflow, no param → passthrough
{
name: "98",
from: "openai",
to: "iflow",
model: "minimax-test",
inputJSON: `{"model":"minimax-test","messages":[{"role":"user","content":"hi"}]}`,
expectField: "",
expectErr: false,
},
// Case 99: OpenAI to iflow, reasoning_effort=medium → reasoning_split=true
{
name: "99",
from: "openai",
to: "iflow",
model: "minimax-test",
inputJSON: `{"model":"minimax-test","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"medium"}`,
expectField: "reasoning_split",
expectValue: "true",
expectErr: false,
},
// Case 100: OpenAI to iflow, reasoning_effort=auto → reasoning_split=true
{
name: "100",
from: "openai",
to: "iflow",
model: "minimax-test",
inputJSON: `{"model":"minimax-test","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"auto"}`,
expectField: "reasoning_split",
expectValue: "true",
expectErr: false,
},
// Case 101: OpenAI to iflow, reasoning_effort=none → reasoning_split=false
{
name: "101",
from: "openai",
to: "iflow",
model: "minimax-test",
inputJSON: `{"model":"minimax-test","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"none"}`,
expectField: "reasoning_split",
expectValue: "false",
expectErr: false,
},
// Case 102: Gemini to iflow, no param → passthrough
{
name: "102",
from: "gemini",
to: "iflow",
model: "minimax-test",
inputJSON: `{"model":"minimax-test","contents":[{"role":"user","parts":[{"text":"hi"}]}]}`,
expectField: "",
expectErr: false,
},
// Case 103: Gemini to iflow, thinkingBudget=8192 → reasoning_split=true
{
name: "103",
from: "gemini",
to: "iflow",
model: "minimax-test",
inputJSON: `{"model":"minimax-test","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":8192}}}`,
expectField: "reasoning_split",
expectValue: "true",
expectErr: false,
},
// Case 104: Gemini to iflow, thinkingBudget=-1 → reasoning_split=true
{
name: "104",
from: "gemini",
to: "iflow",
model: "minimax-test",
inputJSON: `{"model":"minimax-test","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":-1}}}`,
expectField: "reasoning_split",
expectValue: "true",
expectErr: false,
},
// Case 105: Gemini to iflow, thinkingBudget=0 → reasoning_split=false
{
name: "105",
from: "gemini",
to: "iflow",
model: "minimax-test",
inputJSON: `{"model":"minimax-test","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":0}}}`,
expectField: "reasoning_split",
expectValue: "false",
expectErr: false,
},
// Gemini Family Cross-Channel Consistency (Cases 106-114)
// Tests that gemini/gemini-cli/antigravity as same API family should have consistent validation behavior
// Case 106: Gemini to Antigravity, thinkingBudget=64000 → exceeds Max error (same family strict validation)
{
name: "106",
from: "gemini",
to: "antigravity",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":64000}}}`,
expectField: "",
expectErr: true,
},
// Case 107: Gemini to Gemini-CLI, thinkingBudget=64000 → exceeds Max error (same family strict validation)
{
name: "107",
from: "gemini",
to: "gemini-cli",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":64000}}}`,
expectField: "",
expectErr: true,
},
// Case 108: Gemini-CLI to Antigravity, thinkingBudget=64000 → exceeds Max error (same family strict validation)
{
name: "108",
from: "gemini-cli",
to: "antigravity",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":64000}}}}`,
expectField: "",
expectErr: true,
},
// Case 109: Gemini-CLI to Gemini, thinkingBudget=64000 → exceeds Max error (same family strict validation)
{
name: "109",
from: "gemini-cli",
to: "gemini",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":64000}}}}`,
expectField: "",
expectErr: true,
},
// Case 110: Gemini to Antigravity, thinkingBudget=8192 → passthrough (normal value)
{
name: "110",
from: "gemini",
to: "antigravity",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":8192}}}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
// Case 111: Gemini-CLI to Antigravity, thinkingBudget=8192 → passthrough (normal value)
{
name: "111",
from: "gemini-cli",
to: "antigravity",
model: "gemini-budget-model",
inputJSON: `{"model":"gemini-budget-model","request":{"contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":8192}}}}`,
expectField: "request.generationConfig.thinkingConfig.thinkingBudget",
expectValue: "8192",
includeThoughts: "true",
expectErr: false,
},
}
runThinkingTests(t, cases)
}
// TestThinkingE2EClaudeAdaptive_Body exercises the Group 3 cases listed in
// docs/thinking-e2e-test-cases.md: Claude 4.6 adaptive thinking and the
// effort/level semantics when converting between protocols. Only the
// translated request body is inspected.
//
// Subgroups:
//   - A: OpenAI reasoning_effort -> Claude output_config.effort
//   - B: Gemini thinkingLevel / thinkingBudget -> Claude output_config.effort
//   - C: Claude adaptive thinking (with optional output_config.effort) -> other
//     protocols, plus Claude -> Claude passthrough and validation
func TestThinkingE2EClaudeAdaptive_Body(t *testing.T) {
	modelRegistry := registry.GetGlobalRegistry()
	clientID := fmt.Sprintf("thinking-e2e-claude-adaptive-%d", time.Now().UnixNano())
	modelRegistry.RegisterClient(clientID, "test", getTestModels())
	defer modelRegistry.UnregisterClient(clientID)

	const (
		sonnet = "claude-sonnet-4-6-model"
		opus   = "claude-opus-4-6-model"

		effortPath       = "output_config.effort"
		geminiLevelPath  = "generationConfig.thinkingConfig.thinkingLevel"
		geminiBudgetPath = "generationConfig.thinkingConfig.thinkingBudget"
	)

	// openAIChat builds an OpenAI chat-completions body carrying reasoning_effort.
	openAIChat := func(model, effort string) string {
		return `{"model":"` + model + `","messages":[{"role":"user","content":"hi"}],"reasoning_effort":"` + effort + `"}`
	}
	// geminiLevel builds a Gemini body that sets thinkingConfig.thinkingLevel.
	geminiLevel := func(model, level string) string {
		return `{"model":"` + model + `","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingLevel":"` + level + `"}}}`
	}
	// geminiBudget builds a Gemini body that sets thinkingConfig.thinkingBudget;
	// budget is spliced in verbatim as a JSON number.
	geminiBudget := func(model, budget string) string {
		return `{"model":"` + model + `","contents":[{"role":"user","parts":[{"text":"hi"}]}],"generationConfig":{"thinkingConfig":{"thinkingBudget":` + budget + `}}}`
	}
	// claudeAdaptive builds a Claude messages body with thinking.type=adaptive.
	// output_config.effort is appended only when effort is non-empty.
	claudeAdaptive := func(model, effort string) string {
		body := `{"model":"` + model + `","messages":[{"role":"user","content":"hi"}],"thinking":{"type":"adaptive"}`
		if effort != "" {
			body += `,"output_config":{"effort":"` + effort + `"}`
		}
		return body + `}`
	}

	cases := []thinkingTestCase{
		// A: OpenAI -> Claude. minimal is expected to become low; xhigh and max
		// are expected to yield max on the Opus model and high on the Sonnet model.
		{name: "A1", from: "openai", to: "claude", model: sonnet,
			inputJSON: openAIChat(sonnet, "minimal"), expectField: effortPath, expectValue: "low"},
		{name: "A2", from: "openai", to: "claude", model: sonnet,
			inputJSON: openAIChat(sonnet, "low"), expectField: effortPath, expectValue: "low"},
		{name: "A3", from: "openai", to: "claude", model: sonnet,
			inputJSON: openAIChat(sonnet, "medium"), expectField: effortPath, expectValue: "medium"},
		{name: "A4", from: "openai", to: "claude", model: sonnet,
			inputJSON: openAIChat(sonnet, "high"), expectField: effortPath, expectValue: "high"},
		{name: "A5", from: "openai", to: "claude", model: opus,
			inputJSON: openAIChat(opus, "xhigh"), expectField: effortPath, expectValue: "max"},
		{name: "A6", from: "openai", to: "claude", model: sonnet,
			inputJSON: openAIChat(sonnet, "xhigh"), expectField: effortPath, expectValue: "high"},
		{name: "A7", from: "openai", to: "claude", model: opus,
			inputJSON: openAIChat(opus, "max"), expectField: effortPath, expectValue: "max"},
		{name: "A8", from: "openai", to: "claude", model: sonnet,
			inputJSON: openAIChat(sonnet, "max"), expectField: effortPath, expectValue: "high"},

		// B: Gemini -> Claude. Levels follow the same mapping as subgroup A;
		// budgets are bucketed into effort levels, budget 0 is expected to
		// disable thinking, and -1 (dynamic) is expected to map to high.
		{name: "B1", from: "gemini", to: "claude", model: sonnet,
			inputJSON: geminiLevel(sonnet, "minimal"), expectField: effortPath, expectValue: "low"},
		{name: "B2", from: "gemini", to: "claude", model: sonnet,
			inputJSON: geminiLevel(sonnet, "low"), expectField: effortPath, expectValue: "low"},
		{name: "B3", from: "gemini", to: "claude", model: sonnet,
			inputJSON: geminiLevel(sonnet, "medium"), expectField: effortPath, expectValue: "medium"},
		{name: "B4", from: "gemini", to: "claude", model: sonnet,
			inputJSON: geminiLevel(sonnet, "high"), expectField: effortPath, expectValue: "high"},
		{name: "B5", from: "gemini", to: "claude", model: opus,
			inputJSON: geminiLevel(opus, "xhigh"), expectField: effortPath, expectValue: "max"},
		{name: "B6", from: "gemini", to: "claude", model: sonnet,
			inputJSON: geminiLevel(sonnet, "xhigh"), expectField: effortPath, expectValue: "high"},
		{name: "B7", from: "gemini", to: "claude", model: sonnet,
			inputJSON: geminiBudget(sonnet, "512"), expectField: effortPath, expectValue: "low"},
		{name: "B8", from: "gemini", to: "claude", model: sonnet,
			inputJSON: geminiBudget(sonnet, "1024"), expectField: effortPath, expectValue: "low"},
		{name: "B9", from: "gemini", to: "claude", model: sonnet,
			inputJSON: geminiBudget(sonnet, "8192"), expectField: effortPath, expectValue: "medium"},
		{name: "B10", from: "gemini", to: "claude", model: sonnet,
			inputJSON: geminiBudget(sonnet, "24576"), expectField: effortPath, expectValue: "high"},
		{name: "B11", from: "gemini", to: "claude", model: opus,
			inputJSON: geminiBudget(opus, "32768"), expectField: effortPath, expectValue: "max"},
		{name: "B12", from: "gemini", to: "claude", model: sonnet,
			inputJSON: geminiBudget(sonnet, "32768"), expectField: effortPath, expectValue: "high"},
		{name: "B13", from: "gemini", to: "claude", model: sonnet,
			inputJSON: geminiBudget(sonnet, "0"), expectField: "thinking.type", expectValue: "disabled"},
		{name: "B14", from: "gemini", to: "claude", model: sonnet,
			inputJSON: geminiBudget(sonnet, "-1"), expectField: effortPath, expectValue: "high"},

		// C: Claude adaptive (+ optional effort) -> OpenAI. Efforts above the
		// model's highest level are expected to clamp to high; a model without
		// thinking support is expected to emit no thinking field.
		{name: "C1", from: "claude", to: "openai", model: "level-model",
			inputJSON: claudeAdaptive("level-model", "minimal"), expectField: "reasoning_effort", expectValue: "minimal"},
		{name: "C2", from: "claude", to: "openai", model: "level-model",
			inputJSON: claudeAdaptive("level-model", "low"), expectField: "reasoning_effort", expectValue: "low"},
		{name: "C3", from: "claude", to: "openai", model: "level-model",
			inputJSON: claudeAdaptive("level-model", "medium"), expectField: "reasoning_effort", expectValue: "medium"},
		{name: "C4", from: "claude", to: "openai", model: "level-model",
			inputJSON: claudeAdaptive("level-model", "high"), expectField: "reasoning_effort", expectValue: "high"},
		{name: "C5", from: "claude", to: "openai", model: "level-model",
			inputJSON: claudeAdaptive("level-model", "xhigh"), expectField: "reasoning_effort", expectValue: "high"},
		{name: "C6", from: "claude", to: "openai", model: "level-model",
			inputJSON: claudeAdaptive("level-model", "max"), expectField: "reasoning_effort", expectValue: "high"},
		{name: "C7", from: "claude", to: "openai", model: "no-thinking-model",
			inputJSON: claudeAdaptive("no-thinking-model", "high"), expectField: ""},

		// C -> Gemini: level-capable models keep a thinkingLevel, budget-only
		// models receive a thinkingBudget (20000 when high or no effort is given).
		{name: "C8", from: "claude", to: "gemini", model: "level-subset-model",
			inputJSON: claudeAdaptive("level-subset-model", "high"), expectField: geminiLevelPath, expectValue: "high", includeThoughts: "true"},
		{name: "C9", from: "claude", to: "gemini", model: "gemini-budget-model",
			inputJSON: claudeAdaptive("gemini-budget-model", "low"), expectField: geminiBudgetPath, expectValue: "1024", includeThoughts: "true"},
		{name: "C10", from: "claude", to: "gemini", model: "gemini-budget-model",
			inputJSON: claudeAdaptive("gemini-budget-model", "medium"), expectField: geminiBudgetPath, expectValue: "8192", includeThoughts: "true"},
		{name: "C11", from: "claude", to: "gemini", model: "gemini-budget-model",
			inputJSON: claudeAdaptive("gemini-budget-model", "high"), expectField: geminiBudgetPath, expectValue: "20000", includeThoughts: "true"},
		{name: "C12", from: "claude", to: "gemini", model: "gemini-budget-model",
			inputJSON: claudeAdaptive("gemini-budget-model", ""), expectField: geminiBudgetPath, expectValue: "20000", includeThoughts: "true"},
		{name: "C13", from: "claude", to: "gemini", model: "gemini-mixed-model",
			inputJSON: claudeAdaptive("gemini-mixed-model", "high"), expectField: geminiLevelPath, expectValue: "high", includeThoughts: "true"},

		// C -> Codex: same clamping as the OpenAI target, under reasoning.effort.
		{name: "C14", from: "claude", to: "codex", model: "level-model",
			inputJSON: claudeAdaptive("level-model", "minimal"), expectField: "reasoning.effort", expectValue: "minimal"},
		{name: "C15", from: "claude", to: "codex", model: "level-model",
			inputJSON: claudeAdaptive("level-model", "low"), expectField: "reasoning.effort", expectValue: "low"},
		{name: "C16", from: "claude", to: "codex", model: "level-model",
			inputJSON: claudeAdaptive("level-model", "high"), expectField: "reasoning.effort", expectValue: "high"},
		{name: "C17", from: "claude", to: "codex", model: "level-model",
			inputJSON: claudeAdaptive("level-model", "xhigh"), expectField: "reasoning.effort", expectValue: "high"},
		{name: "C18", from: "claude", to: "codex", model: "level-model",
			inputJSON: claudeAdaptive("level-model", "max"), expectField: "reasoning.effort", expectValue: "high"},

		// C -> iflow / antigravity: thinking toggles on, or a budget is emitted.
		{name: "C19", from: "claude", to: "iflow", model: "glm-test",
			inputJSON: claudeAdaptive("glm-test", "minimal"), expectField: "chat_template_kwargs.enable_thinking", expectValue: "true"},
		{name: "C20", from: "claude", to: "iflow", model: "minimax-test",
			inputJSON: claudeAdaptive("minimax-test", "high"), expectField: "reasoning_split", expectValue: "true"},
		{name: "C21", from: "claude", to: "antigravity", model: "antigravity-budget-model",
			inputJSON: claudeAdaptive("antigravity-budget-model", ""), expectField: "request.generationConfig.thinkingConfig.thinkingBudget", expectValue: "20000", includeThoughts: "true"},

		// C -> Claude (same protocol): adaptive + effort passes through; an
		// effort the target model does not list is expected to error.
		{name: "C22", from: "claude", to: "claude", model: sonnet,
			inputJSON:   claudeAdaptive(sonnet, "medium"),
			expectField: "thinking.type", expectValue: "adaptive",
			expectField2: effortPath, expectValue2: "medium"},
		{name: "C23", from: "claude", to: "claude", model: opus,
			inputJSON:   claudeAdaptive(opus, "max"),
			expectField: "thinking.type", expectValue: "adaptive",
			expectField2: effortPath, expectValue2: "max"},
		{name: "C24", from: "claude", to: "claude", model: opus,
			inputJSON: claudeAdaptive(opus, "xhigh"), expectErr: true},
		{name: "C25", from: "claude", to: "claude", model: sonnet,
			inputJSON:   claudeAdaptive(sonnet, "high"),
			expectField: "thinking.type", expectValue: "adaptive",
			expectField2: effortPath, expectValue2: "high"},
		{name: "C26", from: "claude", to: "claude", model: sonnet,
			inputJSON: claudeAdaptive(sonnet, "max"), expectErr: true},
		{name: "C27", from: "claude", to: "claude", model: sonnet,
			inputJSON: claudeAdaptive(sonnet, "xhigh"), expectErr: true},
	}

	runThinkingTests(t, cases)
}
// getTestModels returns the shared model definitions for E2E tests.
func getTestModels() []*registry.ModelInfo {
return []*registry.ModelInfo{
{
ID: "level-model",
Object: "model",
Created: 1700000000,
OwnedBy: "test",
Type: "openai",
DisplayName: "Level Model",
Thinking: ®istry.ThinkingSupport{Levels: []string{"minimal", "low", "medium", "high"}, ZeroAllowed: false, DynamicAllowed: false},
},
{
ID: "level-subset-model",
Object: "model",
Created: 1700000000,
OwnedBy: "test",
Type: "gemini",
DisplayName: "Level Subset Model",
Thinking: ®istry.ThinkingSupport{Levels: []string{"low", "high"}, ZeroAllowed: false, DynamicAllowed: false},
},
{
ID: "gemini-budget-model",
Object: "model",
Created: 1700000000,
OwnedBy: "test",
Type: "gemini",
DisplayName: "Gemini Budget Model",
Thinking: ®istry.ThinkingSupport{Min: 128, Max: 20000, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "gemini-mixed-model",
Object: "model",
Created: 1700000000,
OwnedBy: "test",
Type: "gemini",
DisplayName: "Gemini Mixed Model",
Thinking: ®istry.ThinkingSupport{Min: 128, Max: 32768, Levels: []string{"low", "high"}, ZeroAllowed: false, DynamicAllowed: true},
},
{
ID: "claude-budget-model",
Object: "model",
Created: 1700000000,
OwnedBy: "test",
Type: "claude",
DisplayName: "Claude Budget Model",
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-sonnet-4-6-model",
Object: "model",
Created: 1771372800, // 2026-02-17
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.6 Sonnet",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false, Levels: []string{"low", "medium", "high"}},
},
{
ID: "claude-opus-4-6-model",
Object: "model",
Created: 1770318000, // 2026-02-05
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.6 Opus",
Description: "Premium model combining maximum intelligence with practical performance",
ContextLength: 1000000,
MaxCompletionTokens: 128000,
Thinking: ®istry.ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false, Levels: []string{"low", "medium", "high", "max"}},
},
{
ID: "antigravity-budget-model",
Object: "model",
Created: 1700000000,
OwnedBy: "test",
Type: "gemini-cli",
DisplayName: "Antigravity Budget Model",
Thinking: ®istry.ThinkingSupport{Min: 128, Max: 20000, ZeroAllowed: true, DynamicAllowed: true},
},
{
ID: "no-thinking-model",
Object: "model",
Created: 1700000000,
OwnedBy: "test",
Type: "openai",
DisplayName: "No Thinking Model",
Thinking: nil,
},
{
ID: "user-defined-model",
Object: "model",
Created: 1700000000,
OwnedBy: "test",
Type: "openai",
DisplayName: "User Defined Model",
UserDefined: true,
Thinking: nil,
},
{
ID: "glm-test",
Object: "model",
Created: 1700000000,
OwnedBy: "test",
Type: "iflow",
DisplayName: "GLM Test Model",
Thinking: ®istry.ThinkingSupport{Levels: []string{"none", "auto", "minimal", "low", "medium", "high", "xhigh"}},
},
{
ID: "minimax-test",
Object: "model",
Created: 1700000000,
OwnedBy: "test",
Type: "iflow",
DisplayName: "MiniMax Test Model",
Thinking: ®istry.ThinkingSupport{Levels: []string{"none", "auto", "minimal", "low", "medium", "high", "xhigh"}},
},
}
}
// runThinkingTests runs thinking test cases using the real data flow path.
//
// Each case becomes a subtest that:
//  1. parses any thinking suffix off tc.model to obtain the base model name,
//  2. translates tc.inputJSON from tc.from to the target format
//     (an "iflow" target is translated with the "openai" translator and then
//     has thinking applied with the "iflow" applier),
//  3. for Claude targets, forces max_tokens to 200000,
//  4. calls thinking.ApplyThinking with the full (suffixed) tc.model.
//
// Assertions per case:
//   - tc.expectErr: ApplyThinking must fail.
//   - tc.expectField == "": no thinking-related field may exist for tc.to.
//   - otherwise tc.expectField/tc.expectValue (and tc.expectField2/tc.expectValue2
//     when set) must match; numbers are compared as integers.
//   - tc.includeThoughts, when set, is checked for gemini/gemini-cli/antigravity.
//   - for iflow with enable_thinking=true, GLM models must carry
//     chat_template_kwargs.clear_thinking=false and non-GLM models must not
//     carry clear_thinking at all.
func runThinkingTests(t *testing.T, cases []thinkingTestCase) {
	for _, tc := range cases {
		tc := tc // per-iteration copy for the subtest closure (needed before Go 1.22)
		testName := fmt.Sprintf("Case%s_%s->%s_%s", tc.name, tc.from, tc.to, tc.model)
		t.Run(testName, func(t *testing.T) {
			suffixResult := thinking.ParseSuffix(tc.model)
			baseModel := suffixResult.ModelName
			translateTo := tc.to
			applyTo := tc.to
			if tc.to == "iflow" {
				translateTo = "openai"
				applyTo = "iflow"
			}
			body := sdktranslator.TranslateRequest(
				sdktranslator.FromString(tc.from),
				sdktranslator.FromString(translateTo),
				baseModel,
				[]byte(tc.inputJSON),
				true,
			)
			if applyTo == "claude" {
				withMax, setErr := sjson.SetBytes(body, "max_tokens", 200000)
				if setErr != nil {
					t.Fatalf("failed to set max_tokens: %v, body=%s", setErr, string(body))
				}
				body = withMax
			}
			body, err := thinking.ApplyThinking(body, tc.model, tc.from, applyTo, applyTo)
			if tc.expectErr {
				if err == nil {
					t.Fatalf("expected error but got none, body=%s", string(body))
				}
				return
			}
			if err != nil {
				t.Fatalf("unexpected error: %v, body=%s", err, string(body))
			}

			// No expected field: the target must not carry any thinking config.
			if tc.expectField == "" {
				var hasThinking bool
				switch tc.to {
				case "gemini":
					hasThinking = gjson.GetBytes(body, "generationConfig.thinkingConfig").Exists()
				case "gemini-cli", "antigravity":
					hasThinking = gjson.GetBytes(body, "request.generationConfig.thinkingConfig").Exists()
				case "claude":
					hasThinking = gjson.GetBytes(body, "thinking").Exists()
				case "openai":
					hasThinking = gjson.GetBytes(body, "reasoning_effort").Exists()
				case "codex":
					// Any "reasoning" value (including reasoning.effort) counts.
					hasThinking = gjson.GetBytes(body, "reasoning").Exists()
				case "iflow":
					hasThinking = gjson.GetBytes(body, "chat_template_kwargs.enable_thinking").Exists() || gjson.GetBytes(body, "reasoning_split").Exists()
				}
				if hasThinking {
					t.Fatalf("expected no thinking field but found one, body=%s", string(body))
				}
				return
			}

			assertField := func(fieldPath, expected string) {
				t.Helper()
				val := gjson.GetBytes(body, fieldPath)
				if !val.Exists() {
					t.Fatalf("expected field %s not found, body=%s", fieldPath, string(body))
				}
				actualValue := val.String()
				if val.Type == gjson.Number {
					// Budgets are integers; normalize so "8192" matches regardless of JSON formatting.
					actualValue = fmt.Sprintf("%d", val.Int())
				}
				if actualValue != expected {
					t.Fatalf("field %s: expected %q, got %q, body=%s", fieldPath, expected, actualValue, string(body))
				}
			}
			assertField(tc.expectField, tc.expectValue)
			if tc.expectField2 != "" {
				assertField(tc.expectField2, tc.expectValue2)
			}

			if tc.includeThoughts != "" && (tc.to == "gemini" || tc.to == "gemini-cli" || tc.to == "antigravity") {
				path := "generationConfig.thinkingConfig.includeThoughts"
				if tc.to == "gemini-cli" || tc.to == "antigravity" {
					path = "request.generationConfig.thinkingConfig.includeThoughts"
				}
				itVal := gjson.GetBytes(body, path)
				if !itVal.Exists() {
					t.Fatalf("expected includeThoughts field not found, body=%s", string(body))
				}
				actual := fmt.Sprintf("%v", itVal.Bool())
				if actual != tc.includeThoughts {
					t.Fatalf("includeThoughts: expected %s, got %s, body=%s", tc.includeThoughts, actual, string(body))
				}
			}

			// Verify clear_thinking for iFlow GLM models when enable_thinking=true.
			if tc.to == "iflow" && tc.expectField == "chat_template_kwargs.enable_thinking" && tc.expectValue == "true" {
				isGLM := strings.HasPrefix(strings.ToLower(baseModel), "glm")
				ctVal := gjson.GetBytes(body, "chat_template_kwargs.clear_thinking")
				if isGLM {
					if !ctVal.Exists() {
						t.Fatalf("expected clear_thinking field not found for GLM model, body=%s", string(body))
					}
					if ctVal.Bool() {
						t.Fatalf("clear_thinking: expected false, got %v, body=%s", ctVal.Bool(), string(body))
					}
				} else if ctVal.Exists() {
					t.Fatalf("expected no clear_thinking field for non-GLM enable_thinking model, body=%s", string(body))
				}
			}
		})
	}
}